tl_cache.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
  1. # ©️ Dan Gazizullin, 2021-2023
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import copy
  7. import inspect
  8. import logging
  9. import time
  10. import typing
  11. from hikkatl import TelegramClient
  12. from hikkatl.errors.rpcerrorlist import TopicDeletedError
  13. from hikkatl.hints import EntityLike
  14. from hikkatl.network import MTProtoSender
  15. from hikkatl.tl.functions.channels import GetFullChannelRequest
  16. from hikkatl.tl.functions.users import GetFullUserRequest
  17. from hikkatl.tl.tlobject import TLRequest
  18. from hikkatl.tl.types import (
  19. ChannelFull,
  20. Message,
  21. Updates,
  22. UpdatesCombined,
  23. UpdateShort,
  24. UserFull,
  25. )
  26. from hikkatl.utils import is_list_like
  27. from .types import (
  28. CacheRecordEntity,
  29. CacheRecordFullChannel,
  30. CacheRecordFullUser,
  31. CacheRecordPerms,
  32. Module,
  33. )
  34. logger = logging.getLogger(__name__)
  35. def hashable(value: typing.Any) -> bool:
  36. """
  37. Determine whether `value` can be hashed.
  38. This is a copy of `collections.abc.Hashable` from Python 3.8.
  39. """
  40. try:
  41. hash(value)
  42. except TypeError:
  43. return False
  44. return True
  45. class CustomTelegramClient(TelegramClient):
  46. def __init__(self, *args, **kwargs):
  47. super().__init__(*args, **kwargs)
  48. self._hikka_entity_cache: typing.Dict[
  49. typing.Union[str, int],
  50. CacheRecordEntity,
  51. ] = {}
  52. self._hikka_perms_cache: typing.Dict[
  53. typing.Union[str, int],
  54. CacheRecordPerms,
  55. ] = {}
  56. self._hikka_fullchannel_cache: typing.Dict[
  57. typing.Union[str, int],
  58. CacheRecordFullChannel,
  59. ] = {}
  60. self._hikka_fulluser_cache: typing.Dict[
  61. typing.Union[str, int],
  62. CacheRecordFullUser,
  63. ] = {}
  64. self._forbidden_constructors: typing.List[int] = []
  65. self._raw_updates_processor: typing.Optional[
  66. typing.Callable[
  67. [typing.Union[Updates, UpdatesCombined, UpdateShort]],
  68. typing.Any,
  69. ]
  70. ] = None
  71. @property
  72. def raw_updates_processor(self) -> typing.Optional[callable]:
  73. return self._raw_updates_processor
  74. @raw_updates_processor.setter
  75. def raw_updates_processor(self, value: callable):
  76. if self._raw_updates_processor is not None:
  77. raise ValueError("raw_updates_processor is already set")
  78. if not callable(value):
  79. raise ValueError("raw_updates_processor must be callable")
  80. self._raw_updates_processor = value
  81. @property
  82. def hikka_entity_cache(self) -> typing.Dict[int, CacheRecordEntity]:
  83. return self._hikka_entity_cache
  84. @property
  85. def hikka_perms_cache(self) -> typing.Dict[int, CacheRecordPerms]:
  86. return self._hikka_perms_cache
  87. @property
  88. def hikka_fullchannel_cache(self) -> typing.Dict[int, CacheRecordFullChannel]:
  89. return self._hikka_fullchannel_cache
  90. @property
  91. def hikka_fulluser_cache(self) -> typing.Dict[int, CacheRecordFullUser]:
  92. return self._hikka_fulluser_cache
  93. @property
  94. def forbidden_constructors(self) -> typing.List[str]:
  95. return self._forbidden_constructors
  96. async def force_get_entity(self, *args, **kwargs):
  97. """Forcefully makes a request to Telegram to get the entity."""
  98. return await self.get_entity(*args, force=True, **kwargs)
  99. async def get_entity(
  100. self,
  101. entity: EntityLike,
  102. exp: int = 5 * 60,
  103. force: bool = False,
  104. ):
  105. """
  106. Gets the entity and cache it
  107. :param entity: Entity to fetch
  108. :param exp: Expiration time of the cache record and maximum time of already cached record
  109. :param force: Whether to force refresh the cache (make API request)
  110. :return: :obj:`Entity`
  111. """
  112. # Will be used to determine, which client caused logging messages
  113. # parsed via inspect.stack()
  114. _hikka_client_id_logging_tag = copy.copy(self.tg_id) # noqa: F841
  115. if not hashable(entity):
  116. try:
  117. hashable_entity = next(
  118. getattr(entity, attr)
  119. for attr in {"user_id", "channel_id", "chat_id", "id"}
  120. if getattr(entity, attr, None)
  121. )
  122. except StopIteration:
  123. logger.debug(
  124. "Can't parse hashable from entity %s, using legacy resolve",
  125. entity,
  126. )
  127. return await TelegramClient.get_entity(self, entity)
  128. else:
  129. hashable_entity = entity
  130. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  131. hashable_entity = int(str(hashable_entity)[4:])
  132. if (
  133. not force
  134. and hashable_entity
  135. and hashable_entity in self._hikka_entity_cache
  136. and (
  137. not exp
  138. or self._hikka_entity_cache[hashable_entity].ts + exp > time.time()
  139. )
  140. ):
  141. logger.debug(
  142. "Using cached entity %s (%s)",
  143. entity,
  144. type(self._hikka_entity_cache[hashable_entity].entity).__name__,
  145. )
  146. return copy.deepcopy(self._hikka_entity_cache[hashable_entity].entity)
  147. resolved_entity = await TelegramClient.get_entity(self, entity)
  148. if resolved_entity:
  149. cache_record = CacheRecordEntity(hashable_entity, resolved_entity, exp)
  150. self._hikka_entity_cache[hashable_entity] = cache_record
  151. logger.debug("Saved hashable_entity %s to cache", hashable_entity)
  152. if getattr(resolved_entity, "id", None):
  153. logger.debug("Saved resolved_entity id %s to cache", resolved_entity.id)
  154. self._hikka_entity_cache[resolved_entity.id] = cache_record
  155. if getattr(resolved_entity, "username", None):
  156. logger.debug(
  157. "Saved resolved_entity username @%s to cache",
  158. resolved_entity.username,
  159. )
  160. self._hikka_entity_cache[f"@{resolved_entity.username}"] = cache_record
  161. self._hikka_entity_cache[resolved_entity.username] = cache_record
  162. return copy.deepcopy(resolved_entity)
  163. async def get_perms_cached(
  164. self,
  165. entity: EntityLike,
  166. user: typing.Optional[EntityLike] = None,
  167. exp: int = 5 * 60,
  168. force: bool = False,
  169. ):
  170. """
  171. Gets the permissions of the user in the entity and cache it
  172. :param entity: Entity to fetch
  173. :param user: User to fetch
  174. :param exp: Expiration time of the cache record and maximum time of already cached record
  175. :param force: Whether to force refresh the cache (make API request)
  176. :return: :obj:`ChatPermissions`
  177. """
  178. # Will be used to determine, which client caused logging messages
  179. # parsed via inspect.stack()
  180. _hikka_client_id_logging_tag = copy.copy(self.tg_id) # noqa: F841
  181. entity = await self.get_entity(entity)
  182. user = await self.get_entity(user) if user else None
  183. if not hashable(entity) or not hashable(user):
  184. try:
  185. hashable_entity = next(
  186. getattr(entity, attr)
  187. for attr in {"user_id", "channel_id", "chat_id", "id"}
  188. if getattr(entity, attr, None)
  189. )
  190. except StopIteration:
  191. logger.debug(
  192. "Can't parse hashable from entity %s, using legacy method",
  193. entity,
  194. )
  195. return await self.get_permissions(entity, user)
  196. try:
  197. hashable_user = next(
  198. getattr(user, attr)
  199. for attr in {"user_id", "channel_id", "chat_id", "id"}
  200. if getattr(user, attr, None)
  201. )
  202. except StopIteration:
  203. logger.debug(
  204. "Can't parse hashable from user %s, using legacy method",
  205. user,
  206. )
  207. return await self.get_permissions(entity, user)
  208. else:
  209. hashable_entity = entity
  210. hashable_user = user
  211. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  212. hashable_entity = int(str(hashable_entity)[4:])
  213. if str(hashable_user).isdigit() and int(hashable_user) < 0:
  214. hashable_user = int(str(hashable_user)[4:])
  215. if (
  216. not force
  217. and hashable_entity
  218. and hashable_user
  219. and hashable_user in self._hikka_perms_cache.get(hashable_entity, {})
  220. and (
  221. not exp
  222. or self._hikka_perms_cache[hashable_entity][hashable_user].ts + exp
  223. > time.time()
  224. )
  225. ):
  226. logger.debug("Using cached perms %s (%s)", hashable_entity, hashable_user)
  227. return copy.deepcopy(
  228. self._hikka_perms_cache[hashable_entity][hashable_user].perms
  229. )
  230. resolved_perms = await self.get_permissions(entity, user)
  231. if resolved_perms:
  232. cache_record = CacheRecordPerms(
  233. hashable_entity,
  234. hashable_user,
  235. resolved_perms,
  236. exp,
  237. )
  238. self._hikka_perms_cache.setdefault(hashable_entity, {})[
  239. hashable_user
  240. ] = cache_record
  241. logger.debug("Saved hashable_entity %s perms to cache", hashable_entity)
  242. def save_user(key: typing.Union[str, int]):
  243. nonlocal self, cache_record, user, hashable_user
  244. if getattr(user, "id", None):
  245. self._hikka_perms_cache.setdefault(key, {})[user.id] = cache_record
  246. if getattr(user, "username", None):
  247. self._hikka_perms_cache.setdefault(key, {})[
  248. f"@{user.username}"
  249. ] = cache_record
  250. self._hikka_perms_cache.setdefault(key, {})[
  251. user.username
  252. ] = cache_record
  253. if getattr(entity, "id", None):
  254. logger.debug("Saved resolved_entity id %s perms to cache", entity.id)
  255. save_user(entity.id)
  256. if getattr(entity, "username", None):
  257. logger.debug(
  258. "Saved resolved_entity username @%s perms to cache",
  259. entity.username,
  260. )
  261. save_user(f"@{entity.username}")
  262. save_user(entity.username)
  263. return copy.deepcopy(resolved_perms)
  264. async def get_fullchannel(
  265. self,
  266. entity: EntityLike,
  267. exp: int = 300,
  268. force: bool = False,
  269. ) -> ChannelFull:
  270. """
  271. Gets the FullChannelRequest and cache it
  272. :param entity: Channel to fetch ChannelFull of
  273. :param exp: Expiration time of the cache record and maximum time of already cached record
  274. :param force: Whether to force refresh the cache (make API request)
  275. :return: :obj:`ChannelFull`
  276. """
  277. if not hashable(entity):
  278. try:
  279. hashable_entity = next(
  280. getattr(entity, attr)
  281. for attr in {"channel_id", "chat_id", "id"}
  282. if getattr(entity, attr, None)
  283. )
  284. except StopIteration:
  285. logger.debug(
  286. (
  287. "Can't parse hashable from entity %s, using legacy fullchannel"
  288. " request"
  289. ),
  290. entity,
  291. )
  292. return await self(GetFullChannelRequest(channel=entity))
  293. else:
  294. hashable_entity = entity
  295. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  296. hashable_entity = int(str(hashable_entity)[4:])
  297. if (
  298. not force
  299. and self._hikka_fullchannel_cache.get(hashable_entity)
  300. and not self._hikka_fullchannel_cache[hashable_entity].expired
  301. and self._hikka_fullchannel_cache[hashable_entity].ts + exp > time.time()
  302. ):
  303. return self._hikka_fullchannel_cache[hashable_entity].full_channel
  304. result = await self(GetFullChannelRequest(channel=entity))
  305. self._hikka_fullchannel_cache[hashable_entity] = CacheRecordFullChannel(
  306. hashable_entity,
  307. result,
  308. exp,
  309. )
  310. return result
  311. async def get_fulluser(
  312. self,
  313. entity: EntityLike,
  314. exp: int = 300,
  315. force: bool = False,
  316. ) -> UserFull:
  317. """
  318. Gets the FullUserRequest and cache it
  319. :param entity: User to fetch UserFull of
  320. :param exp: Expiration time of the cache record and maximum time of already cached record
  321. :param force: Whether to force refresh the cache (make API request)
  322. :return: :obj:`UserFull`
  323. """
  324. if not hashable(entity):
  325. try:
  326. hashable_entity = next(
  327. getattr(entity, attr)
  328. for attr in {"user_id", "chat_id", "id"}
  329. if getattr(entity, attr, None)
  330. )
  331. except StopIteration:
  332. logger.debug(
  333. (
  334. "Can't parse hashable from entity %s, using legacy fulluser"
  335. " request"
  336. ),
  337. entity,
  338. )
  339. return await self(GetFullUserRequest(entity))
  340. else:
  341. hashable_entity = entity
  342. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  343. hashable_entity = int(str(hashable_entity)[4:])
  344. if (
  345. not force
  346. and self._hikka_fulluser_cache.get(hashable_entity)
  347. and not self._hikka_fulluser_cache[hashable_entity].expired
  348. and self._hikka_fulluser_cache[hashable_entity].ts + exp > time.time()
  349. ):
  350. return self._hikka_fulluser_cache[hashable_entity].full_user
  351. result = await self(GetFullUserRequest(entity))
  352. self._hikka_fulluser_cache[hashable_entity] = CacheRecordFullUser(
  353. hashable_entity,
  354. result,
  355. exp,
  356. )
  357. return result
  358. @staticmethod
  359. def _find_message_obj_in_frame(
  360. chat_id: int,
  361. frame: inspect.FrameInfo,
  362. ) -> typing.Optional[Message]:
  363. """
  364. Finds the message object from the frame
  365. """
  366. logger.debug("Finding message object in frame %s", frame)
  367. return next(
  368. (
  369. obj
  370. for obj in frame.frame.f_locals.values()
  371. if isinstance(obj, Message)
  372. and getattr(obj.reply_to, "forum_topic", False)
  373. and chat_id == getattr(obj.peer_id, "channel_id", None)
  374. ),
  375. None,
  376. )
  377. async def _find_message_obj_in_stack(
  378. self,
  379. chat: EntityLike,
  380. stack: typing.List[inspect.FrameInfo],
  381. ) -> typing.Optional[Message]:
  382. """
  383. Finds the message object from the stack
  384. """
  385. chat_id = (await self.get_entity(chat, exp=0)).id
  386. logger.debug("Finding message object in stack for chat %s", chat_id)
  387. return next(
  388. (
  389. self._find_message_obj_in_frame(chat_id, frame_info)
  390. for frame_info in stack
  391. if self._find_message_obj_in_frame(chat_id, frame_info)
  392. ),
  393. None,
  394. )
  395. async def _find_topic_in_stack(
  396. self,
  397. chat: EntityLike,
  398. stack: typing.List[inspect.FrameInfo],
  399. ) -> typing.Optional[Message]:
  400. """
  401. Finds the message object from the stack
  402. """
  403. message = await self._find_message_obj_in_stack(chat, stack)
  404. return (
  405. (message.reply_to.reply_to_top_id or message.reply_to.reply_to_msg_id)
  406. if message
  407. else None
  408. )
  409. async def _topic_guesser(
  410. self,
  411. native_method: typing.Callable[..., typing.Awaitable[Message]],
  412. stack: typing.List[inspect.FrameInfo],
  413. *args,
  414. **kwargs,
  415. ):
  416. no_retry = kwargs.pop("_topic_no_retry", False)
  417. try:
  418. return await native_method(self, *args, **kwargs)
  419. except TopicDeletedError:
  420. if no_retry:
  421. raise
  422. logger.debug("Topic deleted, trying to guess topic id")
  423. topic = await self._find_topic_in_stack(args[0], stack)
  424. logger.debug("Guessed topic id: %s", topic)
  425. if not topic:
  426. raise
  427. kwargs["reply_to"] = topic
  428. kwargs["_topic_no_retry"] = True
  429. return await self._topic_guesser(native_method, stack, *args, **kwargs)
  430. async def send_file(self, *args, **kwargs) -> Message:
  431. return await self._topic_guesser(
  432. TelegramClient.send_file,
  433. inspect.stack(),
  434. *args,
  435. **kwargs,
  436. )
  437. async def send_message(self, *args, **kwargs) -> Message:
  438. return await self._topic_guesser(
  439. TelegramClient.send_message,
  440. inspect.stack(),
  441. *args,
  442. **kwargs,
  443. )
  444. async def _call(
  445. self,
  446. sender: MTProtoSender,
  447. request: TLRequest,
  448. ordered: bool = False,
  449. flood_sleep_threshold: typing.Optional[int] = None,
  450. ):
  451. """
  452. Calls the given request and handles user-side forbidden constructors
  453. :param sender: Sender to use
  454. :param request: Request to send
  455. :param ordered: Whether to send the request ordered
  456. :param flood_sleep_threshold: Flood sleep threshold
  457. :return: The result of the request
  458. """
  459. # ⚠️⚠️ WARNING! ⚠️⚠️
  460. # If you are a module developer, and you'll try to bypass this protection to
  461. # force user join your channel, you will be added to SCAM modules
  462. # list and you will be banned from Hikka federation.
  463. # Let USER decide, which channel he will follow. Do not be so petty
  464. # I hope, you understood me.
  465. # Thank you
  466. not_tuple = False
  467. if not is_list_like(request):
  468. not_tuple = True
  469. request = (request,)
  470. new_request = []
  471. for item in request:
  472. if item.CONSTRUCTOR_ID in self._forbidden_constructors and next(
  473. (
  474. frame_info.frame.f_locals["self"]
  475. for frame_info in inspect.stack()
  476. if hasattr(frame_info, "frame")
  477. and hasattr(frame_info.frame, "f_locals")
  478. and isinstance(frame_info.frame.f_locals, dict)
  479. and "self" in frame_info.frame.f_locals
  480. and isinstance(frame_info.frame.f_locals["self"], Module)
  481. and not getattr(
  482. frame_info.frame.f_locals["self"], "__origin__", ""
  483. ).startswith("<core")
  484. ),
  485. None,
  486. ):
  487. logger.debug(
  488. "🎉 I protected you from unintented %s (%s)!",
  489. item.__class__.__name__,
  490. item,
  491. )
  492. continue
  493. new_request += [item]
  494. if not new_request:
  495. return
  496. return await TelegramClient._call(
  497. self,
  498. sender,
  499. new_request[0] if not_tuple else tuple(new_request),
  500. ordered,
  501. flood_sleep_threshold,
  502. )
  503. def forbid_constructor(self, constructor: int):
  504. """
  505. Forbids the given constructor to be called
  506. :param constructor: Constructor id to forbid
  507. """
  508. self._forbidden_constructors.extend([constructor])
  509. self._forbidden_constructors = list(set(self._forbidden_constructors))
  510. def forbid_constructors(self, constructors: list):
  511. """
  512. Forbids the given constructors to be called.
  513. All existing forbidden constructors will be removed
  514. :param constructors: Constructor ids to forbid
  515. """
  516. self._forbidden_constructors = list(set(constructors))
  517. def _handle_update(
  518. self: "CustomTelegramClient",
  519. update: typing.Union[Updates, UpdatesCombined, UpdateShort],
  520. ):
  521. if self._raw_updates_processor is not None:
  522. self._raw_updates_processor(update)
  523. super()._handle_update(update)