database.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. # █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
  2. # █▀█ █ █ █ █▀█ █▀▄ █
  3. # © Copyright 2022
  4. # https://t.me/hikariatama
  5. #
  6. # 🔒 Licensed under the GNU AGPLv3
  7. # 🌐 https://www.gnu.org/licenses/agpl-3.0.html
import asyncio
import collections
import collections.abc
import copy
import json
import logging
import os
import time
import typing

try:
    import redis
except ImportError as e:
    if "RAILWAY" in os.environ:
        raise e

from telethon.tl.types import Message
from telethon.errors.rpcerrorlist import ChannelsTooMuchError

from . import utils, main
from .pointers import (
    PointerList,
    PointerDict,
)
from .types import JSONSerializable
from .tl_cache import CustomTelegramClient
  29. DATA_DIR = (
  30. os.path.normpath(os.path.join(utils.get_base_dir(), ".."))
  31. if "OKTETO" not in os.environ and "DOCKER" not in os.environ
  32. else "/data"
  33. )
  34. logger = logging.getLogger(__name__)
  35. class NoAssetsChannel(Exception):
  36. """Raised when trying to read/store asset with no asset channel present"""
  37. class Database(dict):
  38. _next_revision_call = 0
  39. _revisions = []
  40. _assets = None
  41. _me = None
  42. _redis = None
  43. _saving_task = None
  44. def __init__(self, client: CustomTelegramClient):
  45. super().__init__()
  46. self._client = client
  47. def __repr__(self):
  48. return object.__repr__(self)
  49. def _redis_save_sync(self):
  50. with self._redis.pipeline() as pipe:
  51. pipe.set(
  52. str(self._client.tg_id),
  53. json.dumps(self, ensure_ascii=True),
  54. )
  55. pipe.execute()
  56. async def remote_force_save(self) -> bool:
  57. """Force save database to remote endpoint without waiting"""
  58. if not self._redis:
  59. return False
  60. await utils.run_sync(self._redis_save_sync)
  61. logger.debug("Published db to Redis")
  62. return True
  63. async def _redis_save(self) -> bool:
  64. """Save database to redis"""
  65. if not self._redis:
  66. return False
  67. await asyncio.sleep(5)
  68. await utils.run_sync(self._redis_save_sync)
  69. logger.debug("Published db to Redis")
  70. self._saving_task = None
  71. return True
  72. async def redis_init(self) -> bool:
  73. """Init redis database"""
  74. if REDIS_URI := os.environ.get("REDIS_URL") or main.get_config_key("redis_uri"):
  75. self._redis = redis.Redis.from_url(REDIS_URI)
  76. else:
  77. return False
  78. async def init(self):
  79. """Asynchronous initialization unit"""
  80. if os.environ.get("REDIS_URL") or main.get_config_key("redis_uri"):
  81. await self.redis_init()
  82. self._db_path = os.path.join(DATA_DIR, f"config-{self._client.tg_id}.json")
  83. self.read()
  84. try:
  85. self._assets, _ = await utils.asset_channel(
  86. self._client,
  87. "hikka-assets",
  88. "🌆 Your Hikka assets will be stored here",
  89. archive=True,
  90. avatar="https://raw.githubusercontent.com/hikariatama/assets/master/hikka-assets.png",
  91. )
  92. except ChannelsTooMuchError:
  93. self._assets = None
  94. logger.error(
  95. "Can't find and/or create assets folder\n"
  96. "This may cause several consequences, such as:\n"
  97. "- Non working assets feature (e.g. notes)\n"
  98. "- This error will occur every restart\n\n"
  99. "You can solve this by leaving some channels/groups"
  100. )
  101. def read(self):
  102. """Read database and stores it in self"""
  103. if self._redis:
  104. try:
  105. self.update(
  106. **json.loads(
  107. self._redis.get(
  108. str(self._client.tg_id),
  109. ).decode(),
  110. )
  111. )
  112. except Exception:
  113. logger.exception("Error reading redis database")
  114. return
  115. try:
  116. with open(self._db_path, "r", encoding="utf-8") as f:
  117. self.update(**json.load(f))
  118. except (FileNotFoundError, json.decoder.JSONDecodeError):
  119. logger.warning("Database read failed! Creating new one...")
  120. def process_db_autofix(self, db: dict) -> bool:
  121. if not utils.is_serializable(db):
  122. return False
  123. for key, value in db.copy().items():
  124. if not isinstance(key, (str, int)):
  125. logger.warning(
  126. "DbAutoFix: Dropped key %s, because it is not string or int",
  127. key,
  128. )
  129. continue
  130. if not isinstance(value, dict):
  131. # If value is not a dict (module values), drop it,
  132. # otherwise it may cause problems
  133. del db[key]
  134. logger.warning(
  135. "DbAutoFix: Dropped key %s, because it is non-dict, but %s",
  136. key,
  137. type(value),
  138. )
  139. continue
  140. for subkey in value:
  141. if not isinstance(subkey, (str, int)):
  142. del db[key][subkey]
  143. logger.warning(
  144. "DbAutoFix: Dropped subkey %s of db key %s, because it is not"
  145. " string or int",
  146. subkey,
  147. key,
  148. )
  149. continue
  150. return True
  151. def save(self) -> bool:
  152. """Save database"""
  153. if not self.process_db_autofix(self):
  154. try:
  155. rev = self._revisions.pop()
  156. while not self.process_db_autofix(rev):
  157. rev = self._revisions.pop()
  158. except IndexError:
  159. raise RuntimeError(
  160. "Can't find revision to restore broken database from "
  161. "database is most likely broken and will lead to problems, "
  162. "so its save is forbidden."
  163. )
  164. self.clear()
  165. self.update(**rev)
  166. raise RuntimeError(
  167. "Rewriting database to the last revision because new one destructed it"
  168. )
  169. if self._next_revision_call < time.time():
  170. self._revisions += [dict(self)]
  171. self._next_revision_call = time.time() + 3
  172. while len(self._revisions) > 15:
  173. self._revisions.pop()
  174. if self._redis:
  175. if not self._saving_task:
  176. self._saving_task = asyncio.ensure_future(self._redis_save())
  177. return True
  178. try:
  179. with open(self._db_path, "w", encoding="utf-8") as f:
  180. json.dump(self, f, indent=4)
  181. except Exception:
  182. logger.exception("Database save failed!")
  183. return False
  184. return True
  185. async def store_asset(self, message: Message) -> int:
  186. """
  187. Save assets
  188. returns asset_id as integer
  189. """
  190. if not self._assets:
  191. raise NoAssetsChannel("Tried to save asset to non-existing asset channel")
  192. return (
  193. (await self._client.send_message(self._assets, message)).id
  194. if isinstance(message, Message)
  195. else (
  196. await self._client.send_message(
  197. self._assets,
  198. file=message,
  199. force_document=True,
  200. )
  201. ).id
  202. )
  203. async def fetch_asset(self, asset_id: int) -> typing.Optional[Message]:
  204. """Fetch previously saved asset by its asset_id"""
  205. if not self._assets:
  206. raise NoAssetsChannel(
  207. "Tried to fetch asset from non-existing asset channel"
  208. )
  209. asset = await self._client.get_messages(self._assets, ids=[asset_id])
  210. return asset[0] if asset else None
  211. def get(
  212. self,
  213. owner: str,
  214. key: str,
  215. default: typing.Optional[JSONSerializable] = None,
  216. ) -> JSONSerializable:
  217. """Get database key"""
  218. try:
  219. return self[owner][key]
  220. except KeyError:
  221. return default
  222. def set(self, owner: str, key: str, value: JSONSerializable) -> bool:
  223. """Set database key"""
  224. if not utils.is_serializable(owner):
  225. raise RuntimeError(
  226. "Attempted to write object to "
  227. f"{owner=} ({type(owner)=}) of database. It is not "
  228. "JSON-serializable key which will cause errors"
  229. )
  230. if not utils.is_serializable(key):
  231. raise RuntimeError(
  232. "Attempted to write object to "
  233. f"{key=} ({type(key)=}) of database. It is not "
  234. "JSON-serializable key which will cause errors"
  235. )
  236. if not utils.is_serializable(value):
  237. raise RuntimeError(
  238. "Attempted to write object of "
  239. f"{key=} ({type(value)=}) to database. It is not "
  240. "JSON-serializable value which will cause errors"
  241. )
  242. super().setdefault(owner, {})[key] = value
  243. return self.save()
  244. def pointer(
  245. self,
  246. owner: str,
  247. key: str,
  248. default: typing.Optional[JSONSerializable] = None,
  249. ) -> JSONSerializable:
  250. """Get a pointer to database key"""
  251. value = self.get(owner, key, default)
  252. mapping = {
  253. list: PointerList,
  254. dict: PointerDict,
  255. collections.abc.Hashable: lambda v: v,
  256. }
  257. pointer_constructor = next(
  258. (pointer for type_, pointer in mapping.items() if isinstance(value, type_)),
  259. None,
  260. )
  261. if pointer_constructor is None:
  262. raise ValueError(
  263. f"Pointer for type {type(value).__name__} is not implemented"
  264. )
  265. return pointer_constructor(self, owner, key, default)