dragon.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # ©️ Dan Gazizullin, 2021-2023
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import asyncio
  7. import builtins
  8. import importlib
  9. import inspect
  10. import io
  11. import logging
  12. import os
  13. import subprocess
  14. import sys
  15. import traceback
  16. import typing
  17. from io import BytesIO
  18. from sys import version_info
  19. import git
  20. try:
  21. from PIL import Image
  22. except Exception:
  23. PIP_AVAILABLE = False
  24. else:
  25. PIP_AVAILABLE = True
  26. from pyrogram import Client, errors, types
  27. from .. import version
  28. from .._internal import restart
  29. from ..database import Database
  30. from ..tl_cache import CustomTelegramClient
  31. from ..types import JSONSerializable
  32. DRAGON_EMOJI = "<emoji document_id=5375360100196163660>🐲</emoji>"
  33. native_import = builtins.__import__
  34. logger = logging.getLogger(__name__)
  35. class ImportLock:
  36. # This is used to ensure, that dynamic dragon import passes in
  37. # the right client. Whenever one of the clients attempts to install
  38. # dragon-specific module, it must aqcuire the `import_lock` or wait
  39. # until it's released. Then set the `current_client` variable to self.
  40. def __init__(self):
  41. self.lock = asyncio.Lock()
  42. self.current_client = None
  43. def __call__(self, client: CustomTelegramClient) -> typing.ContextManager:
  44. self.current_client = client
  45. return self
  46. async def __aenter__(self):
  47. await self.lock.acquire()
  48. async def __aexit__(self, *_):
  49. self.current_client = None
  50. self.lock.release()
  51. import_lock = ImportLock()
  52. class DragonDb:
  53. def __init__(self, db: Database):
  54. self.db = db
  55. def get(
  56. self,
  57. module: str,
  58. variable: str,
  59. default: typing.Optional[typing.Any] = None,
  60. ) -> JSONSerializable:
  61. return self.db.get(f"dragon.{module}", variable, default)
  62. def set(self, module: str, variable: str, value: JSONSerializable) -> bool:
  63. return self.db.set(f"dragon.{module}", variable, value)
  64. def get_collection(self, module: str) -> typing.Dict[str, JSONSerializable]:
  65. return dict.get(self.db, f"dragon.{module}", {})
  66. def remove(self, module: str, variable: str) -> JSONSerializable:
  67. if f"dragon.{module}" not in self.db:
  68. return None
  69. return self.db[f"dragon.{module}"].pop(variable, None)
  70. def close(self):
  71. pass
  72. class DragonDbWrapper:
  73. def __init__(self, db: DragonDb):
  74. self.db = db
  75. class Notifier:
  76. def __init__(self, modules_help: "ModulesHelpDict"):
  77. self.modules_help = modules_help
  78. self.cache = {}
  79. def __enter__(self):
  80. self.modules_help.notifier = self
  81. return self
  82. def __exit__(self, *_):
  83. self.modules_help.notifier = None
  84. def notify(self, key: str, value: dict):
  85. self.cache[key] = value
  86. @property
  87. def modname(self):
  88. return next(iter(self.cache), "Unknown")
  89. @property
  90. def commands(self):
  91. return {
  92. key.split()[0]: (
  93. ((key.split()[1] + " - ") if len(key.split()) > 1 else "") + value
  94. )
  95. for key, value in self.cache[self.modname].items()
  96. }
  97. class ModulesHelpDict(dict):
  98. def __init__(self, *args, **kwargs):
  99. super().__init__(*args, **kwargs)
  100. self.notifier = None
  101. def append(self, obj: dict):
  102. # convert help from old to new type
  103. module_name = list(obj.keys())[0]
  104. cmds = obj[module_name]
  105. commands = {}
  106. for cmd in cmds:
  107. cmd_name = list(cmd.keys())[0]
  108. cmd_desc = cmd[cmd_name]
  109. commands[cmd_name] = cmd_desc
  110. self[module_name] = commands
  111. def __setitem__(self, key, value):
  112. super().__setitem__(key, value)
  113. if self.notifier:
  114. self.notifier.notify(key, value)
  115. def get_notifier(self) -> Notifier:
  116. return Notifier(self)
  117. class DragonMisc:
  118. def __init__(self, client: CustomTelegramClient):
  119. self.client = client
  120. self.modules_help = ModulesHelpDict()
  121. self.requirements_list = []
  122. self.python_version = f"{version_info[0]}.{version_info[1]}.{version_info[2]}"
  123. self.gitrepo = git.Repo(
  124. path=os.path.abspath(os.path.join(os.path.dirname(version.__file__), ".."))
  125. )
  126. self.commits_since_tag = 0
  127. self.userbot_version = version.__version__
  128. @property
  129. def prefix(self):
  130. return self.client.loader.get_prefix("dragon")
  131. class DragonConfig:
  132. def __init__(self, client: CustomTelegramClient):
  133. self.api_id = client.api_id
  134. self.api_hash = client.api_hash
  135. self.db_type = ""
  136. self.db_url = ""
  137. self.db_name = ""
  138. self.test_server = False
  139. class DragonScripts:
  140. def __init__(self, misc: DragonMisc):
  141. self.interact_with_to_delete = []
  142. self.misc = misc
  143. @staticmethod
  144. def text(message: types.Message):
  145. return message.text if message.text else message.caption
  146. @staticmethod
  147. def restart():
  148. restart()
  149. @staticmethod
  150. def format_exc(e: Exception, hint: str = None):
  151. traceback.print_exc()
  152. if isinstance(e, errors.RPCError):
  153. return (
  154. "<b>Telegram API error!</b>\n"
  155. f"<code>[{e.CODE} {e.ID or e.NAME}] - {e.MESSAGE}</code>"
  156. )
  157. hint_text = f"\n\n<b>Hint: {hint}</b>" if hint else ""
  158. return f"<b>Error!</b>\n<code>{e.__class__.__name__}: {e}</code>" + hint_text
  159. @staticmethod
  160. def with_reply(func):
  161. async def wrapped(client: Client, message: types.Message):
  162. if not message.reply_to_message:
  163. await message.edit("<b>Reply to message is required</b>")
  164. else:
  165. return await func(client, message)
  166. return wrapped
  167. async def interact_with(self, message: types.Message) -> types.Message:
  168. """
  169. Check history with bot and return bot's response
  170. Example:
  171. .. code-block:: python
  172. bot_msg = await interact_with(await bot.send_message("@BotFather", "/start"))
  173. :param message: already sent message to bot
  174. :return: bot's response
  175. """
  176. await asyncio.sleep(1)
  177. # noinspection PyProtectedMember
  178. response = await message._client.get_history(message.chat.id, limit=1)
  179. seconds_waiting = 0
  180. while response[0].from_user.is_self:
  181. seconds_waiting += 1
  182. if seconds_waiting >= 5:
  183. raise RuntimeError("bot didn't answer in 5 seconds")
  184. await asyncio.sleep(1)
  185. # noinspection PyProtectedMember
  186. response = await message._client.get_history(message.chat.id, limit=1)
  187. self.interact_with_to_delete.append(message.message_id)
  188. self.interact_with_to_delete.append(response[0].message_id)
  189. return response[0]
  190. def format_module_help(self, module_name: str):
  191. commands = self.misc.modules_help[module_name]
  192. help_text = (
  193. f"{DRAGON_EMOJI} <b>Help for"
  194. f" </b><code>{module_name}</code>\n\n<b>Usage:</b>\n"
  195. )
  196. for command, desc in commands.items():
  197. cmd = command.split(maxsplit=1)
  198. args = " <code>" + cmd[1] + "</code>" if len(cmd) > 1 else ""
  199. help_text += (
  200. f"<code>{self.misc.prefix}{cmd[0]}</code>{args} — <i>{desc}</i>\n"
  201. )
  202. return help_text
  203. def format_small_module_help(self, module_name: str):
  204. commands = self.misc.modules_help[module_name]
  205. help_text = (
  206. f"{DRAGON_EMOJI }<b>Help for </b><code>{module_name}</code>\n\n<b>Commands"
  207. " list:</b>\n"
  208. )
  209. for command, desc in commands.items():
  210. cmd = command.split(maxsplit=1)
  211. args = " <code>" + cmd[1] + "</code>" if len(cmd) > 1 else ""
  212. help_text += f"<code>{self.misc.prefix}{cmd[0]}</code>{args}\n"
  213. help_text += (
  214. f"\n<b>Get full usage:</b> <code>{self.misc.prefix}help"
  215. f" {module_name}</code>"
  216. )
  217. return help_text
  218. def import_library(
  219. self,
  220. library_name: str,
  221. package_name: typing.Optional[str] = None,
  222. ):
  223. """
  224. Loads a library, or installs it in ImportError case
  225. :param library_name: library name (import example...)
  226. :param package_name: package name in PyPi (pip install example)
  227. :return: loaded module
  228. """
  229. if package_name is None:
  230. package_name = library_name
  231. self.misc.requirements_list.append(package_name)
  232. try:
  233. return importlib.import_module(library_name)
  234. except ImportError:
  235. completed = subprocess.run(
  236. [
  237. sys.executable,
  238. "-m",
  239. "pip",
  240. "install",
  241. "--upgrade",
  242. "-q",
  243. "--disable-pip-version-check",
  244. "--no-warn-script-location",
  245. *(
  246. ["--user"]
  247. if "PIP_TARGET" not in os.environ
  248. and "VIRTUAL_ENV" not in os.environ
  249. else []
  250. ),
  251. package_name,
  252. ],
  253. check=False,
  254. )
  255. if completed.returncode != 0:
  256. raise RuntimeError(
  257. f"Failed to install library {package_name} (pip exited with code"
  258. f" {completed.returncode})"
  259. )
  260. return importlib.import_module(library_name)
  261. @staticmethod
  262. def resize_image(
  263. input_img: typing.Union[bytes, io.BytesIO],
  264. output: typing.Optional[io.BytesIO] = None,
  265. img_type: str = "PNG",
  266. ) -> io.BytesIO:
  267. if not PIP_AVAILABLE:
  268. raise RuntimeError("Install Pillow with: pip install Pillow -U")
  269. if output is None:
  270. output = BytesIO()
  271. output.name = f"sticker.{img_type.lower()}"
  272. with Image.open(input_img) as img:
  273. # We used to use thumbnail(size) here, but it returns with a *max* dimension of 512,512
  274. # rather than making one side exactly 512 so we have to calculate dimensions manually :(
  275. if img.width == img.height:
  276. size = (512, 512)
  277. elif img.width < img.height:
  278. size = (max(512 * img.width // img.height, 1), 512)
  279. else:
  280. size = (512, max(512 * img.height // img.width, 1))
  281. img.resize(size).save(output, img_type)
  282. return output
  283. class DragonCompat:
  284. def __init__(self, client: CustomTelegramClient):
  285. self.client = client
  286. self.db = DragonDbWrapper(DragonDb(client.loader.db))
  287. self.misc = DragonMisc(client)
  288. self.scripts = DragonScripts(self.misc)
  289. self.config = DragonConfig(client)
  290. def patched_import(name: str, *args, **kwargs):
  291. caller = inspect.currentframe().f_back
  292. caller_name = caller.f_globals.get("__name__")
  293. if name.startswith("utils") and caller_name.startswith("dragon"):
  294. if not import_lock.current_client:
  295. raise RuntimeError("Dragon client not set")
  296. if name.split(".", maxsplit=1)[1] in {"db", "misc", "scripts", "config"}:
  297. return getattr(
  298. import_lock.current_client.dragon_compat,
  299. name.split(".", maxsplit=1)[1],
  300. )
  301. raise ImportError(f"Unknown module {name}")
  302. return native_import(name, *args, **kwargs)
  303. builtins.__import__ = patched_import
  304. def apply_compat(client: CustomTelegramClient):
  305. client.dragon_compat = DragonCompat(client)