tl_cache.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
  2. # █▀█ █ █ █ █▀█ █▀▄ █
  3. # © Copyright 2022
  4. # https://t.me/hikariatama
  5. #
  6. # 🔒 Licensed under the GNU AGPLv3
  7. # 🌐 https://www.gnu.org/licenses/agpl-3.0.html
  8. import copy
  9. import time
  10. import asyncio
  11. import logging
  12. from typing import Optional, Union
  13. from telethon.hints import EntityLike
  14. from telethon import TelegramClient
  15. from telethon.tl.functions.channels import GetFullChannelRequest
  16. from telethon.tl.functions.users import GetFullUserRequest
  17. from telethon.tl.types import ChannelFull, UserFull
  18. logger = logging.getLogger(__name__)
  19. def hashable(value):
  20. """Determine whether `value` can be hashed."""
  21. try:
  22. hash(value)
  23. except TypeError:
  24. return False
  25. return True
  26. class CacheRecord:
  27. def __init__(
  28. self,
  29. hashable_entity: "Hashable", # type: ignore
  30. resolved_entity: EntityLike,
  31. exp: int,
  32. ):
  33. self.entity = copy.deepcopy(resolved_entity)
  34. self._hashable_entity = copy.deepcopy(hashable_entity)
  35. self._exp = round(time.time() + exp)
  36. self.ts = time.time()
  37. def expired(self):
  38. return self._exp < time.time()
  39. def __eq__(self, record: "CacheRecord"):
  40. return hash(record) == hash(self)
  41. def __hash__(self):
  42. return hash(self._hashable_entity)
  43. def __str__(self):
  44. return f"CacheRecord of {self.entity}"
  45. def __repr__(self):
  46. return f"CacheRecord(entity={type(self.entity).__name__}(...), exp={self._exp})"
  47. class CacheRecordPerms:
  48. def __init__(
  49. self,
  50. hashable_entity: "Hashable", # type: ignore
  51. hashable_user: "Hashable", # type: ignore
  52. resolved_perms: EntityLike,
  53. exp: int,
  54. ):
  55. self.perms = copy.deepcopy(resolved_perms)
  56. self._hashable_entity = copy.deepcopy(hashable_entity)
  57. self._hashable_user = copy.deepcopy(hashable_user)
  58. self._exp = round(time.time() + exp)
  59. self.ts = time.time()
  60. def expired(self):
  61. return self._exp < time.time()
  62. def __eq__(self, record: "CacheRecordPerms"):
  63. return hash(record) == hash(self)
  64. def __hash__(self):
  65. return hash((self._hashable_entity, self._hashable_user))
  66. def __str__(self):
  67. return f"CacheRecordPerms of {self.perms}"
  68. def __repr__(self):
  69. return (
  70. f"CacheRecordPerms(perms={type(self.perms).__name__}(...), exp={self._exp})"
  71. )
  72. class CacheRecordFullChannel:
  73. def __init__(self, channel_id: int, full_channel: ChannelFull, exp: int):
  74. self.channel_id = channel_id
  75. self.full_channel = full_channel
  76. self._exp = round(time.time() + exp)
  77. self.ts = time.time()
  78. def expired(self):
  79. return self._exp < time.time()
  80. def __eq__(self, record: "CacheRecordFullChannel"):
  81. return hash(record) == hash(self)
  82. def __hash__(self):
  83. return hash((self._hashable_entity, self._hashable_user))
  84. def __str__(self):
  85. return f"CacheRecordFullChannel of {self.channel_id}"
  86. def __repr__(self):
  87. return (
  88. f"CacheRecordFullChannel(channel_id={self.channel_id}(...),"
  89. f" exp={self._exp})"
  90. )
  91. class CacheRecordFullUser:
  92. def __init__(self, user_id: int, full_user: UserFull, exp: int):
  93. self.user_id = user_id
  94. self.full_user = full_user
  95. self._exp = round(time.time() + exp)
  96. self.ts = time.time()
  97. def expired(self):
  98. return self._exp < time.time()
  99. def __eq__(self, record: "CacheRecordFullUser"):
  100. return hash(record) == hash(self)
  101. def __hash__(self):
  102. return hash((self._hashable_entity, self._hashable_user))
  103. def __str__(self):
  104. return f"CacheRecordFullUser of {self.user_id}"
  105. def __repr__(self):
  106. return f"CacheRecordFullUser(channel_id={self.user_id}(...), exp={self._exp})"
  107. def install_entity_caching(client: TelegramClient):
  108. client._hikka_entity_cache = {}
  109. old = client.get_entity
  110. async def new(
  111. entity: EntityLike,
  112. exp: Optional[int] = 5 * 60,
  113. force: Optional[bool] = False,
  114. ):
  115. # Will be used to determine, which client caused logging messages
  116. # parsed via inspect.stack()
  117. _hikka_client_id_logging_tag = copy.copy(client.tg_id) # skipcq
  118. if not hashable(entity):
  119. try:
  120. hashable_entity = next(
  121. getattr(entity, attr)
  122. for attr in {"user_id", "channel_id", "chat_id", "id"}
  123. if getattr(entity, attr, None)
  124. )
  125. except StopIteration:
  126. logger.debug(
  127. f"Can't parse hashable from {entity=}, using legacy resolve"
  128. )
  129. return await old(entity)
  130. else:
  131. hashable_entity = entity
  132. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  133. hashable_entity = int(str(hashable_entity)[4:])
  134. if (
  135. not force
  136. and hashable_entity
  137. and hashable_entity in client._hikka_entity_cache
  138. and (
  139. not exp
  140. or client._hikka_entity_cache[hashable_entity].ts + exp > time.time()
  141. )
  142. ):
  143. logger.debug(
  144. "Using cached entity"
  145. f" {entity} ({type(client._hikka_entity_cache[hashable_entity].entity).__name__})"
  146. )
  147. return copy.deepcopy(client._hikka_entity_cache[hashable_entity].entity)
  148. resolved_entity = await old(entity)
  149. if resolved_entity:
  150. cache_record = CacheRecord(hashable_entity, resolved_entity, exp)
  151. client._hikka_entity_cache[hashable_entity] = cache_record
  152. logger.debug(f"Saved hashable_entity {hashable_entity} to cache")
  153. if getattr(resolved_entity, "id", None):
  154. logger.debug(f"Saved resolved_entity id {resolved_entity.id} to cache")
  155. client._hikka_entity_cache[resolved_entity.id] = cache_record
  156. if getattr(resolved_entity, "username", None):
  157. logger.debug(
  158. f"Saved resolved_entity username @{resolved_entity.username} to"
  159. " cache"
  160. )
  161. client._hikka_entity_cache[
  162. f"@{resolved_entity.username}"
  163. ] = cache_record
  164. client._hikka_entity_cache[resolved_entity.username] = cache_record
  165. return copy.deepcopy(resolved_entity)
  166. async def cleaner(client: TelegramClient):
  167. while True:
  168. for record, record_data in client._hikka_entity_cache.copy().items():
  169. if record_data.expired():
  170. del client._hikka_entity_cache[record]
  171. logger.debug(f"Cleaned outdated cache {record=}")
  172. await asyncio.sleep(3)
  173. client.get_entity = new
  174. client.force_get_entity = old
  175. asyncio.ensure_future(cleaner(client))
  176. logger.debug("Monkeypatched client with entity cacher")
  177. def install_perms_caching(client: TelegramClient):
  178. client._hikka_perms_cache = {}
  179. old = client.get_permissions
  180. async def new(
  181. entity: EntityLike,
  182. user: Optional[EntityLike] = None,
  183. exp: Optional[int] = 5 * 60,
  184. force: Optional[bool] = False,
  185. ):
  186. # Will be used to determine, which client caused logging messages
  187. # parsed via inspect.stack()
  188. _hikka_client_id_logging_tag = copy.copy(client.tg_id) # skipcq
  189. entity = await client.get_entity(entity)
  190. user = await client.get_entity(user) if user else None
  191. if not hashable(entity) or not hashable(user):
  192. try:
  193. hashable_entity = next(
  194. getattr(entity, attr)
  195. for attr in {"user_id", "channel_id", "chat_id", "id"}
  196. if getattr(entity, attr, None)
  197. )
  198. except StopIteration:
  199. logger.debug(
  200. f"Can't parse hashable from {entity=}, using legacy method"
  201. )
  202. return await old(entity, user)
  203. try:
  204. hashable_user = next(
  205. getattr(user, attr)
  206. for attr in {"user_id", "channel_id", "chat_id", "id"}
  207. if getattr(user, attr, None)
  208. )
  209. except StopIteration:
  210. logger.debug(f"Can't parse hashable from {user=}, using legacy method")
  211. return await old(entity, user)
  212. else:
  213. hashable_entity = entity
  214. hashable_user = user
  215. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  216. hashable_entity = int(str(hashable_entity)[4:])
  217. if str(hashable_user).isdigit() and int(hashable_user) < 0:
  218. hashable_user = int(str(hashable_user)[4:])
  219. if (
  220. not force
  221. and hashable_entity
  222. and hashable_user
  223. and hashable_user in client._hikka_perms_cache.get(hashable_entity, {})
  224. and (
  225. not exp
  226. or client._hikka_perms_cache[hashable_entity][hashable_user].ts + exp
  227. > time.time()
  228. )
  229. ):
  230. logger.debug(f"Using cached perms {hashable_entity} ({hashable_user})")
  231. return copy.deepcopy(
  232. client._hikka_perms_cache[hashable_entity][hashable_user].perms
  233. )
  234. resolved_perms = await old(entity, user)
  235. if resolved_perms:
  236. cache_record = CacheRecordPerms(
  237. hashable_entity,
  238. hashable_user,
  239. resolved_perms,
  240. exp,
  241. )
  242. client._hikka_perms_cache.setdefault(hashable_entity, {})[
  243. hashable_user
  244. ] = cache_record
  245. logger.debug(f"Saved hashable_entity {hashable_entity} perms to cache")
  246. def save_user(key: Union[str, int]):
  247. nonlocal client, cache_record, user, hashable_user
  248. if getattr(user, "id", None):
  249. client._hikka_perms_cache.setdefault(key, {})[
  250. user.id
  251. ] = cache_record
  252. if getattr(user, "username", None):
  253. client._hikka_perms_cache.setdefault(key, {})[
  254. f"@{user.username}"
  255. ] = cache_record
  256. client._hikka_perms_cache.setdefault(key, {})[
  257. user.username
  258. ] = cache_record
  259. if getattr(entity, "id", None):
  260. logger.debug(f"Saved resolved_entity id {entity.id} perms to cache")
  261. save_user(entity.id)
  262. if getattr(entity, "username", None):
  263. logger.debug(
  264. f"Saved resolved_entity username @{entity.username} perms to cache"
  265. )
  266. save_user(f"@{entity.username}")
  267. save_user(entity.username)
  268. return copy.deepcopy(resolved_perms)
  269. async def cleaner(client: TelegramClient):
  270. while True:
  271. for chat, chat_data in client._hikka_perms_cache.copy().items():
  272. for user, user_data in chat_data.items().copy():
  273. if user_data.expired():
  274. del client._hikka_perms_cache[chat][user]
  275. logger.debug(f"Cleaned outdated perms cache {chat=} {user=}")
  276. await asyncio.sleep(3)
  277. client.get_perms_cached = new
  278. asyncio.ensure_future(cleaner(client))
  279. logger.debug("Monkeypatched client with perms cacher")
  280. def install_fullchannel_caching(client: TelegramClient):
  281. client._hikka_fullchannel_cache = {}
  282. async def get_fullchannel(
  283. entity: EntityLike,
  284. exp: Optional[int] = 300,
  285. force: Optional[bool] = False,
  286. ) -> ChannelFull:
  287. """
  288. Gets the FullChannelRequest and cache it
  289. :param entity: Channel to fetch ChannelFull of
  290. :param exp: Expiration time of the cache record and maximum time of already cached record
  291. :param force: Whether to force refresh the cache (make API request)
  292. :return: :obj:`ChannelFull`
  293. """
  294. if not hashable(entity):
  295. try:
  296. hashable_entity = next(
  297. getattr(entity, attr)
  298. for attr in {"channel_id", "chat_id", "id"}
  299. if getattr(entity, attr, None)
  300. )
  301. except StopIteration:
  302. logger.debug(
  303. f"Can't parse hashable from {entity=}, using legacy fullchannel"
  304. " request"
  305. )
  306. return await client(GetFullChannelRequest(channel=entity))
  307. else:
  308. hashable_entity = entity
  309. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  310. hashable_entity = int(str(hashable_entity)[4:])
  311. if (
  312. not force
  313. and client._hikka_fullchannel_cache.get(hashable_entity)
  314. and not client._hikka_fullchannel_cache[hashable_entity].expired()
  315. and client._hikka_fullchannel_cache[hashable_entity].ts + exp > time.time()
  316. ):
  317. return client._hikka_fullchannel_cache[hashable_entity].full_channel
  318. result = await client(GetFullChannelRequest(channel=entity))
  319. client._hikka_fullchannel_cache[hashable_entity] = CacheRecordFullChannel(
  320. hashable_entity,
  321. result,
  322. exp,
  323. )
  324. return result
  325. async def cleaner(client: TelegramClient):
  326. while True:
  327. for channel_id, record in client._hikka_fullchannel_cache.copy().items():
  328. if record.expired():
  329. del client._hikka_fullchannel_cache[channel_id]
  330. logger.debug(f"Cleaned outdated fullchannel cache {channel_id=}")
  331. await asyncio.sleep(3)
  332. client.get_fullchannel = get_fullchannel
  333. asyncio.ensure_future(cleaner(client))
  334. logger.debug("Monkeypatched client with fullchannel cacher")
  335. def install_fulluser_caching(client: TelegramClient):
  336. client._hikka_fulluser_cache = {}
  337. async def get_fulluser(
  338. entity: EntityLike,
  339. exp: Optional[int] = 300,
  340. force: Optional[bool] = False,
  341. ) -> ChannelFull:
  342. """
  343. Gets the FullUserRequest and cache it
  344. :param entity: User to fetch UserFull of
  345. :param exp: Expiration time of the cache record and maximum time of already cached record
  346. :param force: Whether to force refresh the cache (make API request)
  347. :return: :obj:`UserFull`
  348. """
  349. if not hashable(entity):
  350. try:
  351. hashable_entity = next(
  352. getattr(entity, attr)
  353. for attr in {"user_id", "chat_id", "id"}
  354. if getattr(entity, attr, None)
  355. )
  356. except StopIteration:
  357. logger.debug(
  358. f"Can't parse hashable from {entity=}, using legacy fulluser"
  359. " request"
  360. )
  361. return await client(GetFullUserRequest(entity))
  362. else:
  363. hashable_entity = entity
  364. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  365. hashable_entity = int(str(hashable_entity)[4:])
  366. if (
  367. not force
  368. and client._hikka_fulluser_cache.get(hashable_entity)
  369. and not client._hikka_fulluser_cache[hashable_entity].expired()
  370. and client._hikka_fulluser_cache[hashable_entity].ts + exp > time.time()
  371. ):
  372. return client._hikka_fulluser_cache[hashable_entity].full_channel
  373. result = await client(GetFullUserRequest(entity))
  374. client._hikka_fulluser_cache[hashable_entity] = CacheRecordFullUser(
  375. hashable_entity,
  376. result,
  377. exp,
  378. )
  379. return result
  380. async def cleaner(client: TelegramClient):
  381. while True:
  382. for user_id, record in client._hikka_fulluser_cache.copy().items():
  383. if record.expired():
  384. del client._hikka_fulluser_cache[user_id]
  385. logger.debug(f"Cleaned outdated fulluser cache {user_id=}")
  386. await asyncio.sleep(3)
  387. client.get_fulluser = get_fulluser
  388. asyncio.ensure_future(cleaner(client))
  389. logger.debug("Monkeypatched client with fulluser cacher")