tl_cache.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708
  1. # ©️ Dan Gazizullin, 2021-2023
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import copy
  7. import inspect
  8. import logging
  9. import time
  10. import typing
  11. from hikkatl import TelegramClient
  12. from hikkatl import __name__ as __base_name__
  13. from hikkatl import helpers
  14. from hikkatl._updates import ChannelState, Entity, EntityType, SessionState
  15. from hikkatl.errors import RPCError
  16. from hikkatl.errors.rpcerrorlist import TopicDeletedError
  17. from hikkatl.hints import EntityLike
  18. from hikkatl.network import MTProtoSender
  19. from hikkatl.tl import functions
  20. from hikkatl.tl.alltlobjects import LAYER
  21. from hikkatl.tl.functions.channels import GetFullChannelRequest
  22. from hikkatl.tl.functions.users import GetFullUserRequest
  23. from hikkatl.tl.tlobject import TLRequest
  24. from hikkatl.tl.types import (
  25. ChannelFull,
  26. Message,
  27. Pong,
  28. Updates,
  29. UpdatesCombined,
  30. UpdateShort,
  31. UserFull,
  32. )
  33. from hikkatl.utils import is_list_like
  34. from .types import (
  35. CacheRecordEntity,
  36. CacheRecordFullChannel,
  37. CacheRecordFullUser,
  38. CacheRecordPerms,
  39. Module,
  40. )
  41. logger = logging.getLogger(__name__)
  42. def hashable(value: typing.Any) -> bool:
  43. """
  44. Determine whether `value` can be hashed.
  45. This is a copy of `collections.abc.Hashable` from Python 3.8.
  46. """
  47. try:
  48. hash(value)
  49. except TypeError:
  50. return False
  51. return True
  52. class CustomTelegramClient(TelegramClient):
  53. def __init__(self, *args, **kwargs):
  54. super().__init__(*args, **kwargs)
  55. self._hikka_entity_cache: typing.Dict[
  56. typing.Union[str, int],
  57. CacheRecordEntity,
  58. ] = {}
  59. self._hikka_perms_cache: typing.Dict[
  60. typing.Union[str, int],
  61. CacheRecordPerms,
  62. ] = {}
  63. self._hikka_fullchannel_cache: typing.Dict[
  64. typing.Union[str, int],
  65. CacheRecordFullChannel,
  66. ] = {}
  67. self._hikka_fulluser_cache: typing.Dict[
  68. typing.Union[str, int],
  69. CacheRecordFullUser,
  70. ] = {}
  71. self._forbidden_constructors: typing.List[int] = []
  72. self._raw_updates_processor: typing.Optional[
  73. typing.Callable[
  74. [typing.Union[Updates, UpdatesCombined, UpdateShort]],
  75. typing.Any,
  76. ]
  77. ] = None
  78. async def connect(self, unix_socket_path: typing.Optional[str] = None):
  79. if self.session is None:
  80. raise ValueError(
  81. "TelegramClient instance cannot be reused after logging out"
  82. )
  83. if self._loop is None:
  84. self._loop = helpers.get_running_loop()
  85. elif self._loop != helpers.get_running_loop():
  86. raise RuntimeError(
  87. "The asyncio event loop must not change after connection (see the FAQ"
  88. " for details)"
  89. )
  90. connection = self._connection(
  91. self.session.server_address,
  92. self.session.port,
  93. self.session.dc_id,
  94. loggers=self._log,
  95. proxy=self._proxy,
  96. local_addr=self._local_addr,
  97. )
  98. if unix_socket_path is not None:
  99. connection.set_unix_socket(unix_socket_path)
  100. if not await self._sender.connect(connection):
  101. # We don't want to init or modify anything if we were already connected
  102. return
  103. self.session.auth_key = self._sender.auth_key
  104. self.session.save()
  105. if self._catch_up:
  106. ss = SessionState(0, 0, False, 0, 0, 0, 0, None)
  107. cs = []
  108. for entity_id, state in self.session.get_update_states():
  109. if entity_id == 0:
  110. # TODO current session doesn't store self-user info but adding that is breaking on downstream session impls
  111. ss = SessionState(
  112. 0,
  113. 0,
  114. False,
  115. state.pts,
  116. state.qts,
  117. int(state.date.timestamp()),
  118. state.seq,
  119. None,
  120. )
  121. else:
  122. cs.append(ChannelState(entity_id, state.pts))
  123. self._message_box.load(ss, cs)
  124. for state in cs:
  125. try:
  126. entity = self.session.get_input_entity(state.channel_id)
  127. except ValueError:
  128. self._log[__name__].warning(
  129. "No access_hash in cache for channel %s, will not catch up",
  130. state.channel_id,
  131. )
  132. else:
  133. self._mb_entity_cache.put(
  134. Entity(
  135. EntityType.CHANNEL, entity.channel_id, entity.access_hash
  136. )
  137. )
  138. self._init_request.query = functions.help.GetConfigRequest()
  139. req = self._init_request
  140. if self._no_updates:
  141. req = functions.InvokeWithoutUpdatesRequest(req)
  142. await self._sender.send(functions.InvokeWithLayerRequest(LAYER, req))
  143. if self._message_box.is_empty():
  144. me = await self.get_me()
  145. if me:
  146. await self._on_login(
  147. me
  148. ) # also calls GetState to initialize the MessageBox
  149. self._updates_handle = self.loop.create_task(self._update_loop())
  150. self._keepalive_handle = self.loop.create_task(self._keepalive_loop())
  151. @property
  152. def raw_updates_processor(self) -> typing.Optional[callable]:
  153. return self._raw_updates_processor
  154. @raw_updates_processor.setter
  155. def raw_updates_processor(self, value: callable):
  156. if self._raw_updates_processor is not None:
  157. raise ValueError("raw_updates_processor is already set")
  158. if not callable(value):
  159. raise ValueError("raw_updates_processor must be callable")
  160. self._raw_updates_processor = value
  161. @property
  162. def hikka_entity_cache(self) -> typing.Dict[int, CacheRecordEntity]:
  163. return self._hikka_entity_cache
  164. @property
  165. def hikka_perms_cache(self) -> typing.Dict[int, CacheRecordPerms]:
  166. return self._hikka_perms_cache
  167. @property
  168. def hikka_fullchannel_cache(self) -> typing.Dict[int, CacheRecordFullChannel]:
  169. return self._hikka_fullchannel_cache
  170. @property
  171. def hikka_fulluser_cache(self) -> typing.Dict[int, CacheRecordFullUser]:
  172. return self._hikka_fulluser_cache
  173. @property
  174. def forbidden_constructors(self) -> typing.List[str]:
  175. return self._forbidden_constructors
  176. async def force_get_entity(self, *args, **kwargs):
  177. """Forcefully makes a request to Telegram to get the entity."""
  178. return await self.get_entity(*args, force=True, **kwargs)
  179. async def get_entity(
  180. self,
  181. entity: EntityLike,
  182. exp: int = 5 * 60,
  183. force: bool = False,
  184. ):
  185. """
  186. Gets the entity and cache it
  187. :param entity: Entity to fetch
  188. :param exp: Expiration time of the cache record and maximum time of already cached record
  189. :param force: Whether to force refresh the cache (make API request)
  190. :return: :obj:`Entity`
  191. """
  192. # Will be used to determine, which client caused logging messages
  193. # parsed via inspect.stack()
  194. _hikka_client_id_logging_tag = copy.copy(self.tg_id) # noqa: F841
  195. if not hashable(entity):
  196. try:
  197. hashable_entity = next(
  198. getattr(entity, attr)
  199. for attr in {"user_id", "channel_id", "chat_id", "id"}
  200. if getattr(entity, attr, None)
  201. )
  202. except StopIteration:
  203. logger.debug(
  204. "Can't parse hashable from entity %s, using legacy resolve",
  205. entity,
  206. )
  207. return await super().get_entity(entity)
  208. else:
  209. hashable_entity = entity
  210. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  211. hashable_entity = int(str(hashable_entity)[4:])
  212. if (
  213. not force
  214. and hashable_entity
  215. and hashable_entity in self._hikka_entity_cache
  216. and (
  217. not exp
  218. or self._hikka_entity_cache[hashable_entity].ts + exp > time.time()
  219. )
  220. ):
  221. logger.debug(
  222. "Using cached entity %s (%s)",
  223. entity,
  224. type(self._hikka_entity_cache[hashable_entity].entity).__name__,
  225. )
  226. return copy.deepcopy(self._hikka_entity_cache[hashable_entity].entity)
  227. resolved_entity = await super().get_entity(entity)
  228. if resolved_entity:
  229. cache_record = CacheRecordEntity(hashable_entity, resolved_entity, exp)
  230. self._hikka_entity_cache[hashable_entity] = cache_record
  231. logger.debug("Saved hashable_entity %s to cache", hashable_entity)
  232. if getattr(resolved_entity, "id", None):
  233. logger.debug("Saved resolved_entity id %s to cache", resolved_entity.id)
  234. self._hikka_entity_cache[resolved_entity.id] = cache_record
  235. if getattr(resolved_entity, "username", None):
  236. logger.debug(
  237. "Saved resolved_entity username @%s to cache",
  238. resolved_entity.username,
  239. )
  240. self._hikka_entity_cache[f"@{resolved_entity.username}"] = cache_record
  241. self._hikka_entity_cache[resolved_entity.username] = cache_record
  242. return copy.deepcopy(resolved_entity)
  243. async def get_perms_cached(
  244. self,
  245. entity: EntityLike,
  246. user: typing.Optional[EntityLike] = None,
  247. exp: int = 5 * 60,
  248. force: bool = False,
  249. ):
  250. """
  251. Gets the permissions of the user in the entity and cache it
  252. :param entity: Entity to fetch
  253. :param user: User to fetch
  254. :param exp: Expiration time of the cache record and maximum time of already cached record
  255. :param force: Whether to force refresh the cache (make API request)
  256. :return: :obj:`ChatPermissions`
  257. """
  258. # Will be used to determine, which client caused logging messages
  259. # parsed via inspect.stack()
  260. _hikka_client_id_logging_tag = copy.copy(self.tg_id) # noqa: F841
  261. entity = await self.get_entity(entity)
  262. user = await self.get_entity(user) if user else None
  263. if not hashable(entity) or not hashable(user):
  264. try:
  265. hashable_entity = next(
  266. getattr(entity, attr)
  267. for attr in {"user_id", "channel_id", "chat_id", "id"}
  268. if getattr(entity, attr, None)
  269. )
  270. except StopIteration:
  271. logger.debug(
  272. "Can't parse hashable from entity %s, using legacy method",
  273. entity,
  274. )
  275. return await self.get_permissions(entity, user)
  276. try:
  277. hashable_user = next(
  278. getattr(user, attr)
  279. for attr in {"user_id", "channel_id", "chat_id", "id"}
  280. if getattr(user, attr, None)
  281. )
  282. except StopIteration:
  283. logger.debug(
  284. "Can't parse hashable from user %s, using legacy method",
  285. user,
  286. )
  287. return await self.get_permissions(entity, user)
  288. else:
  289. hashable_entity = entity
  290. hashable_user = user
  291. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  292. hashable_entity = int(str(hashable_entity)[4:])
  293. if str(hashable_user).isdigit() and int(hashable_user) < 0:
  294. hashable_user = int(str(hashable_user)[4:])
  295. if (
  296. not force
  297. and hashable_entity
  298. and hashable_user
  299. and hashable_user in self._hikka_perms_cache.get(hashable_entity, {})
  300. and (
  301. not exp
  302. or self._hikka_perms_cache[hashable_entity][hashable_user].ts + exp
  303. > time.time()
  304. )
  305. ):
  306. logger.debug("Using cached perms %s (%s)", hashable_entity, hashable_user)
  307. return copy.deepcopy(
  308. self._hikka_perms_cache[hashable_entity][hashable_user].perms
  309. )
  310. resolved_perms = await self.get_permissions(entity, user)
  311. if resolved_perms:
  312. cache_record = CacheRecordPerms(
  313. hashable_entity,
  314. hashable_user,
  315. resolved_perms,
  316. exp,
  317. )
  318. self._hikka_perms_cache.setdefault(hashable_entity, {})[
  319. hashable_user
  320. ] = cache_record
  321. logger.debug("Saved hashable_entity %s perms to cache", hashable_entity)
  322. def save_user(key: typing.Union[str, int]):
  323. nonlocal self, cache_record, user, hashable_user
  324. if getattr(user, "id", None):
  325. self._hikka_perms_cache.setdefault(key, {})[user.id] = cache_record
  326. if getattr(user, "username", None):
  327. self._hikka_perms_cache.setdefault(key, {})[
  328. f"@{user.username}"
  329. ] = cache_record
  330. self._hikka_perms_cache.setdefault(key, {})[
  331. user.username
  332. ] = cache_record
  333. if getattr(entity, "id", None):
  334. logger.debug("Saved resolved_entity id %s perms to cache", entity.id)
  335. save_user(entity.id)
  336. if getattr(entity, "username", None):
  337. logger.debug(
  338. "Saved resolved_entity username @%s perms to cache",
  339. entity.username,
  340. )
  341. save_user(f"@{entity.username}")
  342. save_user(entity.username)
  343. return copy.deepcopy(resolved_perms)
  344. async def get_fullchannel(
  345. self,
  346. entity: EntityLike,
  347. exp: int = 300,
  348. force: bool = False,
  349. ) -> ChannelFull:
  350. """
  351. Gets the FullChannelRequest and cache it
  352. :param entity: Channel to fetch ChannelFull of
  353. :param exp: Expiration time of the cache record and maximum time of already cached record
  354. :param force: Whether to force refresh the cache (make API request)
  355. :return: :obj:`ChannelFull`
  356. """
  357. if not hashable(entity):
  358. try:
  359. hashable_entity = next(
  360. getattr(entity, attr)
  361. for attr in {"channel_id", "chat_id", "id"}
  362. if getattr(entity, attr, None)
  363. )
  364. except StopIteration:
  365. logger.debug(
  366. (
  367. "Can't parse hashable from entity %s, using legacy fullchannel"
  368. " request"
  369. ),
  370. entity,
  371. )
  372. return await self(GetFullChannelRequest(channel=entity))
  373. else:
  374. hashable_entity = entity
  375. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  376. hashable_entity = int(str(hashable_entity)[4:])
  377. if (
  378. not force
  379. and self._hikka_fullchannel_cache.get(hashable_entity)
  380. and not self._hikka_fullchannel_cache[hashable_entity].expired
  381. and self._hikka_fullchannel_cache[hashable_entity].ts + exp > time.time()
  382. ):
  383. return self._hikka_fullchannel_cache[hashable_entity].full_channel
  384. result = await self(GetFullChannelRequest(channel=entity))
  385. self._hikka_fullchannel_cache[hashable_entity] = CacheRecordFullChannel(
  386. hashable_entity,
  387. result,
  388. exp,
  389. )
  390. return result
  391. async def get_fulluser(
  392. self,
  393. entity: EntityLike,
  394. exp: int = 300,
  395. force: bool = False,
  396. ) -> UserFull:
  397. """
  398. Gets the FullUserRequest and cache it
  399. :param entity: User to fetch UserFull of
  400. :param exp: Expiration time of the cache record and maximum time of already cached record
  401. :param force: Whether to force refresh the cache (make API request)
  402. :return: :obj:`UserFull`
  403. """
  404. if not hashable(entity):
  405. try:
  406. hashable_entity = next(
  407. getattr(entity, attr)
  408. for attr in {"user_id", "chat_id", "id"}
  409. if getattr(entity, attr, None)
  410. )
  411. except StopIteration:
  412. logger.debug(
  413. (
  414. "Can't parse hashable from entity %s, using legacy fulluser"
  415. " request"
  416. ),
  417. entity,
  418. )
  419. return await self(GetFullUserRequest(entity))
  420. else:
  421. hashable_entity = entity
  422. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  423. hashable_entity = int(str(hashable_entity)[4:])
  424. if (
  425. not force
  426. and self._hikka_fulluser_cache.get(hashable_entity)
  427. and not self._hikka_fulluser_cache[hashable_entity].expired
  428. and self._hikka_fulluser_cache[hashable_entity].ts + exp > time.time()
  429. ):
  430. return self._hikka_fulluser_cache[hashable_entity].full_user
  431. result = await self(GetFullUserRequest(entity))
  432. self._hikka_fulluser_cache[hashable_entity] = CacheRecordFullUser(
  433. hashable_entity,
  434. result,
  435. exp,
  436. )
  437. return result
  438. @staticmethod
  439. def _find_message_obj_in_frame(
  440. chat_id: int,
  441. frame: inspect.FrameInfo,
  442. ) -> typing.Optional[Message]:
  443. """
  444. Finds the message object from the frame
  445. """
  446. logger.debug("Finding message object in frame %s", frame)
  447. return next(
  448. (
  449. obj
  450. for obj in frame.frame.f_locals.values()
  451. if isinstance(obj, Message)
  452. and getattr(obj.reply_to, "forum_topic", False)
  453. and chat_id == getattr(obj.peer_id, "channel_id", None)
  454. ),
  455. None,
  456. )
  457. async def _find_message_obj_in_stack(
  458. self,
  459. chat: EntityLike,
  460. stack: typing.List[inspect.FrameInfo],
  461. ) -> typing.Optional[Message]:
  462. """
  463. Finds the message object from the stack
  464. """
  465. chat_id = (await self.get_entity(chat, exp=0)).id
  466. logger.debug("Finding message object in stack for chat %s", chat_id)
  467. return next(
  468. (
  469. self._find_message_obj_in_frame(chat_id, frame_info)
  470. for frame_info in stack
  471. if self._find_message_obj_in_frame(chat_id, frame_info)
  472. ),
  473. None,
  474. )
  475. async def _find_topic_in_stack(
  476. self,
  477. chat: EntityLike,
  478. stack: typing.List[inspect.FrameInfo],
  479. ) -> typing.Optional[Message]:
  480. """
  481. Finds the message object from the stack
  482. """
  483. message = await self._find_message_obj_in_stack(chat, stack)
  484. return (
  485. (message.reply_to.reply_to_top_id or message.reply_to.reply_to_msg_id)
  486. if message
  487. else None
  488. )
  489. async def _topic_guesser(
  490. self,
  491. native_method: typing.Callable[..., typing.Awaitable[Message]],
  492. stack: typing.List[inspect.FrameInfo],
  493. *args,
  494. **kwargs,
  495. ):
  496. no_retry = kwargs.pop("_topic_no_retry", False)
  497. try:
  498. return await native_method(*args, **kwargs)
  499. except TopicDeletedError:
  500. if no_retry:
  501. raise
  502. logger.debug("Topic deleted, trying to guess topic id")
  503. topic = await self._find_topic_in_stack(args[0], stack)
  504. logger.debug("Guessed topic id: %s", topic)
  505. if not topic:
  506. raise
  507. kwargs["reply_to"] = topic
  508. kwargs["_topic_no_retry"] = True
  509. return await self._topic_guesser(native_method, stack, *args, **kwargs)
  510. async def send_file(self, *args, **kwargs) -> Message:
  511. return await self._topic_guesser(
  512. super().send_file,
  513. inspect.stack(),
  514. *args,
  515. **kwargs,
  516. )
  517. async def send_message(self, *args, **kwargs) -> Message:
  518. return await self._topic_guesser(
  519. super().send_message,
  520. inspect.stack(),
  521. *args,
  522. **kwargs,
  523. )
  524. async def _call(
  525. self,
  526. sender: MTProtoSender,
  527. request: TLRequest,
  528. ordered: bool = False,
  529. flood_sleep_threshold: typing.Optional[int] = None,
  530. ):
  531. """
  532. Calls the given request and handles user-side forbidden constructors
  533. :param sender: Sender to use
  534. :param request: Request to send
  535. :param ordered: Whether to send the request ordered
  536. :param flood_sleep_threshold: Flood sleep threshold
  537. :return: The result of the request
  538. """
  539. # ⚠️⚠️ WARNING! ⚠️⚠️
  540. # If you are a module developer, and you'll try to bypass this protection to
  541. # force user join your channel, you will be added to SCAM modules
  542. # list and you will be banned from Hikka federation.
  543. # Let USER decide, which channel he will follow. Do not be so petty
  544. # I hope, you understood me.
  545. # Thank you
  546. not_tuple = False
  547. if not is_list_like(request):
  548. not_tuple = True
  549. request = (request,)
  550. new_request = []
  551. for item in request:
  552. if item.CONSTRUCTOR_ID in self._forbidden_constructors and next(
  553. (
  554. frame_info.frame.f_locals["self"]
  555. for frame_info in inspect.stack()
  556. if hasattr(frame_info, "frame")
  557. and hasattr(frame_info.frame, "f_locals")
  558. and isinstance(frame_info.frame.f_locals, dict)
  559. and "self" in frame_info.frame.f_locals
  560. and isinstance(frame_info.frame.f_locals["self"], Module)
  561. and not getattr(
  562. frame_info.frame.f_locals["self"], "__origin__", ""
  563. ).startswith("<core")
  564. ),
  565. None,
  566. ):
  567. logger.debug(
  568. "🎉 I protected you from unintented %s (%s)!",
  569. item.__class__.__name__,
  570. item,
  571. )
  572. continue
  573. new_request += [item]
  574. if not new_request:
  575. return
  576. return await super()._call(
  577. sender,
  578. new_request[0] if not_tuple else tuple(new_request),
  579. ordered,
  580. flood_sleep_threshold,
  581. )
  582. def forbid_constructor(self, constructor: int):
  583. """
  584. Forbids the given constructor to be called
  585. :param constructor: Constructor id to forbid
  586. """
  587. self._forbidden_constructors.extend([constructor])
  588. self._forbidden_constructors = list(set(self._forbidden_constructors))
  589. def forbid_constructors(self, constructors: list):
  590. """
  591. Forbids the given constructors to be called.
  592. All existing forbidden constructors will be removed
  593. :param constructors: Constructor ids to forbid
  594. """
  595. self._forbidden_constructors = list(set(constructors))
  596. def _handle_update(
  597. self: "CustomTelegramClient",
  598. update: typing.Union[Updates, UpdatesCombined, UpdateShort],
  599. ):
  600. if self._raw_updates_processor is not None:
  601. self._raw_updates_processor(update)
  602. super()._handle_update(update)