pyroproxy.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. # ©️ Dan Gazizullin, 2021-2023
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import asyncio
  7. import copy
  8. import datetime
  9. import functools
  10. import logging
  11. import re
  12. import typing
  13. import telethon
  14. from pyrogram import Client as PyroClient
  15. from pyrogram import errors as pyro_errors
  16. from pyrogram import raw
  17. from .. import translations, utils
  18. from ..tl_cache import CustomTelegramClient
  19. from ..version import __version__
  20. PROXY = {
  21. pyro_object: telethon.tl.alltlobjects.tlobjects[constructor_id]
  22. for constructor_id, pyro_object in raw.all.objects.items()
  23. if constructor_id in telethon.tl.alltlobjects.tlobjects
  24. }
  25. REVERSED_PROXY = {
  26. **{tl_object: pyro_object for pyro_object, tl_object in PROXY.items()},
  27. **{
  28. tl_object: raw.all.objects[tl_object.CONSTRUCTOR_ID]
  29. for _, tl_object in utils.iter_attrs(telethon.tl.custom)
  30. if getattr(tl_object, "CONSTRUCTOR_ID", None) in raw.all.objects
  31. },
  32. }
  33. PYRO_ERRORS = {
  34. cls.ID: cls
  35. for _, cls in utils.iter_attrs(pyro_errors)
  36. if hasattr(cls, "ID") and issubclass(cls, pyro_errors.RPCError)
  37. }
  38. logger = logging.getLogger(__name__)
  39. class PyroProxyClient(PyroClient):
  40. def __init__(self, tl_client: CustomTelegramClient):
  41. self.tl_client = tl_client
  42. super().__init__(
  43. **{
  44. "name": "proxied_pyrogram_client",
  45. "api_id": tl_client.api_id,
  46. "api_hash": tl_client.api_hash,
  47. "app_version": (
  48. f"Hikka v{__version__[0]}.{__version__[1]}.{__version__[2]}"
  49. ),
  50. "lang_code": tl_client.loader.db.get(
  51. translations.__name__, "lang", "en"
  52. ).split()[0],
  53. "in_memory": True,
  54. "phone_number": tl_client.hikka_me.phone,
  55. }
  56. )
  57. # We need to set this to True so pyro thinks he's connected
  58. # even tho it's not. We don't need to connect to Telegram as
  59. # we redirect all requests to telethon's handler
  60. self.is_connected = True
  61. self.conn = tl_client.session._conn
  62. async def start(self):
  63. self.me = await self.get_me()
  64. self.tl_client.raw_updates_processor = self._on_event
  65. def _on_event(
  66. self,
  67. event: typing.Union[
  68. telethon.tl.types.Updates,
  69. telethon.tl.types.UpdatesCombined,
  70. telethon.tl.types.UpdateShort,
  71. ],
  72. ):
  73. asyncio.ensure_future(self.handle_updates(self._tl2pyro(event)))
  74. async def invoke(
  75. self,
  76. query: raw.core.TLObject,
  77. *args,
  78. **kwargs,
  79. ) -> typing.Union[typing.List[raw.core.TLObject], raw.core.TLObject]:
  80. logger.debug(
  81. "Running Pyrogram's invoke of %s with Telethon proxying",
  82. query.__class__.__name__,
  83. )
  84. if self.tl_client.session.takeout_id:
  85. query = raw.functions.InvokeWithTakeout(
  86. takeout_id=self.tl_client.session.takeout_id,
  87. query=query,
  88. )
  89. try:
  90. r = await self.tl_client(self._pyro2tl(query))
  91. except telethon.errors.rpcerrorlist.RPCError as e:
  92. raise self._tl_error2pyro(e)
  93. return self._tl2pyro(r)
  94. @staticmethod
  95. def _tl_error2pyro(
  96. error: telethon.errors.rpcerrorlist.RPCError,
  97. ) -> pyro_errors.RPCError:
  98. rpc = (
  99. re.sub(r"([A-Z])", r"_\1", error.__class__.__name__)
  100. .upper()
  101. .strip("_")
  102. .rsplit("ERROR", maxsplit=1)[0]
  103. .strip("_")
  104. )
  105. if rpc in PYRO_ERRORS:
  106. return PYRO_ERRORS[rpc]()
  107. return PYRO_ERRORS.get(
  108. f"{rpc}_X",
  109. PYRO_ERRORS.get(
  110. f"{rpc}_0",
  111. pyro_errors.RPCError,
  112. ),
  113. )()
  114. def _pyro2tl(self, pyro_obj: raw.core.TLObject) -> telethon.tl.TLObject:
  115. """
  116. Recursively converts Pyrogram TLObjects to Telethon TLObjects (methods,
  117. types and everything else, which is in tl schema)
  118. :param pyro_obj: Pyrogram TLObject
  119. :return: Telethon TLObject
  120. :raises TypeError: if it's not possible to convert Pyrogram TLObject to
  121. Telethon TLObject
  122. """
  123. pyro_obj = self._convert(pyro_obj)
  124. if isinstance(pyro_obj, list):
  125. return [self._pyro2tl(i) for i in pyro_obj]
  126. if isinstance(pyro_obj, dict):
  127. return {k: self._pyro2tl(v) for k, v in pyro_obj.items()}
  128. if not isinstance(pyro_obj, raw.core.TLObject):
  129. return pyro_obj
  130. if type(pyro_obj) not in PROXY:
  131. raise TypeError(
  132. f"Cannot convert Pyrogram's {type(pyro_obj)} to Telethon TLObject"
  133. )
  134. return PROXY[type(pyro_obj)](
  135. **{
  136. attr: self._pyro2tl(getattr(pyro_obj, attr))
  137. for attr in pyro_obj.__slots__
  138. }
  139. )
  140. def _tl2pyro(self, tl_obj: telethon.tl.TLObject) -> raw.core.TLObject:
  141. """
  142. Recursively converts Telethon TLObjects to Pyrogram TLObjects (methods,
  143. types and everything else, which is in tl schema)
  144. :param tl_obj: Telethon TLObject
  145. :return: Pyrogram TLObject
  146. :raises TypeError: if it's not possible to convert Telethon TLObject to
  147. Pyrogram TLObject
  148. """
  149. tl_obj = self._convert(tl_obj)
  150. if (
  151. isinstance(getattr(tl_obj, "from_id", None), int)
  152. and tl_obj.from_id
  153. and hasattr(tl_obj, "sender_id")
  154. ):
  155. tl_obj = copy.copy(tl_obj)
  156. tl_obj.from_id = telethon.tl.types.PeerUser(tl_obj.sender_id)
  157. if isinstance(tl_obj, list):
  158. return [self._tl2pyro(i) for i in tl_obj]
  159. if isinstance(tl_obj, dict):
  160. return {k: self._tl2pyro(v) for k, v in tl_obj.items()}
  161. if isinstance(tl_obj, int) and str(tl_obj).startswith("-100"):
  162. return int(str(tl_obj)[4:])
  163. if not isinstance(tl_obj, telethon.tl.TLObject):
  164. return tl_obj
  165. if type(tl_obj) not in REVERSED_PROXY:
  166. raise TypeError(
  167. f"Cannot convert Telethon's {type(tl_obj)} to Pyrogram TLObject"
  168. )
  169. hints = typing.get_type_hints(REVERSED_PROXY[type(tl_obj)].__init__) or {}
  170. return REVERSED_PROXY[type(tl_obj)](
  171. **{
  172. attr: self._convert_types(
  173. hints.get(attr),
  174. self._tl2pyro(getattr(tl_obj, attr)),
  175. )
  176. for attr in REVERSED_PROXY[type(tl_obj)].__slots__
  177. }
  178. )
  179. @staticmethod
  180. def _get_origin(hint: typing.Any) -> typing.Any:
  181. try:
  182. return typing.get_origin(hint)
  183. except Exception:
  184. return None
  185. def _convert_types(self, hint: typing.Any, value: typing.Any) -> typing.Any:
  186. if not value and (
  187. self._get_origin(hint) in {typing.List, list}
  188. or (
  189. self._get_origin(hint) is typing.Union
  190. and any(
  191. self._get_origin(i) in {typing.List, list} for i in hint.__args__
  192. )
  193. )
  194. ):
  195. return []
  196. return value
  197. def _convert(self, obj: typing.Any) -> typing.Any:
  198. if isinstance(obj, datetime.datetime):
  199. return int(obj.timestamp())
  200. return obj
  201. async def resolve_peer(self, *args, **kwargs):
  202. return self._tl2pyro(await self.tl_client.get_entity(*args, **kwargs))
  203. async def fetch_peers(
  204. self,
  205. peers: typing.List[
  206. typing.Union[raw.types.User, raw.types.Chat, raw.types.Channel]
  207. ],
  208. ) -> bool:
  209. return any(getattr(peer, "min", False) for peer in peers)
  210. @property
  211. def iter_chat_members(self):
  212. return self.get_chat_members
  213. @property
  214. def iter_dialogs(self):
  215. return self.get_dialogs
  216. @property
  217. def iter_history(self):
  218. return self.get_chat_history
  219. @property
  220. def iter_profile_photos(self):
  221. return self.get_chat_photos
  222. async def save_file(
  223. self,
  224. path: typing.Union[str, typing.BinaryIO],
  225. file_id: int = None,
  226. file_part: int = 0,
  227. progress: typing.Callable = None,
  228. progress_args: tuple = (),
  229. ):
  230. return self._tl2pyro(
  231. await self.tl_client.upload_file(
  232. path,
  233. part_size_kb=file_part,
  234. progress_callback=(
  235. functools.partial(progress, *progress_args)
  236. if progress and callable(progress)
  237. else None
  238. ),
  239. )
  240. )