proxypass.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # ©️ Dan Gazizullin, 2021-2023
  2. # This file is a part of Hikka Userbot
  3. # 🌐 https://github.com/hikariatama/Hikka
  4. # You can redistribute it and/or modify it under the terms of the GNU AGPLv3
  5. # 🔑 https://www.gnu.org/licenses/agpl-3.0.html
  6. import asyncio
  7. import logging
  8. import os
  9. import re
  10. import typing
  11. from .. import utils
  12. logger = logging.getLogger(__name__)
  13. class ProxyPasser:
  14. def __init__(self, change_url_callback: callable = lambda _: None):
  15. self._tunnel_url = None
  16. self._sproc = None
  17. self._url_available = asyncio.Event()
  18. self._url_available.set()
  19. self._lock = asyncio.Lock()
  20. self._change_url_callback = change_url_callback
  21. async def _read_stream(
  22. self,
  23. callback: callable,
  24. stream: typing.BinaryIO,
  25. delay: int,
  26. ) -> None:
  27. for getline in iter(stream.readline, ""):
  28. await asyncio.sleep(delay)
  29. data_chunk = await getline
  30. if await callback(data_chunk.decode("utf-8")):
  31. if not self._url_available.is_set():
  32. self._url_available.set()
  33. def kill(self):
  34. try:
  35. self._sproc.terminate()
  36. except Exception:
  37. logger.exception("Failed to kill proxy pass process")
  38. else:
  39. logger.debug("Proxy pass tunnel killed")
  40. async def _process_stream(self, stdout_line: str) -> None:
  41. logger.debug(stdout_line)
  42. regex = r"tunneled.*?(https:\/\/.+)"
  43. if re.search(regex, stdout_line):
  44. self._tunnel_url = re.search(regex, stdout_line)[1]
  45. self._change_url_callback(self._tunnel_url)
  46. logger.debug("Proxy pass tunneled: %s", self._tunnel_url)
  47. self._url_available.set()
  48. async def get_url(self, port: int, no_retry: bool = False) -> typing.Optional[str]:
  49. async with self._lock:
  50. if self._tunnel_url:
  51. try:
  52. await asyncio.wait_for(self._sproc.wait(), timeout=0.05)
  53. except asyncio.TimeoutError:
  54. return self._tunnel_url
  55. else:
  56. self.kill()
  57. if "DOCKER" in os.environ:
  58. # We're in a Docker container, so we can't use ssh
  59. # Also, the concept of Docker is to keep
  60. # everything isolated, so we can't proxy-pass to
  61. # open web.
  62. return None
  63. logger.debug("Starting proxy pass shell for port %d", port)
  64. self._sproc = await asyncio.create_subprocess_shell(
  65. (
  66. "ssh -o StrictHostKeyChecking=no -R"
  67. f" 80:127.0.0.1:{port} nokey@localhost.run"
  68. ),
  69. stdin=asyncio.subprocess.PIPE,
  70. stdout=asyncio.subprocess.PIPE,
  71. stderr=asyncio.subprocess.PIPE,
  72. )
  73. utils.atexit(self.kill)
  74. self._url_available = asyncio.Event()
  75. logger.debug("Starting proxy pass reader for port %d", port)
  76. asyncio.ensure_future(
  77. self._read_stream(
  78. self._process_stream,
  79. self._sproc.stdout,
  80. 1,
  81. )
  82. )
  83. try:
  84. await asyncio.wait_for(self._url_available.wait(), 15)
  85. except asyncio.TimeoutError:
  86. self.kill()
  87. self._tunnel_url = None
  88. if no_retry:
  89. return None
  90. return await self.get_url(port, no_retry=True)
  91. logger.debug("Proxy pass tunnel url to port %d: %s", port, self._tunnel_url)
  92. return self._tunnel_url