tl_cache.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. # ©️ Dan Gazizullin, 2021-2023
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import copy
  7. import inspect
  8. import logging
  9. import time
  10. import typing
  11. from telethon import TelegramClient
  12. from telethon.errors.rpcerrorlist import TopicDeletedError
  13. from telethon.hints import EntityLike
  14. from telethon.network import MTProtoSender
  15. from telethon.tl.functions.channels import GetFullChannelRequest
  16. from telethon.tl.functions.users import GetFullUserRequest
  17. from telethon.tl.tlobject import TLRequest
  18. from telethon.tl.types import (
  19. ChannelFull,
  20. Message,
  21. Updates,
  22. UpdatesCombined,
  23. UpdateShort,
  24. UserFull,
  25. )
  26. from telethon.utils import is_list_like
  27. from .types import (
  28. CacheRecordEntity,
  29. CacheRecordFullChannel,
  30. CacheRecordFullUser,
  31. CacheRecordPerms,
  32. Module,
  33. )
  34. logger = logging.getLogger(__name__)
  35. def hashable(value: typing.Any) -> bool:
  36. """
  37. Determine whether `value` can be hashed.
  38. This is a copy of `collections.abc.Hashable` from Python 3.8.
  39. """
  40. try:
  41. hash(value)
  42. except TypeError:
  43. return False
  44. return True
  45. class CustomTelegramClient(TelegramClient):
  46. def __init__(self, *args, **kwargs):
  47. super().__init__(*args, **kwargs)
  48. self._hikka_entity_cache = {}
  49. self._hikka_perms_cache = {}
  50. self._hikka_fullchannel_cache = {}
  51. self._hikka_fulluser_cache = {}
  52. self.__forbidden_constructors = []
  53. self.raw_updates_processor = None # Will be monkeypatched by pyro proxy
  54. async def force_get_entity(self, *args, **kwargs):
  55. """Forcefully makes a request to Telegram to get the entity."""
  56. return await self.get_entity(*args, force=True, **kwargs)
  57. async def get_entity(
  58. self,
  59. entity: EntityLike,
  60. exp: int = 5 * 60,
  61. force: bool = False,
  62. ):
  63. """
  64. Gets the entity and cache it
  65. :param entity: Entity to fetch
  66. :param exp: Expiration time of the cache record and maximum time of already cached record
  67. :param force: Whether to force refresh the cache (make API request)
  68. :return: :obj:`Entity`
  69. """
  70. # Will be used to determine, which client caused logging messages
  71. # parsed via inspect.stack()
  72. _hikka_client_id_logging_tag = copy.copy(self.tg_id) # skipcq
  73. if not hashable(entity):
  74. try:
  75. hashable_entity = next(
  76. getattr(entity, attr)
  77. for attr in {"user_id", "channel_id", "chat_id", "id"}
  78. if getattr(entity, attr, None)
  79. )
  80. except StopIteration:
  81. logger.debug(
  82. "Can't parse hashable from entity %s, using legacy resolve",
  83. entity,
  84. )
  85. return await TelegramClient.get_entity(self, entity)
  86. else:
  87. hashable_entity = entity
  88. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  89. hashable_entity = int(str(hashable_entity)[4:])
  90. if (
  91. not force
  92. and hashable_entity
  93. and hashable_entity in self._hikka_entity_cache
  94. and (
  95. not exp
  96. or self._hikka_entity_cache[hashable_entity].ts + exp > time.time()
  97. )
  98. ):
  99. logger.debug(
  100. "Using cached entity %s (%s)",
  101. entity,
  102. type(self._hikka_entity_cache[hashable_entity].entity).__name__,
  103. )
  104. return copy.deepcopy(self._hikka_entity_cache[hashable_entity].entity)
  105. resolved_entity = await TelegramClient.get_entity(self, entity)
  106. if resolved_entity:
  107. cache_record = CacheRecordEntity(hashable_entity, resolved_entity, exp)
  108. self._hikka_entity_cache[hashable_entity] = cache_record
  109. logger.debug("Saved hashable_entity %s to cache", hashable_entity)
  110. if getattr(resolved_entity, "id", None):
  111. logger.debug("Saved resolved_entity id %s to cache", resolved_entity.id)
  112. self._hikka_entity_cache[resolved_entity.id] = cache_record
  113. if getattr(resolved_entity, "username", None):
  114. logger.debug(
  115. "Saved resolved_entity username @%s to cache",
  116. resolved_entity.username,
  117. )
  118. self._hikka_entity_cache[f"@{resolved_entity.username}"] = cache_record
  119. self._hikka_entity_cache[resolved_entity.username] = cache_record
  120. return copy.deepcopy(resolved_entity)
  121. async def get_perms_cached(
  122. self,
  123. entity: EntityLike,
  124. user: typing.Optional[EntityLike] = None,
  125. exp: int = 5 * 60,
  126. force: bool = False,
  127. ):
  128. """
  129. Gets the permissions of the user in the entity and cache it
  130. :param entity: Entity to fetch
  131. :param user: User to fetch
  132. :param exp: Expiration time of the cache record and maximum time of already cached record
  133. :param force: Whether to force refresh the cache (make API request)
  134. :return: :obj:`ChatPermissions`
  135. """
  136. # Will be used to determine, which client caused logging messages
  137. # parsed via inspect.stack()
  138. _hikka_client_id_logging_tag = copy.copy(self.tg_id) # skipcq
  139. entity = await self.get_entity(entity)
  140. user = await self.get_entity(user) if user else None
  141. if not hashable(entity) or not hashable(user):
  142. try:
  143. hashable_entity = next(
  144. getattr(entity, attr)
  145. for attr in {"user_id", "channel_id", "chat_id", "id"}
  146. if getattr(entity, attr, None)
  147. )
  148. except StopIteration:
  149. logger.debug(
  150. "Can't parse hashable from entity %s, using legacy method",
  151. entity,
  152. )
  153. return await self.get_permissions(entity, user)
  154. try:
  155. hashable_user = next(
  156. getattr(user, attr)
  157. for attr in {"user_id", "channel_id", "chat_id", "id"}
  158. if getattr(user, attr, None)
  159. )
  160. except StopIteration:
  161. logger.debug(
  162. "Can't parse hashable from user %s, using legacy method",
  163. user,
  164. )
  165. return await self.get_permissions(entity, user)
  166. else:
  167. hashable_entity = entity
  168. hashable_user = user
  169. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  170. hashable_entity = int(str(hashable_entity)[4:])
  171. if str(hashable_user).isdigit() and int(hashable_user) < 0:
  172. hashable_user = int(str(hashable_user)[4:])
  173. if (
  174. not force
  175. and hashable_entity
  176. and hashable_user
  177. and hashable_user in self._hikka_perms_cache.get(hashable_entity, {})
  178. and (
  179. not exp
  180. or self._hikka_perms_cache[hashable_entity][hashable_user].ts + exp
  181. > time.time()
  182. )
  183. ):
  184. logger.debug("Using cached perms %s (%s)", hashable_entity, hashable_user)
  185. return copy.deepcopy(
  186. self._hikka_perms_cache[hashable_entity][hashable_user].perms
  187. )
  188. resolved_perms = await self.get_permissions(entity, user)
  189. if resolved_perms:
  190. cache_record = CacheRecordPerms(
  191. hashable_entity,
  192. hashable_user,
  193. resolved_perms,
  194. exp,
  195. )
  196. self._hikka_perms_cache.setdefault(hashable_entity, {})[
  197. hashable_user
  198. ] = cache_record
  199. logger.debug("Saved hashable_entity %s perms to cache", hashable_entity)
  200. def save_user(key: typing.Union[str, int]):
  201. nonlocal self, cache_record, user, hashable_user
  202. if getattr(user, "id", None):
  203. self._hikka_perms_cache.setdefault(key, {})[user.id] = cache_record
  204. if getattr(user, "username", None):
  205. self._hikka_perms_cache.setdefault(key, {})[
  206. f"@{user.username}"
  207. ] = cache_record
  208. self._hikka_perms_cache.setdefault(key, {})[
  209. user.username
  210. ] = cache_record
  211. if getattr(entity, "id", None):
  212. logger.debug("Saved resolved_entity id %s perms to cache", entity.id)
  213. save_user(entity.id)
  214. if getattr(entity, "username", None):
  215. logger.debug(
  216. "Saved resolved_entity username @%s perms to cache",
  217. entity.username,
  218. )
  219. save_user(f"@{entity.username}")
  220. save_user(entity.username)
  221. return copy.deepcopy(resolved_perms)
  222. async def get_fullchannel(
  223. self,
  224. entity: EntityLike,
  225. exp: int = 300,
  226. force: bool = False,
  227. ) -> ChannelFull:
  228. """
  229. Gets the FullChannelRequest and cache it
  230. :param entity: Channel to fetch ChannelFull of
  231. :param exp: Expiration time of the cache record and maximum time of already cached record
  232. :param force: Whether to force refresh the cache (make API request)
  233. :return: :obj:`ChannelFull`
  234. """
  235. if not hashable(entity):
  236. try:
  237. hashable_entity = next(
  238. getattr(entity, attr)
  239. for attr in {"channel_id", "chat_id", "id"}
  240. if getattr(entity, attr, None)
  241. )
  242. except StopIteration:
  243. logger.debug(
  244. "Can't parse hashable from entity %s, using legacy fullchannel"
  245. " request",
  246. entity,
  247. )
  248. return await self(GetFullChannelRequest(channel=entity))
  249. else:
  250. hashable_entity = entity
  251. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  252. hashable_entity = int(str(hashable_entity)[4:])
  253. if (
  254. not force
  255. and self._hikka_fullchannel_cache.get(hashable_entity)
  256. and not self._hikka_fullchannel_cache[hashable_entity].expired
  257. and self._hikka_fullchannel_cache[hashable_entity].ts + exp > time.time()
  258. ):
  259. return self._hikka_fullchannel_cache[hashable_entity].full_channel
  260. result = await self(GetFullChannelRequest(channel=entity))
  261. self._hikka_fullchannel_cache[hashable_entity] = CacheRecordFullChannel(
  262. hashable_entity,
  263. result,
  264. exp,
  265. )
  266. return result
  267. async def get_fulluser(
  268. self,
  269. entity: EntityLike,
  270. exp: int = 300,
  271. force: bool = False,
  272. ) -> UserFull:
  273. """
  274. Gets the FullUserRequest and cache it
  275. :param entity: User to fetch UserFull of
  276. :param exp: Expiration time of the cache record and maximum time of already cached record
  277. :param force: Whether to force refresh the cache (make API request)
  278. :return: :obj:`UserFull`
  279. """
  280. if not hashable(entity):
  281. try:
  282. hashable_entity = next(
  283. getattr(entity, attr)
  284. for attr in {"user_id", "chat_id", "id"}
  285. if getattr(entity, attr, None)
  286. )
  287. except StopIteration:
  288. logger.debug(
  289. "Can't parse hashable from entity %s, using legacy fulluser"
  290. " request",
  291. entity,
  292. )
  293. return await self(GetFullUserRequest(entity))
  294. else:
  295. hashable_entity = entity
  296. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  297. hashable_entity = int(str(hashable_entity)[4:])
  298. if (
  299. not force
  300. and self._hikka_fulluser_cache.get(hashable_entity)
  301. and not self._hikka_fulluser_cache[hashable_entity].expired
  302. and self._hikka_fulluser_cache[hashable_entity].ts + exp > time.time()
  303. ):
  304. return self._hikka_fulluser_cache[hashable_entity].full_user
  305. result = await self(GetFullUserRequest(entity))
  306. self._hikka_fulluser_cache[hashable_entity] = CacheRecordFullUser(
  307. hashable_entity,
  308. result,
  309. exp,
  310. )
  311. return result
  312. @staticmethod
  313. def _find_message_obj_in_frame(
  314. chat_id: int,
  315. frame: inspect.FrameInfo,
  316. ) -> typing.Optional[Message]:
  317. """
  318. Finds the message object from the frame
  319. """
  320. logger.debug("Finding message object in frame %s", frame)
  321. return next(
  322. (
  323. obj
  324. for obj in frame.frame.f_locals.values()
  325. if isinstance(obj, Message)
  326. and getattr(obj.reply_to, "forum_topic", False)
  327. and chat_id == getattr(obj.peer_id, "channel_id", None)
  328. ),
  329. None,
  330. )
  331. async def _find_message_obj_in_stack(
  332. self,
  333. chat: EntityLike,
  334. stack: typing.List[inspect.FrameInfo],
  335. ) -> typing.Optional[Message]:
  336. """
  337. Finds the message object from the stack
  338. """
  339. chat_id = (await self.get_entity(chat, exp=0)).id
  340. logger.debug("Finding message object in stack for chat %s", chat_id)
  341. return next(
  342. (
  343. self._find_message_obj_in_frame(chat_id, frame_info)
  344. for frame_info in stack
  345. if self._find_message_obj_in_frame(chat_id, frame_info)
  346. ),
  347. None,
  348. )
  349. async def _find_topic_in_stack(
  350. self,
  351. chat: EntityLike,
  352. stack: typing.List[inspect.FrameInfo],
  353. ) -> typing.Optional[Message]:
  354. """
  355. Finds the message object from the stack
  356. """
  357. message = await self._find_message_obj_in_stack(chat, stack)
  358. return (
  359. (message.reply_to.reply_to_top_id or message.reply_to.reply_to_msg_id)
  360. if message
  361. else None
  362. )
  363. async def _topic_guesser(
  364. self,
  365. native_method: typing.Callable[..., typing.Awaitable[Message]],
  366. stack: typing.List[inspect.FrameInfo],
  367. *args,
  368. **kwargs,
  369. ):
  370. no_retry = kwargs.pop("_topic_no_retry", False)
  371. try:
  372. return await native_method(self, *args, **kwargs)
  373. except TopicDeletedError:
  374. if no_retry:
  375. raise
  376. logger.debug("Topic deleted, trying to guess topic id")
  377. topic = await self._find_topic_in_stack(args[0], stack)
  378. logger.debug("Guessed topic id: %s", topic)
  379. if not topic:
  380. raise
  381. kwargs["reply_to"] = topic
  382. kwargs["_topic_no_retry"] = True
  383. return await self._topic_guesser(native_method, stack, *args, **kwargs)
  384. async def send_file(self, *args, **kwargs) -> Message:
  385. return await self._topic_guesser(
  386. TelegramClient.send_file,
  387. inspect.stack(),
  388. *args,
  389. **kwargs,
  390. )
  391. async def send_message(self, *args, **kwargs) -> Message:
  392. return await self._topic_guesser(
  393. TelegramClient.send_message,
  394. inspect.stack(),
  395. *args,
  396. **kwargs,
  397. )
  398. async def _call(
  399. self,
  400. sender: MTProtoSender,
  401. request: TLRequest,
  402. ordered: bool = False,
  403. flood_sleep_threshold: typing.Optional[int] = None,
  404. ):
  405. """
  406. Calls the given request and handles user-side forbidden constructors
  407. :param sender: Sender to use
  408. :param request: Request to send
  409. :param ordered: Whether to send the request ordered
  410. :param flood_sleep_threshold: Flood sleep threshold
  411. :return: The result of the request
  412. """
  413. # ⚠️⚠️ WARNING! ⚠️⚠️
  414. # If you are a module developer, and you'll try to bypass this protection to
  415. # force user join your channel, you will be added to SCAM modules
  416. # list and you will be banned from Hikka federation.
  417. # Let USER decide, which channel he will follow. Do not be so petty
  418. # I hope, you understood me.
  419. # Thank you
  420. if not self.__forbidden_constructors:
  421. return await TelegramClient._call(
  422. self,
  423. sender,
  424. request,
  425. ordered,
  426. flood_sleep_threshold,
  427. )
  428. not_tuple = False
  429. if not is_list_like(request):
  430. not_tuple = True
  431. request = (request,)
  432. new_request = []
  433. for item in request:
  434. if item.CONSTRUCTOR_ID in self.__forbidden_constructors and next(
  435. (
  436. frame_info.frame.f_locals["self"]
  437. for frame_info in inspect.stack()
  438. if hasattr(frame_info, "frame")
  439. and hasattr(frame_info.frame, "f_locals")
  440. and isinstance(frame_info.frame.f_locals, dict)
  441. and "self" in frame_info.frame.f_locals
  442. and isinstance(frame_info.frame.f_locals["self"], Module)
  443. and not getattr(
  444. frame_info.frame.f_locals["self"], "__origin__", ""
  445. ).startswith("<core")
  446. ),
  447. None,
  448. ):
  449. logger.debug(
  450. "🎉 I protected you from unintented %s (%s)!",
  451. item.__class__.__name__,
  452. item,
  453. )
  454. continue
  455. new_request += [item]
  456. if not new_request:
  457. return
  458. return await TelegramClient._call(
  459. self,
  460. sender,
  461. new_request[0] if not_tuple else tuple(new_request),
  462. ordered,
  463. flood_sleep_threshold,
  464. )
  465. def forbid_constructor(self, constructor: int):
  466. """
  467. Forbids the given constructor to be called
  468. :param constructor: Constructor id to forbid
  469. """
  470. self.__forbidden_constructors.extend([constructor])
  471. self.__forbidden_constructors = list(set(self.__forbidden_constructors))
  472. def forbid_constructors(self, constructors: list):
  473. """
  474. Forbids the given constructors to be called.
  475. All existing forbidden constructors will be removed
  476. :param constructors: Constructor ids to forbid
  477. """
  478. self.__forbidden_constructors = list(set(constructors))
  479. def _handle_update(
  480. self: "CustomTelegramClient",
  481. update: typing.Union[Updates, UpdatesCombined, UpdateShort],
  482. ):
  483. if self.raw_updates_processor is not None:
  484. self.raw_updates_processor(update)
  485. super()._handle_update(update)