123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476 |
- # █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
- # █▀█ █ █ █ █▀█ █▀▄ █
- # © Copyright 2022
- # https://t.me/hikariatama
- #
- # 🔒 Licensed under the GNU AGPLv3
- # 🌐 https://www.gnu.org/licenses/agpl-3.0.html
- import copy
- import time
- import asyncio
- import logging
- from typing import Optional, Union
- from telethon.hints import EntityLike
- from telethon import TelegramClient
- from telethon.tl.functions.channels import GetFullChannelRequest
- from telethon.tl.functions.users import GetFullUserRequest
- from telethon.tl.types import ChannelFull, UserFull
# Module logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
def hashable(value):
    """
    Report whether ``value`` can be passed to ``hash()``.

    Only ``TypeError`` (the error raised for unhashable objects) is treated as
    "not hashable"; any other exception from ``__hash__`` propagates.

    :param value: Object to probe
    :return: ``True`` if hashing succeeds, ``False`` otherwise
    """
    can_hash = True
    try:
        hash(value)
    except TypeError:
        can_hash = False
    return can_hash
class CacheRecord:
    """
    Snapshot of a resolved entity, stored in ``client._hikka_entity_cache``.

    Both the resolved entity and the key it was requested by are deep-copied,
    so later mutation by the caller cannot alter the cached data.
    """

    def __init__(
        self,
        hashable_entity: "Hashable",  # type: ignore
        resolved_entity: "EntityLike",  # type: ignore
        exp: int,
    ):
        """
        :param hashable_entity: Cache key the entity was requested by
        :param resolved_entity: Entity returned by the resolver
        :param exp: Lifetime of the record, in seconds
        """
        self.entity = copy.deepcopy(resolved_entity)
        self._hashable_entity = copy.deepcopy(hashable_entity)
        # Absolute expiry timestamp, rounded to whole seconds
        self._exp = round(time.time() + exp)
        # Creation time; lookups compare it against a caller-supplied max age
        self.ts = time.time()

    def expired(self) -> bool:
        """Return whether the record's lifetime has elapsed."""
        return self._exp < time.time()

    def __eq__(self, record: object) -> bool:
        # Compare keys, not hashes: distinct keys may share a hash
        # (e.g. hash(-1) == hash(-2)), a record must not equal its raw key,
        # and comparing with an unhashable object must not raise.
        if not isinstance(record, CacheRecord):
            return NotImplemented
        return self._hashable_entity == record._hashable_entity

    def __hash__(self) -> int:
        return hash(self._hashable_entity)

    def __str__(self):
        return f"CacheRecord of {self.entity}"

    def __repr__(self):
        return f"CacheRecord(entity={type(self.entity).__name__}(...), exp={self._exp})"
class CacheRecordPerms:
    """
    Snapshot of a permissions lookup for one (chat, user) pair.

    Stored in ``client._hikka_perms_cache[chat_key][user_key]``. The
    permissions object and both keys are deep-copied on creation.
    """

    def __init__(
        self,
        hashable_entity: "Hashable",  # type: ignore
        hashable_user: "Hashable",  # type: ignore
        resolved_perms: "ParticipantPermissions",  # type: ignore
        exp: int,
    ):
        """
        :param hashable_entity: Cache key of the chat
        :param hashable_user: Cache key of the user
        :param resolved_perms: Value returned by ``client.get_permissions``
        :param exp: Lifetime of the record, in seconds
        """
        self.perms = copy.deepcopy(resolved_perms)
        self._hashable_entity = copy.deepcopy(hashable_entity)
        self._hashable_user = copy.deepcopy(hashable_user)
        # Absolute expiry timestamp, rounded to whole seconds
        self._exp = round(time.time() + exp)
        # Creation time; lookups compare it against a caller-supplied max age
        self.ts = time.time()

    def expired(self) -> bool:
        """Return whether the record's lifetime has elapsed."""
        return self._exp < time.time()

    def _key(self) -> tuple:
        # Identity of the record: the (chat, user) pair it was cached under
        return (self._hashable_entity, self._hashable_user)

    def __eq__(self, record: object) -> bool:
        # Compare keys, not hashes: distinct keys may share a hash and
        # unrelated objects must neither compare equal nor raise.
        if not isinstance(record, CacheRecordPerms):
            return NotImplemented
        return self._key() == record._key()

    def __hash__(self) -> int:
        return hash(self._key())

    def __str__(self):
        return f"CacheRecordPerms of {self.perms}"

    def __repr__(self):
        return (
            f"CacheRecordPerms(perms={type(self.perms).__name__}(...), exp={self._exp})"
        )
class CacheRecordFullChannel:
    """Cached ``GetFullChannelRequest`` result, keyed by a channel identifier."""

    def __init__(self, channel_id: int, full_channel: "ChannelFull", exp: int):
        """
        :param channel_id: Cache key of the channel
        :param full_channel: Result of the full-channel request
        :param exp: Lifetime of the record, in seconds
        """
        self.channel_id = channel_id
        self.full_channel = full_channel
        # Absolute expiry timestamp, rounded to whole seconds
        self._exp = round(time.time() + exp)
        # Creation time; lookups compare it against a caller-supplied max age
        self.ts = time.time()

    def expired(self) -> bool:
        """Return whether the record's lifetime has elapsed."""
        return self._exp < time.time()

    def __eq__(self, record: object) -> bool:
        if not isinstance(record, CacheRecordFullChannel):
            return NotImplemented
        return self.channel_id == record.channel_id

    def __hash__(self) -> int:
        # The record is identified solely by the channel key it is stored under
        return hash(self.channel_id)

    def __str__(self):
        return f"CacheRecordFullChannel of {self.channel_id}"

    def __repr__(self):
        return f"CacheRecordFullChannel(channel_id={self.channel_id}, exp={self._exp})"
class CacheRecordFullUser:
    """Cached ``GetFullUserRequest`` result, keyed by a user identifier."""

    def __init__(self, user_id: int, full_user: "UserFull", exp: int):
        """
        :param user_id: Cache key of the user
        :param full_user: Result of the full-user request
        :param exp: Lifetime of the record, in seconds
        """
        self.user_id = user_id
        self.full_user = full_user
        # Absolute expiry timestamp, rounded to whole seconds
        self._exp = round(time.time() + exp)
        # Creation time; lookups compare it against a caller-supplied max age
        self.ts = time.time()

    def expired(self) -> bool:
        """Return whether the record's lifetime has elapsed."""
        return self._exp < time.time()

    def __eq__(self, record: object) -> bool:
        if not isinstance(record, CacheRecordFullUser):
            return NotImplemented
        return self.user_id == record.user_id

    def __hash__(self) -> int:
        # The record is identified solely by the user key it is stored under
        return hash(self.user_id)

    def __str__(self):
        return f"CacheRecordFullUser of {self.user_id}"

    def __repr__(self):
        return f"CacheRecordFullUser(user_id={self.user_id}, exp={self._exp})"
def install_entity_caching(client: TelegramClient):
    """
    Monkeypatch ``client.get_entity`` with a caching wrapper.

    After this call:

    - ``client.get_entity`` serves entities from ``client._hikka_entity_cache``
      when possible;
    - ``client.force_get_entity`` is the original, uncached resolver;
    - a background task evicts expired records every 3 seconds.

    :param client: Client to patch in place
    """
    client._hikka_entity_cache = {}

    old = client.get_entity

    def _unmark_id(key):
        # Marked channel ids look like -100XXXXXXXXXX, i.e. -(10**12 + id).
        # Map them to the bare id so they share a key with ``resolved_entity.id``.
        text = str(key)
        if text.startswith("-") and text[1:].isdecimal() and int(text) < -(10**12):
            return -int(text) - 10**12
        return key

    async def new(
        entity: EntityLike,
        exp: Optional[int] = 5 * 60,
        force: Optional[bool] = False,
    ):
        """
        Resolve an entity, reusing a cached copy when one is fresh enough.

        :param entity: Anything accepted by the original ``get_entity``
        :param exp: Maximum age (seconds) of a cached record to accept, and the
            lifetime of a newly stored record. A falsy value accepts cached
            records of any age.
        :param force: Skip the cache lookup and always call the resolver
        :return: Deep copy of the resolved entity
        """
        # Will be used to determine, which client caused logging messages
        # parsed via inspect.stack()
        _hikka_client_id_logging_tag = copy.copy(client.tg_id)  # skipcq

        if hashable(entity):
            hashable_entity = entity
        else:
            # Use the first truthy id-like attribute as the key. A tuple keeps
            # the priority order deterministic (set order of str varies per run).
            hashable_entity = next(
                (
                    getattr(entity, attr)
                    for attr in ("user_id", "channel_id", "chat_id", "id")
                    if getattr(entity, attr, None)
                ),
                None,
            )
            if hashable_entity is None:
                logger.debug(
                    f"Can't parse hashable from {entity=}, using legacy resolve"
                )
                return await old(entity)

        hashable_entity = _unmark_id(hashable_entity)

        if (
            not force
            and hashable_entity
            and hashable_entity in client._hikka_entity_cache
            and (
                not exp
                or client._hikka_entity_cache[hashable_entity].ts + exp > time.time()
            )
        ):
            logger.debug(
                "Using cached entity"
                f" {entity} ({type(client._hikka_entity_cache[hashable_entity].entity).__name__})"
            )
            return copy.deepcopy(client._hikka_entity_cache[hashable_entity].entity)

        resolved_entity = await old(entity)

        if resolved_entity:
            # ``exp`` may be None (no age limit on lookups); CacheRecord needs a
            # number, so such records get a zero lifetime instead of crashing.
            cache_record = CacheRecord(hashable_entity, resolved_entity, exp or 0)
            cache = client._hikka_entity_cache
            cache[hashable_entity] = cache_record
            logger.debug(f"Saved hashable_entity {hashable_entity} to cache")

            # Also index the record by id and username so later lookups by
            # either alias hit the same entry.
            if getattr(resolved_entity, "id", None):
                logger.debug(f"Saved resolved_entity id {resolved_entity.id} to cache")
                cache[resolved_entity.id] = cache_record

            if getattr(resolved_entity, "username", None):
                logger.debug(
                    f"Saved resolved_entity username @{resolved_entity.username} to"
                    " cache"
                )
                cache[f"@{resolved_entity.username}"] = cache_record
                cache[resolved_entity.username] = cache_record

        return copy.deepcopy(resolved_entity)

    async def cleaner(client: TelegramClient):
        # Periodically drop expired records; iterate a snapshot since we delete
        while True:
            for record, record_data in client._hikka_entity_cache.copy().items():
                if record_data.expired():
                    del client._hikka_entity_cache[record]
                    logger.debug(f"Cleaned outdated cache {record=}")

            await asyncio.sleep(3)

    client.get_entity = new
    client.force_get_entity = old
    # Keep a strong reference to the task: the event loop holds only weak
    # references to tasks, per the asyncio documentation.
    client._hikka_entity_cache_cleaner = asyncio.ensure_future(cleaner(client))
    logger.debug("Monkeypatched client with entity cacher")
def install_perms_caching(client: TelegramClient):
    """
    Attach a caching permissions getter to ``client`` as ``get_perms_cached``.

    ``client.get_permissions`` itself is not replaced. Records are kept in
    ``client._hikka_perms_cache`` as ``{chat_key: {user_key: CacheRecordPerms}}``
    and a background task evicts expired ones every 3 seconds.

    :param client: Client to patch in place
    """
    client._hikka_perms_cache = {}

    old = client.get_permissions

    def _first_id(obj):
        # First truthy id-like attribute of ``obj``, or None. A tuple keeps the
        # priority order deterministic (set order of str varies per run).
        return next(
            (
                getattr(obj, attr)
                for attr in ("user_id", "channel_id", "chat_id", "id")
                if getattr(obj, attr, None)
            ),
            None,
        )

    def _unmark_id(key):
        # Marked channel ids look like -100XXXXXXXXXX, i.e. -(10**12 + id);
        # map them to the bare id.
        text = str(key)
        if text.startswith("-") and text[1:].isdecimal() and int(text) < -(10**12):
            return -int(text) - 10**12
        return key

    async def new(
        entity: EntityLike,
        user: Optional[EntityLike] = None,
        exp: Optional[int] = 5 * 60,
        force: Optional[bool] = False,
    ):
        """
        Get permissions of ``user`` in ``entity``, reusing a cached copy.

        Both ``entity`` and ``user`` are first resolved via ``client.get_entity``.
        When no cache key can be derived, the uncached original is called.

        :param entity: Chat to check
        :param user: User to check; omitted/falsy values are passed as None
        :param exp: Maximum age (seconds) of a cached record to accept, and the
            lifetime of a newly stored record. A falsy value accepts cached
            records of any age.
        :param force: Skip the cache lookup and always call the original
        :return: Deep copy of the permissions returned by ``get_permissions``
        """
        # Will be used to determine, which client caused logging messages
        # parsed via inspect.stack()
        _hikka_client_id_logging_tag = copy.copy(client.tg_id)  # skipcq

        entity = await client.get_entity(entity)
        user = await client.get_entity(user) if user else None

        if not hashable(entity) or not hashable(user):
            hashable_entity = _first_id(entity)
            if hashable_entity is None:
                logger.debug(
                    f"Can't parse hashable from {entity=}, using legacy method"
                )
                return await old(entity, user)

            hashable_user = _first_id(user)
            if hashable_user is None:
                logger.debug(f"Can't parse hashable from {user=}, using legacy method")
                return await old(entity, user)
        else:
            hashable_entity = entity
            hashable_user = user

        hashable_entity = _unmark_id(hashable_entity)
        hashable_user = _unmark_id(hashable_user)

        if (
            not force
            and hashable_entity
            and hashable_user
            and hashable_user in client._hikka_perms_cache.get(hashable_entity, {})
            and (
                not exp
                or client._hikka_perms_cache[hashable_entity][hashable_user].ts + exp
                > time.time()
            )
        ):
            logger.debug(f"Using cached perms {hashable_entity} ({hashable_user})")
            return copy.deepcopy(
                client._hikka_perms_cache[hashable_entity][hashable_user].perms
            )

        resolved_perms = await old(entity, user)

        if resolved_perms:
            cache_record = CacheRecordPerms(
                hashable_entity,
                hashable_user,
                resolved_perms,
                # ``exp`` may be None; CacheRecordPerms needs a number
                exp or 0,
            )
            client._hikka_perms_cache.setdefault(hashable_entity, {})[
                hashable_user
            ] = cache_record
            logger.debug(f"Saved hashable_entity {hashable_entity} perms to cache")

            def save_user(key: Union[str, int]):
                # Register ``cache_record`` under every alias of ``user`` in chat ``key``
                if getattr(user, "id", None):
                    client._hikka_perms_cache.setdefault(key, {})[
                        user.id
                    ] = cache_record

                if getattr(user, "username", None):
                    chat_cache = client._hikka_perms_cache.setdefault(key, {})
                    chat_cache[f"@{user.username}"] = cache_record
                    chat_cache[user.username] = cache_record

            if getattr(entity, "id", None):
                logger.debug(f"Saved resolved_entity id {entity.id} perms to cache")
                save_user(entity.id)

            if getattr(entity, "username", None):
                logger.debug(
                    f"Saved resolved_entity username @{entity.username} perms to cache"
                )
                save_user(f"@{entity.username}")
                save_user(entity.username)

        return copy.deepcopy(resolved_perms)

    async def cleaner(client: TelegramClient):
        while True:
            for chat, chat_data in client._hikka_perms_cache.copy().items():
                # ``list()`` snapshots the items: dict views have no ``.copy()``
                # and entries are deleted while iterating.
                for user, user_data in list(chat_data.items()):
                    if user_data.expired():
                        del client._hikka_perms_cache[chat][user]
                        logger.debug(f"Cleaned outdated perms cache {chat=} {user=}")

                # Drop chats left without any records so the outer dict shrinks
                if not chat_data:
                    client._hikka_perms_cache.pop(chat, None)

            await asyncio.sleep(3)

    client.get_perms_cached = new
    # Keep a strong reference to the task: the event loop holds only weak
    # references to tasks, per the asyncio documentation.
    client._hikka_perms_cache_cleaner = asyncio.ensure_future(cleaner(client))
    logger.debug("Monkeypatched client with perms cacher")
def install_fullchannel_caching(client: TelegramClient):
    """
    Attach a caching ``GetFullChannelRequest`` helper as ``client.get_fullchannel``.

    Records are kept in ``client._hikka_fullchannel_cache`` and a background
    task evicts expired ones every 3 seconds.

    :param client: Client to patch in place
    """
    client._hikka_fullchannel_cache = {}

    async def get_fullchannel(
        entity: EntityLike,
        exp: Optional[int] = 300,
        force: Optional[bool] = False,
    ) -> ChannelFull:
        """
        Fetch the result of ``GetFullChannelRequest`` for a channel and cache it.

        :param entity: Channel to fetch ChannelFull of
        :param exp: Expiration time of the cache record and maximum age of an
            already cached record. ``None`` accepts any unexpired cached record.
        :param force: Whether to force refresh the cache (make API request)
        :return: :obj:`ChannelFull`

        NOTE(review): the raw ``GetFullChannelRequest`` result is returned as-is;
        Telethon documents it as ``messages.ChatFull`` (wrapping the
        ``ChannelFull``). Confirm the annotation against the Telethon version.
        """
        if hashable(entity):
            hashable_entity = entity
        else:
            # First truthy id-like attribute, in deterministic priority order
            hashable_entity = next(
                (
                    getattr(entity, attr)
                    for attr in ("channel_id", "chat_id", "id")
                    if getattr(entity, attr, None)
                ),
                None,
            )
            if hashable_entity is None:
                logger.debug(
                    f"Can't parse hashable from {entity=}, using legacy fullchannel"
                    " request"
                )
                return await client(GetFullChannelRequest(channel=entity))

        # Marked channel ids look like -100XXXXXXXXXX, i.e. -(10**12 + id);
        # map them to the bare id.
        text = str(hashable_entity)
        if text.startswith("-") and text[1:].isdecimal() and int(text) < -(10**12):
            hashable_entity = -int(text) - 10**12

        record = client._hikka_fullchannel_cache.get(hashable_entity)
        if (
            not force
            and record is not None
            and not record.expired()
            and (exp is None or record.ts + exp > time.time())
        ):
            return record.full_channel

        result = await client(GetFullChannelRequest(channel=entity))
        client._hikka_fullchannel_cache[hashable_entity] = CacheRecordFullChannel(
            hashable_entity,
            result,
            # ``exp`` may be None; the record needs a numeric lifetime
            exp or 0,
        )
        return result

    async def cleaner(client: TelegramClient):
        while True:
            for channel_id, record in client._hikka_fullchannel_cache.copy().items():
                if record.expired():
                    del client._hikka_fullchannel_cache[channel_id]
                    logger.debug(f"Cleaned outdated fullchannel cache {channel_id=}")

            await asyncio.sleep(3)

    client.get_fullchannel = get_fullchannel
    # Keep a strong reference to the task: the event loop holds only weak
    # references to tasks, per the asyncio documentation.
    client._hikka_fullchannel_cache_cleaner = asyncio.ensure_future(cleaner(client))
    logger.debug("Monkeypatched client with fullchannel cacher")
def install_fulluser_caching(client: TelegramClient):
    """
    Attach a caching ``GetFullUserRequest`` helper as ``client.get_fulluser``.

    Records are kept in ``client._hikka_fulluser_cache`` and a background task
    evicts expired ones every 3 seconds.

    :param client: Client to patch in place
    """
    client._hikka_fulluser_cache = {}

    async def get_fulluser(
        entity: EntityLike,
        exp: Optional[int] = 300,
        force: Optional[bool] = False,
    ) -> UserFull:
        """
        Fetch the result of ``GetFullUserRequest`` for a user and cache it.

        :param entity: User to fetch UserFull of
        :param exp: Expiration time of the cache record and maximum age of an
            already cached record. ``None`` accepts any unexpired cached record.
        :param force: Whether to force refresh the cache (make API request)
        :return: :obj:`UserFull`

        NOTE(review): the raw ``GetFullUserRequest`` result is returned as-is;
        depending on the TL layer it may be a ``users.UserFull`` wrapper rather
        than ``types.UserFull``. Confirm against the Telethon version.
        """
        if hashable(entity):
            hashable_entity = entity
        else:
            # First truthy id-like attribute, in deterministic priority order
            hashable_entity = next(
                (
                    getattr(entity, attr)
                    for attr in ("user_id", "chat_id", "id")
                    if getattr(entity, attr, None)
                ),
                None,
            )
            if hashable_entity is None:
                logger.debug(
                    f"Can't parse hashable from {entity=}, using legacy fulluser"
                    " request"
                )
                return await client(GetFullUserRequest(entity))

        # Marked channel ids look like -100XXXXXXXXXX, i.e. -(10**12 + id);
        # normalized the same way as in the other cachers.
        text = str(hashable_entity)
        if text.startswith("-") and text[1:].isdecimal() and int(text) < -(10**12):
            hashable_entity = -int(text) - 10**12

        record = client._hikka_fulluser_cache.get(hashable_entity)
        if (
            not force
            and record is not None
            and not record.expired()
            and (exp is None or record.ts + exp > time.time())
        ):
            # CacheRecordFullUser stores the payload as ``full_user``
            return record.full_user

        result = await client(GetFullUserRequest(entity))
        client._hikka_fulluser_cache[hashable_entity] = CacheRecordFullUser(
            hashable_entity,
            result,
            # ``exp`` may be None; the record needs a numeric lifetime
            exp or 0,
        )
        return result

    async def cleaner(client: TelegramClient):
        while True:
            for user_id, record in client._hikka_fulluser_cache.copy().items():
                if record.expired():
                    del client._hikka_fulluser_cache[user_id]
                    logger.debug(f"Cleaned outdated fulluser cache {user_id=}")

            await asyncio.sleep(3)

    client.get_fulluser = get_fulluser
    # Keep a strong reference to the task: the event loop holds only weak
    # references to tasks, per the asyncio documentation.
    client._hikka_fulluser_cache_cleaner = asyncio.ensure_future(cleaner(client))
    logger.debug("Monkeypatched client with fulluser cacher")
|