tl_cache.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. # █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
  2. # █▀█ █ █ █ █▀█ █▀▄ █
  3. # © Copyright 2022
  4. # https://t.me/hikariatama
  5. #
  6. # 🔒 Licensed under the GNU AGPLv3
  7. # 🌐 https://www.gnu.org/licenses/agpl-3.0.html
  8. import copy
  9. import inspect
  10. import time
  11. import asyncio
  12. import logging
  13. from typing import Optional, Union
  14. from telethon import TelegramClient
  15. from telethon.hints import EntityLike
  16. from telethon.utils import is_list_like
  17. from telethon.network import MTProtoSender
  18. from telethon.tl.tlobject import TLRequest
  19. from telethon.tl.functions.channels import GetFullChannelRequest
  20. from telethon.tl.functions.users import GetFullUserRequest
  21. from telethon.tl.types import ChannelFull, UserFull
  22. from .types import (
  23. CacheRecord,
  24. CacheRecordPerms,
  25. CacheRecordFullChannel,
  26. CacheRecordFullUser,
  27. Module,
  28. )
  29. logger = logging.getLogger(__name__)
  30. def hashable(value):
  31. """Determine whether `value` can be hashed."""
  32. try:
  33. hash(value)
  34. except TypeError:
  35. return False
  36. return True
  37. class CustomTelegramClient(TelegramClient):
  38. def __init__(self, *args, **kwargs):
  39. super().__init__(*args, **kwargs)
  40. self._hikka_entity_cache = {}
  41. self._hikka_perms_cache = {}
  42. self._hikka_fullchannel_cache = {}
  43. self._hikka_fulluser_cache = {}
  44. self.__forbidden_constructors = []
  45. asyncio.ensure_future(self.cleaner())
  46. async def force_get_entity(self, *args, **kwargs):
  47. return await self.get_entity(*args, force=True, **kwargs)
  48. async def get_entity(
  49. self,
  50. entity: EntityLike,
  51. exp: int = 5 * 60,
  52. force: bool = False,
  53. ):
  54. # Will be used to determine, which client caused logging messages
  55. # parsed via inspect.stack()
  56. _hikka_client_id_logging_tag = copy.copy(self.tg_id) # skipcq
  57. if not hashable(entity):
  58. try:
  59. hashable_entity = next(
  60. getattr(entity, attr)
  61. for attr in {"user_id", "channel_id", "chat_id", "id"}
  62. if getattr(entity, attr, None)
  63. )
  64. except StopIteration:
  65. logger.debug(
  66. f"Can't parse hashable from {entity=}, using legacy resolve"
  67. )
  68. return await TelegramClient.get_entity(self, entity)
  69. else:
  70. hashable_entity = entity
  71. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  72. hashable_entity = int(str(hashable_entity)[4:])
  73. if (
  74. not force
  75. and hashable_entity
  76. and hashable_entity in self._hikka_entity_cache
  77. and (
  78. not exp
  79. or self._hikka_entity_cache[hashable_entity].ts + exp > time.time()
  80. )
  81. ):
  82. logger.debug(
  83. "Using cached entity"
  84. f" {entity} ({type(self._hikka_entity_cache[hashable_entity].entity).__name__})"
  85. )
  86. return copy.deepcopy(self._hikka_entity_cache[hashable_entity].entity)
  87. resolved_entity = await TelegramClient.get_entity(self, entity)
  88. if resolved_entity:
  89. cache_record = CacheRecord(hashable_entity, resolved_entity, exp)
  90. self._hikka_entity_cache[hashable_entity] = cache_record
  91. logger.debug(f"Saved hashable_entity {hashable_entity} to cache")
  92. if getattr(resolved_entity, "id", None):
  93. logger.debug(f"Saved resolved_entity id {resolved_entity.id} to cache")
  94. self._hikka_entity_cache[resolved_entity.id] = cache_record
  95. if getattr(resolved_entity, "username", None):
  96. logger.debug(
  97. f"Saved resolved_entity username @{resolved_entity.username} to"
  98. " cache"
  99. )
  100. self._hikka_entity_cache[f"@{resolved_entity.username}"] = cache_record
  101. self._hikka_entity_cache[resolved_entity.username] = cache_record
  102. return copy.deepcopy(resolved_entity)
  103. async def get_perms_cached(
  104. self,
  105. entity: EntityLike,
  106. user: Optional[EntityLike] = None,
  107. exp: int = 5 * 60,
  108. force: bool = False,
  109. ):
  110. # Will be used to determine, which client caused logging messages
  111. # parsed via inspect.stack()
  112. _hikka_client_id_logging_tag = copy.copy(self.tg_id) # skipcq
  113. entity = await self.get_entity(entity)
  114. user = await self.get_entity(user) if user else None
  115. if not hashable(entity) or not hashable(user):
  116. try:
  117. hashable_entity = next(
  118. getattr(entity, attr)
  119. for attr in {"user_id", "channel_id", "chat_id", "id"}
  120. if getattr(entity, attr, None)
  121. )
  122. except StopIteration:
  123. logger.debug(
  124. f"Can't parse hashable from {entity=}, using legacy method"
  125. )
  126. return await self.get_permissions(entity, user)
  127. try:
  128. hashable_user = next(
  129. getattr(user, attr)
  130. for attr in {"user_id", "channel_id", "chat_id", "id"}
  131. if getattr(user, attr, None)
  132. )
  133. except StopIteration:
  134. logger.debug(f"Can't parse hashable from {user=}, using legacy method")
  135. return await self.get_permissions(entity, user)
  136. else:
  137. hashable_entity = entity
  138. hashable_user = user
  139. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  140. hashable_entity = int(str(hashable_entity)[4:])
  141. if str(hashable_user).isdigit() and int(hashable_user) < 0:
  142. hashable_user = int(str(hashable_user)[4:])
  143. if (
  144. not force
  145. and hashable_entity
  146. and hashable_user
  147. and hashable_user in self._hikka_perms_cache.get(hashable_entity, {})
  148. and (
  149. not exp
  150. or self._hikka_perms_cache[hashable_entity][hashable_user].ts + exp
  151. > time.time()
  152. )
  153. ):
  154. logger.debug(f"Using cached perms {hashable_entity} ({hashable_user})")
  155. return copy.deepcopy(
  156. self._hikka_perms_cache[hashable_entity][hashable_user].perms
  157. )
  158. resolved_perms = await self.get_permissions(entity, user)
  159. if resolved_perms:
  160. cache_record = CacheRecordPerms(
  161. hashable_entity,
  162. hashable_user,
  163. resolved_perms,
  164. exp,
  165. )
  166. self._hikka_perms_cache.setdefault(hashable_entity, {})[
  167. hashable_user
  168. ] = cache_record
  169. logger.debug(f"Saved hashable_entity {hashable_entity} perms to cache")
  170. def save_user(key: Union[str, int]):
  171. nonlocal self, cache_record, user, hashable_user
  172. if getattr(user, "id", None):
  173. self._hikka_perms_cache.setdefault(key, {})[user.id] = cache_record
  174. if getattr(user, "username", None):
  175. self._hikka_perms_cache.setdefault(key, {})[
  176. f"@{user.username}"
  177. ] = cache_record
  178. self._hikka_perms_cache.setdefault(key, {})[
  179. user.username
  180. ] = cache_record
  181. if getattr(entity, "id", None):
  182. logger.debug(f"Saved resolved_entity id {entity.id} perms to cache")
  183. save_user(entity.id)
  184. if getattr(entity, "username", None):
  185. logger.debug(
  186. f"Saved resolved_entity username @{entity.username} perms to cache"
  187. )
  188. save_user(f"@{entity.username}")
  189. save_user(entity.username)
  190. return copy.deepcopy(resolved_perms)
  191. async def get_fullchannel(
  192. self,
  193. entity: EntityLike,
  194. exp: int = 300,
  195. force: bool = False,
  196. ) -> ChannelFull:
  197. """
  198. Gets the FullChannelRequest and cache it
  199. :param entity: Channel to fetch ChannelFull of
  200. :param exp: Expiration time of the cache record and maximum time of already cached record
  201. :param force: Whether to force refresh the cache (make API request)
  202. :return: :obj:`ChannelFull`
  203. """
  204. if not hashable(entity):
  205. try:
  206. hashable_entity = next(
  207. getattr(entity, attr)
  208. for attr in {"channel_id", "chat_id", "id"}
  209. if getattr(entity, attr, None)
  210. )
  211. except StopIteration:
  212. logger.debug(
  213. f"Can't parse hashable from {entity=}, using legacy fullchannel"
  214. " request"
  215. )
  216. return await self(GetFullChannelRequest(channel=entity))
  217. else:
  218. hashable_entity = entity
  219. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  220. hashable_entity = int(str(hashable_entity)[4:])
  221. if (
  222. not force
  223. and self._hikka_fullchannel_cache.get(hashable_entity)
  224. and not self._hikka_fullchannel_cache[hashable_entity].expired()
  225. and self._hikka_fullchannel_cache[hashable_entity].ts + exp > time.time()
  226. ):
  227. return self._hikka_fullchannel_cache[hashable_entity].full_channel
  228. result = await self(GetFullChannelRequest(channel=entity))
  229. self._hikka_fullchannel_cache[hashable_entity] = CacheRecordFullChannel(
  230. hashable_entity,
  231. result,
  232. exp,
  233. )
  234. return result
  235. async def get_fulluser(
  236. self,
  237. entity: EntityLike,
  238. exp: int = 300,
  239. force: bool = False,
  240. ) -> UserFull:
  241. """
  242. Gets the FullUserRequest and cache it
  243. :param entity: User to fetch UserFull of
  244. :param exp: Expiration time of the cache record and maximum time of already cached record
  245. :param force: Whether to force refresh the cache (make API request)
  246. :return: :obj:`UserFull`
  247. """
  248. if not hashable(entity):
  249. try:
  250. hashable_entity = next(
  251. getattr(entity, attr)
  252. for attr in {"user_id", "chat_id", "id"}
  253. if getattr(entity, attr, None)
  254. )
  255. except StopIteration:
  256. logger.debug(
  257. f"Can't parse hashable from {entity=}, using legacy fulluser"
  258. " request"
  259. )
  260. return await self(GetFullUserRequest(entity))
  261. else:
  262. hashable_entity = entity
  263. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  264. hashable_entity = int(str(hashable_entity)[4:])
  265. if (
  266. not force
  267. and self._hikka_fulluser_cache.get(hashable_entity)
  268. and not self._hikka_fulluser_cache[hashable_entity].expired()
  269. and self._hikka_fulluser_cache[hashable_entity].ts + exp > time.time()
  270. ):
  271. return self._hikka_fulluser_cache[hashable_entity].full_user
  272. result = await self(GetFullUserRequest(entity))
  273. self._hikka_fulluser_cache[hashable_entity] = CacheRecordFullUser(
  274. hashable_entity,
  275. result,
  276. exp,
  277. )
  278. return result
  279. async def _call(
  280. self,
  281. sender: MTProtoSender,
  282. request: TLRequest,
  283. ordered: bool = False,
  284. flood_sleep_threshold: int = None,
  285. ):
  286. # ⚠️⚠️ WARNING! ⚠️⚠️
  287. # If you are a module developer, and you'll try to bypass this protection to
  288. # force user join your channel, you will be added to SCAM modules
  289. # list and you will be banned from Hikka federation.
  290. # Let USER decide, which channel he will follow. Do not be so petty
  291. # I hope, you understood me.
  292. # Thank you
  293. if not self.__forbidden_constructors:
  294. return await TelegramClient._call(
  295. self,
  296. sender,
  297. request,
  298. ordered,
  299. flood_sleep_threshold,
  300. )
  301. not_tuple = False
  302. if not is_list_like(request):
  303. not_tuple = True
  304. request = (request,)
  305. new_request = []
  306. for item in request:
  307. if item.CONSTRUCTOR_ID in self.__forbidden_constructors and next(
  308. (
  309. frame_info.frame.f_locals["self"]
  310. for frame_info in inspect.stack()
  311. if hasattr(frame_info, "frame")
  312. and hasattr(frame_info.frame, "f_locals")
  313. and isinstance(frame_info.frame.f_locals, dict)
  314. and "self" in frame_info.frame.f_locals
  315. and isinstance(frame_info.frame.f_locals["self"], Module)
  316. and frame_info.frame.f_locals["self"].__class__.__name__
  317. not in {
  318. "APIRatelimiterMod",
  319. "ForbidJoinMod",
  320. "LoaderMod",
  321. "HikkaSettingsMod",
  322. }
  323. # APIRatelimiterMod is a core proxy, so it wraps around every module in Hikka, if installed
  324. # ForbidJoinMod is also a Core proxy, so it wraps around every module in Hikka, if installed
  325. # LoaderMod prompts user to join developers' channels
  326. # HikkaSettings prompts user to join channels, required by modules
  327. ),
  328. None,
  329. ):
  330. logger.debug(
  331. "🎉 I protected you from unintented"
  332. f" {item.__class__.__name__} ({item})!"
  333. )
  334. continue
  335. new_request += [item]
  336. if not new_request:
  337. return
  338. return await TelegramClient._call(
  339. self,
  340. sender,
  341. new_request[0] if not_tuple else tuple(new_request),
  342. ordered,
  343. flood_sleep_threshold,
  344. )
  345. def forbid_joins(self):
  346. self.__forbidden_constructors.extend([615851205, 1817183516])
  347. async def cleaner(self):
  348. while True:
  349. for record, record_data in self._hikka_entity_cache.copy().items():
  350. if record_data.expired():
  351. del self._hikka_entity_cache[record]
  352. logger.debug(f"Cleaned outdated cache {record=}")
  353. for chat, chat_data in self._hikka_perms_cache.copy().items():
  354. for user, user_data in chat_data.copy().items():
  355. if user_data.expired():
  356. del self._hikka_perms_cache[chat][user]
  357. logger.debug(f"Cleaned outdated perms cache {chat=} {user=}")
  358. for channel_id, record in self._hikka_fullchannel_cache.copy().items():
  359. if record.expired():
  360. del self._hikka_fullchannel_cache[channel_id]
  361. logger.debug(f"Cleaned outdated fullchannel cache {channel_id=}")
  362. for user_id, record in self._hikka_fulluser_cache.copy().items():
  363. if record.expired():
  364. del self._hikka_fulluser_cache[user_id]
  365. logger.debug(f"Cleaned outdated fulluser cache {user_id=}")
  366. await asyncio.sleep(3)