throttling.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import asyncio
  2. from aiogram import Dispatcher, types
  3. from aiogram.dispatcher import DEFAULT_RATE_LIMIT
  4. from aiogram.dispatcher.handler import CancelHandler, current_handler
  5. from aiogram.dispatcher.middlewares import BaseMiddleware
  6. from aiogram.utils.exceptions import Throttled
  class ThrottlingMiddleware(BaseMiddleware):
      """
      Anti-flood middleware: throttles incoming messages per handler.

      Before a message reaches its handler, the middleware asks the dispatcher
      to throttle a key derived from that handler.  When the dispatcher raises
      ``Throttled`` the user is notified and the handler is cancelled.

      A handler may customise throttling through two optional attributes:
      ``throttling_rate_limit`` (rate passed to ``Dispatcher.throttle``) and
      ``throttling_key`` (the key to throttle on).
      """

      def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix="antiflood_"):
          """
          :param limit: rate used when the handler does not define
              ``throttling_rate_limit``
          :param key_prefix: prefix of the automatically generated keys
          """
          self.rate_limit = limit
          self.prefix = key_prefix
          super().__init__()

      def _throttling_key(self, handler):
          """
          Return the throttling key for ``handler``.

          Single source of truth for both ``on_process_message`` and
          ``message_throttled`` so that the key which was throttled is always
          the key which is checked afterwards.

          :param handler: current handler, or a falsy value when there is none
          :return: the handler's ``throttling_key`` attribute if it has one,
              otherwise ``"<prefix>_<handler name>"``; ``"<prefix>_message"``
              when there is no handler or the handler exposes no ``__name__``
          """
          if not handler:
              return f"{self.prefix}_message"
          try:
              # An explicitly configured key always wins (even a falsy one),
              # exactly like ``getattr(handler, "throttling_key", default)``.
              return handler.throttling_key
          except AttributeError:
              pass
          # Resolve the name lazily and defensively: callables such as
          # ``functools.partial`` objects have no ``__name__`` and must not
          # crash the middleware; they share the generic message key instead.
          name = getattr(handler, "__name__", "message")
          return f"{self.prefix}_{name}"

      async def on_process_message(self, message: types.Message, data: dict):
          """
          This handler is called when dispatcher receives a message

          :param message: incoming message
          :param data: middleware data (not used here)
          :raises CancelHandler: when the message was throttled
          """
          # Get current handler
          handler = current_handler.get()
          # Get dispatcher from context
          dispatcher = Dispatcher.get_current()

          # If handler was configured, get rate limit from handler
          if handler:
              limit = getattr(handler, "throttling_rate_limit", self.rate_limit)
          else:
              limit = self.rate_limit
          key = self._throttling_key(handler)

          # Use Dispatcher.throttle method.
          try:
              await dispatcher.throttle(key, rate=limit)
          except Throttled as t:
              # Execute action
              await self.message_throttled(message, t)
              # Cancel current handler
              raise CancelHandler()

      async def message_throttled(self, message: types.Message, throttled: Throttled):
          """
          Notify user only on first exceed and notify about unlocking only on last exceed

          Replies "Too many requests! " while ``throttled.exceeded_count`` is at
          most 2, sleeps for the remaining part of the rate window, then
          re-checks the key and replies "Unlocked." only if ``exceeded_count``
          did not change during the sleep.

          :param message: message that triggered the throttling
          :param throttled: exception raised by ``Dispatcher.throttle``
          """
          handler = current_handler.get()
          dispatcher = Dispatcher.get_current()
          key = self._throttling_key(handler)

          # Calculate how much time is left till the block ends
          delta = throttled.rate - throttled.delta

          # Prevent flooding
          if throttled.exceeded_count <= 2:
              await message.reply("Too many requests! ")

          # Sleep.
          await asyncio.sleep(delta)

          # Check lock status
          thr = await dispatcher.check_key(key)

          # If current message is not last with current key - do not send message
          if thr.exceeded_count == throttled.exceeded_count:
              await message.reply("Unlocked.")