types.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135
  1. # ©️ Dan Gazizullin, 2021-2022
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import ast
  7. import asyncio
  8. import contextlib
  9. import copy
  10. import importlib
  11. import importlib.machinery
  12. import importlib.util
  13. import inspect
  14. import logging
  15. import os
  16. import re
  17. import sys
  18. import time
  19. import typing
  20. from dataclasses import dataclass, field
  21. from importlib.abc import SourceLoader
  22. import requests
  23. from telethon.hints import EntityLike
  24. from telethon.tl.functions.account import UpdateNotifySettingsRequest
  25. from telethon.tl.types import (
  26. Channel,
  27. ChannelFull,
  28. InputPeerNotifySettings,
  29. Message,
  30. UserFull,
  31. )
  32. from . import validators, version # skipcq: PY-W2000
  33. from ._reference_finder import replace_all_refs
  34. from .inline.types import BotInlineMessage # skipcq: PY-W2000
  35. from .inline.types import (
  36. BotInlineCall,
  37. BotMessage,
  38. InlineCall,
  39. InlineMessage,
  40. InlineQuery,
  41. InlineUnit,
  42. )
  43. from .pointers import PointerDict, PointerList # skipcq: PY-W2000
  44. logger = logging.getLogger(__name__)
  45. JSONSerializable = typing.Union[str, int, float, bool, list, dict, None]
  46. HikkaReplyMarkup = typing.Union[typing.List[typing.List[dict]], typing.List[dict], dict]
  47. ListLike = typing.Union[list, set, tuple]
  48. Command = typing.Callable[..., typing.Awaitable[typing.Any]]
  49. class StringLoader(SourceLoader):
  50. """Load a python module/file from a string"""
  51. def __init__(self, data: str, origin: str):
  52. self.data = data.encode("utf-8") if isinstance(data, str) else data
  53. self.origin = origin
  54. def get_source(self, _=None) -> str:
  55. return self.data.decode("utf-8")
  56. def get_code(self, fullname: str) -> bytes:
  57. return (
  58. compile(source, self.origin, "exec", dont_inherit=True)
  59. if (source := self.get_data(fullname))
  60. else None
  61. )
  62. def get_filename(self, *args, **kwargs) -> str:
  63. return self.origin
  64. def get_data(self, *args, **kwargs) -> bytes:
  65. return self.data
  66. class Module:
  67. strings = {"name": "Unknown"}
  68. """There is no help for this module"""
  69. def config_complete(self):
  70. """Called when module.config is populated"""
  71. async def client_ready(self):
  72. """Called after client is ready (after config_loaded)"""
  73. def internal_init(self):
  74. """Called after the class is initialized in order to pass the client and db. Do not call it yourself
  75. """
  76. self.db = self.allmodules.db
  77. self._db = self.allmodules.db
  78. self.client = self.allmodules.client
  79. self._client = self.allmodules.client
  80. self.lookup = self.allmodules.lookup
  81. self.get_prefix = self.allmodules.get_prefix
  82. self.inline = self.allmodules.inline
  83. self.allclients = self.allmodules.allclients
  84. self.tg_id = self._client.tg_id
  85. self._tg_id = self._client.tg_id
  86. async def on_unload(self):
  87. """Called after unloading / reloading module"""
  88. async def on_dlmod(self):
  89. """
  90. Called after the module is first time loaded with .dlmod or .loadmod
  91. Possible use-cases:
  92. - Send reaction to author's channel message
  93. - Create asset folder
  94. - ...
  95. ⚠️ Note, that any error there will not interrupt module load, and will just
  96. send a message to logs with verbosity INFO and exception traceback
  97. """
  98. async def invoke(
  99. self,
  100. command: str,
  101. args: typing.Optional[str] = None,
  102. peer: typing.Optional[EntityLike] = None,
  103. message: typing.Optional[Message] = None,
  104. edit: bool = False,
  105. ) -> Message:
  106. """
  107. Invoke another command
  108. :param command: Command to invoke
  109. :param args: Arguments to pass to command
  110. :param peer: Peer to send the command to. If not specified, will send to the current chat
  111. :param edit: Whether to edit the message
  112. :returns Message:
  113. """
  114. if command not in self.allmodules.commands:
  115. raise ValueError(f"Command {command} not found")
  116. if not message and not peer:
  117. raise ValueError("Either peer or message must be specified")
  118. cmd = f"{self.get_prefix()}{command} {args or ''}".strip()
  119. message = (
  120. (await self._client.send_message(peer, cmd))
  121. if peer
  122. else (await (message.edit if edit else message.respond)(cmd))
  123. )
  124. await self.allmodules.commands[command](message)
  125. return message
  126. @property
  127. def commands(self) -> typing.Dict[str, Command]:
  128. """List of commands that module supports"""
  129. return get_commands(self)
  130. @property
  131. def hikka_commands(self) -> typing.Dict[str, Command]:
  132. """List of commands that module supports"""
  133. return get_commands(self)
  134. @property
  135. def inline_handlers(self) -> typing.Dict[str, Command]:
  136. """List of inline handlers that module supports"""
  137. return get_inline_handlers(self)
  138. @property
  139. def hikka_inline_handlers(self) -> typing.Dict[str, Command]:
  140. """List of inline handlers that module supports"""
  141. return get_inline_handlers(self)
  142. @property
  143. def callback_handlers(self) -> typing.Dict[str, Command]:
  144. """List of callback handlers that module supports"""
  145. return get_callback_handlers(self)
  146. @property
  147. def hikka_callback_handlers(self) -> typing.Dict[str, Command]:
  148. """List of callback handlers that module supports"""
  149. return get_callback_handlers(self)
  150. @property
  151. def watchers(self) -> typing.Dict[str, Command]:
  152. """List of watchers that module supports"""
  153. return get_watchers(self)
  154. @property
  155. def hikka_watchers(self) -> typing.Dict[str, Command]:
  156. """List of watchers that module supports"""
  157. return get_watchers(self)
  158. @commands.setter
  159. def commands(self, _):
  160. pass
  161. @hikka_commands.setter
  162. def hikka_commands(self, _):
  163. pass
  164. @inline_handlers.setter
  165. def inline_handlers(self, _):
  166. pass
  167. @hikka_inline_handlers.setter
  168. def hikka_inline_handlers(self, _):
  169. pass
  170. @callback_handlers.setter
  171. def callback_handlers(self, _):
  172. pass
  173. @hikka_callback_handlers.setter
  174. def hikka_callback_handlers(self, _):
  175. pass
  176. @watchers.setter
  177. def watchers(self, _):
  178. pass
  179. @hikka_watchers.setter
  180. def hikka_watchers(self, _):
  181. pass
  182. async def animate(
  183. self,
  184. message: typing.Union[Message, InlineMessage],
  185. frames: typing.List[str],
  186. interval: typing.Union[float, int],
  187. *,
  188. inline: bool = False,
  189. ) -> None:
  190. """
  191. Animate message
  192. :param message: Message to animate
  193. :param frames: A List of strings which are the frames of animation
  194. :param interval: Animation delay
  195. :param inline: Whether to use inline bot for animation
  196. :returns message:
  197. Please, note that if you set `inline=True`, first frame will be shown with an empty
  198. button due to the limitations of Telegram API
  199. """
  200. from . import utils
  201. with contextlib.suppress(AttributeError):
  202. _hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
  203. if interval < 0.1:
  204. logger.warning(
  205. "Resetting animation interval to 0.1s, because it may get you in"
  206. " floodwaits"
  207. )
  208. interval = 0.1
  209. for frame in frames:
  210. if isinstance(message, Message):
  211. if inline:
  212. message = await self.inline.form(
  213. message=message,
  214. text=frame,
  215. reply_markup={"text": "\u0020\u2800", "data": "empty"},
  216. )
  217. else:
  218. message = await utils.answer(message, frame)
  219. elif isinstance(message, InlineMessage) and inline:
  220. await message.edit(frame)
  221. await asyncio.sleep(interval)
  222. return message
  223. def get(
  224. self,
  225. key: str,
  226. default: typing.Optional[JSONSerializable] = None,
  227. ) -> JSONSerializable:
  228. return self._db.get(self.__class__.__name__, key, default)
  229. def set(self, key: str, value: JSONSerializable) -> bool:
  230. self._db.set(self.__class__.__name__, key, value)
  231. def pointer(
  232. self,
  233. key: str,
  234. default: typing.Optional[JSONSerializable] = None,
  235. ) -> typing.Union[JSONSerializable, PointerList, PointerDict]:
  236. return self._db.pointer(self.__class__.__name__, key, default)
  237. async def _approve(
  238. self,
  239. call: InlineCall,
  240. channel: EntityLike,
  241. event: asyncio.Event,
  242. ):
  243. from . import utils
  244. local_event = asyncio.Event()
  245. self.__approve += [(channel, local_event)] # skipcq: PTC-W0037
  246. await local_event.wait()
  247. event.status = local_event.status
  248. event.set()
  249. await call.edit(
  250. "💫 <b>Joined <a"
  251. f' href="https://t.me/{channel.username}">{utils.escape_html(channel.title)}</a></b>',
  252. gif="https://static.hikari.gay/0d32cbaa959e755ac8eef610f01ba0bd.gif",
  253. )
  254. async def _decline(
  255. self,
  256. call: InlineCall,
  257. channel: EntityLike,
  258. event: asyncio.Event,
  259. ):
  260. from . import utils
  261. self._db.set(
  262. "hikka.main",
  263. "declined_joins",
  264. list(set(self._db.get("hikka.main", "declined_joins", []) + [channel.id])),
  265. )
  266. event.status = False
  267. event.set()
  268. await call.edit(
  269. "✖️ <b>Declined joining <a"
  270. f' href="https://t.me/{channel.username}">{utils.escape_html(channel.title)}</a></b>',
  271. gif="https://static.hikari.gay/0d32cbaa959e755ac8eef610f01ba0bd.gif",
  272. )
  273. async def request_join(
  274. self,
  275. peer: EntityLike,
  276. reason: str,
  277. assure_joined: typing.Optional[bool] = False,
  278. ) -> bool:
  279. """
  280. Request to join a channel.
  281. :param peer: The channel to join.
  282. :param reason: The reason for joining.
  283. :param assure_joined: If set, module will not be loaded unless the required channel is joined.
  284. ⚠️ Works only in `client_ready`!
  285. ⚠️ If user declines to join channel, he will not be asked to
  286. join again, so unless he joins it manually, module will not be loaded
  287. ever.
  288. :return: Status of the request.
  289. :rtype: bool
  290. :notice: This method will block module loading until the request is approved or declined.
  291. """
  292. from . import utils
  293. event = asyncio.Event()
  294. await self.client(
  295. UpdateNotifySettingsRequest(
  296. peer=self.inline.bot_username,
  297. settings=InputPeerNotifySettings(show_previews=False, silent=False),
  298. )
  299. )
  300. channel = await self.client.get_entity(peer)
  301. if channel.id in self._db.get("hikka.main", "declined_joins", []):
  302. if assure_joined:
  303. raise LoadError(
  304. f"You need to join @{channel.username} in order to use this module"
  305. )
  306. return False
  307. if not isinstance(channel, Channel):
  308. raise TypeError("`peer` field must be a channel")
  309. if getattr(channel, "left", True):
  310. channel = await self.client.force_get_entity(peer)
  311. if not getattr(channel, "left", True):
  312. return True
  313. await self.inline.bot.send_animation(
  314. self.tg_id,
  315. "https://static.hikari.gay/ab3adf144c94a0883bfe489f4eebc520.gif",
  316. caption=(
  317. self._client.loader.lookup("translations")
  318. .strings("requested_join")
  319. .format(
  320. self.__class__.__name__,
  321. channel.username,
  322. utils.escape_html(channel.title),
  323. utils.escape_html(reason),
  324. )
  325. ),
  326. reply_markup=self.inline.generate_markup(
  327. [
  328. {
  329. "text": "💫 Approve",
  330. "callback": self._approve,
  331. "args": (channel, event),
  332. },
  333. {
  334. "text": "✖️ Decline",
  335. "callback": self._decline,
  336. "args": (channel, event),
  337. },
  338. ]
  339. ),
  340. )
  341. self.hikka_wait_channel_approve = (
  342. self.__class__.__name__,
  343. channel,
  344. reason,
  345. )
  346. await event.wait()
  347. with contextlib.suppress(AttributeError):
  348. delattr(self, "hikka_wait_channel_approve")
  349. if assure_joined and not event.status:
  350. raise LoadError(
  351. f"You need to join @{channel.username} in order to use this module"
  352. )
  353. return event.status
  354. async def import_lib(
  355. self,
  356. url: str,
  357. *,
  358. suspend_on_error: typing.Optional[bool] = False,
  359. _did_requirements: bool = False,
  360. ) -> "Library":
  361. """
  362. Import library from url and register it in :obj:`Modules`
  363. :param url: Url to import
  364. :param suspend_on_error: Will raise :obj:`loader.SelfSuspend` if library can't be loaded
  365. :return: :obj:`Library`
  366. :raise: SelfUnload if :attr:`suspend_on_error` is True and error occurred
  367. :raise: HTTPError if library is not found
  368. :raise: ImportError if library doesn't have any class which is a subclass of :obj:`loader.Library`
  369. :raise: ImportError if library name doesn't end with `Lib`
  370. :raise: RuntimeError if library throws in :method:`init`
  371. :raise: RuntimeError if library classname exists in :obj:`Modules`.libraries
  372. """
  373. from . import utils # Avoiding circular import
  374. from .loader import USER_INSTALL, VALID_PIP_PACKAGES
  375. from .translations import Strings
  376. def _raise(e: Exception):
  377. if suspend_on_error:
  378. raise SelfSuspend("Required library is not available or is corrupted.")
  379. raise e
  380. if not utils.check_url(url):
  381. _raise(ValueError("Invalid url for library"))
  382. code = await utils.run_sync(requests.get, url)
  383. code.raise_for_status()
  384. code = code.text
  385. if re.search(r"# ?scope: ?hikka_min", code):
  386. ver = tuple(
  387. map(
  388. int,
  389. re.search(r"# ?scope: ?hikka_min ((\d+\.){2}\d+)", code)[1].split(
  390. "."
  391. ),
  392. )
  393. )
  394. if version.__version__ < ver:
  395. _raise(
  396. RuntimeError(
  397. f"Library requires Hikka version {'{}.{}.{}'.format(*ver)}+"
  398. )
  399. )
  400. module = f"hikka.libraries.{url.replace('%', '%%').replace('.', '%d')}"
  401. origin = f"<library {url}>"
  402. spec = importlib.machinery.ModuleSpec(
  403. module,
  404. StringLoader(code, origin),
  405. origin=origin,
  406. )
  407. try:
  408. instance = importlib.util.module_from_spec(spec)
  409. sys.modules[module] = instance
  410. spec.loader.exec_module(instance)
  411. except ImportError as e:
  412. logger.info(
  413. "Library loading failed, attemping dependency installation (%s)",
  414. e.name,
  415. )
  416. # Let's try to reinstall dependencies
  417. try:
  418. requirements = list(
  419. filter(
  420. lambda x: not x.startswith(("-", "_", ".")),
  421. map(
  422. str.strip,
  423. VALID_PIP_PACKAGES.search(code)[1].split(),
  424. ),
  425. )
  426. )
  427. except TypeError:
  428. logger.warning(
  429. "No valid pip packages specified in code, attemping"
  430. " installation from error"
  431. )
  432. requirements = [e.name]
  433. logger.debug("Installing requirements: %s", requirements)
  434. if not requirements or _did_requirements:
  435. _raise(e)
  436. pip = await asyncio.create_subprocess_exec(
  437. sys.executable,
  438. "-m",
  439. "pip",
  440. "install",
  441. "--upgrade",
  442. "-q",
  443. "--disable-pip-version-check",
  444. "--no-warn-script-location",
  445. *["--user"] if USER_INSTALL else [],
  446. *requirements,
  447. )
  448. rc = await pip.wait()
  449. if rc != 0:
  450. _raise(e)
  451. importlib.invalidate_caches()
  452. kwargs = utils.get_kwargs()
  453. kwargs["_did_requirements"] = True
  454. return await self._mod_import_lib(**kwargs) # Try again
  455. lib_obj = next(
  456. (
  457. value()
  458. for value in vars(instance).values()
  459. if inspect.isclass(value) and issubclass(value, Library)
  460. ),
  461. None,
  462. )
  463. if not lib_obj:
  464. _raise(ImportError("Invalid library. No class found"))
  465. if not lib_obj.__class__.__name__.endswith("Lib"):
  466. _raise(
  467. ImportError(
  468. "Invalid library. Classname {} does not end with 'Lib'".format(
  469. lib_obj.__class__.__name__
  470. )
  471. )
  472. )
  473. if (
  474. all(
  475. line.replace(" ", "") != "#scope:no_stats" for line in code.splitlines()
  476. )
  477. and self._db.get("hikka.main", "stats", True)
  478. and url is not None
  479. and utils.check_url(url)
  480. ):
  481. with contextlib.suppress(Exception):
  482. await self.lookup("loader")._send_stats(url)
  483. lib_obj.source_url = url.strip("/")
  484. lib_obj.allmodules = self.allmodules
  485. lib_obj.internal_init()
  486. for old_lib in self.allmodules.libraries:
  487. if old_lib.name == lib_obj.name and (
  488. not isinstance(getattr(old_lib, "version", None), tuple)
  489. and not isinstance(getattr(lib_obj, "version", None), tuple)
  490. or old_lib.version >= lib_obj.version
  491. ):
  492. logger.debug("Using existing instance of library %s", old_lib.name)
  493. return old_lib
  494. if hasattr(lib_obj, "init"):
  495. if not callable(lib_obj.init):
  496. _raise(ValueError("Library init() must be callable"))
  497. try:
  498. await lib_obj.init()
  499. except Exception:
  500. _raise(RuntimeError("Library init() failed"))
  501. if hasattr(lib_obj, "config"):
  502. if not isinstance(lib_obj.config, LibraryConfig):
  503. _raise(
  504. RuntimeError("Library config must be a `LibraryConfig` instance")
  505. )
  506. libcfg = lib_obj.db.get(
  507. lib_obj.__class__.__name__,
  508. "__config__",
  509. {},
  510. )
  511. for conf in lib_obj.config:
  512. with contextlib.suppress(Exception):
  513. lib_obj.config.set_no_raise(
  514. conf,
  515. (
  516. libcfg[conf]
  517. if conf in libcfg
  518. else os.environ.get(f"{lib_obj.__class__.__name__}.{conf}")
  519. or lib_obj.config.getdef(conf)
  520. ),
  521. )
  522. if hasattr(lib_obj, "strings"):
  523. lib_obj.strings = Strings(lib_obj, self.translator)
  524. lib_obj.translator = self.translator
  525. for old_lib in self.allmodules.libraries:
  526. if old_lib.name == lib_obj.name:
  527. if hasattr(old_lib, "on_lib_update") and callable(
  528. old_lib.on_lib_update
  529. ):
  530. await old_lib.on_lib_update(lib_obj)
  531. replace_all_refs(old_lib, lib_obj)
  532. logger.debug(
  533. "Replacing existing instance of library %s with updated object",
  534. lib_obj.name,
  535. )
  536. return lib_obj
  537. self.allmodules.libraries += [lib_obj]
  538. return lib_obj
  539. class DragonModule:
  540. """Module is running in compatibility mode with Dragon, so it might be unstable"""
  541. # fmt: off
  542. strings_ru = {"_cls_doc": "Модуль запущен в режиме совместимости с Dragon, поэтому он может быть нестабильным"}
  543. strings_de = {"_cls_doc": "Das Modul wird im Dragon-Kompatibilitäts modus ausgeführt, daher kann es instabil sein"}
  544. strings_tr = {"_cls_doc": "Modül Dragon uyumluluğu modunda çalıştığı için istikrarsız olabilir"}
  545. strings_uz = {"_cls_doc": "Modul Dragon muvofiqligi rejimida ishlamoqda, shuning uchun u beqaror bo'lishi mumkin"}
  546. strings_es = {"_cls_doc": "El módulo se ejecuta en modo de compatibilidad con Dragon, por lo que puede ser inestable"}
  547. strings_kk = {"_cls_doc": "Модуль Dragon қамтамасыз ету режимінде іске қосылған, сондықтан белсенді емес болуы мүмкін"}
  548. strings_tt = {"_clc_doc": "Модуль Dragon белән ярашучанлык режимда эшли башлады, шуңа күрә ул тотрыксыз була ала"}
  549. # fmt: on
  550. def __init__(self):
  551. self.name = "Unknown"
  552. self.url = None
  553. self.commands = {}
  554. self.watchers = {}
  555. self.hikka_watchers = {}
  556. self.inline_handlers = {}
  557. self.hikka_inline_handlers = {}
  558. self.callback_handlers = {}
  559. self.hikka_callback_handlers = {}
  560. @property
  561. def hikka_commands(
  562. self,
  563. ) -> typing.Dict[str, Command]:
  564. return self.commands
  565. @property
  566. def __origin__(self) -> str:
  567. return f"<dragon {self.name}>"
  568. def config_complete(self):
  569. pass
  570. async def client_ready(self):
  571. pass
  572. async def on_unload(self):
  573. pass
  574. async def on_dlmod(self):
  575. pass
  576. class Library:
  577. """All external libraries must have a class-inheritant from this class"""
  578. def internal_init(self):
  579. self.name = self.__class__.__name__
  580. self.db = self.allmodules.db
  581. self._db = self.allmodules.db
  582. self.client = self.allmodules.client
  583. self._client = self.allmodules.client
  584. self.tg_id = self._client.tg_id
  585. self._tg_id = self._client.tg_id
  586. self.lookup = self.allmodules.lookup
  587. self.get_prefix = self.allmodules.get_prefix
  588. self.inline = self.allmodules.inline
  589. self.allclients = self.allmodules.allclients
  590. def _lib_get(
  591. self,
  592. key: str,
  593. default: typing.Optional[JSONSerializable] = None,
  594. ) -> JSONSerializable:
  595. return self._db.get(self.__class__.__name__, key, default)
  596. def _lib_set(self, key: str, value: JSONSerializable) -> bool:
  597. self._db.set(self.__class__.__name__, key, value)
  598. def _lib_pointer(
  599. self,
  600. key: str,
  601. default: typing.Optional[JSONSerializable] = None,
  602. ) -> typing.Union[JSONSerializable, PointerDict, PointerList]:
  603. return self._db.pointer(self.__class__.__name__, key, default)
  604. class LoadError(Exception):
  605. """Tells user, why your module can't be loaded, if raised in `client_ready`"""
  606. def __init__(self, error_message: str): # skipcq: PYL-W0231
  607. self._error = error_message
  608. def __str__(self) -> str:
  609. return self._error
  610. class CoreOverwriteError(LoadError):
  611. """Is being raised when core module or command is overwritten"""
  612. def __init__(
  613. self,
  614. module: typing.Optional[str] = None,
  615. command: typing.Optional[str] = None,
  616. ):
  617. self.type = "module" if module else "command"
  618. self.target = module or command
  619. super().__init__(str(self))
  620. def __str__(self) -> str:
  621. return (
  622. f"{'Module' if self.type == 'module' else 'command'} {self.target} will not"
  623. " be overwritten, because it's core"
  624. )
  625. class CoreUnloadError(Exception):
  626. """Is being raised when user tries to unload core module"""
  627. def __init__(self, module: str):
  628. self.module = module
  629. super().__init__()
  630. def __str__(self) -> str:
  631. return f"Module {self.module} will not be unloaded, because it's core"
  632. class SelfUnload(Exception):
  633. """Silently unloads module, if raised in `client_ready`"""
  634. def __init__(self, error_message: str = ""):
  635. super().__init__()
  636. self._error = error_message
  637. def __str__(self) -> str:
  638. return self._error
  639. class SelfSuspend(Exception):
  640. """
  641. Silently suspends module, if raised in `client_ready`
  642. Commands and watcher will not be registered if raised
  643. Module won't be unloaded from db and will be unfreezed after restart, unless
  644. the exception is raised again
  645. """
  646. def __init__(self, error_message: str = ""):
  647. super().__init__()
  648. self._error = error_message
  649. def __str__(self) -> str:
  650. return self._error
  651. class StopLoop(Exception):
  652. """Stops the loop, in which is raised"""
  653. class ModuleConfig(dict):
  654. """Stores config for modules and apparently libraries"""
  655. def __init__(self, *entries: typing.Union[str, "ConfigValue"]):
  656. if all(isinstance(entry, ConfigValue) for entry in entries):
  657. # New config format processing
  658. self._config = {config.option: config for config in entries}
  659. else:
  660. # Legacy config processing
  661. keys = []
  662. values = []
  663. defaults = []
  664. docstrings = []
  665. for i, entry in enumerate(entries):
  666. if i % 3 == 0:
  667. keys += [entry]
  668. elif i % 3 == 1:
  669. values += [entry]
  670. defaults += [entry]
  671. else:
  672. docstrings += [entry]
  673. self._config = {
  674. key: ConfigValue(option=key, default=default, doc=doc)
  675. for key, default, doc in zip(keys, defaults, docstrings)
  676. }
  677. super().__init__(
  678. {option: config.value for option, config in self._config.items()}
  679. )
  680. def getdoc(self, key: str, message: typing.Optional[Message] = None) -> str:
  681. """Get the documentation by key"""
  682. ret = self._config[key].doc
  683. if callable(ret):
  684. try:
  685. # Compatibility tweak
  686. # does nothing in Hikka
  687. ret = ret(message)
  688. except Exception:
  689. ret = ret()
  690. return ret
  691. def getdef(self, key: str) -> str:
  692. """Get the default value by key"""
  693. return self._config[key].default
  694. def __setitem__(self, key: str, value: typing.Any):
  695. self._config[key].value = value
  696. super().__setitem__(key, value)
  697. def set_no_raise(self, key: str, value: typing.Any):
  698. self._config[key].set_no_raise(value)
  699. super().__setitem__(key, value)
  700. def __getitem__(self, key: str) -> typing.Any:
  701. try:
  702. return self._config[key].value
  703. except KeyError:
  704. return None
  705. def reload(self):
  706. for key in self._config:
  707. super().__setitem__(key, self._config[key].value)
  708. LibraryConfig = ModuleConfig
  709. class _Placeholder:
  710. """Placeholder to determine if the default value is going to be set"""
  711. async def wrap(func: typing.Callable[[], typing.Awaitable]) -> typing.Any:
  712. with contextlib.suppress(Exception):
  713. return await func()
  714. def syncwrap(func: typing.Callable[[], typing.Any]) -> typing.Any:
  715. with contextlib.suppress(Exception):
  716. return func()
@dataclass(repr=True)
class ConfigValue:
    """
    A single config option: name, default, doc, current value, and optional
    validator and change callback.

    Every assignment to ``value`` goes through the custom ``__setattr__`` below,
    which literal-parses strings, normalizes sequences and validates.
    """

    option: str
    default: typing.Any = None
    # Plain string, or a callable producing the description.
    doc: typing.Union[typing.Callable[[], str], str] = "No description"
    # A _Placeholder instance means "not given"; __post_init__ swaps in `default`.
    value: typing.Any = field(default_factory=_Placeholder)
    # NOTE(review): annotated as a plain callable, but __setattr__ uses it as an
    # object with `.validate(value)` and `.internal_id` — confirm the real type.
    validator: typing.Optional[
        typing.Callable[[JSONSerializable], JSONSerializable]
    ] = None
    # Invoked after `value` is set with validation enabled; coroutine functions
    # are scheduled with asyncio.ensure_future, plain callables are called.
    on_change: typing.Optional[
        typing.Union[typing.Callable[[], typing.Awaitable], typing.Callable]
    ] = None

    def __post_init__(self):
        # The generated __init__ assigns `value` before `validator`/`on_change`,
        # so that first assignment only sees their class-level None defaults.
        # This re-assignment of the default is therefore the first one that is
        # validated and that can trigger `on_change`.
        # NOTE(review): an explicitly passed `value` is never validated here.
        if isinstance(self.value, _Placeholder):
            self.value = self.default

    def set_no_raise(self, value: typing.Any) -> None:
        """
        Sets the config value w/o ValidationError being raised.

        An invalid value is replaced by ``default``. ``on_change`` is not
        called. Intended for internal use only.
        """
        return self.__setattr__("value", value, ignore_validation=True)

    def __setattr__(
        self,
        key: str,
        value: typing.Any,
        *,
        ignore_validation: bool = False,
    ):
        """
        Assign an attribute; assignments to ``value`` are normalized and validated.

        :raises validators.ValidationError: If the validator rejects ``value``
                                            and ``ignore_validation`` is False.
        """
        if key == "value":
            # Strings holding Python literals ("123", "[1, 2]", "True") become
            # the corresponding objects; anything literal_eval rejects is kept.
            try:
                value = ast.literal_eval(value)
            except Exception:
                pass

            # Convert value to list if it's tuple just not to mess up
            # with json convertations
            if isinstance(value, (set, tuple)):
                value = list(value)

            if isinstance(value, list):
                value = [
                    item.strip() if isinstance(item, str) else item for item in value
                ]

            if self.validator is not None:
                if value is not None:
                    try:
                        value = self.validator.validate(value)
                    except validators.ValidationError as e:
                        if not ignore_validation:
                            raise e

                        logger.debug(
                            "Config value was broken (%s), so it was reset to %s",
                            value,
                            self.default,
                        )

                        value = self.default
                else:
                    # None is replaced by the empty value of known validator
                    # kinds; other validators keep None.
                    defaults = {
                        "String": "",
                        "Integer": 0,
                        "Boolean": False,
                        "Series": [],
                        "Float": 0.0,
                    }

                    if self.validator.internal_id in defaults:
                        logger.debug(
                            "Config value was None, so it was reset to %s",
                            defaults[self.validator.internal_id],
                        )
                        value = defaults[self.validator.internal_id]

            # This attribute will tell the `Loader` to save this value in db
            self._save_marker = True

        object.__setattr__(self, key, value)

        if key == "value" and not ignore_validation and callable(self.on_change):
            # Fire-and-forget: wrap/syncwrap swallow errors raised by the callback.
            if inspect.iscoroutinefunction(self.on_change):
                asyncio.ensure_future(wrap(self.on_change))
            else:
                syncwrap(self.on_change)
  def _get_members(
      mod: Module,
      ending: str,
      attribute: typing.Optional[str] = None,
      strict: bool = False,
  ) -> dict:
      """
      Collect callable members of ``mod`` selected by name or by a marker.

      A member is selected when it is callable, is not a ``property`` on the
      module's class, and either:

      * its name matches ``ending`` (``name == ending`` if ``strict``,
        otherwise ``name.endswith(ending)``), or
      * ``attribute`` is given and the member has a truthy attribute with
        that name.

      :param mod: Module instance to introspect
      :param ending: Name suffix to look for (the exact name when ``strict``)
      :param attribute: Optional marker attribute that also selects a member
      :param strict: Require an exact name match instead of a suffix match
      :return: Mapping of lowercase keys to members.
          - Name matches are keyed by the name with its last occurrence of
            ``ending`` removed, so an exact match in strict mode gets the
            key ``""``.
          - Members selected only by ``attribute`` keep their full name.
      """
      members = {}
      # `dir` returns names sorted alphabetically, so on a key collision the
      # alphabetically last name wins.
      for name in dir(mod):
          # Skip properties without reading them: reading one through the
          # instance would run its getter
          if isinstance(getattr(type(mod), name, None), property):
              continue

          # Read the member once. Names listed by `dir` that cannot be read
          # (e.g. unset __slots__) are skipped instead of aborting the scan.
          member = getattr(mod, name, None)
          if not callable(member):
              continue

          name_matches = name == ending if strict else name.endswith(ending)
          if not name_matches and not (
              attribute and getattr(member, attribute, False)
          ):
              continue

          key = name.rsplit(ending, maxsplit=1)[0] if name_matches else name
          members[key.lower()] = member

      return members
  class CacheRecordEntity:
      """
      Cache record holding a resolved entity until its expiry time.

      Records compare and hash by the key the entity was stored under
      (``hashable_entity``), not by the entity object itself.
      """

      def __init__(
          self,
          hashable_entity: "Hashable",  # type: ignore
          resolved_entity: EntityLike,
          exp: int,
      ):
          # Private deep copies, so the cached data cannot be changed later
          # through references the caller still holds
          self.entity = copy.deepcopy(resolved_entity)
          self._hashable_entity = copy.deepcopy(hashable_entity)
          # `exp` is a lifetime in seconds: store the absolute (rounded)
          # expiry moment plus the creation timestamp
          self._exp = round(time.time() + exp)
          self.ts = time.time()

      @property
      def expired(self) -> bool:
          """Whether the current time is past the expiry moment."""
          return time.time() > self._exp

      def __eq__(self, record: "CacheRecordEntity") -> bool:
          # Equality is delegated to the hash of the lookup key
          other_hash = hash(record)
          return other_hash == hash(self)

      def __hash__(self) -> int:
          return hash(self._hashable_entity)

      def __str__(self) -> str:
          return f"CacheRecordEntity of {self.entity}"

      def __repr__(self) -> str:
          entity_type = type(self.entity).__name__
          return f"CacheRecordEntity(entity={entity_type}(...), exp={self._exp})"
  class CacheRecordPerms:
      """
      Cache record holding resolved permissions until their expiry time.

      Records compare and hash by the ``(entity key, user key)`` pair they
      were stored under, not by the permissions object.
      """

      def __init__(
          self,
          hashable_entity: "Hashable",  # type: ignore
          hashable_user: "Hashable",  # type: ignore
          resolved_perms: EntityLike,
          exp: int,
      ):
          # Deep copies detach the record from the caller's objects
          self.perms = copy.deepcopy(resolved_perms)
          self._hashable_entity = copy.deepcopy(hashable_entity)
          self._hashable_user = copy.deepcopy(hashable_user)
          # `exp` is a lifetime in seconds: store the absolute (rounded)
          # expiry moment plus the creation timestamp
          self._exp = round(time.time() + exp)
          self.ts = time.time()

      @property
      def expired(self) -> bool:
          """Whether the current time is past the expiry moment."""
          return time.time() > self._exp

      def __eq__(self, record: "CacheRecordPerms") -> bool:
          # Equality is delegated to the hash of the lookup key pair
          other_hash = hash(record)
          return other_hash == hash(self)

      def __hash__(self) -> int:
          lookup_key = (self._hashable_entity, self._hashable_user)
          return hash(lookup_key)

      def __str__(self) -> str:
          return f"CacheRecordPerms of {self.perms}"

      def __repr__(self) -> str:
          perms_type = type(self.perms).__name__
          return f"CacheRecordPerms(perms={perms_type}(...), exp={self._exp})"
  class CacheRecordFullChannel:
      """
      Cache record holding a ``ChannelFull`` for ``channel_id``.

      The full channel object is stored by reference (not copied). Records
      compare and hash by ``channel_id``.
      """

      def __init__(self, channel_id: int, full_channel: "ChannelFull", exp: int):
          self.channel_id = channel_id
          self.full_channel = full_channel
          # `exp` is a lifetime in seconds: store the absolute (rounded)
          # expiry moment plus the creation timestamp
          self._exp = round(time.time() + exp)
          self.ts = time.time()

      @property
      def expired(self) -> bool:
          """Whether the current time is past the expiry moment."""
          return self._exp < time.time()

      def __eq__(self, record: "CacheRecordFullChannel") -> bool:
          return hash(record) == hash(self)

      def __hash__(self) -> int:
          # This class has no `_hashable_*` attributes; the channel id alone
          # identifies the record
          return hash(self.channel_id)

      def __str__(self) -> str:
          return f"CacheRecordFullChannel of {self.channel_id}"

      def __repr__(self) -> str:
          return f"CacheRecordFullChannel(channel_id={self.channel_id}, exp={self._exp})"
  class CacheRecordFullUser:
      """
      Cache record holding a ``UserFull`` for ``user_id``.

      The full user object is stored by reference (not copied). Records
      compare and hash by ``user_id``.
      """

      def __init__(self, user_id: int, full_user: "UserFull", exp: int):
          self.user_id = user_id
          self.full_user = full_user
          # `exp` is a lifetime in seconds: store the absolute (rounded)
          # expiry moment plus the creation timestamp
          self._exp = round(time.time() + exp)
          self.ts = time.time()

      @property
      def expired(self) -> bool:
          """Whether the current time is past the expiry moment."""
          return self._exp < time.time()

      def __eq__(self, record: "CacheRecordFullUser") -> bool:
          return hash(record) == hash(self)

      def __hash__(self) -> int:
          # This class has no `_hashable_*` attributes; the user id alone
          # identifies the record
          return hash(self.user_id)

      def __str__(self) -> str:
          return f"CacheRecordFullUser of {self.user_id}"

      def __repr__(self) -> str:
          return f"CacheRecordFullUser(user_id={self.user_id}, exp={self._exp})"
  def get_commands(mod: Module) -> dict:
      """
      Return the command handlers of ``mod``.

      A callable member qualifies when its name ends with ``cmd`` or it has
      a truthy ``is_command`` attribute. Keys are lowercased, and the ``cmd``
      suffix is dropped for name matches.
      """
      return _get_members(mod, ending="cmd", attribute="is_command")
  def get_inline_handlers(mod: Module) -> dict:
      """
      Return the inline-query handlers of ``mod``.

      A callable member qualifies when its name ends with ``_inline_handler``
      or it has a truthy ``is_inline_handler`` attribute. Keys are
      lowercased, and the suffix is dropped for name matches.
      """
      return _get_members(
          mod,
          ending="_inline_handler",
          attribute="is_inline_handler",
      )
  def get_callback_handlers(mod: Module) -> dict:
      """
      Return the callback handlers of ``mod``.

      A callable member qualifies when its name ends with ``_callback_handler``
      or it has a truthy ``is_callback_handler`` attribute. Keys are
      lowercased, and the suffix is dropped for name matches.
      """
      return _get_members(
          mod,
          ending="_callback_handler",
          attribute="is_callback_handler",
      )
  def get_watchers(mod: Module) -> dict:
      """
      Return the watchers of ``mod``.

      Unlike the other getters, the name match is exact: the member named
      ``watcher`` qualifies (it ends up under the key ``""``), as does any
      callable member with a truthy ``is_watcher`` attribute (keyed by its
      lowercased name).
      """
      return _get_members(mod, "watcher", attribute="is_watcher", strict=True)