loader.py 35 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030
  1. """Registers modules"""
  2. # Friendly Telegram (telegram userbot)
  3. # Copyright (C) 2018-2021 The Authors
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. # This program is distributed in the hope that it will be useful,
  9. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  11. # GNU Affero General Public License for more details.
  12. # You should have received a copy of the GNU Affero General Public License
  13. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  14. # █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
  15. # █▀█ █ █ █ █▀█ █▀▄ █
  16. # © Copyright 2022
  17. # https://t.me/hikariatama
  18. #
  19. # 🔒 Licensed under the GNU AGPLv3
  20. # 🌐 https://www.gnu.org/licenses/agpl-3.0.html
  21. import asyncio
  22. import contextlib
  23. import copy
  24. from functools import partial, wraps
  25. import importlib
  26. import importlib.util
  27. import inspect
  28. import logging
  29. import os
  30. import sys
  31. from importlib.machinery import ModuleSpec
  32. from types import FunctionType
  33. from typing import Any, Hashable, Optional, Union, List
  34. import requests
  35. from telethon.tl.types import Message
  36. from . import security, utils, validators
  37. from ._types import (
  38. ConfigValue, # skipcq
  39. LoadError, # skipcq
  40. Module,
  41. Library, # skipcq
  42. ModuleConfig, # skipcq
  43. LibraryConfig, # skipcq
  44. SelfUnload,
  45. SelfSuspend,
  46. StopLoop,
  47. InlineMessage,
  48. CoreOverwriteError,
  49. StringLoader,
  50. )
  51. from .inline.core import InlineManager
  52. from .translations import Strings
  53. logger = logging.getLogger(__name__)
  54. owner = security.owner
  55. sudo = security.sudo
  56. support = security.support
  57. group_owner = security.group_owner
  58. group_admin_add_admins = security.group_admin_add_admins
  59. group_admin_change_info = security.group_admin_change_info
  60. group_admin_ban_users = security.group_admin_ban_users
  61. group_admin_delete_messages = security.group_admin_delete_messages
  62. group_admin_pin_messages = security.group_admin_pin_messages
  63. group_admin_invite_users = security.group_admin_invite_users
  64. group_admin = security.group_admin
  65. group_member = security.group_member
  66. pm = security.pm
  67. unrestricted = security.unrestricted
  68. inline_everyone = security.inline_everyone
  69. async def stop_placeholder() -> bool:
  70. return True
  71. class Placeholder:
  72. """Placeholder"""
  73. class InfiniteLoop:
  74. _task = None
  75. status = False
  76. module_instance = None # Will be passed later
  77. def __init__(
  78. self,
  79. func: FunctionType,
  80. interval: int,
  81. autostart: bool,
  82. wait_before: bool,
  83. stop_clause: Union[str, None],
  84. ):
  85. self.func = func
  86. self.interval = interval
  87. self._wait_before = wait_before
  88. self._stop_clause = stop_clause
  89. self.autostart = autostart
  90. def _stop(self, *args, **kwargs):
  91. self._wait_for_stop.set()
  92. def stop(self, *args, **kwargs):
  93. with contextlib.suppress(AttributeError):
  94. _hikka_client_id_logging_tag = copy.copy(
  95. self.module_instance.allmodules.client._tg_id
  96. )
  97. if self._task:
  98. logger.debug(f"Stopped loop for {self.func}")
  99. self._wait_for_stop = asyncio.Event()
  100. self.status = False
  101. self._task.add_done_callback(self._stop)
  102. self._task.cancel()
  103. return asyncio.ensure_future(self._wait_for_stop.wait())
  104. logger.debug("Loop is not running")
  105. return asyncio.ensure_future(stop_placeholder())
  106. def start(self, *args, **kwargs):
  107. with contextlib.suppress(AttributeError):
  108. _hikka_client_id_logging_tag = copy.copy(
  109. self.module_instance.allmodules.client._tg_id
  110. )
  111. if not self._task:
  112. logger.debug(f"Started loop for {self.func}")
  113. self._task = asyncio.ensure_future(self.actual_loop(*args, **kwargs))
  114. else:
  115. logger.debug("Attempted to start already running loop")
  116. async def actual_loop(self, *args, **kwargs):
  117. # Wait for loader to set attribute
  118. while not self.module_instance:
  119. await asyncio.sleep(0.01)
  120. if isinstance(self._stop_clause, str) and self._stop_clause:
  121. self.module_instance.set(self._stop_clause, True)
  122. self.status = True
  123. while self.status:
  124. if self._wait_before:
  125. await asyncio.sleep(self.interval)
  126. if (
  127. isinstance(self._stop_clause, str)
  128. and self._stop_clause
  129. and not self.module_instance.get(self._stop_clause, False)
  130. ):
  131. break
  132. try:
  133. await self.func(self.module_instance, *args, **kwargs)
  134. except StopLoop:
  135. break
  136. except Exception:
  137. logger.exception("Error running loop!")
  138. if not self._wait_before:
  139. await asyncio.sleep(self.interval)
  140. self._wait_for_stop.set()
  141. self.status = False
  142. def __del__(self):
  143. self.stop()
  144. def loop(
  145. interval: int = 5,
  146. autostart: Optional[bool] = False,
  147. wait_before: Optional[bool] = False,
  148. stop_clause: Optional[str] = None,
  149. ) -> FunctionType:
  150. """
  151. Create new infinite loop from class method
  152. :param interval: Loop iterations delay
  153. :param autostart: Start loop once module is loaded
  154. :param wait_before: Insert delay before actual iteration, rather than after
  155. :param stop_clause: Database key, based on which the loop will run.
  156. This key will be set to `True` once loop is started,
  157. and will stop after key resets to `False`
  158. :attr status: Boolean, describing whether the loop is running
  159. """
  160. def wrapped(func):
  161. return InfiniteLoop(func, interval, autostart, wait_before, stop_clause)
  162. return wrapped
  163. MODULES_NAME = "modules"
  164. ru_keys = 'ёйцукенгшщзхъфывапролджэячсмитьбю.Ё"№;%:?ЙЦУКЕНГШЩЗХЪФЫВАПРОЛДЖЭ/ЯЧСМИТЬБЮ,'
  165. en_keys = "`qwertyuiop[]asdfghjkl;'zxcvbnm,./~@#$%^&QWERTYUIOP{}ASDFGHJKL:\"|ZXCVBNM<>?"
  166. BASE_DIR = (
  167. os.path.normpath(os.path.join(utils.get_base_dir(), ".."))
  168. if "OKTETO" not in os.environ and "DOCKER" not in os.environ
  169. else "/data"
  170. )
  171. LOADED_MODULES_DIR = os.path.join(BASE_DIR, "loaded_modules")
  172. if not os.path.isdir(LOADED_MODULES_DIR) and "DYNO" not in os.environ:
  173. os.mkdir(LOADED_MODULES_DIR, mode=0o755)
  174. def translatable_docstring(cls):
  175. """Decorator that makes triple-quote docstrings translatable"""
  176. @wraps(cls.config_complete)
  177. def config_complete(self, *args, **kwargs):
  178. for command_, func_ in get_commands(cls).items():
  179. try:
  180. func_.__doc__ = self.strings[f"_cmd_doc_{command_}"]
  181. except AttributeError:
  182. func_.__func__.__doc__ = self.strings[f"_cmd_doc_{command_}"]
  183. for inline_handler_, func_ in get_inline_handlers(cls).items():
  184. try:
  185. func_.__doc__ = self.strings[f"_ihandle_doc_{inline_handler_}"]
  186. except AttributeError:
  187. func_.__func__.__doc__ = self.strings[f"_ihandle_doc_{inline_handler_}"]
  188. self.__doc__ = self.strings["_cls_doc"]
  189. return self.config_complete._old_(self, *args, **kwargs)
  190. config_complete._old_ = cls.config_complete
  191. cls.config_complete = config_complete
  192. for command, func in get_commands(cls).items():
  193. cls.strings[f"_cmd_doc_{command}"] = inspect.getdoc(func)
  194. for inline_handler, func in get_inline_handlers(cls).items():
  195. cls.strings[f"_ihandle_doc_{inline_handler}"] = inspect.getdoc(func)
  196. cls.strings["_cls_doc"] = inspect.getdoc(cls)
  197. return cls
  198. tds = translatable_docstring # Shorter name for modules to use
  199. def ratelimit(func):
  200. """Decorator that causes ratelimiting for this command to be enforced more strictly"""
  201. func.ratelimit = True
  202. return func
  203. def get_commands(mod):
  204. """Introspect the module to get its commands"""
  205. return {
  206. method_name.rsplit("cmd", maxsplit=1)[0]: getattr(mod, method_name)
  207. for method_name in dir(mod)
  208. if callable(getattr(mod, method_name)) and method_name.endswith("cmd")
  209. }
  210. def get_inline_handlers(mod):
  211. """Introspect the module to get its inline handlers"""
  212. return {
  213. method_name.rsplit("_inline_handler", maxsplit=1)[0]: getattr(mod, method_name)
  214. for method_name in dir(mod)
  215. if callable(getattr(mod, method_name))
  216. and method_name.endswith("_inline_handler")
  217. }
  218. def get_callback_handlers(mod):
  219. """Introspect the module to get its callback handlers"""
  220. return {
  221. method_name.rsplit("_callback_handler", maxsplit=1)[0]: getattr(
  222. mod,
  223. method_name,
  224. )
  225. for method_name in dir(mod)
  226. if callable(getattr(mod, method_name))
  227. and method_name.endswith("_callback_handler")
  228. }
  229. class Modules:
  230. """Stores all registered modules"""
  231. client = None
  232. _initial_registration = True
  233. def __init__(self):
  234. self.commands = {}
  235. self.inline_handlers = {}
  236. self.callback_handlers = {}
  237. self.aliases = {}
  238. self.modules = [] # skipcq: PTC-W0052
  239. self.libraries = []
  240. self.watchers = []
  241. self._log_handlers = []
  242. self._core_commands = []
  243. def register_all(self, client, db, mods=None):
  244. """Load all modules in the module directory"""
  245. self._db = db
  246. external_mods = []
  247. if not mods:
  248. mods = [
  249. os.path.join(utils.get_base_dir(), MODULES_NAME, mod)
  250. for mod in filter(
  251. lambda x: (x.endswith(".py") and not x.startswith("_")),
  252. os.listdir(os.path.join(utils.get_base_dir(), MODULES_NAME)),
  253. )
  254. ]
  255. if "DYNO" not in os.environ and not db.get(__name__, "secure_boot", False):
  256. external_mods = [
  257. os.path.join(LOADED_MODULES_DIR, mod)
  258. for mod in filter(
  259. lambda x: (
  260. x.endswith(f"{client._tg_id}.py") and not x.startswith("_")
  261. ),
  262. os.listdir(LOADED_MODULES_DIR),
  263. )
  264. ]
  265. else:
  266. external_mods = []
  267. self._register_modules(mods)
  268. self._register_modules(external_mods, "<file>")
  269. def _register_modules(self, modules: list, origin: str = "<core>"):
  270. with contextlib.suppress(AttributeError):
  271. _hikka_client_id_logging_tag = copy.copy(self.client._tg_id)
  272. for mod in modules:
  273. try:
  274. mod_shortname = (
  275. os.path.basename(mod)
  276. .rsplit(".py", maxsplit=1)[0]
  277. .rsplit("_", maxsplit=1)[0]
  278. )
  279. module_name = f"{__package__}.{MODULES_NAME}.{mod_shortname}"
  280. user_friendly_origin = (
  281. "<core {}>" if origin == "<core>" else "<file {}>"
  282. ).format(mod_shortname)
  283. logger.debug(f"Loading {module_name} from filesystem")
  284. with open(mod, "r") as file:
  285. spec = ModuleSpec(
  286. module_name,
  287. StringLoader(file.read(), user_friendly_origin),
  288. origin=user_friendly_origin,
  289. )
  290. self.register_module(spec, module_name, origin)
  291. except BaseException as e:
  292. logger.exception(f"Failed to load module {mod} due to {e}:")
  293. def register_module(
  294. self,
  295. spec: ModuleSpec,
  296. module_name: str,
  297. origin: str = "<core>",
  298. save_fs: bool = False,
  299. ) -> Module:
  300. """Register single module from importlib spec"""
  301. with contextlib.suppress(AttributeError):
  302. _hikka_client_id_logging_tag = copy.copy(self.client._tg_id)
  303. module = importlib.util.module_from_spec(spec)
  304. sys.modules[module_name] = module
  305. spec.loader.exec_module(module)
  306. ret = None
  307. ret = next(
  308. (
  309. value()
  310. for value in vars(module).values()
  311. if inspect.isclass(value) and issubclass(value, Module)
  312. ),
  313. None,
  314. )
  315. if hasattr(module, "__version__"):
  316. ret.__version__ = module.__version__
  317. if ret is None:
  318. ret = module.register(module_name)
  319. if not isinstance(ret, Module):
  320. raise TypeError(f"Instance is not a Module, it is {type(ret)}")
  321. self.complete_registration(ret)
  322. ret.__origin__ = origin
  323. cls_name = ret.__class__.__name__
  324. if save_fs and "DYNO" not in os.environ:
  325. path = os.path.join(
  326. LOADED_MODULES_DIR,
  327. f"{cls_name}_{self.client._tg_id}.py",
  328. )
  329. if origin == "<string>":
  330. with open(path, "w") as f:
  331. f.write(spec.loader.data.decode("utf-8"))
  332. logger.debug(f"Saved {cls_name=} to {path=}")
  333. return ret
  334. def add_aliases(self, aliases: dict):
  335. """Saves aliases and applies them to <core>/<file> modules"""
  336. self.aliases.update(aliases)
  337. for alias, cmd in aliases.items():
  338. self.add_alias(alias, cmd)
  339. def register_commands(self, instance: Module):
  340. """Register commands from instance"""
  341. with contextlib.suppress(AttributeError):
  342. _hikka_client_id_logging_tag = copy.copy(self.client._tg_id)
  343. if getattr(instance, "__origin__", "") == "<core>":
  344. self._core_commands += list(map(lambda x: x.lower(), instance.commands))
  345. for command in instance.commands.copy():
  346. # Restrict overwriting core modules' commands
  347. if (
  348. command.lower() in self._core_commands
  349. and getattr(instance, "__origin__", "") != "<core>"
  350. ):
  351. with contextlib.suppress(Exception):
  352. self.modules.remove(instance)
  353. raise CoreOverwriteError(command=command)
  354. # Verify that command does not already exist, or,
  355. # if it does, the command must be from the same class name
  356. if command.lower() in self.commands:
  357. if (
  358. hasattr(instance.commands[command], "__self__")
  359. and hasattr(self.commands[command], "__self__")
  360. and instance.commands[command].__self__.__class__.__name__
  361. != self.commands[command].__self__.__class__.__name__
  362. ):
  363. logger.debug(f"Duplicate command {command}")
  364. logger.debug(f"Replacing command for {self.commands[command]}")
  365. if not instance.commands[command].__doc__:
  366. logger.debug(f"Missing docs for {command}")
  367. self.commands.update({command.lower(): instance.commands[command]})
  368. for alias, cmd in self.aliases.items():
  369. if cmd in instance.commands:
  370. self.add_alias(alias, cmd)
  371. for handler in instance.inline_handlers.copy():
  372. if handler.lower() in self.inline_handlers:
  373. if (
  374. hasattr(instance.inline_handlers[handler], "__self__")
  375. and hasattr(self.inline_handlers[handler], "__self__")
  376. and instance.inline_handlers[handler].__self__.__class__.__name__
  377. != self.inline_handlers[handler].__self__.__class__.__name__
  378. ):
  379. logger.debug(f"Duplicate inline_handler {handler}")
  380. logger.debug(
  381. f"Replacing inline_handler for {self.inline_handlers[handler]}"
  382. )
  383. if not instance.inline_handlers[handler].__doc__:
  384. logger.debug(f"Missing docs for {handler}")
  385. self.inline_handlers.update(
  386. {handler.lower(): instance.inline_handlers[handler]}
  387. )
  388. for handler in instance.callback_handlers.copy():
  389. if handler.lower() in self.callback_handlers and (
  390. hasattr(instance.callback_handlers[handler], "__self__")
  391. and hasattr(self.callback_handlers[handler], "__self__")
  392. and instance.callback_handlers[handler].__self__.__class__.__name__
  393. != self.callback_handlers[handler].__self__.__class__.__name__
  394. ):
  395. logger.debug(f"Duplicate callback_handler {handler}")
  396. self.callback_handlers.update(
  397. {handler.lower(): instance.callback_handlers[handler]}
  398. )
  399. def register_watcher(self, instance: Module):
  400. """Register watcher from instance"""
  401. with contextlib.suppress(AttributeError):
  402. _hikka_client_id_logging_tag = copy.copy(self.client._tg_id)
  403. with contextlib.suppress(AttributeError):
  404. if instance.watcher:
  405. for watcher in self.watchers:
  406. if (
  407. hasattr(watcher, "__self__")
  408. and watcher.__self__.__class__.__name__
  409. == instance.watcher.__self__.__class__.__name__
  410. ):
  411. logger.debug(f"Removing watcher for update {watcher}")
  412. self.watchers.remove(watcher)
  413. self.watchers += [instance.watcher]
  414. def _lookup(self, modname: str):
  415. return next(
  416. (lib for lib in self.libraries if lib.name.lower() == modname.lower()),
  417. False,
  418. ) or next(
  419. (
  420. mod
  421. for mod in self.modules
  422. if mod.__class__.__name__.lower() == modname.lower()
  423. or mod.name.lower() == modname.lower()
  424. ),
  425. False,
  426. )
  427. def complete_registration(self, instance: Module):
  428. """Complete registration of instance"""
  429. with contextlib.suppress(AttributeError):
  430. _hikka_client_id_logging_tag = copy.copy(self.client._tg_id)
  431. instance.allmodules = self
  432. instance.hikka = True
  433. instance.get = partial(self._mod_get, _module=instance)
  434. instance.set = partial(self._mod_set, _modname=instance.__class__.__name__)
  435. instance.get_prefix = partial(self._db.get, "hikka.main", "command_prefix", ".")
  436. instance.client = self.client
  437. instance._client = self.client
  438. instance.db = self._db
  439. instance._db = self._db
  440. instance.lookup = self._lookup
  441. instance.import_lib = self._mod_import_lib
  442. for module in self.modules:
  443. if module.__class__.__name__ == instance.__class__.__name__:
  444. if getattr(module, "__origin__", "") == "<core>":
  445. raise CoreOverwriteError(
  446. module=module.__class__.__name__[:-3]
  447. if module.__class__.__name__.endswith("Mod")
  448. else module.__class__.__name__
  449. )
  450. logger.debug(f"Removing module for update {module}")
  451. asyncio.ensure_future(module.on_unload())
  452. self.modules.remove(module)
  453. for method in dir(module):
  454. if isinstance(getattr(module, method), InfiniteLoop):
  455. getattr(module, method).stop()
  456. logger.debug(f"Stopped loop in {module=}, {method=}")
  457. self.modules += [instance]
  458. def _mod_get(
  459. self,
  460. key: str,
  461. default: Optional[Hashable] = None,
  462. _module: Module = None,
  463. ) -> Hashable:
  464. mod, legacy = _module.__class__.__name__, _module.strings["name"]
  465. if self._db.get(legacy, key, Placeholder) is not Placeholder:
  466. for iterkey, value in self._db[legacy].items():
  467. if iterkey == "__config__":
  468. # Config already uses classname as key
  469. # No need to migrate
  470. continue
  471. if isinstance(value, dict) and isinstance(
  472. self._db.get(mod, iterkey), dict
  473. ):
  474. self._db[mod][iterkey].update(value)
  475. else:
  476. self._db.set(mod, iterkey, value)
  477. logger.debug(f"Migrated {legacy} -> {mod}")
  478. del self._db[legacy]
  479. return self._db.get(mod, key, default)
  480. def _mod_set(self, key: str, value: Hashable, _modname: str = None) -> bool:
  481. return self._db.set(_modname, key, value)
  482. async def _mod_import_lib(
  483. self,
  484. url: str,
  485. *,
  486. suspend_on_error: Optional[bool] = False,
  487. ) -> object:
  488. """
  489. Import library from url and register it in :obj:`Modules`
  490. :param url: Url to import
  491. :param suspend_on_error: Will raise :obj:`loader.SelfSuspend` if library can't be loaded
  492. :return: Library class instance
  493. :raise: HTTPError if library is not found
  494. :raise: ImportError if library doesn't have any class which is a subclass of :obj:`loader.Library`
  495. :raise: ImportError if library name doesn't end with `Lib`
  496. :raise: RuntimeError if library throws in :method:`init`
  497. :raise: RuntimeError if library classname exists in :obj:`Modules`.libraries
  498. """
  499. def _raise(e: Exception):
  500. if suspend_on_error:
  501. raise SelfSuspend("Required library is not available or is corrupted.")
  502. else:
  503. raise e
  504. if not utils.check_url(url):
  505. _raise(ValueError("Invalid url for library"))
  506. code = await utils.run_sync(requests.get, url)
  507. code.raise_for_status()
  508. code = code.text
  509. module = f"hikka.libraries.{url.replace('%', '%%').replace('.', '%d')}"
  510. origin = f"<library {url}>"
  511. spec = ModuleSpec(module, StringLoader(code, origin), origin=origin)
  512. instance = importlib.util.module_from_spec(spec)
  513. sys.modules[module] = instance
  514. spec.loader.exec_module(instance)
  515. class_instance = next(
  516. (
  517. value()
  518. for value in vars(instance).values()
  519. if inspect.isclass(value) and issubclass(value, Library)
  520. ),
  521. None,
  522. )
  523. if not class_instance:
  524. _raise(ImportError("Invalid library. No class found"))
  525. if not class_instance.__class__.__name__.endswith("Lib"):
  526. _raise(ImportError("Invalid library. Class name must end with 'Lib'"))
  527. class_instance.client = self.client
  528. class_instance._client = self.client # skipcq
  529. class_instance.db = self._db # skipcq
  530. class_instance._db = self._db # skipcq
  531. class_instance.name = class_instance.__class__.__name__
  532. class_instance.source_url = url.strip("/")
  533. for lib in self.libraries.copy():
  534. if lib.source_url == class_instance.source_url:
  535. logging.debug(f"Using existing instance of library {lib.source_url}")
  536. return lib
  537. if hasattr(class_instance, "init") and callable(class_instance.init):
  538. try:
  539. await class_instance.init()
  540. except Exception:
  541. _raise(RuntimeError("Library init() failed"))
  542. if hasattr(class_instance, "config"):
  543. if not isinstance(class_instance.config, LibraryConfig):
  544. _raise(
  545. RuntimeError("Library config must be a `LibraryConfig` instance")
  546. )
  547. libcfg = class_instance.db.get(
  548. class_instance.__class__.__name__,
  549. "__config__",
  550. {},
  551. )
  552. for conf in class_instance.config.keys():
  553. with contextlib.suppress(Exception):
  554. class_instance.config.set_no_raise(
  555. conf,
  556. (
  557. libcfg[conf]
  558. if conf in libcfg.keys()
  559. else os.environ.get(
  560. f"{class_instance.__class__.__name__}.{conf}"
  561. )
  562. or class_instance.config.getdef(conf)
  563. ),
  564. )
  565. self.libraries += [class_instance]
  566. if len([x.name for x in self.libraries]) != len(
  567. {x.name for x in self.libraries}
  568. ):
  569. self.libraries.remove(class_instance)
  570. _raise(
  571. RuntimeError(
  572. "Use different classname for your library. You have a override"
  573. )
  574. )
  575. return class_instance
  576. def dispatch(self, command: str) -> tuple:
  577. """Dispatch command to appropriate module"""
  578. change = str.maketrans(ru_keys + en_keys, en_keys + ru_keys)
  579. try:
  580. return command, self.commands[command.lower()]
  581. except KeyError:
  582. try:
  583. cmd = self.aliases[command.lower()]
  584. return cmd, self.commands[cmd.lower()]
  585. except KeyError:
  586. try:
  587. cmd = self.aliases[str.translate(command, change).lower()]
  588. return cmd, self.commands[cmd.lower()]
  589. except KeyError:
  590. try:
  591. cmd = str.translate(command, change).lower()
  592. return cmd, self.commands[cmd.lower()]
  593. except KeyError:
  594. return command, None
  595. def send_config(self, db, translator, skip_hook: bool = False):
  596. """Configure modules"""
  597. for mod in self.modules:
  598. self.send_config_one(mod, db, translator, skip_hook)
  599. def send_config_one(
  600. self,
  601. mod: "Module",
  602. db: "Database", # type: ignore
  603. translator: "Translator" = None, # type: ignore
  604. skip_hook: bool = False,
  605. ):
  606. """Send config to single instance"""
  607. with contextlib.suppress(AttributeError):
  608. _hikka_client_id_logging_tag = copy.copy(self.client._tg_id)
  609. if hasattr(mod, "config"):
  610. modcfg = db.get(
  611. mod.__class__.__name__,
  612. "__config__",
  613. {},
  614. )
  615. try:
  616. for conf in mod.config.keys():
  617. with contextlib.suppress(validators.ValidationError):
  618. mod.config.set_no_raise(
  619. conf,
  620. (
  621. modcfg[conf]
  622. if conf in modcfg.keys()
  623. else os.environ.get(f"{mod.__class__.__name__}.{conf}")
  624. or mod.config.getdef(conf)
  625. ),
  626. )
  627. except AttributeError:
  628. logger.warning(
  629. "Got invalid config instance. Expected `ModuleConfig`, got"
  630. f" {type(mod.config)=}, {mod.config=}"
  631. )
  632. if skip_hook:
  633. return
  634. if not hasattr(mod, "name"):
  635. mod.name = mod.strings["name"]
  636. if hasattr(mod, "strings"):
  637. mod.strings = Strings(mod, translator)
  638. mod.translator = translator
  639. try:
  640. mod.config_complete()
  641. except Exception as e:
  642. logger.exception(f"Failed to send mod config complete signal due to {e}")
  643. raise
  644. async def send_ready(self, client, db, allclients):
  645. """Send all data to all modules"""
  646. self.client = client
  647. # Init inline manager anyway, so the modules
  648. # can access its `init_complete`
  649. inline_manager = InlineManager(client, db, self)
  650. await inline_manager._register_manager()
  651. # We save it to `Modules` attribute, so not to re-init
  652. # it everytime module is loaded. Then we can just
  653. # re-assign it to all modules
  654. self.inline = inline_manager
  655. try:
  656. await asyncio.gather(
  657. *[
  658. self.send_ready_one(mod, client, db, allclients)
  659. for mod in self.modules
  660. ]
  661. )
  662. except Exception as e:
  663. logger.exception(f"Failed to send mod init complete signal due to {e}")
  664. async def _animate(
  665. self,
  666. message: Union[Message, InlineMessage],
  667. frames: List[str],
  668. interval: Union[float, int],
  669. *,
  670. inline: bool = False,
  671. ) -> None:
  672. """
  673. Animate message
  674. :param message: Message to animate
  675. :param frames: A List of strings which are the frames of animation
  676. :param interval: Animation delay
  677. :param inline: Whether to use inline bot for animation
  678. :returns message:
  679. Please, note that if you set `inline=True`, first frame will be shown with an empty
  680. button due to the limitations of Telegram API
  681. """
  682. with contextlib.suppress(AttributeError):
  683. _hikka_client_id_logging_tag = copy.copy(self.client._tg_id)
  684. if interval < 0.1:
  685. logger.warning(
  686. "Resetting animation interval to 0.1s, because it may get you in"
  687. " floodwaits bro"
  688. )
  689. interval = 0.1
  690. for frame in frames:
  691. if isinstance(message, Message):
  692. if inline:
  693. message = await self.inline.form(
  694. message=message,
  695. text=frame,
  696. reply_markup={"text": "\u0020\u2800", "data": "empty"},
  697. )
  698. else:
  699. message = await utils.answer(message, frame)
  700. elif isinstance(message, InlineMessage) and inline:
  701. await message.edit(frame)
  702. await asyncio.sleep(interval)
  703. return message
  704. async def send_ready_one(
  705. self,
  706. mod: Module,
  707. client: "TelegramClient", # type: ignore
  708. db: "Database", # type: ignore
  709. allclients: list,
  710. no_self_unload: bool = False,
  711. from_dlmod: bool = False,
  712. ):
  713. mod.allclients = allclients
  714. mod._client = client
  715. mod._tg_id = client._tg_id
  716. with contextlib.suppress(AttributeError):
  717. _hikka_client_id_logging_tag = copy.copy(client._tg_id)
  718. mod.inline = self.inline
  719. mod.animate = self._animate
  720. for method in dir(mod):
  721. if isinstance(getattr(mod, method), InfiniteLoop):
  722. setattr(getattr(mod, method), "module_instance", mod)
  723. if getattr(mod, method).autostart:
  724. getattr(mod, method).start()
  725. logger.debug(f"Added {mod=} to {method=}")
  726. if from_dlmod:
  727. try:
  728. await mod.on_dlmod(client, db)
  729. except Exception:
  730. logger.info("Can't process `on_dlmod` hook", exc_info=True)
  731. try:
  732. await mod.client_ready(client, db)
  733. except SelfUnload as e:
  734. if no_self_unload:
  735. raise e
  736. logger.debug(f"Unloading {mod}, because it raised SelfUnload")
  737. self.modules.remove(mod)
  738. except SelfSuspend as e:
  739. if no_self_unload:
  740. raise e
  741. logger.debug(f"Suspending {mod}, because it raised SelfSuspend")
  742. return
  743. except Exception as e:
  744. logger.exception(
  745. f"Failed to send mod init complete signal for {mod} due to {e},"
  746. " attempting unload"
  747. )
  748. self.modules.remove(mod)
  749. raise
  750. if not hasattr(mod, "commands"):
  751. mod.commands = get_commands(mod)
  752. if not hasattr(mod, "inline_handlers"):
  753. mod.inline_handlers = get_inline_handlers(mod)
  754. if not hasattr(mod, "callback_handlers"):
  755. mod.callback_handlers = get_callback_handlers(mod)
  756. self.register_commands(mod)
  757. self.register_watcher(mod)
  758. def get_classname(self, name: str) -> str:
  759. return next(
  760. (
  761. module.__class__.__module__
  762. for module in reversed(self.modules)
  763. if name in (module.name, module.__class__.__module__)
  764. ),
  765. name,
  766. )
  767. def unload_module(self, classname: str) -> bool:
  768. """Remove module and all stuff from it"""
  769. worked = []
  770. to_remove = []
  771. with contextlib.suppress(AttributeError):
  772. _hikka_client_id_logging_tag = copy.copy(self.client._tg_id)
  773. for module in self.modules:
  774. if classname.lower() in (
  775. module.name.lower(),
  776. module.__class__.__name__.lower(),
  777. ):
  778. if getattr(module, "__origin__", "") == "<core>":
  779. raise RuntimeError("You can't unload core module")
  780. worked += [module.__class__.__name__]
  781. name = module.__class__.__name__
  782. if "DYNO" not in os.environ:
  783. path = os.path.join(
  784. LOADED_MODULES_DIR,
  785. f"{name}_{self.client._tg_id}.py",
  786. )
  787. if os.path.isfile(path):
  788. os.remove(path)
  789. logger.debug(f"Removed {name} file at {path=}")
  790. logger.debug(f"Removing module for unload {module}")
  791. self.modules.remove(module)
  792. asyncio.ensure_future(module.on_unload())
  793. for method in dir(module):
  794. if isinstance(getattr(module, method), InfiniteLoop):
  795. getattr(module, method).stop()
  796. logger.debug(f"Stopped loop in {module=}, {method=}")
  797. to_remove += module.commands.values()
  798. if hasattr(module, "watcher"):
  799. to_remove += [module.watcher]
  800. logger.debug(f"{to_remove=}, {worked=}")
  801. for watcher in self.watchers.copy():
  802. if watcher in to_remove:
  803. logger.debug(f"Removing {watcher=} for unload")
  804. self.watchers.remove(watcher)
  805. aliases_to_remove = []
  806. for name, command in self.commands.copy().items():
  807. if command in to_remove:
  808. logger.debug(f"Removing {command=} for unload")
  809. del self.commands[name]
  810. aliases_to_remove.append(name)
  811. for alias, command in self.aliases.copy().items():
  812. if command in aliases_to_remove:
  813. del self.aliases[alias]
  814. return worked
  815. def add_alias(self, alias, cmd):
  816. """Make an alias"""
  817. if cmd not in self.commands:
  818. return False
  819. self.aliases[alias.lower().strip()] = cmd
  820. return True
  821. def remove_alias(self, alias):
  822. """Remove an alias"""
  823. try:
  824. del self.aliases[alias.lower().strip()]
  825. except KeyError:
  826. return False
  827. return True
  828. async def log(
  829. self,
  830. type_,
  831. *,
  832. group=None,
  833. affected_uids=None,
  834. data=None,
  835. ):
  836. return await asyncio.gather(
  837. *[fun(type_, group, affected_uids, data) for fun in self._log_handlers]
  838. )
  839. def register_logger(self, _logger):
  840. self._log_handlers += [_logger]