# database.py (12 KB)
# █ █ ▀ █▄▀ ▄▀█ █▀█ ▀
# █▀█ █ █ █ █▀█ █▀▄ █
# © Copyright 2022
# https://t.me/hikariatama
#
# 🔒 Licensed under the GNU AGPLv3
# 🌐 https://www.gnu.org/licenses/agpl-3.0.html
import asyncio
import contextlib
import copy
import json
import logging
import os
import time
from typing import Any, Callable, Union

try:
    import psycopg2
except ImportError as e:
    if "DYNO" in os.environ:
        raise e

try:
    import redis
except ImportError as e:
    if "DYNO" in os.environ:
        raise e

from telethon.errors.rpcerrorlist import ChannelsTooMuchError
from telethon.tl.types import Message

from . import main, utils
  28. DATA_DIR = (
  29. os.path.normpath(os.path.join(utils.get_base_dir(), ".."))
  30. if "OKTETO" not in os.environ and "DOCKER" not in os.environ
  31. else "/data"
  32. )
  33. logger = logging.getLogger(__name__)
  34. class NoAssetsChannel(Exception):
  35. """Raised when trying to read/store asset with no asset channel present"""
  36. class Database(dict):
  37. _next_revision_call = 0
  38. _revisions = []
  39. _assets = None
  40. _me = None
  41. _postgre = None
  42. _redis = None
  43. _saving_task = None
  44. def __init__(self, client):
  45. super().__init__()
  46. self._client = client
  47. def __repr__(self):
  48. return object.__repr__(self)
  49. def _postgre_save_sync(self):
  50. with self._postgre, self._postgre.cursor() as cur:
  51. cur.execute(
  52. "UPDATE hikka SET data = %s WHERE id = %s;",
  53. (json.dumps(self), self._client.tg_id),
  54. )
  55. def _redis_save_sync(self):
  56. with self._redis.pipeline() as pipe:
  57. pipe.set(
  58. str(self._client.tg_id),
  59. json.dumps(self, ensure_ascii=True),
  60. )
  61. pipe.execute()
  62. async def remote_force_save(self) -> bool:
  63. """Force save database to remote endpoint without waiting"""
  64. if not self._postgre and not self._redis:
  65. return False
  66. if self._redis:
  67. await utils.run_sync(self._redis_save_sync)
  68. logger.debug("Published db to Redis")
  69. else:
  70. await utils.run_sync(self._postgre_save_sync)
  71. logger.debug("Published db to PostgreSQL")
  72. return True
  73. async def _postgre_save(self) -> bool:
  74. """Save database to postgresql"""
  75. if not self._postgre:
  76. return False
  77. await asyncio.sleep(5)
  78. await utils.run_sync(self._postgre_save_sync)
  79. logger.debug("Published db to PostgreSQL")
  80. self._saving_task = None
  81. return True
  82. async def _redis_save(self) -> bool:
  83. """Save database to redis"""
  84. if not self._redis:
  85. return False
  86. await asyncio.sleep(5)
  87. await utils.run_sync(self._redis_save_sync)
  88. logger.debug("Published db to Redis")
  89. self._saving_task = None
  90. return True
  91. async def postgre_init(self) -> bool:
  92. """Init postgresql database"""
  93. POSTGRE_URI = os.environ.get("DATABASE_URL") or main.get_config_key(
  94. "postgre_uri"
  95. )
  96. if not POSTGRE_URI:
  97. return False
  98. conn = psycopg2.connect(POSTGRE_URI, sslmode="require")
  99. with conn, conn.cursor() as cur:
  100. cur.execute("CREATE TABLE IF NOT EXISTS hikka (id bigint, data text);")
  101. with contextlib.suppress(Exception):
  102. cur.execute(
  103. "SELECT EXISTS(SELECT 1 FROM hikka WHERE id=%s);",
  104. (self._client.tg_id,),
  105. )
  106. if not cur.fetchone()[0]:
  107. cur.execute(
  108. "INSERT INTO hikka (id, data) VALUES (%s, %s);",
  109. (self._client.tg_id, json.dumps(self)),
  110. )
  111. with contextlib.suppress(Exception):
  112. cur.execute(
  113. "SELECT (column_name, data_type) "
  114. "FROM information_schema.columns "
  115. "WHERE table_name = 'hikka' AND column_name = 'id';"
  116. )
  117. if "integer" in cur.fetchone()[0].lower():
  118. logger.warning(
  119. "Made legacy migration from integer to bigint "
  120. "in postgresql database"
  121. )
  122. cur.execute("ALTER TABLE hikka ALTER COLUMN id TYPE bigint;")
  123. self._postgre = conn
  124. async def redis_init(self) -> bool:
  125. """Init redis database"""
  126. if REDIS_URI := os.environ.get("REDIS_URL") or main.get_config_key("redis_uri"):
  127. self._redis = redis.Redis.from_url(REDIS_URI)
  128. else:
  129. return False
  130. async def init(self):
  131. """Asynchronous initialization unit"""
  132. if os.environ.get("REDIS_URL") or main.get_config_key("redis_uri"):
  133. await self.redis_init()
  134. elif os.environ.get("DATABASE_URL") or main.get_config_key("postgre_uri"):
  135. await self.postgre_init()
  136. self._db_path = os.path.join(DATA_DIR, f"config-{self._client.tg_id}.json")
  137. self.read()
  138. try:
  139. self._assets, _ = await utils.asset_channel(
  140. self._client,
  141. "hikka-assets",
  142. "🌆 Your Hikka assets will be stored here",
  143. archive=True,
  144. avatar="https://raw.githubusercontent.com/hikariatama/assets/master/hikka-assets.png",
  145. )
  146. except ChannelsTooMuchError:
  147. self._assets = None
  148. logger.error(
  149. "Can't find and/or create assets folder\n"
  150. "This may cause several consequences, such as:\n"
  151. "- Non working assets feature (e.g. notes)\n"
  152. "- This error will occur every restart\n\n"
  153. "You can solve this by leaving some channels/groups"
  154. )
  155. def read(self):
  156. """Read database and stores it in self"""
  157. if self._redis:
  158. try:
  159. self.update(
  160. **json.loads(
  161. self._redis.get(
  162. str(self._client.tg_id),
  163. ).decode(),
  164. )
  165. )
  166. except Exception:
  167. logger.exception("Error reading redis database")
  168. return
  169. if self._postgre:
  170. try:
  171. with self._postgre, self._postgre.cursor() as cur:
  172. cur.execute(
  173. "SELECT data FROM hikka WHERE id=%s;",
  174. (self._client.tg_id,),
  175. )
  176. self.update(
  177. **json.loads(
  178. cur.fetchall()[0][0],
  179. ),
  180. )
  181. except Exception:
  182. logger.exception("Error reading postgresql database")
  183. return
  184. try:
  185. with open(self._db_path, "r", encoding="utf-8") as f:
  186. self.update(**json.load(f))
  187. except (FileNotFoundError, json.decoder.JSONDecodeError):
  188. logger.warning("Database read failed! Creating new one...")
  189. def process_db_autofix(self, db: dict) -> bool:
  190. if not utils.is_serializable(db):
  191. return False
  192. for key, value in db.copy().items():
  193. if not isinstance(key, (str, int)):
  194. logger.warning(
  195. f"DbAutoFix: Dropped {key=} , because it is not string or int"
  196. )
  197. continue
  198. if not isinstance(value, dict):
  199. # If value is not a dict (module values), drop it,
  200. # otherwise it may cause problems
  201. del db[key]
  202. logger.warning(
  203. f"DbAutoFix: Dropped {key=}, because it is non-dict {type(value)=}"
  204. )
  205. continue
  206. for subkey in value:
  207. if not isinstance(subkey, (str, int)):
  208. del db[key][subkey]
  209. logger.warning(
  210. f"DbAutoFix: Dropped {subkey=} of db[{key}], because it is not"
  211. " string or int"
  212. )
  213. continue
  214. return True
  215. def save(self) -> bool:
  216. """Save database"""
  217. if not self.process_db_autofix(self):
  218. try:
  219. rev = self._revisions.pop()
  220. while not self.process_db_autofix(rev):
  221. rev = self._revisions.pop()
  222. except IndexError:
  223. raise RuntimeError(
  224. "Can't find revision to restore broken database from "
  225. "database is most likely broken and will lead to problems, "
  226. "so its save is forbidden."
  227. )
  228. self.clear()
  229. self.update(**rev)
  230. raise RuntimeError(
  231. "Rewriting database to the last revision because new one destructed it"
  232. )
  233. if self._next_revision_call < time.time():
  234. self._revisions += [dict(self)]
  235. self._next_revision_call = time.time() + 3
  236. while len(self._revisions) > 15:
  237. self._revisions.pop()
  238. if self._redis:
  239. if not self._saving_task:
  240. self._saving_task = asyncio.ensure_future(self._redis_save())
  241. return True
  242. if self._postgre:
  243. if not self._saving_task:
  244. self._saving_task = asyncio.ensure_future(self._postgre_save())
  245. return True
  246. try:
  247. with open(self._db_path, "w", encoding="utf-8") as f:
  248. json.dump(self, f, indent=4)
  249. except Exception:
  250. logger.exception("Database save failed!")
  251. return False
  252. return True
  253. async def store_asset(self, message: Message) -> int:
  254. """
  255. Save assets
  256. returns asset_id as integer
  257. """
  258. if not self._assets:
  259. raise NoAssetsChannel("Tried to save asset to non-existing asset channel")
  260. return (
  261. (await self._client.send_message(self._assets, message)).id
  262. if isinstance(message, Message)
  263. else (
  264. await self._client.send_message(
  265. self._assets,
  266. file=message,
  267. force_document=True,
  268. )
  269. ).id
  270. )
  271. async def fetch_asset(self, asset_id: int) -> Union[None, Message]:
  272. """Fetch previously saved asset by its asset_id"""
  273. if not self._assets:
  274. raise NoAssetsChannel(
  275. "Tried to fetch asset from non-existing asset channel"
  276. )
  277. asset = await self._client.get_messages(self._assets, ids=[asset_id])
  278. return asset[0] if asset else None
  279. def get(self, owner: str, key: str, default: Any = None) -> Any:
  280. """Get database key"""
  281. try:
  282. return self[owner][key]
  283. except KeyError:
  284. return default
  285. def set(self, owner: str, key: str, value: Any) -> bool:
  286. """Set database key"""
  287. if not utils.is_serializable(owner):
  288. raise RuntimeError(
  289. "Attempted to write object to "
  290. f"{owner=} ({type(owner)=}) of database. It is not "
  291. "JSON-serializable key which will cause errors"
  292. )
  293. if not utils.is_serializable(key):
  294. raise RuntimeError(
  295. "Attempted to write object to "
  296. f"{key=} ({type(key)=}) of database. It is not "
  297. "JSON-serializable key which will cause errors"
  298. )
  299. if not utils.is_serializable(value):
  300. raise RuntimeError(
  301. "Attempted to write object of "
  302. f"{key=} ({type(value)=}) to database. It is not "
  303. "JSON-serializable value which will cause errors"
  304. )
  305. super().setdefault(owner, {})[key] = value
  306. return self.save()