loader.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154
  1. """Registers modules"""
  2. # ©️ Dan Gazizullin, 2021-2023
  3. # This file is a part of Hikka Userbot
  4. # 🌐 https://github.com/hikariatama/Hikka
  5. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  6. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  7. import asyncio
  8. import contextlib
  9. import copy
  10. import importlib
  11. import importlib.machinery
  12. import importlib.util
  13. import inspect
  14. import logging
  15. import os
  16. import re
  17. import sys
  18. import typing
  19. from functools import wraps
  20. from types import FunctionType, ModuleType
  21. from uuid import uuid4
  22. from telethon.tl.tlobject import TLObject
  23. from . import security, utils, validators, version # skipcq
  24. from .database import Database
  25. from .inline.core import InlineManager
  26. from .translations import Strings, Translator
  27. from .types import ConfigValue # skipcq
  28. from .types import ModuleConfig # skipcq
  29. from .types import (
  30. Command,
  31. CoreOverwriteError,
  32. CoreUnloadError,
  33. DragonModule,
  34. InlineMessage,
  35. JSONSerializable,
  36. Library,
  37. LibraryConfig,
  38. LoadError,
  39. Module,
  40. SelfSuspend,
  41. SelfUnload,
  42. StopLoop,
  43. StringLoader,
  44. get_commands,
  45. get_inline_handlers,
  46. )
  47. logger = logging.getLogger(__name__)
  48. owner = security.owner
  49. sudo = security.sudo
  50. support = security.support
  51. group_owner = security.group_owner
  52. group_admin_add_admins = security.group_admin_add_admins
  53. group_admin_change_info = security.group_admin_change_info
  54. group_admin_ban_users = security.group_admin_ban_users
  55. group_admin_delete_messages = security.group_admin_delete_messages
  56. group_admin_pin_messages = security.group_admin_pin_messages
  57. group_admin_invite_users = security.group_admin_invite_users
  58. group_admin = security.group_admin
  59. group_member = security.group_member
  60. pm = security.pm
  61. unrestricted = security.unrestricted
  62. inline_everyone = security.inline_everyone
  63. async def stop_placeholder() -> bool:
  64. return True
  65. class Placeholder:
  66. """Placeholder"""
  67. VALID_PIP_PACKAGES = re.compile(
  68. r"^\s*# ?requires:(?: ?)((?:{url} )*(?:{url}))\s*$".format(
  69. url=r"[-[\]_.~:/?#@!$&'()*+,;%<=>a-zA-Z0-9]+"
  70. ),
  71. re.MULTILINE,
  72. )
  73. USER_INSTALL = "PIP_TARGET" not in os.environ and "VIRTUAL_ENV" not in os.environ
  74. class InfiniteLoop:
  75. _task = None
  76. status = False
  77. module_instance = None # Will be passed later
  78. def __init__(
  79. self,
  80. func: FunctionType,
  81. interval: int,
  82. autostart: bool,
  83. wait_before: bool,
  84. stop_clause: typing.Union[str, None],
  85. ):
  86. self.func = func
  87. self.interval = interval
  88. self._wait_before = wait_before
  89. self._stop_clause = stop_clause
  90. self.autostart = autostart
  91. def _stop(self, *args, **kwargs):
  92. self._wait_for_stop.set()
  93. def stop(self, *args, **kwargs):
  94. with contextlib.suppress(AttributeError):
  95. _hikka_client_id_logging_tag = copy.copy(
  96. self.module_instance.allmodules.client.tg_id
  97. )
  98. if self._task:
  99. logger.debug("Stopped loop for method %s", self.func)
  100. self._wait_for_stop = asyncio.Event()
  101. self.status = False
  102. self._task.add_done_callback(self._stop)
  103. self._task.cancel()
  104. return asyncio.ensure_future(self._wait_for_stop.wait())
  105. logger.debug("Loop is not running")
  106. return asyncio.ensure_future(stop_placeholder())
  107. def start(self, *args, **kwargs):
  108. with contextlib.suppress(AttributeError):
  109. _hikka_client_id_logging_tag = copy.copy(
  110. self.module_instance.allmodules.client.tg_id
  111. )
  112. if not self._task:
  113. logger.debug("Started loop for method %s", self.func)
  114. self._task = asyncio.ensure_future(self.actual_loop(*args, **kwargs))
  115. else:
  116. logger.debug("Attempted to start already running loop")
  117. async def actual_loop(self, *args, **kwargs):
  118. # Wait for loader to set attribute
  119. while not self.module_instance:
  120. await asyncio.sleep(0.01)
  121. if isinstance(self._stop_clause, str) and self._stop_clause:
  122. self.module_instance.set(self._stop_clause, True)
  123. self.status = True
  124. while self.status:
  125. if self._wait_before:
  126. await asyncio.sleep(self.interval)
  127. if (
  128. isinstance(self._stop_clause, str)
  129. and self._stop_clause
  130. and not self.module_instance.get(self._stop_clause, False)
  131. ):
  132. break
  133. try:
  134. await self.func(self.module_instance, *args, **kwargs)
  135. except StopLoop:
  136. break
  137. except Exception:
  138. logger.exception("Error running loop!")
  139. if not self._wait_before:
  140. await asyncio.sleep(self.interval)
  141. self._wait_for_stop.set()
  142. self.status = False
  143. def __del__(self):
  144. self.stop()
  145. def loop(
  146. interval: int = 5,
  147. autostart: typing.Optional[bool] = False,
  148. wait_before: typing.Optional[bool] = False,
  149. stop_clause: typing.Optional[str] = None,
  150. ) -> FunctionType:
  151. """
  152. Create new infinite loop from class method
  153. :param interval: Loop iterations delay
  154. :param autostart: Start loop once module is loaded
  155. :param wait_before: Insert delay before actual iteration, rather than after
  156. :param stop_clause: Database key, based on which the loop will run.
  157. This key will be set to `True` once loop is started,
  158. and will stop after key resets to `False`
  159. :attr status: Boolean, describing whether the loop is running
  160. """
  161. def wrapped(func):
  162. return InfiniteLoop(func, interval, autostart, wait_before, stop_clause)
  163. return wrapped
  164. MODULES_NAME = "modules"
  165. ru_keys = 'ёйцукенгшщзхъфывапролджэячсмитьбю.Ё"№;%:?ЙЦУКЕНГШЩЗХЪФЫВАПРОЛДЖЭ/ЯЧСМИТЬБЮ,'
  166. en_keys = "`qwertyuiop[]asdfghjkl;'zxcvbnm,./~@#$%^&QWERTYUIOP{}ASDFGHJKL:\"|ZXCVBNM<>?"
  167. BASE_DIR = (
  168. "/data"
  169. if "DOCKER" in os.environ
  170. else os.path.normpath(os.path.join(utils.get_base_dir(), ".."))
  171. )
  172. LOADED_MODULES_DIR = os.path.join(BASE_DIR, "loaded_modules")
  173. if not os.path.isdir(LOADED_MODULES_DIR):
  174. os.mkdir(LOADED_MODULES_DIR, mode=0o755)
  175. def translatable_docstring(cls):
  176. """Decorator that makes triple-quote docstrings translatable"""
  177. @wraps(cls.config_complete)
  178. def config_complete(self, *args, **kwargs):
  179. def proccess_decorators(mark: str, obj: str):
  180. nonlocal self
  181. for attr in dir(func_):
  182. if (
  183. attr.endswith("_doc")
  184. and len(attr) == 6
  185. and isinstance(getattr(func_, attr), str)
  186. ):
  187. var = f"strings_{attr.split('_')[0]}"
  188. if not hasattr(self, var):
  189. setattr(self, var, {})
  190. getattr(self, var).setdefault(f"{mark}{obj}", getattr(func_, attr))
  191. for command_, func_ in get_commands(cls).items():
  192. proccess_decorators("_cmd_doc_", command_)
  193. try:
  194. func_.__doc__ = self.strings[f"_cmd_doc_{command_}"]
  195. except AttributeError:
  196. func_.__func__.__doc__ = self.strings[f"_cmd_doc_{command_}"]
  197. for inline_handler_, func_ in get_inline_handlers(cls).items():
  198. proccess_decorators("_ihandle_doc_", inline_handler_)
  199. try:
  200. func_.__doc__ = self.strings[f"_ihandle_doc_{inline_handler_}"]
  201. except AttributeError:
  202. func_.__func__.__doc__ = self.strings[f"_ihandle_doc_{inline_handler_}"]
  203. self.__doc__ = self.strings["_cls_doc"]
  204. return (
  205. self.config_complete._old_(self, *args, **kwargs)
  206. if not kwargs.pop("reload_dynamic_translate", None)
  207. else True
  208. )
  209. config_complete._old_ = cls.config_complete
  210. cls.config_complete = config_complete
  211. for command_, func in get_commands(cls).items():
  212. cls.strings[f"_cmd_doc_{command_}"] = inspect.getdoc(func)
  213. for inline_handler_, func in get_inline_handlers(cls).items():
  214. cls.strings[f"_ihandle_doc_{inline_handler_}"] = inspect.getdoc(func)
  215. cls.strings["_cls_doc"] = inspect.getdoc(cls)
  216. return cls
  217. tds = translatable_docstring # Shorter name for modules to use
  218. def ratelimit(func: Command) -> Command:
  219. """Decorator that causes ratelimiting for this command to be enforced more strictly
  220. """
  221. func.ratelimit = True
  222. return func
  223. def tag(*tags, **kwarg_tags):
  224. """
  225. Tag function (esp. watchers) with some tags
  226. Currently available tags:
  227. • `no_commands` - Ignore all userbot commands in watcher
  228. • `only_commands` - Capture only userbot commands in watcher
  229. • `out` - Capture only outgoing events
  230. • `in` - Capture only incoming events
  231. • `only_messages` - Capture only messages (not join events)
  232. • `editable` - Capture only messages, which can be edited (no forwards etc.)
  233. • `no_media` - Capture only messages without media and files
  234. • `only_media` - Capture only messages with media and files
  235. • `only_photos` - Capture only messages with photos
  236. • `only_videos` - Capture only messages with videos
  237. • `only_audios` - Capture only messages with audios
  238. • `only_docs` - Capture only messages with documents
  239. • `only_stickers` - Capture only messages with stickers
  240. • `only_inline` - Capture only messages with inline queries
  241. • `only_channels` - Capture only messages with channels
  242. • `only_groups` - Capture only messages with groups
  243. • `only_pm` - Capture only messages with private chats
  244. • `startswith` - Capture only messages that start with given text
  245. • `endswith` - Capture only messages that end with given text
  246. • `contains` - Capture only messages that contain given text
  247. • `regex` - Capture only messages that match given regex
  248. • `filter` - Capture only messages that pass given function
  249. • `from_id` - Capture only messages from given user
  250. • `chat_id` - Capture only messages from given chat
  251. • `thumb_url` - Works for inline command handlers. Will be shown in help
  252. • `alias` - Set single alias for a command
  253. • `aliases` - Set multiple aliases for a command
  254. Usage example:
  255. @loader.tag("no_commands", "out")
  256. @loader.tag("no_commands", out=True)
  257. @loader.tag(only_messages=True)
  258. @loader.tag("only_messages", "only_pm", regex=r"^[.] ?hikka$", from_id=659800858)
  259. 💡 These tags can be used directly in `@loader.watcher`:
  260. @loader.watcher("no_commands", out=True)
  261. """
  262. def inner(func: Command) -> Command:
  263. for _tag in tags:
  264. setattr(func, _tag, True)
  265. for _tag, value in kwarg_tags.items():
  266. setattr(func, _tag, value)
  267. return func
  268. return inner
  269. def _mark_method(mark: str, *args, **kwargs) -> typing.Callable[..., Command]:
  270. """
  271. Mark method as a method of a class
  272. """
  273. def decorator(func: Command) -> Command:
  274. setattr(func, mark, True)
  275. for arg in args:
  276. setattr(func, arg, True)
  277. for kwarg, value in kwargs.items():
  278. setattr(func, kwarg, value)
  279. return func
  280. return decorator
  281. def command(*args, **kwargs):
  282. """
  283. Decorator that marks function as userbot command
  284. """
  285. return _mark_method("is_command", *args, **kwargs)
  286. def debug_method(*args, **kwargs):
  287. """
  288. Decorator that marks function as IDM (Internal Debug Method)
  289. :param name: Name of the method
  290. """
  291. return _mark_method("is_debug_method", *args, **kwargs)
  292. def inline_handler(*args, **kwargs):
  293. """
  294. Decorator that marks function as inline handler
  295. """
  296. return _mark_method("is_inline_handler", *args, **kwargs)
  297. def watcher(*args, **kwargs):
  298. """
  299. Decorator that marks function as watcher
  300. """
  301. return _mark_method("is_watcher", *args, **kwargs)
  302. def callback_handler(*args, **kwargs):
  303. """
  304. Decorator that marks function as callback handler
  305. """
  306. return _mark_method("is_callback_handler", *args, **kwargs)
  307. def raw_handler(*updates: TLObject):
  308. """
  309. Decorator that marks function as raw telethon events handler
  310. Use it to prevent zombie-event-handlers, left by unloaded modules
  311. :param updates: Update(-s) to handle
  312. ⚠️ Do not try to simulate behavior of this decorator by yourself!
  313. ⚠️ This feature won't work, if you dynamically declare method with decorator!
  314. """
  315. def inner(func: Command) -> Command:
  316. func.is_raw_handler = True
  317. func.updates = updates
  318. func.id = uuid4().hex
  319. return func
  320. return inner
  321. class Modules:
  322. """Stores all registered modules"""
  323. def __init__(
  324. self,
  325. client: "CustomTelegramClient", # type: ignore
  326. db: Database,
  327. allclients: list,
  328. translator: Translator,
  329. ):
  330. self._initial_registration = True
  331. self.commands = {}
  332. self.inline_handlers = {}
  333. self.callback_handlers = {}
  334. self.aliases = {}
  335. self.modules = [] # skipcq: PTC-W0052
  336. self.dragon_modules = []
  337. self.libraries = []
  338. self.watchers = []
  339. self._log_handlers = []
  340. self._core_commands = []
  341. self.__approve = []
  342. self.allclients = allclients
  343. self.client = client
  344. self._db = db
  345. self.db = db
  346. self._translator = translator
  347. self.secure_boot = False
  348. asyncio.ensure_future(self._junk_collector())
  349. self.inline = InlineManager(self.client, self._db, self)
  350. async def _junk_collector(self):
  351. """
  352. Periodically reloads commands, inline handlers, callback handlers and watchers from loaded
  353. modules to prevent zombie handlers
  354. """
  355. while True:
  356. await asyncio.sleep(30)
  357. commands = {}
  358. inline_handlers = {}
  359. callback_handlers = {}
  360. watchers = []
  361. for module in self.modules:
  362. commands.update(module.hikka_commands)
  363. inline_handlers.update(module.hikka_inline_handlers)
  364. callback_handlers.update(module.hikka_callback_handlers)
  365. watchers.extend(module.hikka_watchers.values())
  366. self.commands = commands
  367. self.inline_handlers = inline_handlers
  368. self.callback_handlers = callback_handlers
  369. self.watchers = watchers
  370. logger.debug(
  371. "Reloaded %s commands,"
  372. " %s inline handlers,"
  373. " %s callback handlers and"
  374. " %s watchers",
  375. len(self.commands),
  376. len(self.inline_handlers),
  377. len(self.callback_handlers),
  378. len(self.watchers),
  379. )
  380. async def register_all(
  381. self,
  382. mods: typing.Optional[typing.List[str]] = None,
  383. no_external: bool = False,
  384. ) -> typing.List[Module]:
  385. """Load all modules in the module directory"""
  386. external_mods = []
  387. if not mods:
  388. mods = [
  389. os.path.join(utils.get_base_dir(), MODULES_NAME, mod)
  390. for mod in filter(
  391. lambda x: (x.endswith(".py") and not x.startswith("_")),
  392. os.listdir(os.path.join(utils.get_base_dir(), MODULES_NAME)),
  393. )
  394. ]
  395. self.secure_boot = self._db.get(__name__, "secure_boot", False)
  396. external_mods = (
  397. []
  398. if self.secure_boot
  399. else [
  400. os.path.join(LOADED_MODULES_DIR, mod)
  401. for mod in filter(
  402. lambda x: (
  403. x.endswith(f"{self.client.tg_id}.py")
  404. and not x.startswith("_")
  405. ),
  406. os.listdir(LOADED_MODULES_DIR),
  407. )
  408. ]
  409. )
  410. loaded = []
  411. loaded += await self._register_modules(mods)
  412. if not no_external:
  413. loaded += await self._register_modules(external_mods, "<file>")
  414. return loaded
  415. async def _register_modules(
  416. self,
  417. modules: list,
  418. origin: str = "<core>",
  419. ) -> typing.List[Module]:
  420. with contextlib.suppress(AttributeError):
  421. _hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
  422. loaded = []
  423. for mod in modules:
  424. try:
  425. mod_shortname = os.path.basename(mod).rsplit(".py", maxsplit=1)[0]
  426. module_name = f"{__package__}.{MODULES_NAME}.{mod_shortname}"
  427. user_friendly_origin = (
  428. "<core {}>" if origin == "<core>" else "<file {}>"
  429. ).format(module_name)
  430. logger.debug("Loading %s from filesystem", module_name)
  431. with open(mod, "r") as file:
  432. spec = importlib.machinery.ModuleSpec(
  433. module_name,
  434. StringLoader(file.read(), user_friendly_origin),
  435. origin=user_friendly_origin,
  436. )
  437. loaded += [await self.register_module(spec, module_name, origin)]
  438. except Exception as e:
  439. logger.exception("Failed to load module %s due to %s:", mod, e)
  440. return loaded
  441. def register_dragon(self, module: ModuleType, instance: DragonModule):
  442. for mod in self.dragon_modules.copy():
  443. if mod.name == instance.name:
  444. logger.debug("Removing dragon module %s for reload", mod.name)
  445. self.unload_dragon(mod)
  446. instance.handlers = []
  447. for name, obj in vars(module).items():
  448. for handler, group in getattr(obj, "handlers", []):
  449. try:
  450. handler = self.client.pyro_proxy.add_handler(handler, group)
  451. instance.handlers.append(handler)
  452. except Exception as e:
  453. logging.exception(
  454. "Can't add handler %s due to %s: %s",
  455. name,
  456. type(e).__name__,
  457. e,
  458. )
  459. self.dragon_modules += [instance]
  460. def unload_dragon(self, instance: DragonModule) -> bool:
  461. for handler in instance.handlers:
  462. try:
  463. self.client.pyro_proxy.remove_handler(*handler)
  464. except Exception as e:
  465. logging.exception(
  466. "Can't remove handler %s due to %s: %s",
  467. handler,
  468. type(e).__name__,
  469. e,
  470. )
  471. if instance in self.dragon_modules:
  472. self.dragon_modules.remove(instance)
  473. return True
  474. return False
  475. async def register_module(
  476. self,
  477. spec: importlib.machinery.ModuleSpec,
  478. module_name: str,
  479. origin: str = "<core>",
  480. save_fs: bool = False,
  481. is_dragon: bool = False,
  482. ) -> typing.Union[Module, typing.Tuple[ModuleType, DragonModule]]:
  483. """Register single module from importlib spec"""
  484. with contextlib.suppress(AttributeError):
  485. _hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
  486. module = importlib.util.module_from_spec(spec)
  487. sys.modules[module_name] = module
  488. spec.loader.exec_module(module)
  489. if is_dragon:
  490. return module, DragonModule()
  491. ret = None
  492. ret = next(
  493. (
  494. value()
  495. for value in vars(module).values()
  496. if inspect.isclass(value) and issubclass(value, Module)
  497. ),
  498. None,
  499. )
  500. if hasattr(module, "__version__"):
  501. ret.__version__ = module.__version__
  502. if ret is None:
  503. ret = module.register(module_name)
  504. if not isinstance(ret, Module):
  505. raise TypeError(f"Instance is not a Module, it is {type(ret)}")
  506. await self.complete_registration(ret)
  507. ret.__origin__ = origin
  508. cls_name = ret.__class__.__name__
  509. if save_fs:
  510. path = os.path.join(
  511. LOADED_MODULES_DIR,
  512. f"{cls_name}_{self.client.tg_id}.py",
  513. )
  514. if origin == "<string>":
  515. with open(path, "w") as f:
  516. f.write(spec.loader.data.decode("utf-8"))
  517. logger.debug("Saved class %s to path %s", cls_name, path)
  518. return ret
  519. def add_aliases(self, aliases: dict):
  520. """Saves aliases and applies them to <core>/<file> modules"""
  521. self.aliases.update(aliases)
  522. for alias, cmd in aliases.items():
  523. self.add_alias(alias, cmd)
  524. def register_raw_handlers(self, instance: Module):
  525. """Register event handlers for a module"""
  526. for name, handler in utils.iter_attrs(instance):
  527. if getattr(handler, "is_raw_handler", False):
  528. self.client.dispatcher.raw_handlers.append(handler)
  529. logger.debug(
  530. "Registered raw handler %s for %s. ID: %s",
  531. name,
  532. instance.__class__.__name__,
  533. handler.id,
  534. )
  535. def register_commands(self, instance: Module):
  536. """Register commands from instance"""
  537. with contextlib.suppress(AttributeError):
  538. _hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
  539. if instance.__origin__.startswith("<core"):
  540. self._core_commands += list(
  541. map(lambda x: x.lower(), list(instance.hikka_commands))
  542. )
  543. for _command, cmd in instance.hikka_commands.items():
  544. # Restrict overwriting core modules' commands
  545. if (
  546. _command.lower() in self._core_commands
  547. and not instance.__origin__.startswith("<core")
  548. ):
  549. with contextlib.suppress(Exception):
  550. self.modules.remove(instance)
  551. raise CoreOverwriteError(command=_command)
  552. self.commands.update({_command.lower(): cmd})
  553. for alias, cmd in self.aliases.copy().items():
  554. if cmd in instance.hikka_commands:
  555. self.add_alias(alias, cmd)
  556. self.register_inline_stuff(instance)
  557. def register_inline_stuff(self, instance: Module):
  558. for name, func in instance.hikka_inline_handlers.copy().items():
  559. if name.lower() in self.inline_handlers:
  560. if (
  561. hasattr(func, "__self__")
  562. and hasattr(self.inline_handlers[name], "__self__")
  563. and (
  564. func.__self__.__class__.__name__
  565. != self.inline_handlers[name].__self__.__class__.__name__
  566. )
  567. ):
  568. logger.debug(
  569. "Duplicate inline_handler %s of %s",
  570. name,
  571. instance.__class__.__name__,
  572. )
  573. logger.debug(
  574. "Replacing inline_handler %s for %s",
  575. self.inline_handlers[name],
  576. instance.__class__.__name__,
  577. )
  578. self.inline_handlers.update({name.lower(): func})
  579. for name, func in instance.hikka_callback_handlers.copy().items():
  580. if name.lower() in self.callback_handlers and (
  581. hasattr(func, "__self__")
  582. and hasattr(self.callback_handlers[name], "__self__")
  583. and func.__self__.__class__.__name__
  584. != self.callback_handlers[name].__self__.__class__.__name__
  585. ):
  586. logger.debug(
  587. "Duplicate callback_handler %s of %s",
  588. name,
  589. instance.__class__.__name__,
  590. )
  591. self.callback_handlers.update({name.lower(): func})
  592. def unregister_inline_stuff(self, instance: Module, purpose: str):
  593. for name, func in instance.hikka_inline_handlers.copy().items():
  594. if name.lower() in self.inline_handlers and (
  595. hasattr(func, "__self__")
  596. and hasattr(self.inline_handlers[name], "__self__")
  597. and func.__self__.__class__.__name__
  598. == self.inline_handlers[name].__self__.__class__.__name__
  599. ):
  600. del self.inline_handlers[name.lower()]
  601. logger.debug(
  602. "Unregistered inline_handler %s of %s for %s",
  603. name,
  604. instance.__class__.__name__,
  605. purpose,
  606. )
  607. for name, func in instance.hikka_callback_handlers.copy().items():
  608. if name.lower() in self.callback_handlers and (
  609. hasattr(func, "__self__")
  610. and hasattr(self.callback_handlers[name], "__self__")
  611. and func.__self__.__class__.__name__
  612. == self.callback_handlers[name].__self__.__class__.__name__
  613. ):
  614. del self.callback_handlers[name.lower()]
  615. logger.debug(
  616. "Unregistered callback_handler %s of %s for %s",
  617. name,
  618. instance.__class__.__name__,
  619. purpose,
  620. )
  621. def register_watchers(self, instance: Module):
  622. """Register watcher from instance"""
  623. with contextlib.suppress(AttributeError):
  624. _hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
  625. for _watcher in self.watchers:
  626. if _watcher.__self__.__class__.__name__ == instance.__class__.__name__:
  627. logger.debug("Removing watcher %s for update", _watcher)
  628. self.watchers.remove(_watcher)
  629. for _watcher in instance.hikka_watchers.values():
  630. self.watchers += [_watcher]
  631. def lookup(
  632. self,
  633. modname: str,
  634. include_dragon: bool = False,
  635. ) -> typing.Union[bool, Module, DragonModule, Library]:
  636. return (
  637. next(
  638. (lib for lib in self.libraries if lib.name.lower() == modname.lower()),
  639. False,
  640. )
  641. or next(
  642. (
  643. mod
  644. for mod in self.modules
  645. if mod.__class__.__name__.lower() == modname.lower()
  646. or mod.name.lower() == modname.lower()
  647. ),
  648. False,
  649. )
  650. or (
  651. next(
  652. (
  653. mod
  654. for mod in self.dragon_modules
  655. if mod.name.lower() == modname.lower()
  656. ),
  657. False,
  658. )
  659. if include_dragon
  660. else False
  661. )
  662. )
  663. @property
  664. def get_approved_channel(self):
  665. return self.__approve.pop(0) if self.__approve else None
  666. def get_prefix(self, userbot: typing.Optional[str] = None) -> str:
  667. """Get prefix for specific userbot. Pass `None` to get Hikka prefix"""
  668. if userbot == "dragon":
  669. key = "dragon.prefix"
  670. default = ","
  671. else:
  672. key = "hikka.main"
  673. default = "."
  674. return self._db.get(key, "command_prefix", default)
  675. async def complete_registration(self, instance: Module):
  676. """Complete registration of instance"""
  677. with contextlib.suppress(AttributeError):
  678. _hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
  679. instance.allmodules = self
  680. instance.internal_init()
  681. for module in self.modules:
  682. if module.__class__.__name__ == instance.__class__.__name__:
  683. if module.__origin__.startswith("<core"):
  684. raise CoreOverwriteError(
  685. module=module.__class__.__name__[:-3]
  686. if module.__class__.__name__.endswith("Mod")
  687. else module.__class__.__name__
  688. )
  689. logger.debug("Removing module %s for update", module)
  690. await module.on_unload()
  691. self.modules.remove(module)
  692. for _, method in utils.iter_attrs(module):
  693. if isinstance(method, InfiniteLoop):
  694. method.stop()
  695. logger.debug(
  696. "Stopped loop in module %s, method %s",
  697. module,
  698. method,
  699. )
  700. self.modules += [instance]
  701. def find_alias(
  702. self,
  703. alias: str,
  704. include_legacy: bool = False,
  705. ) -> typing.Optional[str]:
  706. if not alias:
  707. return None
  708. for command_name, _command in self.commands.items():
  709. aliases = []
  710. if getattr(_command, "alias", None) and not (
  711. aliases := getattr(_command, "aliases", None)
  712. ):
  713. aliases = [_command.alias]
  714. if not aliases:
  715. continue
  716. if any(
  717. alias.lower() == _alias.lower()
  718. and alias.lower() not in self._core_commands
  719. for _alias in aliases
  720. ):
  721. return command_name
  722. if alias in self.aliases and include_legacy:
  723. return self.aliases[alias]
  724. return None
  725. def dispatch(self, _command: str) -> typing.Tuple[str, typing.Optional[str]]:
  726. """Dispatch command to appropriate module"""
  727. return next(
  728. (
  729. (cmd, self.commands[cmd.lower()])
  730. for cmd in [
  731. _command,
  732. self.aliases.get(_command.lower()),
  733. self.find_alias(_command),
  734. ]
  735. if cmd and cmd.lower() in self.commands
  736. ),
  737. (_command, None),
  738. )
  739. def send_config(self, skip_hook: bool = False):
  740. """Configure modules"""
  741. for mod in self.modules:
  742. self.send_config_one(mod, skip_hook)
    def send_config_one(self, mod: Module, skip_hook: bool = False):
        """Send config to single instance

        Applies a value to each of ``mod``'s config options. Unless
        ``skip_hook`` is set, it then attaches a ``Strings`` wrapper and the
        translator and calls ``mod.config_complete()``.

        :param mod: Module instance to configure.
        :param skip_hook: Return right after config values and ``mod.name``
            are set; strings, translator and ``config_complete`` are skipped.
        :raises Exception: anything raised by ``mod.config_complete()`` is
            logged and re-raised.
        """
        # NOTE(review): this local is never read in this method. Presumably a
        # logging handler inspects frame locals for this name to tag records
        # with the client id. Confirm before removing or renaming.
        with contextlib.suppress(AttributeError):
            _hikka_client_id_logging_tag = copy.copy(self.client.tg_id)

        if hasattr(mod, "config"):
            # Persisted values live under the module's class name, key "__config__".
            modcfg = self._db.get(
                mod.__class__.__name__,
                "__config__",
                {},
            )
            try:
                for conf in mod.config:
                    # Values that fail validation are skipped silently.
                    with contextlib.suppress(validators.ValidationError):
                        mod.config.set_no_raise(
                            conf,
                            # Use the stored value if present. Otherwise use the
                            # "<ClassName>.<option>" environment variable, or the
                            # option default when that is unset or empty (`or`
                            # binds tighter than the conditional expression).
                            (
                                modcfg[conf]
                                if conf in modcfg
                                else os.environ.get(f"{mod.__class__.__name__}.{conf}")
                                or mod.config.getdef(conf)
                            ),
                        )
            except AttributeError:
                # Presumably raised when mod.config is not a ModuleConfig, as
                # the warning message states.
                logger.warning(
                    "Got invalid config instance. Expected `ModuleConfig`, got %s, %s",
                    type(mod.config),
                    mod.config,
                )

        # Fall back to the "name" entry of the module's strings.
        if not hasattr(mod, "name"):
            mod.name = mod.strings["name"]

        if skip_hook:
            return

        # Replace mod.strings with a Strings object built from the module and
        # the loader's translator.
        if hasattr(mod, "strings"):
            mod.strings = Strings(mod, self._translator)

        mod.translator = self._translator

        try:
            mod.config_complete()
        except Exception as e:
            logger.exception("Failed to send mod config complete signal due to %s", e)
            raise
  783. async def send_ready(self):
  784. """Send all data to all modules"""
  785. await self.inline.register_manager()
  786. try:
  787. await asyncio.gather(*[self.send_ready_one(mod) for mod in self.modules])
  788. except Exception as e:
  789. logger.exception("Failed to send mod init complete signal due to %s", e)
  790. async def send_ready_one(
  791. self,
  792. mod: Module,
  793. no_self_unload: bool = False,
  794. from_dlmod: bool = False,
  795. ):
  796. with contextlib.suppress(AttributeError):
  797. _hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
  798. if from_dlmod:
  799. try:
  800. if len(inspect.signature(mod.on_dlmod).parameters) == 2:
  801. await mod.on_dlmod(self.client, self._db)
  802. else:
  803. await mod.on_dlmod()
  804. except Exception:
  805. logger.info("Can't process `on_dlmod` hook", exc_info=True)
  806. try:
  807. if len(inspect.signature(mod.client_ready).parameters) == 2:
  808. await mod.client_ready(self.client, self._db)
  809. else:
  810. await mod.client_ready()
  811. except SelfUnload as e:
  812. if no_self_unload:
  813. raise e
  814. logger.debug("Unloading %s, because it raised SelfUnload", mod)
  815. self.modules.remove(mod)
  816. except SelfSuspend as e:
  817. if no_self_unload:
  818. raise e
  819. logger.debug("Suspending %s, because it raised SelfSuspend", mod)
  820. return
  821. except Exception as e:
  822. logger.exception(
  823. "Failed to send mod init complete signal for %s due to %s,"
  824. " attempting unload",
  825. mod,
  826. e,
  827. )
  828. self.modules.remove(mod)
  829. raise
  830. for _, method in utils.iter_attrs(mod):
  831. if isinstance(method, InfiniteLoop):
  832. setattr(method, "module_instance", mod)
  833. if method.autostart:
  834. method.start()
  835. logger.debug("Added module %s to method %s", mod, method)
  836. self.unregister_commands(mod, "update")
  837. self.unregister_raw_handlers(mod, "update")
  838. self.register_commands(mod)
  839. self.register_watchers(mod)
  840. self.register_raw_handlers(mod)
  841. def get_classname(self, name: str) -> str:
  842. return next(
  843. (
  844. module.__class__.__module__
  845. for module in reversed(self.modules)
  846. if name in (module.name, module.__class__.__module__)
  847. ),
  848. name,
  849. )
  850. async def unload_module(self, classname: str) -> typing.List[str]:
  851. """Remove module and all stuff from it"""
  852. worked = []
  853. with contextlib.suppress(AttributeError):
  854. _hikka_client_id_logging_tag = copy.copy(self.client.tg_id)
  855. for module in self.modules:
  856. if classname.lower() in (
  857. module.name.lower(),
  858. module.__class__.__name__.lower(),
  859. ):
  860. if module.__origin__.startswith("<core"):
  861. raise CoreUnloadError(module.__class__.__name__)
  862. worked += [module.__class__.__name__]
  863. name = module.__class__.__name__
  864. path = os.path.join(
  865. LOADED_MODULES_DIR,
  866. f"{name}_{self.client.tg_id}.py",
  867. )
  868. if os.path.isfile(path):
  869. os.remove(path)
  870. logger.debug("Removed %s file at path %s", name, path)
  871. logger.debug("Removing module %s for unload", module)
  872. self.modules.remove(module)
  873. await module.on_unload()
  874. self.unregister_raw_handlers(module, "unload")
  875. self.unregister_loops(module, "unload")
  876. self.unregister_commands(module, "unload")
  877. self.unregister_watchers(module, "unload")
  878. self.unregister_inline_stuff(module, "unload")
  879. logger.debug("Worked: %s", worked)
  880. return worked
  881. def unregister_loops(self, instance: Module, purpose: str):
  882. for name, method in utils.iter_attrs(instance):
  883. if isinstance(method, InfiniteLoop):
  884. logger.debug(
  885. "Stopping loop for %s in module %s, method %s",
  886. purpose,
  887. instance.__class__.__name__,
  888. name,
  889. )
  890. method.stop()
  891. def unregister_commands(self, instance: Module, purpose: str):
  892. for name, cmd in self.commands.copy().items():
  893. if cmd.__self__.__class__.__name__ == instance.__class__.__name__:
  894. logger.debug(
  895. "Removing command %s of module %s for %s",
  896. name,
  897. instance.__class__.__name__,
  898. purpose,
  899. )
  900. del self.commands[name]
  901. for alias, _command in self.aliases.copy().items():
  902. if _command == name:
  903. del self.aliases[alias]
  904. def unregister_watchers(self, instance: Module, purpose: str):
  905. for _watcher in self.watchers.copy():
  906. if _watcher.__self__.__class__.__name__ == instance.__class__.__name__:
  907. logger.debug(
  908. "Removing watcher %s of module %s for %s",
  909. _watcher,
  910. instance.__class__.__name__,
  911. purpose,
  912. )
  913. self.watchers.remove(_watcher)
  914. def unregister_raw_handlers(self, instance: Module, purpose: str):
  915. """Unregister event handlers for a module"""
  916. for handler in self.client.dispatcher.raw_handlers:
  917. if handler.__self__.__class__.__name__ == instance.__class__.__name__:
  918. self.client.dispatcher.raw_handlers.remove(handler)
  919. logger.debug(
  920. "Unregistered raw handler of module %s for %s. ID: %s",
  921. instance.__class__.__name__,
  922. purpose,
  923. handler.id,
  924. )
  925. def add_alias(self, alias: str, cmd: str) -> bool:
  926. """Make an alias"""
  927. if cmd not in self.commands:
  928. return False
  929. self.aliases[alias.lower().strip()] = cmd
  930. return True
  931. def remove_alias(self, alias: str) -> bool:
  932. """Remove an alias"""
  933. return bool(self.aliases.pop(alias.lower().strip(), None))
  934. async def log(self, *args, **kwargs):
  935. """Unnecessary placeholder for logging"""
  936. async def reload_translations(self) -> bool:
  937. if not await self._translator.init():
  938. return False
  939. for module in self.modules:
  940. try:
  941. module.config_complete(reload_dynamic_translate=True)
  942. except Exception as e:
  943. logger.debug(
  944. "Can't complete dynamic translations reload of %s due to %s",
  945. module,
  946. e,
  947. )
  948. return True