database.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. # █ █ ▀ █▄▀ ▄▀█ █▀█ ▀ ▄▀█ ▀█▀ ▄▀█ █▀▄▀█ ▄▀█
  2. # █▀█ █ █ █ █▀█ █▀▄ █ ▄ █▀█ █ █▀█ █ ▀ █ █▀█
  3. #
  4. # © Copyright 2022
  5. #
  6. # https://t.me/hikariatama
  7. #
  8. # 🔒 Licensed under the GNU GPLv3
  9. # 🌐 https://www.gnu.org/licenses/agpl-3.0.html
import copy
import json
import logging
import os
import time
from typing import Any, Union

from telethon.errors.rpcerrorlist import ChannelsTooMuchError
from telethon.tl.functions.channels import EditTitleRequest
from telethon.tl.types import Message

from . import utils
  19. DATA_DIR = (
  20. os.path.normpath(os.path.join(utils.get_base_dir(), ".."))
  21. if "OKTETO" not in os.environ
  22. else "/data"
  23. )
  24. logger = logging.getLogger(__name__)
  25. class NoAssetsChannel(Exception):
  26. """Raised when trying to read/store asset with no asset channel present"""
  27. class Database(dict):
  28. _next_revision_call = 0
  29. _revisions = []
  30. _assets = None
  31. _me = None
  32. def __init__(self, client):
  33. super().__init__()
  34. self._client = client
  35. def __repr__(self):
  36. return object.__repr__(self)
  37. async def init(self):
  38. """Asynchronous initialization unit"""
  39. self._me = await self._client.get_me()
  40. self._db_path = os.path.join(DATA_DIR, f"config-{self._me.id}.json")
  41. self.read()
  42. try:
  43. channel_entity = await (
  44. dialog.entity
  45. async for dialog in self._client.iter_dialogs(
  46. None,
  47. ignore_migrated=True,
  48. )
  49. if (
  50. dialog.name in {f"hikka-{self._me.id}-assets", "hikka-assets"}
  51. and dialog.is_channel
  52. and dialog.entity.participants_count == 1
  53. )
  54. ).__anext__()
  55. if channel_entity.title != "hikka-assets":
  56. await self._client(EditTitleRequest(channel_entity, "hikka-assets"))
  57. await utils.set_avatar(
  58. self._client,
  59. channel_entity,
  60. "https://raw.githubusercontent.com/hikariatama/assets/master/hikka-assets.png",
  61. )
  62. logger.info("Made legacy assets migration")
  63. except Exception:
  64. pass
  65. try:
  66. self._assets, _ = await utils.asset_channel(
  67. self._client,
  68. "hikka-assets",
  69. "🌆 Your Hikka assets will be stored here",
  70. archive=True,
  71. avatar="https://raw.githubusercontent.com/hikariatama/assets/master/hikka-assets.png",
  72. )
  73. except ChannelsTooMuchError:
  74. self._assets = None
  75. logger.critical(
  76. "Can't find and/or create assets folder\n"
  77. "This may cause several consequences, such as:\n"
  78. "- Non working assets feature (e.g. notes)\n"
  79. "- This error will occur every restart\n\n"
  80. "You can solve this by leaving some channels/groups"
  81. )
  82. def read(self) -> str:
  83. """Read database"""
  84. try:
  85. with open(self._db_path, "r", encoding="utf-8") as f:
  86. data = json.loads(f.read())
  87. self.update(**data)
  88. return data
  89. except (FileNotFoundError, json.decoder.JSONDecodeError):
  90. logger.warning("Database read failed! Creating new one...")
  91. return {}
  92. def process_db_autofix(self, db: dict) -> bool:
  93. if not utils.is_serializable(db):
  94. return False
  95. for key, value in db.copy().items():
  96. if not isinstance(key, (str, int)):
  97. logger.warning(f"DbAutoFix: Dropped {key=} , because it is not string or int") # fmt: skip
  98. continue
  99. if not isinstance(value, dict):
  100. # If value is not a dict (module values), drop it,
  101. # otherwise it may cause problems
  102. del db[key]
  103. logger.warning(f"DbAutoFix: Dropped {key=}, because it is non-dict {type(value)=}") # fmt: skip
  104. continue
  105. for subkey in value:
  106. if not isinstance(subkey, (str, int)):
  107. del db[key][subkey]
  108. logger.warning(f"DbAutoFix: Dropped {subkey=} of db[{key}], because it is not string or int") # fmt: skip
  109. continue
  110. return True
  111. def save(self) -> bool:
  112. """Save database"""
  113. if not self.process_db_autofix(self):
  114. try:
  115. rev = self._revisions.pop()
  116. while not self.process_db_autofix(rev):
  117. rev = self._revisions.pop()
  118. except IndexError:
  119. raise RuntimeError(
  120. "Can't find revision to restore broken database from "
  121. "database is most likely broken and will lead to problems, "
  122. "so its save is forbidden."
  123. )
  124. self.clear()
  125. self.update(**rev)
  126. raise RuntimeError(
  127. "Rewriting database to the last revision "
  128. "because new one destructed it"
  129. )
  130. if self._next_revision_call < time.time():
  131. self._revisions += [dict(self)]
  132. self._next_revision_call = time.time() + 3
  133. while len(self._revisions) > 15:
  134. self._revisions.pop()
  135. try:
  136. with open(self._db_path, "w", encoding="utf-8") as f:
  137. f.write(json.dumps(self))
  138. except Exception:
  139. logger.exception("Database save failed!")
  140. return False
  141. return True
  142. async def store_asset(self, message: Message) -> int:
  143. """
  144. Save assets
  145. returns asset_id as integer
  146. """
  147. if not self._assets:
  148. raise NoAssetsChannel("Tried to save asset to non-existing asset channel") # fmt: skip
  149. return (
  150. (await self._client.send_message(self._assets, message)).id
  151. if isinstance(message, Message)
  152. else (
  153. await self._client.send_message(
  154. self._assets,
  155. file=message,
  156. force_document=True,
  157. )
  158. ).id
  159. )
  160. async def fetch_asset(self, asset_id: int) -> Union[None, Message]:
  161. """Fetch previously saved asset by its asset_id"""
  162. if not self._assets:
  163. raise NoAssetsChannel("Tried to fetch asset from non-existing asset channel") # fmt: skip
  164. asset = await self._client.get_messages(self._assets, ids=[asset_id])
  165. if not asset:
  166. return None
  167. return asset[0]
  168. def get(self, owner: str, key: str, default: Any = None) -> Any:
  169. """Get database key"""
  170. try:
  171. return self[owner][key]
  172. except KeyError:
  173. return default
  174. def set(self, owner: str, key: str, value: Any) -> bool:
  175. """Set database key"""
  176. if not utils.is_serializable(owner):
  177. raise RuntimeError(
  178. "Attempted to write object to "
  179. f"{owner=} ({type(owner)=}) of database. It is not "
  180. "JSON-serializable key which will cause errors"
  181. )
  182. if not utils.is_serializable(key):
  183. raise RuntimeError(
  184. "Attempted to write object to "
  185. f"{key=} ({type(key)=}) of database. It is not "
  186. "JSON-serializable key which will cause errors"
  187. )
  188. if not utils.is_serializable(value):
  189. raise RuntimeError(
  190. "Attempted to write object of "
  191. f"{key=} ({type(value)=}) to database. It is not "
  192. "JSON-serializable value which will cause errors"
  193. )
  194. super().setdefault(owner, {})[key] = value
  195. return self.save()
  196. def __setitem__(self, key: str, value: dict) -> bool:
  197. if not isinstance(value, dict):
  198. raise RuntimeError("Attempted to write non-dict value in a first layer of database") # fmt: skip
  199. dict.__setitem__(self, key, value)
  200. return self.save()
  201. def __delitem__(self, key: str) -> bool:
  202. dict.__delitem__(self, key)
  203. return self.save()