tl_cache.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. # █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
  2. # █▀█ █ █ █ █▀█ █▀▄ █
  3. # © Copyright 2022
  4. # https://t.me/hikariatama
  5. #
  6. # 🔒 Licensed under the GNU AGPLv3
  7. # 🌐 https://www.gnu.org/licenses/agpl-3.0.html
  8. import copy
  9. import inspect
  10. import time
  11. import logging
  12. import typing
  13. from telethon import TelegramClient
  14. from telethon.hints import EntityLike
  15. from telethon.utils import is_list_like
  16. from telethon.network import MTProtoSender
  17. from telethon.tl.tlobject import TLRequest
  18. from telethon.tl.functions.channels import GetFullChannelRequest
  19. from telethon.tl.functions.users import GetFullUserRequest
  20. from telethon.tl.types import ChannelFull, UserFull
  21. from .types import (
  22. CacheRecord,
  23. CacheRecordPerms,
  24. CacheRecordFullChannel,
  25. CacheRecordFullUser,
  26. Module,
  27. )
  # Module-level logger for this file.
  logger = logging.getLogger(__name__)


  def hashable(value: typing.Any) -> bool:
      """
      Determine whether `value` can be hashed.

      Unlike ``isinstance(value, collections.abc.Hashable)``, which only checks
      that the type defines ``__hash__``, this actually calls :func:`hash`, so
      containers holding unhashable items (e.g. ``(1, [])``) are correctly
      reported as unhashable.

      :param value: Object to test
      :return: ``True`` if ``hash(value)`` succeeds, ``False`` if it raises
          :class:`TypeError`
      """
      try:
          hash(value)
      except TypeError:
          return False
      return True
  class CustomTelegramClient(TelegramClient):
      """
      Telethon client with per-instance caches for entities, permissions,
      ``ChannelFull`` and ``UserFull`` results, plus a guard that drops
      requests with forbidden constructors when they are issued from a
      non-core :class:`Module`.
      """

      # Attribute names probed, in priority order, to derive a cache key from
      # an unhashable object (e.g. a TL entity or a message). This is a tuple
      # on purpose: iterating a set of strings yields an order that depends on
      # per-process hash randomization, so the chosen key (e.g. ``chat_id`` vs.
      # a message's own ``id``) used to vary from run to run.
      _HIKKA_ID_ATTRS = ("user_id", "channel_id", "chat_id", "id")

      # Telethon "marked" channel ids are -(10 ** 12 + channel_id)
      # (see ``telethon.utils.resolve_id``).
      _HIKKA_CHANNEL_MARK = 10**12

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          # Caches keyed by a hashable form of the requested entity
          self._hikka_entity_cache = {}
          # {chat key: {user key: CacheRecordPerms}}
          self._hikka_perms_cache = {}
          self._hikka_fullchannel_cache = {}
          self._hikka_fulluser_cache = {}
          # Constructor ids that non-core modules are not allowed to send
          self.__forbidden_constructors = []

      @staticmethod
      def _hikka_extract_id(
          obj: typing.Any,
          attrs: typing.Iterable[str],
      ) -> typing.Any:
          """
          Return the first truthy attribute of `obj` among `attrs`, checked in
          the given order, or ``None`` if there is none.
          """
          for attr in attrs:
              value = getattr(obj, attr, None)
              if value:
                  return value
          return None

      @classmethod
      def _hikka_normalize_id(cls, value: typing.Any) -> typing.Any:
          """
          Turn a marked channel id (an int below ``-10 ** 12``, i.e. ``-100…``)
          into the bare channel id, so it shares a cache key with the resolved
          entity's ``id``. Any other value is returned unchanged.
          """
          if isinstance(value, int) and value < -cls._HIKKA_CHANNEL_MARK:
              return -value - cls._HIKKA_CHANNEL_MARK
          return value

      async def force_get_entity(self, *args, **kwargs):
          """Forcefully makes a request to Telegram to get the entity."""
          return await self.get_entity(*args, force=True, **kwargs)

      async def get_entity(
          self,
          entity: EntityLike,
          exp: int = 5 * 60,
          force: bool = False,
      ):
          """
          Gets the entity and caches it
          :param entity: Entity to fetch
          :param exp: Expiration time of the cache record and maximum age, in
              seconds, of an already cached record to reuse. A falsy value
              reuses a cached record regardless of its age
          :param force: Whether to force refresh the cache (make API request)
          :return: :obj:`Entity` (a deep copy, so callers cannot mutate the
              cached object)
          """
          # Will be used to determine, which client caused logging messages
          # parsed via inspect.stack()
          _hikka_client_id_logging_tag = copy.copy(self.tg_id)  # skipcq

          if hashable(entity):
              hashable_entity = entity
          else:
              hashable_entity = self._hikka_extract_id(entity, self._HIKKA_ID_ATTRS)
              if hashable_entity is None:
                  logger.debug(
                      "Can't parse hashable from entity %s, using legacy resolve",
                      entity,
                  )
                  return await TelegramClient.get_entity(self, entity)

          hashable_entity = self._hikka_normalize_id(hashable_entity)

          if (
              not force
              and hashable_entity
              and hashable_entity in self._hikka_entity_cache
              and (
                  not exp
                  or self._hikka_entity_cache[hashable_entity].ts + exp > time.time()
              )
          ):
              logger.debug(
                  "Using cached entity %s (%s)",
                  entity,
                  type(self._hikka_entity_cache[hashable_entity].entity).__name__,
              )
              return copy.deepcopy(self._hikka_entity_cache[hashable_entity].entity)

          resolved_entity = await TelegramClient.get_entity(self, entity)

          if resolved_entity:
              cache_record = CacheRecord(hashable_entity, resolved_entity, exp)
              self._hikka_entity_cache[hashable_entity] = cache_record
              logger.debug("Saved hashable_entity %s to cache", hashable_entity)

              # Alias the same record under the id and username so later
              # lookups by either form hit the cache.
              if getattr(resolved_entity, "id", None):
                  logger.debug("Saved resolved_entity id %s to cache", resolved_entity.id)
                  self._hikka_entity_cache[resolved_entity.id] = cache_record

              if getattr(resolved_entity, "username", None):
                  logger.debug(
                      "Saved resolved_entity username @%s to cache",
                      resolved_entity.username,
                  )
                  self._hikka_entity_cache[f"@{resolved_entity.username}"] = cache_record
                  self._hikka_entity_cache[resolved_entity.username] = cache_record

          return copy.deepcopy(resolved_entity)

      async def get_perms_cached(
          self,
          entity: EntityLike,
          user: typing.Optional[EntityLike] = None,
          exp: int = 5 * 60,
          force: bool = False,
      ):
          """
          Gets the permissions of the user in the entity and caches them
          :param entity: Entity to fetch
          :param user: User to fetch. When omitted, the request is not cached
              (no hashable key can be derived from ``None``)
          :param exp: Expiration time of the cache record and maximum age, in
              seconds, of an already cached record to reuse. A falsy value
              reuses a cached record regardless of its age
          :param force: Whether to force refresh the cache (make API request)
          :return: :obj:`ChatPermissions` (a deep copy of the cached value)
          """
          # Will be used to determine, which client caused logging messages
          # parsed via inspect.stack()
          _hikka_client_id_logging_tag = copy.copy(self.tg_id)  # skipcq

          entity = await self.get_entity(entity)
          user = await self.get_entity(user) if user else None

          if hashable(entity) and hashable(user):
              hashable_entity = entity
              hashable_user = user
          else:
              hashable_entity = self._hikka_extract_id(entity, self._HIKKA_ID_ATTRS)
              if hashable_entity is None:
                  logger.debug(
                      "Can't parse hashable from entity %s, using legacy method",
                      entity,
                  )
                  return await self.get_permissions(entity, user)

              hashable_user = self._hikka_extract_id(user, self._HIKKA_ID_ATTRS)
              if hashable_user is None:
                  logger.debug(
                      "Can't parse hashable from user %s, using legacy method",
                      user,
                  )
                  return await self.get_permissions(entity, user)

          hashable_entity = self._hikka_normalize_id(hashable_entity)
          hashable_user = self._hikka_normalize_id(hashable_user)

          if (
              not force
              and hashable_entity
              and hashable_user
              and hashable_user in self._hikka_perms_cache.get(hashable_entity, {})
              and (
                  not exp
                  or self._hikka_perms_cache[hashable_entity][hashable_user].ts + exp
                  > time.time()
              )
          ):
              logger.debug("Using cached perms %s (%s)", hashable_entity, hashable_user)
              return copy.deepcopy(
                  self._hikka_perms_cache[hashable_entity][hashable_user].perms
              )

          resolved_perms = await self.get_permissions(entity, user)

          if resolved_perms:
              cache_record = CacheRecordPerms(
                  hashable_entity,
                  hashable_user,
                  resolved_perms,
                  exp,
              )
              self._hikka_perms_cache.setdefault(hashable_entity, {})[
                  hashable_user
              ] = cache_record
              logger.debug("Saved hashable_entity %s perms to cache", hashable_entity)

              # Alternative keys the same user can be looked up by
              user_keys = []
              if getattr(user, "id", None):
                  user_keys.append(user.id)
              if getattr(user, "username", None):
                  user_keys += [f"@{user.username}", user.username]

              def save_user(key: typing.Union[str, int]):
                  # Alias the record under every user key, inside the bucket of
                  # the chat key `key` (the bucket is created only if needed)
                  if not user_keys:
                      return
                  bucket = self._hikka_perms_cache.setdefault(key, {})
                  for user_key in user_keys:
                      bucket[user_key] = cache_record

              if getattr(entity, "id", None):
                  logger.debug("Saved resolved_entity id %s perms to cache", entity.id)
                  save_user(entity.id)

              if getattr(entity, "username", None):
                  logger.debug(
                      "Saved resolved_entity username @%s perms to cache",
                      entity.username,
                  )
                  save_user(f"@{entity.username}")
                  save_user(entity.username)

          return copy.deepcopy(resolved_perms)

      async def get_fullchannel(
          self,
          entity: EntityLike,
          exp: int = 300,
          force: bool = False,
      ) -> ChannelFull:
          """
          Gets the FullChannelRequest and caches it
          :param entity: Channel to fetch ChannelFull of
          :param exp: Expiration time of the cache record and maximum age, in
              seconds, of an already cached record to reuse
          :param force: Whether to force refresh the cache (make API request)
          :return: :obj:`ChannelFull` (the cached object itself, not a copy)
          """
          if hashable(entity):
              hashable_entity = entity
          else:
              hashable_entity = self._hikka_extract_id(
                  entity,
                  ("channel_id", "chat_id", "id"),
              )
              if hashable_entity is None:
                  logger.debug(
                      "Can't parse hashable from entity %s, using legacy fullchannel"
                      " request",
                      entity,
                  )
                  return await self(GetFullChannelRequest(channel=entity))

          hashable_entity = self._hikka_normalize_id(hashable_entity)

          if not force:
              cached = self._hikka_fullchannel_cache.get(hashable_entity)
              if (
                  cached
                  and not cached.expired()
                  and cached.ts + exp > time.time()
              ):
                  return cached.full_channel

          result = await self(GetFullChannelRequest(channel=entity))
          self._hikka_fullchannel_cache[hashable_entity] = CacheRecordFullChannel(
              hashable_entity,
              result,
              exp,
          )
          return result

      async def get_fulluser(
          self,
          entity: EntityLike,
          exp: int = 300,
          force: bool = False,
      ) -> UserFull:
          """
          Gets the FullUserRequest and caches it
          :param entity: User to fetch UserFull of
          :param exp: Expiration time of the cache record and maximum age, in
              seconds, of an already cached record to reuse
          :param force: Whether to force refresh the cache (make API request)
          :return: :obj:`UserFull` (the cached object itself, not a copy)
          """
          if hashable(entity):
              hashable_entity = entity
          else:
              hashable_entity = self._hikka_extract_id(
                  entity,
                  ("user_id", "chat_id", "id"),
              )
              if hashable_entity is None:
                  logger.debug(
                      "Can't parse hashable from entity %s, using legacy fulluser"
                      " request",
                      entity,
                  )
                  return await self(GetFullUserRequest(entity))

          hashable_entity = self._hikka_normalize_id(hashable_entity)

          if not force:
              cached = self._hikka_fulluser_cache.get(hashable_entity)
              if (
                  cached
                  and not cached.expired()
                  and cached.ts + exp > time.time()
              ):
                  return cached.full_user

          result = await self(GetFullUserRequest(entity))
          self._hikka_fulluser_cache[hashable_entity] = CacheRecordFullUser(
              hashable_entity,
              result,
              exp,
          )
          return result

      @staticmethod
      def _hikka_calling_module() -> typing.Optional["Module"]:
          """
          Return the innermost :class:`Module` instance found as ``self`` in a
          frame of the current call stack whose ``__origin__`` does not start
          with ``"<core"``, or ``None`` if there is none.
          """
          # context=0: only the frame objects are needed, so skip reading
          # source lines for every frame (the default context=1 does that).
          stack = inspect.stack(0)
          try:
              for frame_info in stack:
                  # ``f_locals`` is a plain dict up to Python 3.12, but a
                  # write-through mapping proxy (not a dict) from 3.13 on
                  # (PEP 667), so access it as a mapping.
                  try:
                      candidate = frame_info.frame.f_locals["self"]
                  except (AttributeError, KeyError, TypeError):
                      continue

                  if not isinstance(candidate, Module):
                      continue

                  origin = getattr(candidate, "__origin__", "")
                  if isinstance(origin, str) and origin.startswith("<core"):
                      continue

                  return candidate

              return None
          finally:
              # FrameInfo objects keep frames alive; drop them explicitly to
              # avoid reference cycles (recommended by the ``inspect`` docs).
              del stack

      async def _call(
          self,
          sender: MTProtoSender,
          request: TLRequest,
          ordered: bool = False,
          flood_sleep_threshold: typing.Optional[int] = None,
      ):
          """
          Calls the given request and handles user-side forbidden constructors.
          Requests whose ``CONSTRUCTOR_ID`` is forbidden are dropped when a
          non-core :class:`Module` instance is found on the call stack.
          :param sender: Sender to use
          :param request: Request (or list-like batch of requests) to send
          :param ordered: Whether to send the request ordered
          :param flood_sleep_threshold: Flood sleep threshold
          :return: The result of the request, or ``None`` if every request in
              it was dropped
          """
          # ⚠️⚠️ WARNING! ⚠️⚠️
          # If you are a module developer and you try to bypass this protection
          # to force the user to join your channel, you will be added to the
          # SCAM modules list and banned from the Hikka federation.
          # Let the USER decide which channels to follow. Do not be so petty.
          # I hope you understood me.
          # Thank you
          if not self.__forbidden_constructors:
              return await TelegramClient._call(
                  self,
                  sender,
                  request,
                  ordered,
                  flood_sleep_threshold,
              )

          not_tuple = not is_list_like(request)
          if not_tuple:
              request = (request,)

          # The calling stack is the same for every item of the batch, so it is
          # inspected at most once, and only if a forbidden item is present.
          unset = object()
          calling_module = unset
          new_request = []

          for item in request:
              if item.CONSTRUCTOR_ID in self.__forbidden_constructors:
                  if calling_module is unset:
                      calling_module = self._hikka_calling_module()

                  if calling_module:
                      logger.debug(
                          "🎉 I protected you from unintented %s (%s)!",
                          item.__class__.__name__,
                          item,
                      )
                      continue

              new_request.append(item)

          if not new_request:
              return

          return await TelegramClient._call(
              self,
              sender,
              new_request[0] if not_tuple else tuple(new_request),
              ordered,
              flood_sleep_threshold,
          )

      def forbid_constructor(self, constructor: int):
          """
          Forbids the given constructor to be called, in addition to the
          already forbidden ones
          :param constructor: Constructor id to forbid
          """
          self.__forbidden_constructors = list(
              {*self.__forbidden_constructors, constructor}
          )

      def forbid_constructors(self, constructors: list):
          """
          Forbids the given constructors to be called.
          All existing forbidden constructors will be removed
          :param constructors: Constructor ids to forbid
          """
          self.__forbidden_constructors = list(set(constructors))