listener.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908
import asyncio
import atexit
import contextlib
import copy
import dataclasses
import datetime
import functools
import io
import json
import logging
import os
import re
import signal
import struct
import sys
import time
from hashlib import sha256
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Dict, List, Optional, Union
from zlib import crc32

from hikkatl.crypto import AES, AuthKey
from hikkatl.errors import (
    AuthKeyNotFound,
    BadMessageError,
    InvalidBufferError,
    InvalidChecksumError,
    SecurityError,
    TypeNotFoundError,
)
from hikkatl.extensions.binaryreader import BinaryReader
from hikkatl.extensions.messagepacker import MessagePacker
from hikkatl.network.connection import ConnectionTcpFull
from hikkatl.network.mtprotosender import MTProtoSender
from hikkatl.network.mtprotostate import MTProtoState
from hikkatl.network.requeststate import RequestState
from hikkatl.sessions import SQLiteSession
from hikkatl.tl import TLRequest
from hikkatl.tl.core import GzipPacked, MessageContainer, TLMessage
from hikkatl.tl.functions import (
    InitConnectionRequest,
    InvokeAfterMsgRequest,
    InvokeAfterMsgsRequest,
    InvokeWithLayerRequest,
    InvokeWithMessagesRangeRequest,
    InvokeWithoutUpdatesRequest,
    InvokeWithTakeoutRequest,
    PingRequest,
)
from hikkatl.tl.functions.account import DeleteAccountRequest, UpdateProfileRequest
from hikkatl.tl.functions.auth import (
    BindTempAuthKeyRequest,
    CancelCodeRequest,
    CheckRecoveryPasswordRequest,
    ExportAuthorizationRequest,
    ExportLoginTokenRequest,
    ImportAuthorizationRequest,
    LogOutRequest,
    RecoverPasswordRequest,
    RequestPasswordRecoveryRequest,
    ResetAuthorizationsRequest,
    ResetLoginEmailRequest,
    SendCodeRequest,
)
from hikkatl.tl.functions.help import GetConfigRequest
from hikkatl.tl.functions.messages import (
    ForwardMessagesRequest,
    GetHistoryRequest,
    SearchRequest,
)
from hikkatl.tl.types import (
    InputPeerUser,
    Message,
    MsgsAck,
    PeerUser,
    UpdateNewMessage,
    Updates,
    UpdatesCombined,
    UpdateShortMessage,
)
  80. os.chdir(os.path.dirname(os.path.abspath(__file__)))
  81. logging.basicConfig(level=logging.DEBUG)
  82. handler = logging.StreamHandler()
  83. handler.setLevel(logging.INFO)
  84. handler.setFormatter(
  85. logging.Formatter(
  86. fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
  87. datefmt="%Y-%m-%d %H:%M:%S",
  88. style="%",
  89. )
  90. )
  91. rotating_handler = RotatingFileHandler(
  92. filename="hikka.log",
  93. mode="a",
  94. maxBytes=10 * 1024 * 1024,
  95. backupCount=1,
  96. encoding="utf-8",
  97. delay=0,
  98. )
  99. rotating_handler.setLevel(logging.DEBUG)
  100. rotating_handler.setFormatter(
  101. logging.Formatter(
  102. fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
  103. datefmt="%Y-%m-%d %H:%M:%S",
  104. style="%",
  105. )
  106. )
  107. logging.getLogger().handlers[0].setLevel(logging.CRITICAL)
  108. logging.getLogger().addHandler(handler)
  109. logging.getLogger().addHandler(rotating_handler)
  110. class _OpaqueRequest(TLRequest):
  111. def __init__(self, data: bytes):
  112. self.data = data
  113. def _bytes(self):
  114. return self.data
  115. class CustomMTProtoState(MTProtoState):
  116. def write_data_as_message(
  117. self,
  118. buffer,
  119. data,
  120. content_related,
  121. *,
  122. after_id=None,
  123. msg_id=None,
  124. ):
  125. msg_id = msg_id or self._get_new_msg_id()
  126. seq_no = self._get_seq_no(content_related)
  127. if after_id is None:
  128. body = GzipPacked.gzip_if_smaller(content_related, data)
  129. else:
  130. body = GzipPacked.gzip_if_smaller(
  131. content_related,
  132. bytes(InvokeAfterMsgRequest(after_id, _OpaqueRequest(data))),
  133. )
  134. buffer.write(struct.pack("<qii", msg_id, seq_no, len(body)))
  135. buffer.write(body)
  136. return msg_id
class CustomMessagePacker(MessagePacker):
    """Message packer that keeps each request state's preassigned msg_id."""

    async def get(self):
        """Wait for queued states and pack as many as fit into one payload.

        :return: ``(batch, data)``: the packed request states and their
            serialized bytes (wrapped in a message container when more than
            one state was packed), or ``(None, None)`` if nothing was packed.
        """
        # Sleep until at least one state has been queued.
        if not self._deque:
            self._ready.clear()
            await self._ready.wait()
        buffer = io.BytesIO()
        batch = []
        size = 0
        # NOTE(review): ``<=`` lets a batch reach MAXIMUM_LENGTH + 1 states;
        # confirm this is intended.
        while self._deque and len(batch) <= MessageContainer.MAXIMUM_LENGTH:
            state = self._deque.popleft()
            size += len(state.data) + TLMessage.SIZE_OVERHEAD
            if size <= MessageContainer.MAXIMUM_SIZE:
                # Pass the state's own msg_id through (TCP.read presets it for
                # requests relayed from the local client) so the id is kept;
                # CustomMTProtoState generates a new one when it is falsy.
                state.msg_id = self._state.write_data_as_message(
                    buffer,
                    state.data,
                    isinstance(state.request, TLRequest),
                    after_id=state.after.msg_id if state.after else None,
                    msg_id=state.msg_id,
                )
                batch.append(state)
                self._log.debug(
                    "Assigned msg_id = %d to %s (%x)",
                    state.msg_id,
                    state.request.__class__.__name__,
                    id(state.request),
                )
                continue
            # Does not fit this batch: put it back for the next call...
            if batch:
                self._deque.appendleft(state)
                break
            # ...unless it is alone and still too big: it can never be sent,
            # so fail its future and drop it instead of re-queueing forever.
            self._log.warning(
                "Message payload for %s is too long (%d) and cannot be sent",
                state.request.__class__.__name__,
                len(state.data),
            )
            state.future.set_exception(ValueError("Request payload is too big"))
            size = 0
            continue
        if not batch:
            return None, None
        if len(batch) > 1:
            # Several messages: wrap them in one container message
            # (constructor id + count + the already-serialized messages).
            data = (
                struct.pack("<Ii", MessageContainer.CONSTRUCTOR_ID, len(batch))
                + buffer.getvalue()
            )
            buffer = io.BytesIO()
            container_id = self._state.write_data_as_message(
                buffer, data, content_related=False
            )
            for s in batch:
                s.container_id = container_id
        data = buffer.getvalue()
        return batch, data
  190. class ClientFullPacketCodec:
  191. tag = None
  192. def encode_packet(self, data):
  193. length = len(data) + 12
  194. data = struct.pack("<ii", length, 0) + data
  195. crc = struct.pack("<I", crc32(data))
  196. return data + crc
  197. async def read_packet(self, reader):
  198. packet_len_seq = await reader.readexactly(8)
  199. packet_len, seq = struct.unpack("<ii", packet_len_seq)
  200. if packet_len < 0 and seq < 0:
  201. body = await reader.readexactly(4)
  202. raise InvalidBufferError(body)
  203. body = await reader.readexactly(packet_len - 8)
  204. checksum = struct.unpack("<I", body[-4:])[0]
  205. body = body[:-4]
  206. valid_checksum = crc32(packet_len_seq + body)
  207. if checksum != valid_checksum:
  208. raise InvalidChecksumError(checksum, valid_checksum)
  209. return body
  210. def get_config_key(key: str) -> Union[str, bool]:
  211. """
  212. Parse and return key from config
  213. :param key: Key name in config
  214. :return: Value of config key or `False`, if it doesn't exist
  215. """
  216. try:
  217. return json.loads(Path("../config.json").read_text()).get(key, False)
  218. except FileNotFoundError:
  219. return False
# Import-time reference point taken with time.perf_counter(); it appears to be
# unused elsewhere in this file.
start_ts = time.perf_counter()
  221. class CustomMTProtoSender(MTProtoSender):
  222. def __init__(
  223. self,
  224. auth_key,
  225. *,
  226. loggers,
  227. retries=5,
  228. delay=1,
  229. auto_reconnect=True,
  230. connect_timeout=None,
  231. auth_key_callback=None,
  232. updates_queue=None,
  233. auto_reconnect_callback=None,
  234. ):
  235. super().__init__(
  236. auth_key,
  237. loggers=loggers,
  238. retries=retries,
  239. delay=delay,
  240. auto_reconnect=auto_reconnect,
  241. connect_timeout=connect_timeout,
  242. auth_key_callback=auth_key_callback,
  243. updates_queue=updates_queue,
  244. auto_reconnect_callback=auto_reconnect_callback,
  245. )
  246. self._state = CustomMTProtoState(self.auth_key, loggers=self._loggers)
  247. self._send_queue = CustomMessagePacker(self._state, loggers=self._loggers)
  248. def external_append(self, state):
  249. self._send_queue.append(state)
  250. def external_extend(self, states):
  251. self._send_queue.extend(states)
  252. async def _send_loop(self):
  253. while self._user_connected and not self._reconnecting:
  254. if self._pending_ack:
  255. ack = RequestState(MsgsAck(list(self._pending_ack)))
  256. self._send_queue.append(ack)
  257. self._last_acks.append(ack)
  258. self._pending_ack.clear()
  259. self._log.debug("Waiting for messages to send...")
  260. batch, data = await self._send_queue.get()
  261. if not data:
  262. continue
  263. logging.debug("Sending data %s", data)
  264. self._log.debug(
  265. "Encrypting %d message(s) in %d bytes for sending",
  266. len(batch),
  267. len(data),
  268. )
  269. data = self._state.encrypt_message_data(data)
  270. for state in batch:
  271. if not isinstance(state, list):
  272. if isinstance(state.request, TLRequest):
  273. self._pending_state[state.msg_id] = state
  274. else:
  275. for s in state:
  276. if isinstance(s.request, TLRequest):
  277. self._pending_state[s.msg_id] = s
  278. try:
  279. await self._connection.send(data)
  280. except IOError as e:
  281. self._log.info("Connection closed while sending data")
  282. self._start_reconnect(e)
  283. return
  284. self._log.debug("Encrypted messages put in a queue to be sent")
  285. def partial_decrypt(self, body):
  286. if len(body) < 8:
  287. raise InvalidBufferError(body)
  288. key_id = struct.unpack("<Q", body[:8])[0]
  289. if key_id != self._state.auth_key.key_id:
  290. raise SecurityError("Server replied with an invalid auth key")
  291. msg_key = body[8:24]
  292. aes_key, aes_iv = self._state._calc_key(
  293. self._state.auth_key.key, msg_key, False
  294. )
  295. body = AES.decrypt_ige(body[24:], aes_key, aes_iv)
  296. our_key = sha256(self._state.auth_key.key[96 : 96 + 32] + body)
  297. if msg_key != our_key.digest()[8:24]:
  298. raise SecurityError("Received msg_key doesn't match with expected one")
  299. return body[16:]
  300. async def _handle_recv(self, body: bytes):
  301. try:
  302. message = self._state.decrypt_message_data(body)
  303. if message is None:
  304. return False
  305. except TypeNotFoundError as e:
  306. self._log.info(
  307. "Type %08x not found, remaining data %r",
  308. e.invalid_constructor_id,
  309. e.remaining,
  310. )
  311. return False
  312. except SecurityError as e:
  313. self._log.warning(
  314. "Security error while unpacking a received message: %s", e
  315. )
  316. return False
  317. except BufferError as e:
  318. if isinstance(e, InvalidBufferError) and e.code == 404:
  319. self._log.info(
  320. "Server does not know about the current auth key; the session may"
  321. " need to be recreated"
  322. )
  323. await self._disconnect(error=AuthKeyNotFound())
  324. else:
  325. self._log.warning("Invalid buffer %s", e)
  326. self._start_reconnect(e)
  327. return -1
  328. except Exception as e:
  329. self._log.exception("Unhandled error while decrypting data")
  330. self._start_reconnect(e)
  331. return -1
  332. try:
  333. await self._process_message(message)
  334. except Exception:
  335. self._log.exception("Unhandled error while processing msgs")
  336. logging.debug("Got message from Telegram %s", message)
  337. try:
  338. msg = self.partial_decrypt(body)
  339. to_censor = ""
  340. if isinstance(message, TLMessage):
  341. if isinstance(message.obj, Updates) and (
  342. malicious := next(
  343. (
  344. update
  345. for update in message.obj.updates
  346. if isinstance(update, UpdateNewMessage)
  347. and isinstance(update.message, Message)
  348. and isinstance(update.message.peer_id, PeerUser)
  349. and update.message.peer_id.user_id == 777000
  350. ),
  351. None,
  352. )
  353. ):
  354. to_censor = malicious.message.message
  355. elif (
  356. isinstance(message.obj, UpdateShortMessage)
  357. and message.obj.user_id == 777000
  358. ):
  359. to_censor = message.obj.message
  360. elif isinstance(message.obj, MessageContainer) and (
  361. any((
  362. isinstance(bigmsg.obj, Updates)
  363. and (
  364. malicious := next(
  365. (
  366. update
  367. for update in bigmsg.obj.updates
  368. if isinstance(update, UpdateNewMessage)
  369. and isinstance(update.message, Message)
  370. and isinstance(update.message.peer_id, PeerUser)
  371. and update.message.peer_id.user_id == 777000
  372. ),
  373. None,
  374. )
  375. )
  376. for bigmsg in message.obj.messages
  377. if isinstance(bigmsg, TLMessage)
  378. ))
  379. ):
  380. to_censor = malicious.message.message
  381. elif isinstance(message.obj, MessageContainer) and (
  382. malicious := next(
  383. (
  384. bigmsg
  385. for bigmsg in message.obj.messages
  386. if isinstance(bigmsg, TLMessage)
  387. and isinstance(bigmsg.obj, UpdateShortMessage)
  388. and bigmsg.obj.user_id == 777000
  389. ),
  390. None,
  391. )
  392. ):
  393. to_censor = malicious.message.message
  394. if to_censor:
  395. to_censor = to_censor.encode()
  396. original_msg = ""
  397. if msg[16:].startswith(b"\xa1\xcfr0"):
  398. with BinaryReader(msg[16:]) as reader:
  399. obj = reader.tgread_object()
  400. assert isinstance(obj, GzipPacked)
  401. original_msg = copy.copy(msg)
  402. msg = obj.data
  403. logging.info("Censoring message %s in %s", to_censor, msg)
  404. msg = msg.replace(to_censor, (b"*" * len(to_censor)))
  405. if original_msg:
  406. msg = original_msg[:16] + msg
  407. if hasattr(self, "_socket"):
  408. logging.debug(
  409. "Got data from socket, forwarding, %s",
  410. ClientFullPacketCodec.encode_packet(None, msg),
  411. )
  412. self._socket.write(ClientFullPacketCodec.encode_packet(None, msg))
  413. await self._socket.drain()
  414. else:
  415. logging.debug("Got data with no socket")
  416. except Exception:
  417. logging.exception("Unhandled error while processing msgs")
  418. return True
  419. def set_socket(self, socket: asyncio.StreamWriter):
  420. self._socket = socket
  421. async def _recv_loop(self):
  422. while self._user_connected and not self._reconnecting:
  423. self._log.debug("Receiving items from the network...")
  424. try:
  425. body = await self._connection.recv()
  426. except IOError as e:
  427. self._log.info("Connection closed while receiving data")
  428. self._start_reconnect(e)
  429. return
  430. except InvalidBufferError as e:
  431. if e.code == 429:
  432. self._log.warning(
  433. "Server indicated flood error at transport level: %s", e
  434. )
  435. await self._disconnect(error=e)
  436. else:
  437. self._log.exception("Server sent invalid buffer")
  438. self._start_reconnect(e)
  439. return
  440. except Exception as e:
  441. self._log.exception("Unhandled error while receiving data")
  442. self._start_reconnect(e)
  443. return
  444. res = await self._handle_recv(body)
  445. if res is False:
  446. continue
  447. elif res == -1:
  448. return
  449. async def _handle_bad_notification(self, message):
  450. bad_msg = message.obj
  451. states = self._pop_states(bad_msg.bad_msg_id)
  452. self._log.debug("Handling bad msg %s", bad_msg)
  453. if bad_msg.error_code in (16, 17):
  454. to = self._state.update_time_offset(correct_msg_id=message.msg_id)
  455. self._log.info("System clock is wrong, set time offset to %ds", to)
  456. elif bad_msg.error_code == 32:
  457. self._state._sequence += 1
  458. elif bad_msg.error_code == 33:
  459. self._state._sequence -= 1
  460. else:
  461. for state in states:
  462. state.future.set_exception(
  463. BadMessageError(state.request, bad_msg.error_code)
  464. )
  465. return
  466. self._send_queue.extend(states)
  467. self._log.debug("%d messages will be resent due to bad msg", len(states))
  468. class SessionStorage:
  469. def __init__(self):
  470. self._sessions: List[SQLiteSession] = []
  471. self._safe_sessions: List[SQLiteSession] = []
  472. self._clients: Dict[int, MTProtoSender] = {}
  473. async def pop_client(self, client_id: int):
  474. await self._clients[client_id].disconnect()
  475. self._clients.pop(client_id)
  476. @property
  477. def client_ids(self) -> List[int]:
  478. return list(self._clients.keys())
  479. @property
  480. def clients(self) -> Dict[int, MTProtoSender]:
  481. return self._clients
  482. @property
  483. def sessions(self) -> List[SQLiteSession]:
  484. return self._safe_sessions
  485. def read_sessions(self):
  486. logging.debug("Reading sessions...")
  487. session_files = list(
  488. filter(
  489. lambda f: f.startswith("hikka-") and f.endswith(".session"),
  490. os.listdir("."),
  491. )
  492. )
  493. for session in session_files:
  494. Path("safe-" + session).write_bytes(Path(session).read_bytes())
  495. self._sessions = [SQLiteSession(session) for session in session_files]
  496. self._safe_sessions = [
  497. SQLiteSession("safe-" + session) for session in session_files
  498. ]
  499. for session in self._safe_sessions:
  500. logging.debug("Processing session %s...", session.filename)
  501. session.set_dc(0, "0.0.0.0", 11111)
  502. session.auth_key = AuthKey(
  503. data=(
  504. "Where are you at?\nWhere have you"
  505. " been?\n問いかけに答えはなく\nWhere are we headed?\nWhat did you"
  506. " mean?\n追いかけても 遅く 遠く\nA bird, a butterfly and my red"
  507. " scarf\nDon't make a mess of memories\nJust let me heal your"
  508. " scars\nThe wall, the owl, forgotten wharf\n時が止まることもなく"
  509. ).encode()
  510. + b"\x00" * 13
  511. )
  512. session.save()
  513. session.close()
  514. def rename(filename: str) -> str:
  515. session_id = re.findall(r"\d+", filename)[-1]
  516. return f"hikka-{session_id}.session"
  517. for session in self._safe_sessions:
  518. os.rename(
  519. os.path.abspath(session.filename),
  520. os.path.abspath(os.path.join("../", rename(session.filename))),
  521. )
  522. async def init_clients(self):
  523. for session in self._sessions:
  524. class _Loggers(dict):
  525. def __missing__(self, key):
  526. if key.startswith("telethon."):
  527. key = key.split(".", maxsplit=1)[1]
  528. return logging.getLogger("hikkatl").getChild(key)
  529. def _auth_key_callback(auth_key):
  530. self.session.auth_key = auth_key
  531. self.session.save()
  532. _updates_queue = asyncio.Queue()
  533. client = CustomMTProtoSender(
  534. session.auth_key,
  535. loggers=_Loggers(),
  536. retries=5,
  537. delay=1,
  538. auto_reconnect=True,
  539. connect_timeout=10,
  540. auth_key_callback=_auth_key_callback,
  541. updates_queue=_updates_queue,
  542. auto_reconnect_callback=None,
  543. )
  544. await client.connect(
  545. ConnectionTcpFull(
  546. session.server_address,
  547. session.port,
  548. session.dc_id,
  549. loggers=_Loggers(),
  550. )
  551. )
  552. client.id = int(re.findall(r"\d+", session.filename)[-1])
  553. logging.debug("Client %s connected", client.id)
  554. self._clients[client.id] = client
  555. @dataclasses.dataclass
  556. class Socket:
  557. reader: asyncio.StreamReader
  558. writer: asyncio.StreamWriter
  559. client_id: int
  560. class TCP:
  561. def __init__(self, session_storage: SessionStorage):
  562. self._sockets = {}
  563. self._socket_files = []
  564. self._session_storage = session_storage
  565. self.gc(init=True)
  566. for client_id in self._session_storage.client_ids:
  567. filename = os.path.abspath(
  568. os.path.join("../", f"hikka-{client_id}-proxy.sock")
  569. )
  570. asyncio.ensure_future(
  571. asyncio.start_unix_server(
  572. functools.partial(
  573. self._process_conn,
  574. client_id=client_id,
  575. filename=filename,
  576. ),
  577. filename,
  578. )
  579. )
  580. def _process_conn(self, reader, writer, client_id, filename):
  581. self._session_storage.clients[client_id].set_socket(writer)
  582. self._socket_files.append(filename)
  583. logging.info("Socket %s connected", filename)
  584. socket = Socket(reader, writer, client_id)
  585. self._sockets[client_id] = socket
  586. asyncio.ensure_future(read_loop(socket))
  587. @staticmethod
  588. async def recv(sock: Socket):
  589. return await ClientFullPacketCodec.read_packet(None, sock.reader)
  590. @staticmethod
  591. async def send(sock: Socket, data: bytes):
  592. sock.writer.write(ClientFullPacketCodec.encode_packet(None, data))
  593. await sock.writer.drain()
  594. def _find_real_request(self, request: TLRequest) -> TLRequest:
  595. if isinstance(
  596. request,
  597. (
  598. InvokeWithLayerRequest,
  599. InvokeAfterMsgRequest,
  600. InvokeAfterMsgsRequest,
  601. InvokeWithMessagesRangeRequest,
  602. InvokeWithTakeoutRequest,
  603. InvokeWithoutUpdatesRequest,
  604. ),
  605. ):
  606. return self._find_real_request(request.query)
  607. return request
  608. def _malicious(self, request: TLRequest) -> bool:
  609. request = self._find_real_request(request)
  610. if (
  611. isinstance(
  612. request,
  613. (
  614. DeleteAccountRequest,
  615. BindTempAuthKeyRequest,
  616. CancelCodeRequest,
  617. CheckRecoveryPasswordRequest,
  618. ExportAuthorizationRequest,
  619. ExportLoginTokenRequest,
  620. ImportAuthorizationRequest,
  621. LogOutRequest,
  622. RecoverPasswordRequest,
  623. RequestPasswordRecoveryRequest,
  624. ResetAuthorizationsRequest,
  625. ResetLoginEmailRequest,
  626. SendCodeRequest,
  627. ),
  628. )
  629. or (
  630. isinstance(request, UpdateProfileRequest)
  631. and "savedmessages"
  632. in (request.first_name + request.last_name).replace(" ", "").lower()
  633. )
  634. or (
  635. isinstance(request, GetHistoryRequest)
  636. and isinstance(request.peer, InputPeerUser)
  637. and request.peer.user_id == 777000
  638. )
  639. or (
  640. isinstance(request, ForwardMessagesRequest)
  641. and isinstance(request.from_peer, InputPeerUser)
  642. and request.from_peer.user_id == 777000
  643. )
  644. or (
  645. isinstance(request, SearchRequest)
  646. and isinstance(request.peer, InputPeerUser)
  647. and request.peer.user_id == 777000
  648. )
  649. ):
  650. return True
  651. async def read(self, conn: Socket):
  652. data = await self.recv(conn)
  653. logging.debug("Got data from client %s", data)
  654. if data:
  655. msg_id = struct.unpack("<q", data[:8])[0]
  656. with BinaryReader(data[16:]) as reader:
  657. tgobject = reader.tgread_object()
  658. logging.debug("Got object %s", tgobject)
  659. if isinstance(tgobject, MsgsAck):
  660. return
  661. while isinstance(tgobject, GzipPacked):
  662. with BinaryReader(tgobject.data) as reader:
  663. tgobject = reader.tgread_object()
  664. logging.debug("Modified object %s", tgobject)
  665. if isinstance(tgobject, InvokeWithLayerRequest) and isinstance(
  666. tgobject.query, InitConnectionRequest
  667. ):
  668. tgobject = GetConfigRequest()
  669. if isinstance(tgobject, MessageContainer):
  670. states = []
  671. for message in tgobject.messages:
  672. state = RequestState(message.obj)
  673. if self._malicious(message.obj):
  674. logging.critical(
  675. "Suspicious request detected, substituting with ping"
  676. )
  677. state = RequestState(PingRequest(ping_id=123456789))
  678. state.msg_id = message.msg_id
  679. states.append(state)
  680. self._session_storage.clients[conn.client_id].external_extend(states)
  681. else:
  682. state = RequestState(tgobject)
  683. if self._malicious(tgobject):
  684. logging.critical(
  685. "Suspicious request detected, substituting with ping"
  686. )
  687. state = RequestState(PingRequest(ping_id=123456789))
  688. state.msg_id = msg_id
  689. self._session_storage.clients[conn.client_id].external_append(state)
  690. def gc(self, init: bool, pop_client: Optional[int] = None):
  691. for client_id in (
  692. [pop_client] if pop_client else self._session_storage.client_ids
  693. ):
  694. with contextlib.suppress(Exception):
  695. self._sockets[client_id].close()
  696. with contextlib.suppress(Exception):
  697. os.remove(
  698. os.path.abspath(
  699. os.path.join("../", f"hikka-{client_id}-proxy.sock")
  700. )
  701. )
  702. if not init:
  703. with contextlib.suppress(Exception):
  704. os.remove(
  705. os.path.abspath(
  706. os.path.join("../", f"hikka-{client_id}.session")
  707. )
  708. )
  709. if not init:
  710. with contextlib.suppress(Exception):
  711. os.remove(
  712. os.path.abspath(
  713. os.path.join("../", f"hikka-{client_id}.session-journal")
  714. )
  715. )
  716. tcp, session_storage, shell = None, None, None
  717. async def read_loop(sock: Socket):
  718. global tcp, session_storage, shell
  719. while True:
  720. try:
  721. await tcp.read(sock)
  722. except (asyncio.IncompleteReadError, ConnectionResetError):
  723. logging.info("Client disconnected, restarting...")
  724. await session_storage.pop_client(sock.client_id)
  725. if shell:
  726. shell.kill()
  727. logging.info("Waiting for sandbox to exit...")
  728. await shell.wait()
  729. logging.info("Sandbox exited")
  730. exit(1)
  731. except Exception as e:
  732. logging.exception(e)
  733. async def main():
  734. global tcp, session_storage, shell
  735. for session in os.listdir("../"):
  736. if session.startswith("hikka-") and session.endswith(".session"):
  737. session = os.path.abspath(os.path.join("../", session))
  738. session = SQLiteSession(session)
  739. if not session.auth_key.key.startswith(b"Where are you at?"):
  740. session.save()
  741. session.close()
  742. os.rename(
  743. os.path.abspath(os.path.join("../", session.filename)),
  744. os.path.abspath(os.path.join("./", session.filename)),
  745. )
  746. else:
  747. session.close()
  748. os.remove(os.path.abspath(os.path.join("../", session.filename)))
  749. session_storage = SessionStorage()
  750. session_storage.read_sessions()
  751. await session_storage.init_clients()
  752. tcp = TCP(session_storage)
  753. logging.info("Startup delay...")
  754. await asyncio.sleep(3)
  755. logging.info("Starting client...")
  756. shell = await asyncio.create_subprocess_shell(
  757. "cd ../ && ./_start_sandbox.sh",
  758. stdin=asyncio.subprocess.PIPE,
  759. stdout=asyncio.subprocess.PIPE,
  760. stderr=asyncio.subprocess.PIPE,
  761. shell=True,
  762. )
  763. while True:
  764. await asyncio.sleep(3600)
  765. async def integrity_checker():
  766. while True:
  767. await asyncio.sleep(5)
  768. def shutdown_handler(sig, frame):
  769. print("Bye")
  770. if shell:
  771. with contextlib.suppress(ProcessLookupError):
  772. os.kill(shell.pid, signal.SIGINT)
  773. if tcp:
  774. tcp.gc(init=False)
  775. sys.exit(0)
  776. if __name__ == "__main__":
  777. signal.signal(signal.SIGINT, shutdown_handler)
  778. asyncio.get_event_loop().run_until_complete(main())