database.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. # ©️ Dan Gazizullin, 2021-2023
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import asyncio
  7. import collections
  8. import json
  9. import logging
  10. import os
  11. import time
  12. try:
  13. import redis
  14. except ImportError as e:
  15. if "RAILWAY" in os.environ:
  16. raise e
  17. import typing
  18. from hikkatl.errors.rpcerrorlist import ChannelsTooMuchError
  19. from hikkatl.tl.types import Message, User
  20. from . import main, utils
  21. from .pointers import (
  22. BaseSerializingMiddlewareDict,
  23. BaseSerializingMiddlewareList,
  24. NamedTupleMiddlewareDict,
  25. NamedTupleMiddlewareList,
  26. PointerDict,
  27. PointerList,
  28. )
  29. from .tl_cache import CustomTelegramClient
  30. from .types import JSONSerializable
  31. __all__ = [
  32. "Database",
  33. "PointerList",
  34. "PointerDict",
  35. "NamedTupleMiddlewareDict",
  36. "NamedTupleMiddlewareList",
  37. "BaseSerializingMiddlewareDict",
  38. "BaseSerializingMiddlewareList",
  39. ]
  40. logger = logging.getLogger(__name__)
  41. class NoAssetsChannel(Exception):
  42. """Raised when trying to read/store asset with no asset channel present"""
  43. class Database(dict):
  44. def __init__(self, client: CustomTelegramClient):
  45. super().__init__()
  46. self._client: CustomTelegramClient = client
  47. self._next_revision_call: int = 0
  48. self._revisions: typing.List[dict] = []
  49. self._assets: int = None
  50. self._me: User = None
  51. self._redis: redis.Redis = None
  52. self._saving_task: asyncio.Future = None
  53. def __repr__(self):
  54. return object.__repr__(self)
  55. def _redis_save_sync(self):
  56. with self._redis.pipeline() as pipe:
  57. pipe.set(
  58. str(self._client.tg_id),
  59. json.dumps(self, ensure_ascii=True),
  60. )
  61. pipe.execute()
  62. async def remote_force_save(self) -> bool:
  63. """Force save database to remote endpoint without waiting"""
  64. if not self._redis:
  65. return False
  66. await utils.run_sync(self._redis_save_sync)
  67. logger.debug("Published db to Redis")
  68. return True
  69. async def _redis_save(self) -> bool:
  70. """Save database to redis"""
  71. if not self._redis:
  72. return False
  73. await asyncio.sleep(5)
  74. await utils.run_sync(self._redis_save_sync)
  75. logger.debug("Published db to Redis")
  76. self._saving_task = None
  77. return True
  78. async def redis_init(self) -> bool:
  79. """Init redis database"""
  80. if REDIS_URI := (
  81. os.environ.get("REDIS_URL") or main.get_config_key("redis_uri")
  82. ):
  83. self._redis = redis.Redis.from_url(REDIS_URI)
  84. else:
  85. return False
  86. async def init(self):
  87. """Asynchronous initialization unit"""
  88. if os.environ.get("REDIS_URL") or main.get_config_key("redis_uri"):
  89. await self.redis_init()
  90. self._db_file = main.BASE_PATH / f"config-{self._client.tg_id}.json"
  91. self.read()
  92. try:
  93. self._assets, _ = await utils.asset_channel(
  94. self._client,
  95. "hikka-assets",
  96. "🌆 Your Hikka assets will be stored here",
  97. archive=True,
  98. avatar="https://raw.githubusercontent.com/hikariatama/assets/master/hikka-assets.png",
  99. )
  100. except ChannelsTooMuchError:
  101. self._assets = None
  102. logger.error(
  103. "Can't find and/or create assets folder\n"
  104. "This may cause several consequences, such as:\n"
  105. "- Non working assets feature (e.g. notes)\n"
  106. "- This error will occur every restart\n\n"
  107. "You can solve this by leaving some channels/groups"
  108. )
  109. def read(self):
  110. """Read database and stores it in self"""
  111. if self._redis:
  112. try:
  113. self.update(
  114. **json.loads(
  115. self._redis.get(
  116. str(self._client.tg_id),
  117. ).decode(),
  118. )
  119. )
  120. except Exception:
  121. logger.exception("Error reading redis database")
  122. return
  123. try:
  124. self.update(**json.loads(self._db_file.read_text()))
  125. except json.decoder.JSONDecodeError:
  126. logger.warning("Database read failed! Creating new one...")
  127. except FileNotFoundError:
  128. logger.debug("Database file not found, creating new one...")
  129. def process_db_autofix(self, db: dict) -> bool:
  130. if not utils.is_serializable(db):
  131. return False
  132. for key, value in db.copy().items():
  133. if not isinstance(key, (str, int)):
  134. logger.warning(
  135. "DbAutoFix: Dropped key %s, because it is not string or int",
  136. key,
  137. )
  138. continue
  139. if not isinstance(value, dict):
  140. # If value is not a dict (module values), drop it,
  141. # otherwise it may cause problems
  142. del db[key]
  143. logger.warning(
  144. "DbAutoFix: Dropped key %s, because it is non-dict, but %s",
  145. key,
  146. type(value),
  147. )
  148. continue
  149. for subkey in value:
  150. if not isinstance(subkey, (str, int)):
  151. del db[key][subkey]
  152. logger.warning(
  153. (
  154. "DbAutoFix: Dropped subkey %s of db key %s, because it is"
  155. " not string or int"
  156. ),
  157. subkey,
  158. key,
  159. )
  160. continue
  161. return True
  162. def save(self) -> bool:
  163. """Save database"""
  164. if not self.process_db_autofix(self):
  165. try:
  166. rev = self._revisions.pop()
  167. while not self.process_db_autofix(rev):
  168. rev = self._revisions.pop()
  169. except IndexError:
  170. raise RuntimeError(
  171. "Can't find revision to restore broken database from "
  172. "database is most likely broken and will lead to problems, "
  173. "so its save is forbidden."
  174. )
  175. self.clear()
  176. self.update(**rev)
  177. raise RuntimeError(
  178. "Rewriting database to the last revision because new one destructed it"
  179. )
  180. if self._next_revision_call < time.time():
  181. self._revisions += [dict(self)]
  182. self._next_revision_call = time.time() + 3
  183. while len(self._revisions) > 15:
  184. self._revisions.pop()
  185. if self._redis:
  186. if not self._saving_task:
  187. self._saving_task = asyncio.ensure_future(self._redis_save())
  188. return True
  189. try:
  190. self._db_file.write_text(json.dumps(self, indent=4))
  191. except Exception:
  192. logger.exception("Database save failed!")
  193. return False
  194. return True
  195. async def store_asset(self, message: Message) -> int:
  196. """
  197. Save assets
  198. returns asset_id as integer
  199. """
  200. if not self._assets:
  201. raise NoAssetsChannel("Tried to save asset to non-existing asset channel")
  202. return (
  203. (await self._client.send_message(self._assets, message)).id
  204. if isinstance(message, Message)
  205. else (
  206. await self._client.send_message(
  207. self._assets,
  208. file=message,
  209. force_document=True,
  210. )
  211. ).id
  212. )
  213. async def fetch_asset(self, asset_id: int) -> typing.Optional[Message]:
  214. """Fetch previously saved asset by its asset_id"""
  215. if not self._assets:
  216. raise NoAssetsChannel(
  217. "Tried to fetch asset from non-existing asset channel"
  218. )
  219. asset = await self._client.get_messages(self._assets, ids=[asset_id])
  220. return asset[0] if asset else None
  221. def get(
  222. self,
  223. owner: str,
  224. key: str,
  225. default: typing.Optional[JSONSerializable] = None,
  226. ) -> JSONSerializable:
  227. """Get database key"""
  228. try:
  229. return self[owner][key]
  230. except KeyError:
  231. return default
  232. def set(self, owner: str, key: str, value: JSONSerializable) -> bool:
  233. """Set database key"""
  234. if not utils.is_serializable(owner):
  235. raise RuntimeError(
  236. "Attempted to write object to "
  237. f"{owner=} ({type(owner)=}) of database. It is not "
  238. "JSON-serializable key which will cause errors"
  239. )
  240. if not utils.is_serializable(key):
  241. raise RuntimeError(
  242. "Attempted to write object to "
  243. f"{key=} ({type(key)=}) of database. It is not "
  244. "JSON-serializable key which will cause errors"
  245. )
  246. if not utils.is_serializable(value):
  247. raise RuntimeError(
  248. "Attempted to write object of "
  249. f"{key=} ({type(value)=}) to database. It is not "
  250. "JSON-serializable value which will cause errors"
  251. )
  252. super().setdefault(owner, {})[key] = value
  253. return self.save()
  254. def pointer(
  255. self,
  256. owner: str,
  257. key: str,
  258. default: typing.Optional[JSONSerializable] = None,
  259. item_type: typing.Optional[typing.Any] = None,
  260. ) -> typing.Union[JSONSerializable, PointerList, PointerDict]:
  261. """Get a pointer to database key"""
  262. value = self.get(owner, key, default)
  263. mapping = {
  264. list: PointerList,
  265. dict: PointerDict,
  266. collections.abc.Hashable: lambda v: v,
  267. }
  268. pointer_constructor = next(
  269. (pointer for type_, pointer in mapping.items() if isinstance(value, type_)),
  270. None,
  271. )
  272. if (current_value := self.get(owner, key, None)) and type(
  273. current_value
  274. ) is not type(default):
  275. raise ValueError(
  276. f"Can't switch the type of pointer in database (current: {type(current_value)}, requested: {type(default)})"
  277. )
  278. if pointer_constructor is None:
  279. raise ValueError(
  280. f"Pointer for type {type(value).__name__} is not implemented"
  281. )
  282. if item_type is not None:
  283. if isinstance(value, list):
  284. for item in self.get(owner, key, default):
  285. if not isinstance(item, dict):
  286. raise ValueError(
  287. "Item type can only be specified for dedicated keys and"
  288. " can't be mixed with other ones"
  289. )
  290. return NamedTupleMiddlewareList(
  291. pointer_constructor(self, owner, key, default),
  292. item_type,
  293. )
  294. if isinstance(value, dict):
  295. for item in self.get(owner, key, default).values():
  296. if not isinstance(item, dict):
  297. raise ValueError(
  298. "Item type can only be specified for dedicated keys and"
  299. " can't be mixed with other ones"
  300. )
  301. return NamedTupleMiddlewareDict(
  302. pointer_constructor(self, owner, key, default),
  303. item_type,
  304. )
  305. return pointer_constructor(self, owner, key, default)