api_protection.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # ©️ Dan Gazizullin, 2021-2023
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import asyncio
  7. import io
  8. import json
  9. import logging
  10. import random
  11. import time
  12. import typing
  13. from hikkatl.tl import functions
  14. from hikkatl.tl.tlobject import TLRequest
  15. from hikkatl.tl.types import Message
  16. from hikkatl.utils import is_list_like
  17. from .. import loader, utils
  18. from ..inline.types import InlineCall
  19. from ..web.debugger import WebDebugger
  20. logger = logging.getLogger(__name__)
  21. GROUPS = [
  22. "auth",
  23. "account",
  24. "users",
  25. "contacts",
  26. "messages",
  27. "updates",
  28. "photos",
  29. "upload",
  30. "help",
  31. "channels",
  32. "bots",
  33. "payments",
  34. "stickers",
  35. "phone",
  36. "langpack",
  37. "folders",
  38. "stats",
  39. ]
  40. CONSTRUCTORS = {
  41. (lambda x: x[0].lower() + x[1:])(
  42. method.__class__.__name__.rsplit("Request", 1)[0]
  43. ): method.CONSTRUCTOR_ID
  44. for method in utils.array_sum(
  45. [
  46. [
  47. method
  48. for method in dir(getattr(functions, group))
  49. if isinstance(method, TLRequest)
  50. ]
  51. for group in GROUPS
  52. ]
  53. )
  54. }
  55. @loader.tds
  56. class APIRatelimiterMod(loader.Module):
  57. """Helps userbot avoid spamming Telegram API"""
  58. strings = {"name": "APILimiter"}
  59. def __init__(self):
  60. self._ratelimiter: typing.List[tuple] = []
  61. self._suspend_until = 0
  62. self._lock = False
  63. self.config = loader.ModuleConfig(
  64. loader.ConfigValue(
  65. "time_sample",
  66. 15,
  67. lambda: self.strings("_cfg_time_sample"),
  68. validator=loader.validators.Integer(minimum=1),
  69. ),
  70. loader.ConfigValue(
  71. "threshold",
  72. 100,
  73. lambda: self.strings("_cfg_threshold"),
  74. validator=loader.validators.Integer(minimum=10),
  75. ),
  76. loader.ConfigValue(
  77. "local_floodwait",
  78. 30,
  79. lambda: self.strings("_cfg_local_floodwait"),
  80. validator=loader.validators.Integer(minimum=10, maximum=3600),
  81. ),
  82. loader.ConfigValue(
  83. "forbidden_methods",
  84. ["joinChannel", "importChatInvite"],
  85. lambda: self.strings("_cfg_forbidden_methods"),
  86. validator=loader.validators.MultiChoice(
  87. [
  88. "sendReaction",
  89. "joinChannel",
  90. "importChatInvite",
  91. ]
  92. ),
  93. on_change=lambda: self._client.forbid_constructors(
  94. map(
  95. lambda x: CONSTRUCTORS[x],
  96. self.config["forbidden_constructors"],
  97. )
  98. ),
  99. ),
  100. )
  101. async def client_ready(self):
  102. asyncio.ensure_future(self._install_protection())
  103. async def _install_protection(self):
  104. await asyncio.sleep(30) # Restart lock
  105. if hasattr(self._client._call, "_old_call_rewritten"):
  106. raise loader.SelfUnload("Already installed")
  107. old_call = self._client._call
  108. async def new_call(
  109. sender: "MTProtoSender", # type: ignore # noqa: F821
  110. request: TLRequest,
  111. ordered: bool = False,
  112. flood_sleep_threshold: int = None,
  113. ):
  114. await asyncio.sleep(random.randint(1, 5) / 100)
  115. req = (request,) if not is_list_like(request) else request
  116. for r in req:
  117. if (
  118. time.perf_counter() > self._suspend_until
  119. and not self.get(
  120. "disable_protection",
  121. True,
  122. )
  123. and (
  124. r.__module__.rsplit(".", maxsplit=1)[1]
  125. in {"messages", "account", "channels"}
  126. )
  127. ):
  128. request_name = type(r).__name__
  129. self._ratelimiter += [(request_name, time.perf_counter())]
  130. self._ratelimiter = list(
  131. filter(
  132. lambda x: time.perf_counter() - x[1]
  133. < int(self.config["time_sample"]),
  134. self._ratelimiter,
  135. )
  136. )
  137. if (
  138. len(self._ratelimiter) > int(self.config["threshold"])
  139. and not self._lock
  140. ):
  141. self._lock = True
  142. report = io.BytesIO(
  143. json.dumps(
  144. self._ratelimiter,
  145. indent=4,
  146. ).encode()
  147. )
  148. report.name = "local_fw_report.json"
  149. await self.inline.bot.send_document(
  150. self.tg_id,
  151. report,
  152. caption=self.inline.sanitise_text(
  153. self.strings("warning").format(
  154. self.config["local_floodwait"],
  155. prefix=utils.escape_html(self.get_prefix()),
  156. )
  157. ),
  158. )
  159. # It is intented to use time.sleep instead of asyncio.sleep
  160. time.sleep(int(self.config["local_floodwait"]))
  161. self._lock = False
  162. return await old_call(sender, request, ordered, flood_sleep_threshold)
  163. self._client._call = new_call
  164. self._client._old_call_rewritten = old_call
  165. self._client._call._hikka_overwritten = True
  166. logger.debug("Successfully installed ratelimiter")
  167. async def on_unload(self):
  168. if hasattr(self._client, "_old_call_rewritten"):
  169. self._client._call = self._client._old_call_rewritten
  170. delattr(self._client, "_old_call_rewritten")
  171. logger.debug("Successfully uninstalled ratelimiter")
  172. @loader.command()
  173. async def suspend_api_protect(self, message: Message):
  174. if not (args := utils.get_args_raw(message)) or not args.isdigit():
  175. await utils.answer(message, self.strings("args_invalid"))
  176. return
  177. self._suspend_until = time.perf_counter() + int(args)
  178. await utils.answer(message, self.strings("suspended_for").format(args))
  179. @loader.command()
  180. async def api_fw_protection(self, message: Message):
  181. await self.inline.form(
  182. message=message,
  183. text=self.strings("u_sure"),
  184. reply_markup=[
  185. {"text": self.strings("btn_no"), "action": "close"},
  186. {"text": self.strings("btn_yes"), "callback": self._finish},
  187. ],
  188. )
  189. @property
  190. def _debugger(self) -> WebDebugger:
  191. return logging.getLogger().handlers[0].web_debugger
  192. async def _show_pin(self, call: InlineCall):
  193. await call.answer(f"Werkzeug PIN: {self._debugger.pin}", show_alert=True)
  194. @loader.command()
  195. async def debugger(self, message: Message):
  196. if not self._debugger:
  197. await utils.answer(message, self.strings("debugger_disabled"))
  198. return
  199. await self.inline.form(
  200. message=message,
  201. text=self.strings("web_pin"),
  202. reply_markup=[
  203. [
  204. {
  205. "text": self.strings("web_pin_btn"),
  206. "callback": self._show_pin,
  207. }
  208. ],
  209. [
  210. {"text": self.strings("proxied_url"), "url": self._debugger.url},
  211. {
  212. "text": self.strings("local_url"),
  213. "url": f"http://127.0.0.1:{self._debugger.port}",
  214. },
  215. ],
  216. ],
  217. )
  218. async def _finish(self, call: InlineCall):
  219. state = self.get("disable_protection", True)
  220. self.set("disable_protection", not state)
  221. await call.edit(self.strings("on" if state else "off"))