types.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. # █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
  2. # █▀█ █ █ █ █▀█ █▀▄ █
  3. # © Copyright 2022
  4. # https://t.me/hikariatama
  5. #
  6. # 🔒 Licensed under the GNU AGPLv3
  7. # 🌐 https://www.gnu.org/licenses/agpl-3.0.html
  8. import ast
  9. import asyncio
  10. import contextlib
  11. import copy
  12. import inspect
  13. import logging
  14. from dataclasses import dataclass, field
  15. import time
  16. import typing
  17. from importlib.abc import SourceLoader
  18. from telethon.tl.types import Message, ChannelFull, UserFull
  19. from telethon.hints import EntityLike
  20. from .inline.types import ( # skipcq: PY-W2000
  21. InlineMessage,
  22. BotInlineMessage,
  23. InlineCall,
  24. BotInlineCall,
  25. InlineUnit,
  26. BotMessage,
  27. InlineQuery,
  28. )
  29. from . import validators # skipcq: PY-W2000
  30. from .pointers import ( # skipcq: PY-W2000
  31. PointerList,
  32. PointerDict,
  33. )
  34. logger = logging.getLogger(__name__)
  35. JSONSerializable = typing.Union[str, int, float, bool, list, dict, None]
  36. HikkaReplyMarkup = typing.Union[typing.List[typing.List[dict]], typing.List[dict], dict]
  37. ListLike = typing.Union[list, set, tuple]
  38. class StringLoader(SourceLoader):
  39. """Load a python module/file from a string"""
  40. def __init__(self, data: str, origin: str):
  41. self.data = data.encode("utf-8") if isinstance(data, str) else data
  42. self.origin = origin
  43. def get_code(self, fullname: str) -> str:
  44. return (
  45. compile(source, self.origin, "exec", dont_inherit=True)
  46. if (source := self.get_source(fullname))
  47. else None
  48. )
  49. def get_filename(self, *args, **kwargs) -> str:
  50. return self.origin
  51. def get_data(self, *args, **kwargs) -> bytes:
  52. return self.data
  53. class Module:
  54. strings = {"name": "Unknown"}
  55. """There is no help for this module"""
  56. def config_complete(self):
  57. """Called when module.config is populated"""
  58. async def client_ready(self, client, db):
  59. """Called after client is ready (after config_loaded)"""
  60. async def on_unload(self):
  61. """Called after unloading / reloading module"""
  62. async def on_dlmod(self, client, db):
  63. """
  64. Called after the module is first time loaded with .dlmod or .loadmod
  65. Possible use-cases:
  66. - Send reaction to author's channel message
  67. - Join author's channel
  68. - Create asset folder
  69. - ...
  70. ⚠️ Note, that any error there will not interrupt module load, and will just
  71. send a message to logs with verbosity INFO and exception traceback
  72. """
  73. def __getattr__(self, name: str):
  74. if name in {"hikka_commands", "commands"}:
  75. return get_commands(self)
  76. if name in {"hikka_inline_handlers", "inline_handlers"}:
  77. return get_inline_handlers(self)
  78. if name in {"hikka_callback_handlers", "callback_handlers"}:
  79. return get_callback_handlers(self)
  80. if name in {"hikka_watchers", "watchers"}:
  81. return get_watchers(self)
  82. raise AttributeError(
  83. f"Module {self.__class__.__name__} has no attribute {name}"
  84. )
  85. class Library:
  86. """All external libraries must have a class-inheritant from this class"""
  87. class LoadError(Exception):
  88. """Tells user, why your module can't be loaded, if raised in `client_ready`"""
  89. def __init__(self, error_message: str): # skipcq: PYL-W0231
  90. self._error = error_message
  91. def __str__(self) -> str:
  92. return self._error
  93. class CoreOverwriteError(LoadError):
  94. """Is being raised when core module or command is overwritten"""
  95. def __init__(
  96. self,
  97. module: typing.Optional[str] = None,
  98. command: typing.Optional[str] = None,
  99. ):
  100. self.type = "module" if module else "command"
  101. self.target = module or command
  102. super().__init__(str(self))
  103. def __str__(self) -> str:
  104. return (
  105. f"Module {self.target} will not be overwritten, because it's core"
  106. if self.type == "module"
  107. else f"Command {self.target} will not be overwritten, because it's core"
  108. )
  109. class CoreUnloadError(Exception):
  110. """Is being raised when user tries to unload core module"""
  111. def __init__(self, module: str):
  112. self.module = module
  113. super().__init__()
  114. def __str__(self) -> str:
  115. return f"Module {self.module} will not be unloaded, because it's core"
  116. class SelfUnload(Exception):
  117. """Silently unloads module, if raised in `client_ready`"""
  118. def __init__(self, error_message: str = ""):
  119. super().__init__()
  120. self._error = error_message
  121. def __str__(self) -> str:
  122. return self._error
  123. class SelfSuspend(Exception):
  124. """
  125. Silently suspends module, if raised in `client_ready`
  126. Commands and watcher will not be registered if raised
  127. Module won't be unloaded from db and will be unfreezed after restart, unless
  128. the exception is raised again
  129. """
  130. def __init__(self, error_message: str = ""):
  131. super().__init__()
  132. self._error = error_message
  133. def __str__(self) -> str:
  134. return self._error
  135. class StopLoop(Exception):
  136. """Stops the loop, in which is raised"""
  137. class ModuleConfig(dict):
  138. """Stores config for modules and apparently libraries"""
  139. def __init__(self, *entries):
  140. if all(isinstance(entry, ConfigValue) for entry in entries):
  141. # New config format processing
  142. self._config = {config.option: config for config in entries}
  143. else:
  144. # Legacy config processing
  145. keys = []
  146. values = []
  147. defaults = []
  148. docstrings = []
  149. for i, entry in enumerate(entries):
  150. if i % 3 == 0:
  151. keys += [entry]
  152. elif i % 3 == 1:
  153. values += [entry]
  154. defaults += [entry]
  155. else:
  156. docstrings += [entry]
  157. self._config = {
  158. key: ConfigValue(option=key, default=default, doc=doc)
  159. for key, default, doc in zip(keys, defaults, docstrings)
  160. }
  161. super().__init__(
  162. {option: config.value for option, config in self._config.items()}
  163. )
  164. def getdoc(self, key: str, message: Message = None) -> str:
  165. """Get the documentation by key"""
  166. ret = self._config[key].doc
  167. if callable(ret):
  168. try:
  169. # Compatibility tweak
  170. # does nothing in Hikka
  171. ret = ret(message)
  172. except Exception:
  173. ret = ret()
  174. return ret
  175. def getdef(self, key: str) -> str:
  176. """Get the default value by key"""
  177. return self._config[key].default
  178. def __setitem__(self, key: str, value: typing.Any):
  179. self._config[key].value = value
  180. super().__setitem__(key, value)
  181. def set_no_raise(self, key: str, value: typing.Any):
  182. self._config[key].set_no_raise(value)
  183. super().__setitem__(key, value)
  184. def __getitem__(self, key: str) -> typing.Any:
  185. try:
  186. return self._config[key].value
  187. except KeyError:
  188. return None
  189. def reload(self):
  190. for key in self._config:
  191. super().__setitem__(key, self._config[key].value)
# Libraries use the very same config container as modules
LibraryConfig = ModuleConfig
  193. class _Placeholder:
  194. """Placeholder to determine if the default value is going to be set"""
  195. async def wrap(func: typing.Awaitable):
  196. with contextlib.suppress(Exception):
  197. return await func()
  198. def syncwrap(func: typing.Callable):
  199. with contextlib.suppress(Exception):
  200. return func()
  201. @dataclass(repr=True)
  202. class ConfigValue:
  203. option: str
  204. default: typing.Any = None
  205. doc: typing.Union[callable, str] = "No description"
  206. value: typing.Any = field(default_factory=_Placeholder)
  207. validator: typing.Optional[callable] = None
  208. on_change: typing.Optional[typing.Union[typing.Awaitable, typing.Callable]] = None
  209. def __post_init__(self):
  210. if isinstance(self.value, _Placeholder):
  211. self.value = self.default
  212. def set_no_raise(self, value: typing.Any) -> bool:
  213. """
  214. Sets the config value w/o ValidationError being raised
  215. Should not be used uninternally
  216. """
  217. return self.__setattr__("value", value, ignore_validation=True)
  218. def __setattr__(
  219. self,
  220. key: str,
  221. value: typing.Any,
  222. *,
  223. ignore_validation: bool = False,
  224. ) -> bool:
  225. if key == "value":
  226. try:
  227. value = ast.literal_eval(value)
  228. except Exception:
  229. pass
  230. # Convert value to list if it's tuple just not to mess up
  231. # with json convertations
  232. if isinstance(value, (set, tuple)):
  233. value = list(value)
  234. if isinstance(value, list):
  235. value = [
  236. item.strip() if isinstance(item, str) else item for item in value
  237. ]
  238. if self.validator is not None:
  239. if value is not None:
  240. try:
  241. value = self.validator.validate(value)
  242. except validators.ValidationError as e:
  243. if not ignore_validation:
  244. raise e
  245. logger.debug(
  246. "Config value was broken (%s), so it was reset to %s",
  247. value,
  248. self.default,
  249. )
  250. value = self.default
  251. else:
  252. defaults = {
  253. "String": "",
  254. "Integer": 0,
  255. "Boolean": False,
  256. "Series": [],
  257. "Float": 0.0,
  258. }
  259. if self.validator.internal_id in defaults:
  260. logger.debug(
  261. "Config value was None, so it was reset to %s",
  262. defaults[self.validator.internal_id],
  263. )
  264. value = defaults[self.validator.internal_id]
  265. # This attribute will tell the `Loader` to save this value in db
  266. self._save_marker = True
  267. object.__setattr__(self, key, value)
  268. if key == "value" and not ignore_validation and callable(self.on_change):
  269. if inspect.iscoroutinefunction(self.on_change):
  270. asyncio.ensure_future(wrap(self.on_change))
  271. else:
  272. syncwrap(self.on_change)
  273. def _get_members(
  274. mod: Module,
  275. ending: str,
  276. attribute: typing.Optional[str] = None,
  277. strict: bool = False,
  278. ) -> dict:
  279. """Get method of module, which end with ending"""
  280. return {
  281. (
  282. method_name.rsplit(ending, maxsplit=1)[0]
  283. if (method_name == ending if strict else method_name.endswith(ending))
  284. else method_name
  285. ).lower(): getattr(mod, method_name)
  286. for method_name in dir(mod)
  287. if callable(getattr(mod, method_name))
  288. and (
  289. (method_name == ending if strict else method_name.endswith(ending))
  290. or attribute
  291. and getattr(getattr(mod, method_name), attribute, False)
  292. )
  293. }
  294. class CacheRecord:
  295. def __init__(
  296. self,
  297. hashable_entity: "Hashable", # type: ignore
  298. resolved_entity: EntityLike,
  299. exp: int,
  300. ):
  301. self.entity = copy.deepcopy(resolved_entity)
  302. self._hashable_entity = copy.deepcopy(hashable_entity)
  303. self._exp = round(time.time() + exp)
  304. self.ts = time.time()
  305. def expired(self):
  306. return self._exp < time.time()
  307. def __eq__(self, record: "CacheRecord"):
  308. return hash(record) == hash(self)
  309. def __hash__(self):
  310. return hash(self._hashable_entity)
  311. def __str__(self):
  312. return f"CacheRecord of {self.entity}"
  313. def __repr__(self):
  314. return f"CacheRecord(entity={type(self.entity).__name__}(...), exp={self._exp})"
  315. class CacheRecordPerms:
  316. def __init__(
  317. self,
  318. hashable_entity: "Hashable", # type: ignore
  319. hashable_user: "Hashable", # type: ignore
  320. resolved_perms: EntityLike,
  321. exp: int,
  322. ):
  323. self.perms = copy.deepcopy(resolved_perms)
  324. self._hashable_entity = copy.deepcopy(hashable_entity)
  325. self._hashable_user = copy.deepcopy(hashable_user)
  326. self._exp = round(time.time() + exp)
  327. self.ts = time.time()
  328. def expired(self):
  329. return self._exp < time.time()
  330. def __eq__(self, record: "CacheRecordPerms"):
  331. return hash(record) == hash(self)
  332. def __hash__(self):
  333. return hash((self._hashable_entity, self._hashable_user))
  334. def __str__(self):
  335. return f"CacheRecordPerms of {self.perms}"
  336. def __repr__(self):
  337. return (
  338. f"CacheRecordPerms(perms={type(self.perms).__name__}(...), exp={self._exp})"
  339. )
  340. class CacheRecordFullChannel:
  341. def __init__(self, channel_id: int, full_channel: ChannelFull, exp: int):
  342. self.channel_id = channel_id
  343. self.full_channel = full_channel
  344. self._exp = round(time.time() + exp)
  345. self.ts = time.time()
  346. def expired(self):
  347. return self._exp < time.time()
  348. def __eq__(self, record: "CacheRecordFullChannel"):
  349. return hash(record) == hash(self)
  350. def __hash__(self):
  351. return hash((self._hashable_entity, self._hashable_user))
  352. def __str__(self):
  353. return f"CacheRecordFullChannel of {self.channel_id}"
  354. def __repr__(self):
  355. return (
  356. f"CacheRecordFullChannel(channel_id={self.channel_id}(...),"
  357. f" exp={self._exp})"
  358. )
  359. class CacheRecordFullUser:
  360. def __init__(self, user_id: int, full_user: UserFull, exp: int):
  361. self.user_id = user_id
  362. self.full_user = full_user
  363. self._exp = round(time.time() + exp)
  364. self.ts = time.time()
  365. def expired(self):
  366. return self._exp < time.time()
  367. def __eq__(self, record: "CacheRecordFullUser"):
  368. return hash(record) == hash(self)
  369. def __hash__(self):
  370. return hash((self._hashable_entity, self._hashable_user))
  371. def __str__(self):
  372. return f"CacheRecordFullUser of {self.user_id}"
  373. def __repr__(self):
  374. return f"CacheRecordFullUser(channel_id={self.user_id}(...), exp={self._exp})"
  375. def get_commands(mod: Module) -> dict:
  376. """Introspect the module to get its commands"""
  377. return _get_members(mod, "cmd", "is_command")
  378. def get_inline_handlers(mod: Module) -> dict:
  379. """Introspect the module to get its inline handlers"""
  380. return _get_members(mod, "_inline_handler", "is_inline_handler")
  381. def get_callback_handlers(mod: Module) -> dict:
  382. """Introspect the module to get its callback handlers"""
  383. return _get_members(mod, "_callback_handler", "is_callback_handler")
  384. def get_watchers(mod: Module) -> dict:
  385. """Introspect the module to get its watchers"""
  386. return _get_members(
  387. mod,
  388. "watcher",
  389. "is_watcher",
  390. strict=True,
  391. )