pointers.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. # ©️ Dan Gazizullin, 2021-2023
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import typing
  7. class PointerList(list):
  8. """Pointer to list saved in database"""
  9. def __init__(
  10. self,
  11. db: "Database", # type: ignore # noqa: F821
  12. module: str,
  13. key: str,
  14. default: typing.Optional[typing.Any] = None,
  15. ):
  16. self._db = db
  17. self._module = module
  18. self._key = key
  19. self._default = default
  20. super().__init__(db.get(module, key, default))
  21. @property
  22. def data(self) -> list:
  23. return list(self)
  24. @data.setter
  25. def data(self, value: list):
  26. self.clear()
  27. self.extend(value)
  28. self._save()
  29. def __repr__(self):
  30. return f"PointerList({list(self)})"
  31. def __str__(self):
  32. return f"PointerList({list(self)})"
  33. def __delitem__(self, __i: typing.Union[typing.SupportsIndex, slice]) -> None:
  34. a = super().__delitem__(__i)
  35. self._save()
  36. return a
  37. def __setitem__(
  38. self,
  39. __i: typing.Union[typing.SupportsIndex, slice],
  40. __v: typing.Any,
  41. ) -> None:
  42. a = super().__setitem__(__i, __v)
  43. self._save()
  44. return a
  45. def __iadd__(self, __x: typing.Iterable) -> "Self": # type: ignore # noqa: F821
  46. a = super().__iadd__(__x)
  47. self._save()
  48. return a
  49. def __imul__(self, __x: int) -> "Self": # type: ignore # noqa: F821
  50. a = super().__imul__(__x)
  51. self._save()
  52. return a
  53. def append(self, value: typing.Any):
  54. super().append(value)
  55. self._save()
  56. def extend(self, value: typing.Iterable):
  57. super().extend(value)
  58. self._save()
  59. def insert(self, index: int, value: typing.Any):
  60. super().insert(index, value)
  61. self._save()
  62. def remove(self, value: typing.Any):
  63. super().remove(value)
  64. self._save()
  65. def pop(self, index: int = -1) -> typing.Any:
  66. a = super().pop(index)
  67. self._save()
  68. return a
  69. def clear(self) -> None:
  70. super().clear()
  71. self._save()
  72. def _save(self):
  73. self._db.set(self._module, self._key, list(self))
  74. def tolist(self):
  75. return self._db.get(self._module, self._key, self._default)
  76. class PointerDict(dict):
  77. """Pointer to dict saved in database"""
  78. def __init__(
  79. self,
  80. db: "Database", # type: ignore # noqa: F821
  81. module: str,
  82. key: str,
  83. default: typing.Optional[typing.Any] = None,
  84. ):
  85. self._db = db
  86. self._module = module
  87. self._key = key
  88. self._default = default
  89. super().__init__(db.get(module, key, default))
  90. @property
  91. def data(self) -> dict:
  92. return dict(self)
  93. @data.setter
  94. def data(self, value: dict):
  95. self.clear()
  96. self.update(value)
  97. self._save()
  98. def __repr__(self):
  99. return f"PointerDict({dict(self)})"
  100. def __bool__(self) -> bool:
  101. return bool(self._db.get(self._module, self._key, self._default))
  102. def __setitem__(self, key: str, value: typing.Any):
  103. super().__setitem__(key, value)
  104. self._save()
  105. def __delitem__(self, key: str):
  106. super().__delitem__(key)
  107. self._save()
  108. def __str__(self):
  109. return f"PointerDict({dict(self)})"
  110. def update(self, __m: dict) -> None:
  111. super().update(__m)
  112. self._save()
  113. def setdefault(self, key: str, default: typing.Any = None) -> typing.Any:
  114. a = super().setdefault(key, default)
  115. self._save()
  116. return a
  117. def pop(self, key: str, default: typing.Any = None) -> typing.Any:
  118. a = super().pop(key, default)
  119. self._save()
  120. return a
  121. def popitem(self) -> tuple:
  122. a = super().popitem()
  123. self._save()
  124. return a
  125. def clear(self) -> None:
  126. super().clear()
  127. self._save()
  128. def _save(self):
  129. self._db.set(self._module, self._key, dict(self))
  130. def todict(self):
  131. return self._db.get(self._module, self._key, self._default)
  132. class BaseSerializingMiddlewareDict:
  133. def __init__(self, pointer: PointerDict):
  134. self._pointer = pointer
  135. def serialize(self, item: typing.Any) -> "JSONSerializable": # type: ignore # noqa: F821
  136. raise NotImplementedError
  137. def deserialize(self, item: "JSONSerializable") -> typing.Any: # type: ignore # noqa: F821
  138. raise NotImplementedError
  139. def __getitem__(self, key: typing.Any) -> typing.Any:
  140. return self.deserialize(self._pointer[key])
  141. def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
  142. self._pointer[key] = self.serialize(value)
  143. def __delitem__(self, key: typing.Any) -> None:
  144. del self._pointer[key]
  145. def __iter__(self) -> typing.Iterator[typing.Any]:
  146. for key, value in self._pointer.items():
  147. yield (key, self.deserialize(value))
  148. def __len__(self) -> int:
  149. return len(self._pointer)
  150. def __contains__(self, item: typing.Any) -> bool:
  151. return item in self._pointer
  152. def __str__(self) -> str:
  153. return f"{self.__class__.__name__}({self._pointer})"
  154. def __repr__(self) -> str:
  155. return f"{self.__class__.__name__}({self._pointer})"
  156. def pop(self, key: typing.Any) -> typing.Any:
  157. return self.deserialize(self._pointer.pop(key))
  158. def popitem(self) -> typing.Any:
  159. return self.deserialize(self._pointer.popitem())
  160. def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
  161. return self.deserialize(self._pointer[key]) if key in self._pointer else default
  162. def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
  163. return self.deserialize(self._pointer.setdefault(key, self.serialize(default)))
  164. def clear(self) -> None:
  165. self._pointer.clear()
  166. def todict(self) -> dict:
  167. return {
  168. key: self.deserialize(value) for key, value in self._pointer.data.items()
  169. }
  170. def keys(self) -> typing.KeysView:
  171. return self._pointer.keys()
  172. def values(self) -> typing.Iterable[typing.Any]:
  173. return (self.deserialize(value) for value in self._pointer.values())
  174. class BaseSerializingMiddlewareList:
  175. def __init__(self, pointer: PointerList):
  176. self._pointer = pointer
  177. def serialize(self, item: typing.Any) -> "JSONSerializable": # type: ignore # noqa: F821
  178. raise NotImplementedError
  179. def deserialize(self, item: "JSONSerializable") -> typing.Any: # type: ignore # noqa: F821
  180. raise NotImplementedError
  181. def remove(self, item: typing.Any) -> None:
  182. self._pointer.remove(self.serialize(item))
  183. def pop(self, index: int) -> typing.Any:
  184. return self.deserialize(self._pointer.pop(index))
  185. def insert(self, index: int, item: typing.Any) -> None:
  186. self._pointer.insert(index, self.serialize(item))
  187. def append(self, item: typing.Any) -> None:
  188. self._pointer.append(self.serialize(item))
  189. def extend(self, items: typing.Iterable[typing.Any]) -> None:
  190. self._pointer.extend([self.serialize(item) for item in items])
  191. def __getitem__(self, key: typing.Any) -> typing.Any:
  192. return self.deserialize(self._pointer[key])
  193. def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
  194. self._pointer[key] = self.serialize(value)
  195. def __delitem__(self, key: typing.Any) -> None:
  196. del self._pointer[key]
  197. def __iter__(self) -> typing.Iterator[typing.Any]:
  198. return (self.deserialize(item) for item in self._pointer)
  199. def __len__(self) -> int:
  200. return len(self._pointer)
  201. def __contains__(self, item: typing.Any) -> bool:
  202. return self.serialize(item) in self._pointer
  203. def __reversed__(self) -> typing.Iterator[typing.Any]:
  204. return (self.deserialize(item) for item in reversed(self._pointer))
  205. def __str__(self) -> str:
  206. return f"{self.__class__.__name__}({self._pointer})"
  207. def __repr__(self) -> str:
  208. return f"{self.__class__.__name__}({self._pointer})"
  209. def tolist(self) -> list:
  210. return [self.deserialize(item) for item in self._pointer.data]
  211. class NamedTupleMiddlewareList(BaseSerializingMiddlewareList):
  212. def __init__(self, pointer: PointerList, item_type: typing.Type[typing.Any]):
  213. super().__init__(pointer)
  214. self._item_type = item_type
  215. def serialize(self, item: typing.Any) -> "JSONSerializable": # type: ignore # noqa: F821
  216. return item._asdict()
  217. def deserialize(self, item: "JSONSerializable") -> typing.Any: # type: ignore # noqa: F821
  218. return self._item_type(**item)
  219. class NamedTupleMiddlewareDict(BaseSerializingMiddlewareDict):
  220. def __init__(self, pointer: PointerList, item_type: typing.Type[typing.Any]):
  221. super().__init__(pointer)
  222. self._item_type = item_type
  223. def serialize(self, item: typing.Any) -> "JSONSerializable": # type: ignore # noqa: F821
  224. return item._asdict()
  225. def deserialize(self, item: "JSONSerializable") -> typing.Any: # type: ignore # noqa: F821
  226. return self._item_type(**item)