database.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # ©️ Dan Gazizullin, 2021-2023
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import asyncio
  7. import collections
  8. import json
  9. import logging
  10. import os
  11. import time
  12. try:
  13. import redis
  14. except ImportError as e:
  15. if "RAILWAY" in os.environ:
  16. raise e
  17. import typing
  18. from telethon.errors.rpcerrorlist import ChannelsTooMuchError
  19. from telethon.tl.types import Message
  20. from . import main, utils
  21. from .pointers import PointerDict, PointerList
  22. from .tl_cache import CustomTelegramClient
  23. from .types import JSONSerializable
  24. DATA_DIR = (
  25. os.path.normpath(os.path.join(utils.get_base_dir(), ".."))
  26. if "DOCKER" not in os.environ
  27. else "/data"
  28. )
  29. logger = logging.getLogger(__name__)
  30. class NoAssetsChannel(Exception):
  31. """Raised when trying to read/store asset with no asset channel present"""
  32. class Database(dict):
  33. _next_revision_call = 0
  34. _revisions = []
  35. _assets = None
  36. _me = None
  37. _redis = None
  38. _saving_task = None
  39. def __init__(self, client: CustomTelegramClient):
  40. super().__init__()
  41. self._client = client
  42. def __repr__(self):
  43. return object.__repr__(self)
  44. def _redis_save_sync(self):
  45. with self._redis.pipeline() as pipe:
  46. pipe.set(
  47. str(self._client.tg_id),
  48. json.dumps(self, ensure_ascii=True),
  49. )
  50. pipe.execute()
  51. async def remote_force_save(self) -> bool:
  52. """Force save database to remote endpoint without waiting"""
  53. if not self._redis:
  54. return False
  55. await utils.run_sync(self._redis_save_sync)
  56. logger.debug("Published db to Redis")
  57. return True
  58. async def _redis_save(self) -> bool:
  59. """Save database to redis"""
  60. if not self._redis:
  61. return False
  62. await asyncio.sleep(5)
  63. await utils.run_sync(self._redis_save_sync)
  64. logger.debug("Published db to Redis")
  65. self._saving_task = None
  66. return True
  67. async def redis_init(self) -> bool:
  68. """Init redis database"""
  69. if REDIS_URI := os.environ.get("REDIS_URL") or main.get_config_key("redis_uri"):
  70. self._redis = redis.Redis.from_url(REDIS_URI)
  71. else:
  72. return False
  73. async def init(self):
  74. """Asynchronous initialization unit"""
  75. if os.environ.get("REDIS_URL") or main.get_config_key("redis_uri"):
  76. await self.redis_init()
  77. self._db_path = os.path.join(DATA_DIR, f"config-{self._client.tg_id}.json")
  78. self.read()
  79. try:
  80. self._assets, _ = await utils.asset_channel(
  81. self._client,
  82. "hikka-assets",
  83. "🌆 Your Hikka assets will be stored here",
  84. archive=True,
  85. avatar="https://raw.githubusercontent.com/hikariatama/assets/master/hikka-assets.png",
  86. )
  87. except ChannelsTooMuchError:
  88. self._assets = None
  89. logger.error(
  90. "Can't find and/or create assets folder\n"
  91. "This may cause several consequences, such as:\n"
  92. "- Non working assets feature (e.g. notes)\n"
  93. "- This error will occur every restart\n\n"
  94. "You can solve this by leaving some channels/groups"
  95. )
  96. def read(self):
  97. """Read database and stores it in self"""
  98. if self._redis:
  99. try:
  100. self.update(
  101. **json.loads(
  102. self._redis.get(
  103. str(self._client.tg_id),
  104. ).decode(),
  105. )
  106. )
  107. except Exception:
  108. logger.exception("Error reading redis database")
  109. return
  110. try:
  111. with open(self._db_path, "r", encoding="utf-8") as f:
  112. self.update(**json.load(f))
  113. except (FileNotFoundError, json.decoder.JSONDecodeError):
  114. logger.warning("Database read failed! Creating new one...")
  115. def process_db_autofix(self, db: dict) -> bool:
  116. if not utils.is_serializable(db):
  117. return False
  118. for key, value in db.copy().items():
  119. if not isinstance(key, (str, int)):
  120. logger.warning(
  121. "DbAutoFix: Dropped key %s, because it is not string or int",
  122. key,
  123. )
  124. continue
  125. if not isinstance(value, dict):
  126. # If value is not a dict (module values), drop it,
  127. # otherwise it may cause problems
  128. del db[key]
  129. logger.warning(
  130. "DbAutoFix: Dropped key %s, because it is non-dict, but %s",
  131. key,
  132. type(value),
  133. )
  134. continue
  135. for subkey in value:
  136. if not isinstance(subkey, (str, int)):
  137. del db[key][subkey]
  138. logger.warning(
  139. "DbAutoFix: Dropped subkey %s of db key %s, because it is not"
  140. " string or int",
  141. subkey,
  142. key,
  143. )
  144. continue
  145. return True
  146. def save(self) -> bool:
  147. """Save database"""
  148. if not self.process_db_autofix(self):
  149. try:
  150. rev = self._revisions.pop()
  151. while not self.process_db_autofix(rev):
  152. rev = self._revisions.pop()
  153. except IndexError:
  154. raise RuntimeError(
  155. "Can't find revision to restore broken database from "
  156. "database is most likely broken and will lead to problems, "
  157. "so its save is forbidden."
  158. )
  159. self.clear()
  160. self.update(**rev)
  161. raise RuntimeError(
  162. "Rewriting database to the last revision because new one destructed it"
  163. )
  164. if self._next_revision_call < time.time():
  165. self._revisions += [dict(self)]
  166. self._next_revision_call = time.time() + 3
  167. while len(self._revisions) > 15:
  168. self._revisions.pop()
  169. if self._redis:
  170. if not self._saving_task:
  171. self._saving_task = asyncio.ensure_future(self._redis_save())
  172. return True
  173. try:
  174. with open(self._db_path, "w", encoding="utf-8") as f:
  175. json.dump(self, f, indent=4)
  176. except Exception:
  177. logger.exception("Database save failed!")
  178. return False
  179. return True
  180. async def store_asset(self, message: Message) -> int:
  181. """
  182. Save assets
  183. returns asset_id as integer
  184. """
  185. if not self._assets:
  186. raise NoAssetsChannel("Tried to save asset to non-existing asset channel")
  187. return (
  188. (await self._client.send_message(self._assets, message)).id
  189. if isinstance(message, Message)
  190. else (
  191. await self._client.send_message(
  192. self._assets,
  193. file=message,
  194. force_document=True,
  195. )
  196. ).id
  197. )
  198. async def fetch_asset(self, asset_id: int) -> typing.Optional[Message]:
  199. """Fetch previously saved asset by its asset_id"""
  200. if not self._assets:
  201. raise NoAssetsChannel(
  202. "Tried to fetch asset from non-existing asset channel"
  203. )
  204. asset = await self._client.get_messages(self._assets, ids=[asset_id])
  205. return asset[0] if asset else None
  206. def get(
  207. self,
  208. owner: str,
  209. key: str,
  210. default: typing.Optional[JSONSerializable] = None,
  211. ) -> JSONSerializable:
  212. """Get database key"""
  213. try:
  214. return self[owner][key]
  215. except KeyError:
  216. return default
  217. def set(self, owner: str, key: str, value: JSONSerializable) -> bool:
  218. """Set database key"""
  219. if not utils.is_serializable(owner):
  220. raise RuntimeError(
  221. "Attempted to write object to "
  222. f"{owner=} ({type(owner)=}) of database. It is not "
  223. "JSON-serializable key which will cause errors"
  224. )
  225. if not utils.is_serializable(key):
  226. raise RuntimeError(
  227. "Attempted to write object to "
  228. f"{key=} ({type(key)=}) of database. It is not "
  229. "JSON-serializable key which will cause errors"
  230. )
  231. if not utils.is_serializable(value):
  232. raise RuntimeError(
  233. "Attempted to write object of "
  234. f"{key=} ({type(value)=}) to database. It is not "
  235. "JSON-serializable value which will cause errors"
  236. )
  237. super().setdefault(owner, {})[key] = value
  238. return self.save()
  239. def pointer(
  240. self,
  241. owner: str,
  242. key: str,
  243. default: typing.Optional[JSONSerializable] = None,
  244. ) -> typing.Union[JSONSerializable, PointerList, PointerDict]:
  245. """Get a pointer to database key"""
  246. value = self.get(owner, key, default)
  247. mapping = {
  248. list: PointerList,
  249. dict: PointerDict,
  250. collections.abc.Hashable: lambda v: v,
  251. }
  252. pointer_constructor = next(
  253. (pointer for type_, pointer in mapping.items() if isinstance(value, type_)),
  254. None,
  255. )
  256. if pointer_constructor is None:
  257. raise ValueError(
  258. f"Pointer for type {type(value).__name__} is not implemented"
  259. )
  260. return pointer_constructor(self, owner, key, default)