types.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100
  1. # ©️ Dan Gazizullin, 2021-2023
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import ast
  7. import asyncio
  8. import contextlib
  9. import copy
  10. import importlib
  11. import importlib.machinery
  12. import importlib.util
  13. import inspect
  14. import logging
  15. import os
  16. import re
  17. import sys
  18. import time
  19. import typing
  20. from dataclasses import dataclass, field
  21. from importlib.abc import SourceLoader
  22. import requests
  23. from hikkatl.hints import EntityLike
  24. from hikkatl.tl.functions.account import UpdateNotifySettingsRequest
  25. from hikkatl.tl.types import (
  26. Channel,
  27. ChannelFull,
  28. InputPeerNotifySettings,
  29. Message,
  30. UserFull,
  31. )
  32. from . import version
  33. from ._reference_finder import replace_all_refs
  34. from .inline.types import (
  35. BotInlineCall,
  36. BotInlineMessage,
  37. BotMessage,
  38. InlineCall,
  39. InlineMessage,
  40. InlineQuery,
  41. InlineUnit,
  42. )
  43. from .pointers import PointerDict, PointerList
  44. __all__ = [
  45. "JSONSerializable",
  46. "HikkaReplyMarkup",
  47. "ListLike",
  48. "Command",
  49. "StringLoader",
  50. "Module",
  51. "get_commands",
  52. "get_inline_handlers",
  53. "get_callback_handlers",
  54. "BotInlineCall",
  55. "BotMessage",
  56. "InlineCall",
  57. "InlineMessage",
  58. "InlineQuery",
  59. "InlineUnit",
  60. "BotInlineMessage",
  61. "PointerDict",
  62. "PointerList",
  63. ]
  64. logger = logging.getLogger(__name__)
  65. JSONSerializable = typing.Union[str, int, float, bool, list, dict, None]
  66. HikkaReplyMarkup = typing.Union[typing.List[typing.List[dict]], typing.List[dict], dict]
  67. ListLike = typing.Union[list, set, tuple]
  68. Command = typing.Callable[..., typing.Awaitable[typing.Any]]
  69. class StringLoader(SourceLoader):
  70. """Load a python module/file from a string"""
  71. def __init__(self, data: str, origin: str):
  72. self.data = data.encode("utf-8") if isinstance(data, str) else data
  73. self.origin = origin
  74. def get_source(self, _=None) -> str:
  75. return self.data.decode("utf-8")
  76. def get_code(self, fullname: str) -> bytes:
  77. return (
  78. compile(source, self.origin, "exec", dont_inherit=True)
  79. if (source := self.get_data(fullname))
  80. else None
  81. )
  82. def get_filename(self, *args, **kwargs) -> str:
  83. return self.origin
  84. def get_data(self, *args, **kwargs) -> bytes:
  85. return self.data
  86. class Module:
  87. strings = {"name": "Unknown"}
  88. """There is no help for this module"""
  89. def config_complete(self):
  90. """Called when module.config is populated"""
  91. async def client_ready(self):
  92. """Called after client is ready (after config_loaded)"""
  93. def internal_init(self):
  94. """Called after the class is initialized in order to pass the client and db. Do not call it yourself"""
  95. self.db = self.allmodules.db
  96. self._db = self.allmodules.db
  97. self.client = self.allmodules.client
  98. self._client = self.allmodules.client
  99. self.lookup = self.allmodules.lookup
  100. self.get_prefix = self.allmodules.get_prefix
  101. self.inline = self.allmodules.inline
  102. self.allclients = self.allmodules.allclients
  103. self.tg_id = self._client.tg_id
  104. self._tg_id = self._client.tg_id
  105. async def on_unload(self):
  106. """Called after unloading / reloading module"""
  107. async def on_dlmod(self):
  108. """
  109. Called after the module is first time loaded with .dlmod or .loadmod
  110. Possible use-cases:
  111. - Send reaction to author's channel message
  112. - Create asset folder
  113. - ...
  114. ⚠️ Note, that any error there will not interrupt module load, and will just
  115. send a message to logs with verbosity INFO and exception traceback
  116. """
  117. async def invoke(
  118. self,
  119. command: str,
  120. args: typing.Optional[str] = None,
  121. peer: typing.Optional[EntityLike] = None,
  122. message: typing.Optional[Message] = None,
  123. edit: bool = False,
  124. ) -> Message:
  125. """
  126. Invoke another command
  127. :param command: Command to invoke
  128. :param args: Arguments to pass to command
  129. :param peer: Peer to send the command to. If not specified, will send to the current chat
  130. :param edit: Whether to edit the message
  131. :returns Message:
  132. """
  133. if command not in self.allmodules.commands:
  134. raise ValueError(f"Command {command} not found")
  135. if not message and not peer:
  136. raise ValueError("Either peer or message must be specified")
  137. cmd = f"{self.get_prefix()}{command} {args or ''}".strip()
  138. message = (
  139. (await self._client.send_message(peer, cmd))
  140. if peer
  141. else (await (message.edit if edit else message.respond)(cmd))
  142. )
  143. await self.allmodules.commands[command](message)
  144. return message
  145. @property
  146. def commands(self) -> typing.Dict[str, Command]:
  147. """List of commands that module supports"""
  148. return get_commands(self)
  149. @property
  150. def hikka_commands(self) -> typing.Dict[str, Command]:
  151. """List of commands that module supports"""
  152. return get_commands(self)
  153. @property
  154. def inline_handlers(self) -> typing.Dict[str, Command]:
  155. """List of inline handlers that module supports"""
  156. return get_inline_handlers(self)
  157. @property
  158. def hikka_inline_handlers(self) -> typing.Dict[str, Command]:
  159. """List of inline handlers that module supports"""
  160. return get_inline_handlers(self)
  161. @property
  162. def callback_handlers(self) -> typing.Dict[str, Command]:
  163. """List of callback handlers that module supports"""
  164. return get_callback_handlers(self)
  165. @property
  166. def hikka_callback_handlers(self) -> typing.Dict[str, Command]:
  167. """List of callback handlers that module supports"""
  168. return get_callback_handlers(self)
  169. @property
  170. def watchers(self) -> typing.Dict[str, Command]:
  171. """List of watchers that module supports"""
  172. return get_watchers(self)
  173. @property
  174. def hikka_watchers(self) -> typing.Dict[str, Command]:
  175. """List of watchers that module supports"""
  176. return get_watchers(self)
  177. @commands.setter
  178. def commands(self, _):
  179. pass
  180. @hikka_commands.setter
  181. def hikka_commands(self, _):
  182. pass
  183. @inline_handlers.setter
  184. def inline_handlers(self, _):
  185. pass
  186. @hikka_inline_handlers.setter
  187. def hikka_inline_handlers(self, _):
  188. pass
  189. @callback_handlers.setter
  190. def callback_handlers(self, _):
  191. pass
  192. @hikka_callback_handlers.setter
  193. def hikka_callback_handlers(self, _):
  194. pass
  195. @watchers.setter
  196. def watchers(self, _):
  197. pass
  198. @hikka_watchers.setter
  199. def hikka_watchers(self, _):
  200. pass
  201. async def animate(
  202. self,
  203. message: typing.Union[Message, InlineMessage],
  204. frames: typing.List[str],
  205. interval: typing.Union[float, int],
  206. *,
  207. inline: bool = False,
  208. ) -> None:
  209. """
  210. Animate message
  211. :param message: Message to animate
  212. :param frames: A List of strings which are the frames of animation
  213. :param interval: Animation delay
  214. :param inline: Whether to use inline bot for animation
  215. :returns message:
  216. Please, note that if you set `inline=True`, first frame will be shown with an empty
  217. button due to the limitations of Telegram API
  218. """
  219. from . import utils
  220. with contextlib.suppress(AttributeError):
  221. _hikka_client_id_logging_tag = copy.copy(self.client.tg_id) # noqa: F841
  222. if interval < 0.1:
  223. logger.warning(
  224. "Resetting animation interval to 0.1s, because it may get you in"
  225. " floodwaits"
  226. )
  227. interval = 0.1
  228. for frame in frames:
  229. if isinstance(message, Message):
  230. if inline:
  231. message = await self.inline.form(
  232. message=message,
  233. text=frame,
  234. reply_markup={"text": "\u0020\u2800", "data": "empty"},
  235. )
  236. else:
  237. message = await utils.answer(message, frame)
  238. elif isinstance(message, InlineMessage) and inline:
  239. await message.edit(frame)
  240. await asyncio.sleep(interval)
  241. return message
  242. def get(
  243. self,
  244. key: str,
  245. default: typing.Optional[JSONSerializable] = None,
  246. ) -> JSONSerializable:
  247. return self._db.get(self.__class__.__name__, key, default)
  248. def set(self, key: str, value: JSONSerializable) -> bool:
  249. self._db.set(self.__class__.__name__, key, value)
  250. def pointer(
  251. self,
  252. key: str,
  253. default: typing.Optional[JSONSerializable] = None,
  254. item_type: typing.Optional[typing.Any] = None,
  255. ) -> typing.Union[JSONSerializable, PointerList, PointerDict]:
  256. return self._db.pointer(self.__class__.__name__, key, default, item_type)
  257. async def _decline(
  258. self,
  259. call: InlineCall,
  260. channel: EntityLike,
  261. event: asyncio.Event,
  262. ):
  263. from . import utils
  264. self._db.set(
  265. "hikka.main",
  266. "declined_joins",
  267. list(set(self._db.get("hikka.main", "declined_joins", []) + [channel.id])),
  268. )
  269. event.status = False
  270. event.set()
  271. await call.edit(
  272. (
  273. "✖️ <b>Declined joining <a"
  274. f' href="https://t.me/{channel.username}">{utils.escape_html(channel.title)}</a></b>'
  275. ),
  276. gif="https://data.whicdn.com/images/324445359/original.gif",
  277. )
  278. async def request_join(
  279. self,
  280. peer: EntityLike,
  281. reason: str,
  282. assure_joined: typing.Optional[bool] = False,
  283. ) -> bool:
  284. """
  285. Request to join a channel.
  286. :param peer: The channel to join.
  287. :param reason: The reason for joining.
  288. :param assure_joined: If set, module will not be loaded unless the required channel is joined.
  289. ⚠️ Works only in `client_ready`!
  290. ⚠️ If user declines to join channel, he will not be asked to
  291. join again, so unless he joins it manually, module will not be loaded
  292. ever.
  293. :return: Status of the request.
  294. :rtype: bool
  295. :notice: This method will block module loading until the request is approved or declined.
  296. """
  297. from . import utils
  298. event = asyncio.Event()
  299. await self.client(
  300. UpdateNotifySettingsRequest(
  301. peer=self.inline.bot_username,
  302. settings=InputPeerNotifySettings(show_previews=False, silent=False),
  303. )
  304. )
  305. channel = await self.client.get_entity(peer)
  306. if channel.id in self._db.get("hikka.main", "declined_joins", []):
  307. if assure_joined:
  308. raise LoadError(
  309. f"You need to join @{channel.username} in order to use this module"
  310. )
  311. return False
  312. if not isinstance(channel, Channel):
  313. raise TypeError("`peer` field must be a channel")
  314. if getattr(channel, "left", True):
  315. channel = await self.client.force_get_entity(peer)
  316. if not getattr(channel, "left", True):
  317. return True
  318. await self.inline.bot.send_animation(
  319. self.tg_id,
  320. "https://i.gifer.com/SD5S.gif",
  321. caption=(
  322. self._client.loader.lookup("translations")
  323. .strings("requested_join")
  324. .format(
  325. self.__class__.__name__,
  326. channel.username,
  327. utils.escape_html(channel.title),
  328. utils.escape_html(reason),
  329. )
  330. ),
  331. reply_markup=self.inline.generate_markup([
  332. {
  333. "text": "💫 Approve",
  334. "callback": self.lookup("loader").approve_internal,
  335. "args": (channel, event),
  336. },
  337. {
  338. "text": "✖️ Decline",
  339. "callback": self._decline,
  340. "args": (channel, event),
  341. },
  342. ]),
  343. )
  344. self.hikka_wait_channel_approve = (
  345. self.__class__.__name__,
  346. channel,
  347. reason,
  348. )
  349. event.status = False
  350. await event.wait()
  351. with contextlib.suppress(AttributeError):
  352. delattr(self, "hikka_wait_channel_approve")
  353. if assure_joined and not event.status:
  354. raise LoadError(
  355. f"You need to join @{channel.username} in order to use this module"
  356. )
  357. return event.status
  358. async def import_lib(
  359. self,
  360. url: str,
  361. *,
  362. suspend_on_error: typing.Optional[bool] = False,
  363. _did_requirements: bool = False,
  364. ) -> "Library":
  365. """
  366. Import library from url and register it in :obj:`Modules`
  367. :param url: Url to import
  368. :param suspend_on_error: Will raise :obj:`loader.SelfSuspend` if library can't be loaded
  369. :return: :obj:`Library`
  370. :raise: SelfUnload if :attr:`suspend_on_error` is True and error occurred
  371. :raise: HTTPError if library is not found
  372. :raise: ImportError if library doesn't have any class which is a subclass of :obj:`loader.Library`
  373. :raise: ImportError if library name doesn't end with `Lib`
  374. :raise: RuntimeError if library throws in :method:`init`
  375. :raise: RuntimeError if library classname exists in :obj:`Modules`.libraries
  376. """
  377. from . import utils # Avoiding circular import
  378. from .loader import USER_INSTALL, VALID_PIP_PACKAGES
  379. from .translations import Strings
  380. def _raise(e: Exception):
  381. if suspend_on_error:
  382. raise SelfSuspend("Required library is not available or is corrupted.")
  383. raise e
  384. if not utils.check_url(url):
  385. _raise(ValueError("Invalid url for library"))
  386. code = await utils.run_sync(requests.get, url)
  387. code.raise_for_status()
  388. code = code.text
  389. if re.search(r"# ?scope: ?hikka_min", code):
  390. ver = tuple(
  391. map(
  392. int,
  393. re.search(r"# ?scope: ?hikka_min ((\d+\.){2}\d+)", code)[1].split(
  394. "."
  395. ),
  396. )
  397. )
  398. if version.__version__ < ver:
  399. _raise(
  400. RuntimeError(
  401. f"Library requires Hikka version {'{}.{}.{}'.format(*ver)}+"
  402. )
  403. )
  404. module = f"hikka.libraries.{url.replace('%', '%%').replace('.', '%d')}"
  405. origin = f"<library {url}>"
  406. spec = importlib.machinery.ModuleSpec(
  407. module,
  408. StringLoader(code, origin),
  409. origin=origin,
  410. )
  411. try:
  412. instance = importlib.util.module_from_spec(spec)
  413. sys.modules[module] = instance
  414. spec.loader.exec_module(instance)
  415. except ImportError as e:
  416. logger.info(
  417. "Library loading failed, attemping dependency installation (%s)",
  418. e.name,
  419. )
  420. # Let's try to reinstall dependencies
  421. try:
  422. requirements = list(
  423. filter(
  424. lambda x: not x.startswith(("-", "_", ".")),
  425. map(
  426. str.strip,
  427. VALID_PIP_PACKAGES.search(code)[1].split(),
  428. ),
  429. )
  430. )
  431. except TypeError:
  432. logger.warning(
  433. "No valid pip packages specified in code, attemping"
  434. " installation from error"
  435. )
  436. requirements = [e.name]
  437. logger.debug("Installing requirements: %s", requirements)
  438. if not requirements or _did_requirements:
  439. _raise(e)
  440. pip = await asyncio.create_subprocess_exec(
  441. sys.executable,
  442. "-m",
  443. "pip",
  444. "install",
  445. "--upgrade",
  446. "-q",
  447. "--disable-pip-version-check",
  448. "--no-warn-script-location",
  449. *["--user"] if USER_INSTALL else [],
  450. *requirements,
  451. )
  452. rc = await pip.wait()
  453. if rc != 0:
  454. _raise(e)
  455. importlib.invalidate_caches()
  456. kwargs = utils.get_kwargs()
  457. kwargs["_did_requirements"] = True
  458. return await self._mod_import_lib(**kwargs) # Try again
  459. lib_obj = next(
  460. (
  461. value()
  462. for value in vars(instance).values()
  463. if inspect.isclass(value) and issubclass(value, Library)
  464. ),
  465. None,
  466. )
  467. if not lib_obj:
  468. _raise(ImportError("Invalid library. No class found"))
  469. if not lib_obj.__class__.__name__.endswith("Lib"):
  470. _raise(
  471. ImportError(
  472. "Invalid library. Classname {} does not end with 'Lib'".format(
  473. lib_obj.__class__.__name__
  474. )
  475. )
  476. )
  477. if (
  478. all(
  479. line.replace(" ", "") != "#scope:no_stats" for line in code.splitlines()
  480. )
  481. and self._db.get("hikka.main", "stats", True)
  482. and url is not None
  483. and utils.check_url(url)
  484. ):
  485. with contextlib.suppress(Exception):
  486. await self.lookup("loader")._send_stats(url)
  487. lib_obj.source_url = url.strip("/")
  488. lib_obj.allmodules = self.allmodules
  489. lib_obj.internal_init()
  490. for old_lib in self.allmodules.libraries:
  491. if old_lib.name == lib_obj.name and (
  492. not isinstance(getattr(old_lib, "version", None), tuple)
  493. and not isinstance(getattr(lib_obj, "version", None), tuple)
  494. or old_lib.version >= lib_obj.version
  495. ):
  496. logger.debug("Using existing instance of library %s", old_lib.name)
  497. return old_lib
  498. if hasattr(lib_obj, "init"):
  499. if not callable(lib_obj.init):
  500. _raise(ValueError("Library init() must be callable"))
  501. try:
  502. await lib_obj.init()
  503. except Exception:
  504. _raise(RuntimeError("Library init() failed"))
  505. if hasattr(lib_obj, "config"):
  506. if not isinstance(lib_obj.config, LibraryConfig):
  507. _raise(
  508. RuntimeError("Library config must be a `LibraryConfig` instance")
  509. )
  510. libcfg = lib_obj.db.get(
  511. lib_obj.__class__.__name__,
  512. "__config__",
  513. {},
  514. )
  515. for conf in lib_obj.config:
  516. with contextlib.suppress(Exception):
  517. lib_obj.config.set_no_raise(
  518. conf,
  519. (
  520. libcfg[conf]
  521. if conf in libcfg
  522. else os.environ.get(f"{lib_obj.__class__.__name__}.{conf}")
  523. or lib_obj.config.getdef(conf)
  524. ),
  525. )
  526. if hasattr(lib_obj, "strings"):
  527. lib_obj.strings = Strings(lib_obj, self.translator)
  528. lib_obj.translator = self.translator
  529. for old_lib in self.allmodules.libraries:
  530. if old_lib.name == lib_obj.name:
  531. if hasattr(old_lib, "on_lib_update") and callable(
  532. old_lib.on_lib_update
  533. ):
  534. await old_lib.on_lib_update(lib_obj)
  535. replace_all_refs(old_lib, lib_obj)
  536. logger.debug(
  537. "Replacing existing instance of library %s with updated object",
  538. lib_obj.name,
  539. )
  540. return lib_obj
  541. self.allmodules.libraries += [lib_obj]
  542. return lib_obj
  543. class Library:
  544. """All external libraries must have a class-inheritant from this class"""
  545. def internal_init(self):
  546. self.name = self.__class__.__name__
  547. self.db = self.allmodules.db
  548. self._db = self.allmodules.db
  549. self.client = self.allmodules.client
  550. self._client = self.allmodules.client
  551. self.tg_id = self._client.tg_id
  552. self._tg_id = self._client.tg_id
  553. self.lookup = self.allmodules.lookup
  554. self.get_prefix = self.allmodules.get_prefix
  555. self.inline = self.allmodules.inline
  556. self.allclients = self.allmodules.allclients
  557. def _lib_get(
  558. self,
  559. key: str,
  560. default: typing.Optional[JSONSerializable] = None,
  561. ) -> JSONSerializable:
  562. return self._db.get(self.__class__.__name__, key, default)
  563. def _lib_set(self, key: str, value: JSONSerializable) -> bool:
  564. self._db.set(self.__class__.__name__, key, value)
  565. def _lib_pointer(
  566. self,
  567. key: str,
  568. default: typing.Optional[JSONSerializable] = None,
  569. ) -> typing.Union[JSONSerializable, PointerDict, PointerList]:
  570. return self._db.pointer(self.__class__.__name__, key, default)
  571. class LoadError(Exception):
  572. """Tells user, why your module can't be loaded, if raised in `client_ready`"""
  573. def __init__(self, error_message: str): # skipcq: PYL-W0231
  574. self._error = error_message
  575. def __str__(self) -> str:
  576. return self._error
  577. class CoreOverwriteError(LoadError):
  578. """Is being raised when core module or command is overwritten"""
  579. def __init__(
  580. self,
  581. module: typing.Optional[str] = None,
  582. command: typing.Optional[str] = None,
  583. ):
  584. self.type = "module" if module else "command"
  585. self.target = module or command
  586. super().__init__(str(self))
  587. def __str__(self) -> str:
  588. return (
  589. f"{'Module' if self.type == 'module' else 'command'} {self.target} will not"
  590. " be overwritten, because it's core"
  591. )
  592. class CoreUnloadError(Exception):
  593. """Is being raised when user tries to unload core module"""
  594. def __init__(self, module: str):
  595. self.module = module
  596. super().__init__()
  597. def __str__(self) -> str:
  598. return f"Module {self.module} will not be unloaded, because it's core"
  599. class SelfUnload(Exception):
  600. """Silently unloads module, if raised in `client_ready`"""
  601. def __init__(self, error_message: str = ""):
  602. super().__init__()
  603. self._error = error_message
  604. def __str__(self) -> str:
  605. return self._error
  606. class SelfSuspend(Exception):
  607. """
  608. Silently suspends module, if raised in `client_ready`
  609. Commands and watcher will not be registered if raised
  610. Module won't be unloaded from db and will be unfreezed after restart, unless
  611. the exception is raised again
  612. """
  613. def __init__(self, error_message: str = ""):
  614. super().__init__()
  615. self._error = error_message
  616. def __str__(self) -> str:
  617. return self._error
  618. class StopLoop(Exception):
  619. """Stops the loop, in which is raised"""
  620. class ModuleConfig(dict):
  621. """Stores config for modules and apparently libraries"""
  622. def __init__(self, *entries: typing.Union[str, "ConfigValue"]):
  623. if all(isinstance(entry, ConfigValue) for entry in entries):
  624. # New config format processing
  625. self._config = {config.option: config for config in entries}
  626. else:
  627. # Legacy config processing
  628. keys = []
  629. values = []
  630. defaults = []
  631. docstrings = []
  632. for i, entry in enumerate(entries):
  633. if i % 3 == 0:
  634. keys += [entry]
  635. elif i % 3 == 1:
  636. values += [entry]
  637. defaults += [entry]
  638. else:
  639. docstrings += [entry]
  640. self._config = {
  641. key: ConfigValue(option=key, default=default, doc=doc)
  642. for key, default, doc in zip(keys, defaults, docstrings)
  643. }
  644. super().__init__(
  645. {option: config.value for option, config in self._config.items()}
  646. )
  647. def getdoc(self, key: str, message: typing.Optional[Message] = None) -> str:
  648. """Get the documentation by key"""
  649. ret = self._config[key].doc
  650. if callable(ret):
  651. try:
  652. # Compatibility tweak
  653. # does nothing in Hikka
  654. ret = ret(message)
  655. except Exception:
  656. ret = ret()
  657. return ret
  658. def getdef(self, key: str) -> str:
  659. """Get the default value by key"""
  660. return self._config[key].default
  661. def __setitem__(self, key: str, value: typing.Any):
  662. self._config[key].value = value
  663. super().__setitem__(key, value)
  664. def set_no_raise(self, key: str, value: typing.Any):
  665. self._config[key].set_no_raise(value)
  666. super().__setitem__(key, value)
  667. def __getitem__(self, key: str) -> typing.Any:
  668. try:
  669. return self._config[key].value
  670. except KeyError:
  671. return None
  672. def reload(self):
  673. for key in self._config:
  674. super().__setitem__(key, self._config[key].value)
  675. def change_validator(
  676. self,
  677. key: str,
  678. validator: typing.Callable[[JSONSerializable], JSONSerializable],
  679. ):
  680. self._config[key].validator = validator
  681. LibraryConfig = ModuleConfig
  682. class _Placeholder:
  683. """Placeholder to determine if the default value is going to be set"""
  684. async def wrap(func: typing.Callable[[], typing.Awaitable]) -> typing.Any:
  685. with contextlib.suppress(Exception):
  686. return await func()
  687. def syncwrap(func: typing.Callable[[], typing.Any]) -> typing.Any:
  688. with contextlib.suppress(Exception):
  689. return func()
  690. @dataclass(repr=True)
  691. class ConfigValue:
  692. option: str
  693. default: typing.Any = None
  694. doc: typing.Union[typing.Callable[[], str], str] = "No description"
  695. value: typing.Any = field(default_factory=_Placeholder)
  696. validator: typing.Optional[
  697. typing.Callable[[JSONSerializable], JSONSerializable]
  698. ] = None
  699. on_change: typing.Optional[
  700. typing.Union[typing.Callable[[], typing.Awaitable], typing.Callable]
  701. ] = None
  702. def __post_init__(self):
  703. if isinstance(self.value, _Placeholder):
  704. self.value = self.default
  705. def set_no_raise(self, value: typing.Any) -> bool:
  706. """
  707. Sets the config value w/o ValidationError being raised
  708. Should not be used uninternally
  709. """
  710. return self.__setattr__("value", value, ignore_validation=True)
  711. def __setattr__(
  712. self,
  713. key: str,
  714. value: typing.Any,
  715. *,
  716. ignore_validation: bool = False,
  717. ):
  718. if key == "value":
  719. try:
  720. value = ast.literal_eval(value)
  721. except Exception:
  722. pass
  723. # Convert value to list if it's tuple just not to mess up
  724. # with json convertations
  725. if isinstance(value, (set, tuple)):
  726. value = list(value)
  727. if isinstance(value, list):
  728. value = [
  729. item.strip() if isinstance(item, str) else item for item in value
  730. ]
  731. if self.validator is not None:
  732. if value is not None:
  733. from . import validators
  734. try:
  735. value = self.validator.validate(value)
  736. except validators.ValidationError as e:
  737. if not ignore_validation:
  738. raise e
  739. logger.debug(
  740. "Config value was broken (%s), so it was reset to %s",
  741. value,
  742. self.default,
  743. )
  744. value = self.default
  745. else:
  746. defaults = {
  747. "String": "",
  748. "Integer": 0,
  749. "Boolean": False,
  750. "Series": [],
  751. "Float": 0.0,
  752. }
  753. if self.validator.internal_id in defaults:
  754. logger.debug(
  755. "Config value was None, so it was reset to %s",
  756. defaults[self.validator.internal_id],
  757. )
  758. value = defaults[self.validator.internal_id]
  759. # This attribute will tell the `Loader` to save this value in db
  760. self._save_marker = True
  761. object.__setattr__(self, key, value)
  762. if key == "value" and not ignore_validation and callable(self.on_change):
  763. if inspect.iscoroutinefunction(self.on_change):
  764. asyncio.ensure_future(wrap(self.on_change))
  765. else:
  766. syncwrap(self.on_change)
  767. def _get_members(
  768. mod: Module,
  769. ending: str,
  770. attribute: typing.Optional[str] = None,
  771. strict: bool = False,
  772. ) -> dict:
  773. """Get method of module, which end with ending"""
  774. return {
  775. (
  776. method_name.rsplit(ending, maxsplit=1)[0]
  777. if (method_name == ending if strict else method_name.endswith(ending))
  778. else method_name
  779. ).lower(): getattr(mod, method_name)
  780. for method_name in dir(mod)
  781. if not isinstance(getattr(type(mod), method_name, None), property)
  782. and callable(getattr(mod, method_name))
  783. and (
  784. (method_name == ending if strict else method_name.endswith(ending))
  785. or attribute
  786. and getattr(getattr(mod, method_name), attribute, False)
  787. )
  788. }
  789. class CacheRecordEntity:
  790. def __init__(
  791. self,
  792. hashable_entity: "Hashable", # type: ignore # noqa: F821
  793. resolved_entity: EntityLike,
  794. exp: int,
  795. ):
  796. self.entity = copy.deepcopy(resolved_entity)
  797. self._hashable_entity = copy.deepcopy(hashable_entity)
  798. self._exp = round(time.time() + exp)
  799. self.ts = time.time()
  800. @property
  801. def expired(self) -> bool:
  802. return self._exp < time.time()
  803. def __eq__(self, record: "CacheRecordEntity") -> bool:
  804. return hash(record) == hash(self)
  805. def __hash__(self) -> int:
  806. return hash(self._hashable_entity)
  807. def __str__(self) -> str:
  808. return f"CacheRecordEntity of {self.entity}"
  809. def __repr__(self) -> str:
  810. return (
  811. f"CacheRecordEntity(entity={type(self.entity).__name__}(...),"
  812. f" exp={self._exp})"
  813. )
  814. class CacheRecordPerms:
  815. def __init__(
  816. self,
  817. hashable_entity: "Hashable", # type: ignore # noqa: F821
  818. hashable_user: "Hashable", # type: ignore # noqa: F821
  819. resolved_perms: EntityLike,
  820. exp: int,
  821. ):
  822. self.perms = copy.deepcopy(resolved_perms)
  823. self._hashable_entity = copy.deepcopy(hashable_entity)
  824. self._hashable_user = copy.deepcopy(hashable_user)
  825. self._exp = round(time.time() + exp)
  826. self.ts = time.time()
  827. @property
  828. def expired(self) -> bool:
  829. return self._exp < time.time()
  830. def __eq__(self, record: "CacheRecordPerms") -> bool:
  831. return hash(record) == hash(self)
  832. def __hash__(self) -> int:
  833. return hash((self._hashable_entity, self._hashable_user))
  834. def __str__(self) -> str:
  835. return f"CacheRecordPerms of {self.perms}"
  836. def __repr__(self) -> str:
  837. return (
  838. f"CacheRecordPerms(perms={type(self.perms).__name__}(...), exp={self._exp})"
  839. )
  840. class CacheRecordFullChannel:
  841. def __init__(self, channel_id: int, full_channel: ChannelFull, exp: int):
  842. self.channel_id = channel_id
  843. self.full_channel = full_channel
  844. self._exp = round(time.time() + exp)
  845. self.ts = time.time()
  846. @property
  847. def expired(self) -> bool:
  848. return self._exp < time.time()
  849. def __eq__(self, record: "CacheRecordFullChannel") -> bool:
  850. return hash(record) == hash(self)
  851. def __hash__(self) -> int:
  852. return hash((self._hashable_entity, self._hashable_user))
  853. def __str__(self) -> str:
  854. return f"CacheRecordFullChannel of {self.channel_id}"
  855. def __repr__(self) -> str:
  856. return (
  857. f"CacheRecordFullChannel(channel_id={self.channel_id}(...),"
  858. f" exp={self._exp})"
  859. )
  860. class CacheRecordFullUser:
  861. def __init__(self, user_id: int, full_user: UserFull, exp: int):
  862. self.user_id = user_id
  863. self.full_user = full_user
  864. self._exp = round(time.time() + exp)
  865. self.ts = time.time()
  866. @property
  867. def expired(self) -> bool:
  868. return self._exp < time.time()
  869. def __eq__(self, record: "CacheRecordFullUser") -> bool:
  870. return hash(record) == hash(self)
  871. def __hash__(self) -> int:
  872. return hash((self._hashable_entity, self._hashable_user))
  873. def __str__(self) -> str:
  874. return f"CacheRecordFullUser of {self.user_id}"
  875. def __repr__(self) -> str:
  876. return f"CacheRecordFullUser(channel_id={self.user_id}(...), exp={self._exp})"
  877. def get_commands(mod: Module) -> dict:
  878. """Introspect the module to get its commands"""
  879. return _get_members(mod, "cmd", "is_command")
  880. def get_inline_handlers(mod: Module) -> dict:
  881. """Introspect the module to get its inline handlers"""
  882. return _get_members(mod, "_inline_handler", "is_inline_handler")
  883. def get_callback_handlers(mod: Module) -> dict:
  884. """Introspect the module to get its callback handlers"""
  885. return _get_members(mod, "_callback_handler", "is_callback_handler")
  886. def get_watchers(mod: Module) -> dict:
  887. """Introspect the module to get its watchers"""
  888. return _get_members(
  889. mod,
  890. "watcher",
  891. "is_watcher",
  892. strict=True,
  893. )