entity_cache.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
  2. # █▀█ █ █ █ █▀█ █▀▄ █
  3. # © Copyright 2022
  4. # https://t.me/hikariatama
  5. #
  6. # 🔒 Licensed under the GNU AGPLv3
  7. # 🌐 https://www.gnu.org/licenses/agpl-3.0.html
  8. import copy
  9. import time
  10. import asyncio
  11. import logging
  12. from telethon.hints import EntityLike
  13. from telethon import TelegramClient
  14. logger = logging.getLogger(__name__)
  15. def hashable(value):
  16. """Determine whether `value` can be hashed."""
  17. try:
  18. hash(value)
  19. except TypeError:
  20. return False
  21. return True
  22. class CacheRecord:
  23. def __init__(
  24. self,
  25. hashable_entity: "Hashable", # type: ignore
  26. resolved_entity: EntityLike,
  27. ):
  28. self.entity = resolved_entity
  29. self._hashable_entity = hashable_entity
  30. self._exp = round(time.time() + 5 * 60)
  31. def expired(self):
  32. return self._exp < time.time()
  33. def __eq__(self, record: "CacheRecord"):
  34. return hash(record._hashable_entity) == hash(self._hashable_entity)
  35. def __hash__(self):
  36. return hash(self._hashable_entity)
  37. def __str__(self):
  38. return f"CacheRecord of {self.entity}"
  39. def __repr__(self):
  40. return f"CacheRecord(entity={type(self.entity).__name__}(...), exp={self._exp})"
  41. def install_entity_caching(client: TelegramClient):
  42. client._hikka_cache = {}
  43. old = client.get_entity
  44. async def new(entity: EntityLike):
  45. # Will be used to determine, which client caused logging messages
  46. # parsed via inspect.stack()
  47. _hikka_client_id_logging_tag = copy.copy(client.tg_id) # skipcq
  48. if not hashable(entity):
  49. try:
  50. hashable_entity = next(
  51. getattr(entity, attr)
  52. for attr in {"user_id", "channel_id", "chat_id", "id"}
  53. if getattr(entity, attr, None)
  54. )
  55. except StopIteration:
  56. logger.debug(
  57. f"Can't parse hashable from {entity=}, using legacy resolve"
  58. )
  59. return await client.get_entity(entity)
  60. else:
  61. hashable_entity = entity
  62. if str(hashable_entity).isdigit() and int(hashable_entity) < 0:
  63. hashable_entity = int(str(hashable_entity)[4:])
  64. if hashable_entity and hashable_entity in client._hikka_cache:
  65. logger.debug(
  66. "Using cached entity"
  67. f" {entity} ({type(client._hikka_cache[hashable_entity].entity).__name__})"
  68. )
  69. return client._hikka_cache[hashable_entity].entity
  70. resolved_entity = await old(entity)
  71. if resolved_entity:
  72. cache_record = CacheRecord(hashable_entity, resolved_entity)
  73. client._hikka_cache[hashable_entity] = cache_record
  74. logger.debug(f"Saved hashable_entity {hashable_entity} to cache")
  75. if getattr(resolved_entity, "id", None):
  76. logger.debug(f"Saved resolved_entity id {resolved_entity.id} to cache")
  77. client._hikka_cache[resolved_entity.id] = cache_record
  78. if getattr(resolved_entity, "username", None):
  79. logger.debug(
  80. f"Saved resolved_entity username @{resolved_entity.username} to"
  81. " cache"
  82. )
  83. client._hikka_cache[f"@{resolved_entity.username}"] = cache_record
  84. return resolved_entity
  85. async def cleaner(client: TelegramClient):
  86. while True:
  87. for record, record_data in client._hikka_cache.copy().items():
  88. if record_data.expired():
  89. del client._hikka_cache[record]
  90. logger.debug(f"Cleaned outdated cache {record=}")
  91. await asyncio.sleep(3)
  92. client.get_entity = new
  93. client.force_get_entity = old
  94. asyncio.ensure_future(cleaner(client))
  95. logger.debug("Monkeypatched client with cacher")