database.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
# █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
# █▀█ █ █ █ █▀█ █▀▄ █
# © Copyright 2022
# https://t.me/hikariatama
#
# 🔒 Licensed under the GNU AGPLv3
# 🌐 https://www.gnu.org/licenses/agpl-3.0.html

import contextlib
import json
import logging
import os
import time
import asyncio
import collections

# psycopg2 is optional: it is only mandatory on Heroku deployments,
# where the "DYNO" environment variable is present
try:
    import psycopg2
except ImportError as e:
    if "DYNO" in os.environ:
        raise e

# redis is optional as well: mandatory only on Heroku ("DYNO") and
# Railway ("RAILWAY") deployments
try:
    import redis
except ImportError as e:
    if "DYNO" in os.environ or "RAILWAY" in os.environ:
        raise e

from typing import Optional, Union

from telethon.tl.types import Message
from telethon.errors.rpcerrorlist import ChannelsTooMuchError

from . import utils, main
from .pointers import (
    PointerList,
    PointerDict,
)
from .types import JSONSerializable
from .tl_cache import CustomTelegramClient

# Local database files live next to the package sources, except in
# containerized deployments (Okteto/Docker), which mount a /data volume
DATA_DIR = (
    os.path.normpath(os.path.join(utils.get_base_dir(), ".."))
    if "OKTETO" not in os.environ and "DOCKER" not in os.environ
    else "/data"
)

logger = logging.getLogger(__name__)
class NoAssetsChannel(Exception):
    """Raised when trying to read/store asset with no asset channel present"""
  43. class Database(dict):
  44. _next_revision_call = 0
  45. _revisions = []
  46. _assets = None
  47. _me = None
  48. _postgre = None
  49. _redis = None
  50. _saving_task = None
  51. def __init__(self, client: CustomTelegramClient):
  52. super().__init__()
  53. self._client = client
  54. def __repr__(self):
  55. return object.__repr__(self)
  56. def _postgre_save_sync(self):
  57. with self._postgre, self._postgre.cursor() as cur:
  58. cur.execute(
  59. "UPDATE hikka SET data = %s WHERE id = %s;",
  60. (json.dumps(self), self._client.tg_id),
  61. )
  62. def _redis_save_sync(self):
  63. with self._redis.pipeline() as pipe:
  64. pipe.set(
  65. str(self._client.tg_id),
  66. json.dumps(self, ensure_ascii=True),
  67. )
  68. pipe.execute()
  69. async def remote_force_save(self) -> bool:
  70. """Force save database to remote endpoint without waiting"""
  71. if not self._postgre and not self._redis:
  72. return False
  73. if self._redis:
  74. await utils.run_sync(self._redis_save_sync)
  75. logger.debug("Published db to Redis")
  76. else:
  77. await utils.run_sync(self._postgre_save_sync)
  78. logger.debug("Published db to PostgreSQL")
  79. return True
  80. async def _postgre_save(self) -> bool:
  81. """Save database to postgresql"""
  82. if not self._postgre:
  83. return False
  84. await asyncio.sleep(5)
  85. await utils.run_sync(self._postgre_save_sync)
  86. logger.debug("Published db to PostgreSQL")
  87. self._saving_task = None
  88. return True
  89. async def _redis_save(self) -> bool:
  90. """Save database to redis"""
  91. if not self._redis:
  92. return False
  93. await asyncio.sleep(5)
  94. await utils.run_sync(self._redis_save_sync)
  95. logger.debug("Published db to Redis")
  96. self._saving_task = None
  97. return True
  98. async def postgre_init(self) -> bool:
  99. """Init postgresql database"""
  100. POSTGRE_URI = os.environ.get("DATABASE_URL") or main.get_config_key(
  101. "postgre_uri"
  102. )
  103. if not POSTGRE_URI:
  104. return False
  105. conn = psycopg2.connect(POSTGRE_URI, sslmode="require")
  106. with conn, conn.cursor() as cur:
  107. cur.execute("CREATE TABLE IF NOT EXISTS hikka (id bigint, data text);")
  108. with contextlib.suppress(Exception):
  109. cur.execute(
  110. "SELECT EXISTS(SELECT 1 FROM hikka WHERE id=%s);",
  111. (self._client.tg_id,),
  112. )
  113. if not cur.fetchone()[0]:
  114. cur.execute(
  115. "INSERT INTO hikka (id, data) VALUES (%s, %s);",
  116. (self._client.tg_id, json.dumps(self)),
  117. )
  118. with contextlib.suppress(Exception):
  119. cur.execute(
  120. "SELECT (column_name, data_type) "
  121. "FROM information_schema.columns "
  122. "WHERE table_name = 'hikka' AND column_name = 'id';"
  123. )
  124. if "integer" in cur.fetchone()[0].lower():
  125. logger.warning(
  126. "Made legacy migration from integer to bigint "
  127. "in postgresql database"
  128. )
  129. cur.execute("ALTER TABLE hikka ALTER COLUMN id TYPE bigint;")
  130. self._postgre = conn
  131. async def redis_init(self) -> bool:
  132. """Init redis database"""
  133. if REDIS_URI := os.environ.get("REDIS_URL") or main.get_config_key("redis_uri"):
  134. self._redis = redis.Redis.from_url(REDIS_URI)
  135. else:
  136. return False
  137. async def init(self):
  138. """Asynchronous initialization unit"""
  139. if os.environ.get("REDIS_URL") or main.get_config_key("redis_uri"):
  140. await self.redis_init()
  141. elif os.environ.get("DATABASE_URL") or main.get_config_key("postgre_uri"):
  142. await self.postgre_init()
  143. self._db_path = os.path.join(DATA_DIR, f"config-{self._client.tg_id}.json")
  144. self.read()
  145. try:
  146. self._assets, _ = await utils.asset_channel(
  147. self._client,
  148. "hikka-assets",
  149. "🌆 Your Hikka assets will be stored here",
  150. archive=True,
  151. avatar="https://raw.githubusercontent.com/hikariatama/assets/master/hikka-assets.png",
  152. )
  153. except ChannelsTooMuchError:
  154. self._assets = None
  155. logger.error(
  156. "Can't find and/or create assets folder\n"
  157. "This may cause several consequences, such as:\n"
  158. "- Non working assets feature (e.g. notes)\n"
  159. "- This error will occur every restart\n\n"
  160. "You can solve this by leaving some channels/groups"
  161. )
  162. def read(self):
  163. """Read database and stores it in self"""
  164. if self._redis:
  165. try:
  166. self.update(
  167. **json.loads(
  168. self._redis.get(
  169. str(self._client.tg_id),
  170. ).decode(),
  171. )
  172. )
  173. except Exception:
  174. logger.exception("Error reading redis database")
  175. return
  176. if self._postgre:
  177. try:
  178. with self._postgre, self._postgre.cursor() as cur:
  179. cur.execute(
  180. "SELECT data FROM hikka WHERE id=%s;",
  181. (self._client.tg_id,),
  182. )
  183. self.update(
  184. **json.loads(
  185. cur.fetchall()[0][0],
  186. ),
  187. )
  188. except Exception:
  189. logger.exception("Error reading postgresql database")
  190. return
  191. try:
  192. with open(self._db_path, "r", encoding="utf-8") as f:
  193. self.update(**json.load(f))
  194. except (FileNotFoundError, json.decoder.JSONDecodeError):
  195. logger.warning("Database read failed! Creating new one...")
  196. def process_db_autofix(self, db: dict) -> bool:
  197. if not utils.is_serializable(db):
  198. return False
  199. for key, value in db.copy().items():
  200. if not isinstance(key, (str, int)):
  201. logger.warning(
  202. f"DbAutoFix: Dropped {key=} , because it is not string or int"
  203. )
  204. continue
  205. if not isinstance(value, dict):
  206. # If value is not a dict (module values), drop it,
  207. # otherwise it may cause problems
  208. del db[key]
  209. logger.warning(
  210. f"DbAutoFix: Dropped {key=}, because it is non-dict {type(value)=}"
  211. )
  212. continue
  213. for subkey in value:
  214. if not isinstance(subkey, (str, int)):
  215. del db[key][subkey]
  216. logger.warning(
  217. f"DbAutoFix: Dropped {subkey=} of db[{key}], because it is not"
  218. " string or int"
  219. )
  220. continue
  221. return True
  222. def save(self) -> bool:
  223. """Save database"""
  224. if not self.process_db_autofix(self):
  225. try:
  226. rev = self._revisions.pop()
  227. while not self.process_db_autofix(rev):
  228. rev = self._revisions.pop()
  229. except IndexError:
  230. raise RuntimeError(
  231. "Can't find revision to restore broken database from "
  232. "database is most likely broken and will lead to problems, "
  233. "so its save is forbidden."
  234. )
  235. self.clear()
  236. self.update(**rev)
  237. raise RuntimeError(
  238. "Rewriting database to the last revision because new one destructed it"
  239. )
  240. if self._next_revision_call < time.time():
  241. self._revisions += [dict(self)]
  242. self._next_revision_call = time.time() + 3
  243. while len(self._revisions) > 15:
  244. self._revisions.pop()
  245. if self._redis:
  246. if not self._saving_task:
  247. self._saving_task = asyncio.ensure_future(self._redis_save())
  248. return True
  249. if self._postgre:
  250. if not self._saving_task:
  251. self._saving_task = asyncio.ensure_future(self._postgre_save())
  252. return True
  253. try:
  254. with open(self._db_path, "w", encoding="utf-8") as f:
  255. json.dump(self, f, indent=4)
  256. except Exception:
  257. logger.exception("Database save failed!")
  258. return False
  259. return True
  260. async def store_asset(self, message: Message) -> int:
  261. """
  262. Save assets
  263. returns asset_id as integer
  264. """
  265. if not self._assets:
  266. raise NoAssetsChannel("Tried to save asset to non-existing asset channel")
  267. return (
  268. (await self._client.send_message(self._assets, message)).id
  269. if isinstance(message, Message)
  270. else (
  271. await self._client.send_message(
  272. self._assets,
  273. file=message,
  274. force_document=True,
  275. )
  276. ).id
  277. )
  278. async def fetch_asset(self, asset_id: int) -> Union[None, Message]:
  279. """Fetch previously saved asset by its asset_id"""
  280. if not self._assets:
  281. raise NoAssetsChannel(
  282. "Tried to fetch asset from non-existing asset channel"
  283. )
  284. asset = await self._client.get_messages(self._assets, ids=[asset_id])
  285. return asset[0] if asset else None
  286. def get(
  287. self,
  288. owner: str,
  289. key: str,
  290. default: Optional[JSONSerializable] = None,
  291. ) -> JSONSerializable:
  292. """Get database key"""
  293. try:
  294. return self[owner][key]
  295. except KeyError:
  296. return default
  297. def set(self, owner: str, key: str, value: JSONSerializable) -> bool:
  298. """Set database key"""
  299. if not utils.is_serializable(owner):
  300. raise RuntimeError(
  301. "Attempted to write object to "
  302. f"{owner=} ({type(owner)=}) of database. It is not "
  303. "JSON-serializable key which will cause errors"
  304. )
  305. if not utils.is_serializable(key):
  306. raise RuntimeError(
  307. "Attempted to write object to "
  308. f"{key=} ({type(key)=}) of database. It is not "
  309. "JSON-serializable key which will cause errors"
  310. )
  311. if not utils.is_serializable(value):
  312. raise RuntimeError(
  313. "Attempted to write object of "
  314. f"{key=} ({type(value)=}) to database. It is not "
  315. "JSON-serializable value which will cause errors"
  316. )
  317. super().setdefault(owner, {})[key] = value
  318. return self.save()
  319. def pointer(
  320. self,
  321. owner: str,
  322. key: str,
  323. default: Optional[JSONSerializable] = None,
  324. ) -> JSONSerializable:
  325. """Get a pointer to database key"""
  326. value = self.get(owner, key, default)
  327. mapping = {
  328. list: PointerList,
  329. dict: PointerDict,
  330. collections.abc.Hashable: lambda v: v,
  331. }
  332. pointer_constructor = next(
  333. (pointer for type_, pointer in mapping.items() if isinstance(value, type_)),
  334. None,
  335. )
  336. if pointer_constructor is None:
  337. raise ValueError(
  338. f"Pointer for type {type(value).__name__} is not implemented"
  339. )
  340. return pointer_constructor(self, owner, key, default)