12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908 |
- import asyncio
- import atexit
- import contextlib
- import copy
- import dataclasses
- import datetime
- import functools
- import io
- import json
- import logging
- import os
- import re
- import signal
- import struct
- import sys
- import time
- from hashlib import sha256
- from logging.handlers import RotatingFileHandler
- from pathlib import Path
- from typing import Dict, List, Optional, Union
- from zlib import crc32
- from hikkatl.crypto import AES, AuthKey
- from hikkatl.errors import (
- AuthKeyNotFound,
- BadMessageError,
- InvalidBufferError,
- InvalidChecksumError,
- SecurityError,
- TypeNotFoundError,
- )
- from hikkatl.extensions.binaryreader import BinaryReader
- from hikkatl.extensions.messagepacker import MessagePacker
- from hikkatl.network.connection import ConnectionTcpFull
- from hikkatl.network.mtprotosender import MTProtoSender
- from hikkatl.network.mtprotostate import MTProtoState
- from hikkatl.network.requeststate import RequestState
- from hikkatl.sessions import SQLiteSession
- from hikkatl.tl import TLRequest
- from hikkatl.tl.core import GzipPacked, MessageContainer, TLMessage
- from hikkatl.tl.functions import (
- InitConnectionRequest,
- InvokeAfterMsgRequest,
- InvokeAfterMsgsRequest,
- InvokeWithLayerRequest,
- InvokeWithMessagesRangeRequest,
- InvokeWithoutUpdatesRequest,
- InvokeWithTakeoutRequest,
- PingRequest,
- )
- from hikkatl.tl.functions.account import DeleteAccountRequest, UpdateProfileRequest
- from hikkatl.tl.functions.auth import (
- BindTempAuthKeyRequest,
- CancelCodeRequest,
- CheckRecoveryPasswordRequest,
- ExportAuthorizationRequest,
- ExportLoginTokenRequest,
- ImportAuthorizationRequest,
- LogOutRequest,
- RecoverPasswordRequest,
- RequestPasswordRecoveryRequest,
- ResetAuthorizationsRequest,
- ResetLoginEmailRequest,
- SendCodeRequest,
- )
- from hikkatl.tl.functions.help import GetConfigRequest
- from hikkatl.tl.functions.messages import (
- ForwardMessagesRequest,
- GetHistoryRequest,
- SearchRequest,
- )
- from hikkatl.tl.types import (
- InputPeerUser,
- Message,
- MsgsAck,
- PeerUser,
- UpdateNewMessage,
- Updates,
- UpdateShortMessage,
- )
# Run relative to this script's directory: session files, sockets and the
# log file below are all addressed with relative paths.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# One format shared by the console handler and the rotating file handler.
log_formatter = logging.Formatter(
    fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    style="%",
)

# Console output: INFO and above.
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
handler.setFormatter(log_formatter)

# File output: everything (DEBUG and above), rotated at 10 MiB, one backup.
rotating_handler = RotatingFileHandler(
    filename="hikka.log",
    mode="a",
    maxBytes=10 * 1024 * 1024,
    backupCount=1,
    encoding="utf-8",
    delay=False,
)
rotating_handler.setLevel(logging.DEBUG)
rotating_handler.setFormatter(log_formatter)

# Let DEBUG records reach the handlers; each handler filters by its own level.
# The root level is set directly instead of via basicConfig(), which would
# attach an extra stderr handler (echoing records a second time) or, if the
# root logger was already configured, do nothing at all.
logging.getLogger().setLevel(logging.DEBUG)
logging.getLogger().addHandler(handler)
logging.getLogger().addHandler(rotating_handler)
class _OpaqueRequest(TLRequest):
    """A TLRequest whose serialized form is a fixed, pre-encoded payload.

    Lets already-serialized request bytes be nested inside another request,
    e.g. as the inner query of ``InvokeAfterMsgRequest``.
    """

    def __init__(self, data: bytes) -> None:
        # Kept verbatim; nothing is parsed or validated.
        self.data = data

    def _bytes(self) -> bytes:
        """Return the stored payload unchanged."""
        payload = self.data
        return payload
class CustomMTProtoState(MTProtoState):
    """MTProtoState whose message writer can reuse a caller-supplied msg_id."""

    def write_data_as_message(
        self,
        buffer,
        data,
        content_related,
        *,
        after_id=None,
        msg_id=None,
    ):
        """Serialize ``data`` as one MTProto message into ``buffer``.

        :param buffer: Writable binary stream that receives the message.
        :param data: Serialized request bytes.
        :param content_related: Passed to ``_get_seq_no`` and to
            ``GzipPacked.gzip_if_smaller``.
        :param after_id: When not None, the payload is wrapped in
            ``InvokeAfterMsgRequest(after_id, ...)``.
        :param msg_id: msg_id to reuse; a falsy value means "generate a new one".
        :return: The msg_id that was written.
        """
        if not msg_id:
            msg_id = self._get_new_msg_id()
        seq_no = self._get_seq_no(content_related)

        payload = data
        if after_id is not None:
            # InvokeAfterMsgRequest expects a request object, so the raw
            # bytes are wrapped in _OpaqueRequest.
            payload = bytes(InvokeAfterMsgRequest(after_id, _OpaqueRequest(data)))
        body = GzipPacked.gzip_if_smaller(content_related, payload)

        # Message header: msg_id (int64), seq_no (int32), body length (int32).
        buffer.write(struct.pack("<qii", msg_id, seq_no, len(body)))
        buffer.write(body)
        return msg_id
class CustomMessagePacker(MessagePacker):
    """MessagePacker that keeps pre-assigned msg_ids when building batches.

    Relies on the base class's ``_deque`` of pending ``RequestState`` objects,
    its ``_ready`` event, ``_state`` (whose ``write_data_as_message`` must
    accept a ``msg_id`` keyword, as ``CustomMTProtoState``'s does) and ``_log``.
    """

    async def get(self):
        """Pop as many queued requests as fit into one MTProto payload.

        Waits on the base class's ``_ready`` event when the queue is empty.
        Each request is written with its existing ``msg_id`` if set, otherwise
        a new one; when more than one request is taken they are wrapped in a
        ``msg_container``.

        :return: ``(batch, data)`` - the ``RequestState`` objects written and
            the serialized payload - or ``(None, None)`` when nothing could be
            sent (for example, every popped request was too large).
        """
        if not self._deque:
            self._ready.clear()
            await self._ready.wait()

        buffer = io.BytesIO()
        batch = []
        size = 0

        # `<` rather than `<=`: the batch must never exceed MAXIMUM_LENGTH
        # messages once it is wrapped in a container.
        while self._deque and len(batch) < MessageContainer.MAXIMUM_LENGTH:
            state = self._deque.popleft()
            size += len(state.data) + TLMessage.SIZE_OVERHEAD

            if size <= MessageContainer.MAXIMUM_SIZE:
                state.msg_id = self._state.write_data_as_message(
                    buffer,
                    state.data,
                    isinstance(state.request, TLRequest),
                    after_id=state.after.msg_id if state.after else None,
                    msg_id=state.msg_id,
                )
                batch.append(state)
                self._log.debug(
                    "Assigned msg_id = %d to %s (%x)",
                    state.msg_id,
                    state.request.__class__.__name__,
                    id(state.request),
                )
                continue

            if batch:
                # Does not fit in this batch; leave it for the next call.
                self._deque.appendleft(state)
                break

            # A single request larger than the container limit can never be
            # sent. Fail its future instead of re-queueing it forever.
            self._log.warning(
                "Message payload for %s is too long (%d) and cannot be sent",
                state.request.__class__.__name__,
                len(state.data),
            )
            state.future.set_exception(ValueError("Request payload is too big"))
            size = 0

        if not batch:
            return None, None

        if len(batch) > 1:
            # msg_container: constructor id, message count, then the messages.
            data = (
                struct.pack("<Ii", MessageContainer.CONSTRUCTOR_ID, len(batch))
                + buffer.getvalue()
            )
            buffer = io.BytesIO()
            container_id = self._state.write_data_as_message(
                buffer, data, content_related=False
            )
            for state in batch:
                state.container_id = container_id

        return batch, buffer.getvalue()
class ClientFullPacketCodec:
    """TCP "full" transport framing for the sandbox-facing unix socket.

    Frame layout: int32 total length, int32 sequence number, payload, then a
    CRC32 over everything before it. The methods never touch ``self``; this
    file calls them unbound, e.g. ``ClientFullPacketCodec.encode_packet(None, data)``.
    """

    tag = None

    def encode_packet(self, data):
        """Frame ``data`` for sending; the sequence number is always 0.

        :param data: Payload bytes.
        :return: Length + seq header, payload and CRC32 trailer.
        """
        length = len(data) + 12
        framed = struct.pack("<ii", length, 0) + data
        return framed + struct.pack("<I", crc32(framed))

    async def read_packet(self, reader):
        """Read one frame from ``reader`` and return its payload.

        :param reader: ``asyncio.StreamReader`` to read from.
        :return: Payload bytes, without header and CRC.
        :raises InvalidBufferError: On a transport error code (negative length
            and seq) or a length too small to hold header and CRC.
        :raises InvalidChecksumError: If the CRC32 does not match.
        """
        packet_len_seq = await reader.readexactly(8)
        packet_len, seq = struct.unpack("<ii", packet_len_seq)
        if packet_len < 0 and seq < 0:
            # Transport-level error: a 4-byte negative code follows.
            body = await reader.readexactly(4)
            raise InvalidBufferError(body)
        if packet_len < 12:
            # Shorter than header + CRC. Without this check readexactly()
            # would get a negative size, or the CRC unpack would fail, with
            # an unhelpful error.
            raise InvalidBufferError(packet_len_seq)

        body = await reader.readexactly(packet_len - 8)
        checksum = struct.unpack("<I", body[-4:])[0]
        body = body[:-4]

        valid_checksum = crc32(packet_len_seq + body)
        if checksum != valid_checksum:
            raise InvalidChecksumError(checksum, valid_checksum)
        return body
def get_config_key(key: str) -> Union[str, bool]:
    """
    Parse and return key from config

    Reads ``../config.json`` (relative to the current directory) as UTF-8,
    the encoding JSON requires, independent of the platform locale.

    :param key: Key name in config
    :return: Value of config key, or `False` if the key or the config file
        doesn't exist
    """
    try:
        raw = Path("../config.json").read_text(encoding="utf-8")
    except FileNotFoundError:
        return False
    return json.loads(raw).get(key, False)
# High-resolution timestamp taken when this module runs; no code shown in
# this file reads it.
start_ts: float = time.perf_counter()
class CustomMTProtoSender(MTProtoSender):
    """MTProtoSender that relays everything it receives to a sandbox socket.

    Differences from the base sender:

    * it uses ``CustomMTProtoState`` / ``CustomMessagePacker``, so requests
      queued with a pre-set ``msg_id`` (via ``external_append`` /
      ``external_extend``) keep that id;
    * after each incoming message is processed normally, the decrypted
      message (minus its first 16 bytes) is forwarded to the writer given to
      ``set_socket``, with texts of messages from user 777000 masked.
    """

    # Wire constructor ids (little-endian uint32) used when rewriting payloads.
    _GZIP_PACKED_ID = struct.pack("<I", GzipPacked.CONSTRUCTOR_ID)
    _CONTAINER_ID = struct.pack("<I", MessageContainer.CONSTRUCTOR_ID)
    # Texts of messages from this user id are masked before forwarding.
    _SERVICE_USER_ID = 777000

    def __init__(
        self,
        auth_key,
        *,
        loggers,
        retries=5,
        delay=1,
        auto_reconnect=True,
        connect_timeout=None,
        auth_key_callback=None,
        updates_queue=None,
        auto_reconnect_callback=None,
    ):
        super().__init__(
            auth_key,
            loggers=loggers,
            retries=retries,
            delay=delay,
            auto_reconnect=auto_reconnect,
            connect_timeout=connect_timeout,
            auth_key_callback=auth_key_callback,
            updates_queue=updates_queue,
            auto_reconnect_callback=auto_reconnect_callback,
        )
        # Swap in the msg_id-preserving state and packer.
        self._state = CustomMTProtoState(self.auth_key, loggers=self._loggers)
        self._send_queue = CustomMessagePacker(self._state, loggers=self._loggers)
        # Sandbox-side writer; None until set_socket() is called.
        self._socket = None

    def external_append(self, state):
        """Queue one ``RequestState`` (its msg_id, if set, is preserved)."""
        self._send_queue.append(state)

    def external_extend(self, states):
        """Queue several ``RequestState`` objects (msg_ids preserved)."""
        self._send_queue.extend(states)

    async def _send_loop(self):
        """Send queued batches until the user disconnects or a reconnect starts."""
        while self._user_connected and not self._reconnecting:
            if self._pending_ack:
                ack = RequestState(MsgsAck(list(self._pending_ack)))
                self._send_queue.append(ack)
                self._last_acks.append(ack)
                self._pending_ack.clear()

            self._log.debug("Waiting for messages to send...")
            batch, data = await self._send_queue.get()
            if not data:
                continue

            logging.debug("Sending data %s", data)
            self._log.debug(
                "Encrypting %d message(s) in %d bytes for sending",
                len(batch),
                len(data),
            )
            data = self._state.encrypt_message_data(data)

            # Track requests (not acks etc.) so their results can be matched.
            for state in batch:
                if not isinstance(state, list):
                    if isinstance(state.request, TLRequest):
                        self._pending_state[state.msg_id] = state
                else:
                    for s in state:
                        if isinstance(s.request, TLRequest):
                            self._pending_state[s.msg_id] = s

            try:
                await self._connection.send(data)
            except IOError as e:
                self._log.info("Connection closed while sending data")
                self._start_reconnect(e)
                return

            self._log.debug("Encrypted messages put in a queue to be sent")

    def partial_decrypt(self, body):
        """Decrypt an incoming packet without parsing it.

        Verifies the auth key id and msg_key, then returns the plaintext with
        its first 16 bytes removed (server salt and session id in MTProto 2.0),
        i.e. starting at the msg_id. Trailing padding is kept.

        :raises InvalidBufferError: If ``body`` is shorter than 8 bytes.
        :raises SecurityError: On an auth key id or msg_key mismatch.
        """
        if len(body) < 8:
            raise InvalidBufferError(body)

        key_id = struct.unpack("<Q", body[:8])[0]
        if key_id != self._state.auth_key.key_id:
            raise SecurityError("Server replied with an invalid auth key")

        msg_key = body[8:24]
        aes_key, aes_iv = self._state._calc_key(
            self._state.auth_key.key, msg_key, False
        )
        body = AES.decrypt_ige(body[24:], aes_key, aes_iv)

        our_key = sha256(self._state.auth_key.key[96 : 96 + 32] + body)
        if msg_key != our_key.digest()[8:24]:
            raise SecurityError("Received msg_key doesn't match with expected one")

        return body[16:]

    async def _handle_recv(self, body: bytes):
        """Decrypt and process one received packet, then relay it.

        :param body: Encrypted packet body as returned by the connection.
        :return: ``False`` if the packet was skipped, ``-1`` if a disconnect
            or reconnect was started (the receive loop must stop), ``True``
            once the message was processed and the relay was attempted.
        """
        try:
            message = self._state.decrypt_message_data(body)
            if message is None:
                return False
        except TypeNotFoundError as e:
            self._log.info(
                "Type %08x not found, remaining data %r",
                e.invalid_constructor_id,
                e.remaining,
            )
            return False
        except SecurityError as e:
            self._log.warning(
                "Security error while unpacking a received message: %s", e
            )
            return False
        except BufferError as e:
            if isinstance(e, InvalidBufferError) and e.code == 404:
                self._log.info(
                    "Server does not know about the current auth key; the session may"
                    " need to be recreated"
                )
                await self._disconnect(error=AuthKeyNotFound())
            else:
                self._log.warning("Invalid buffer %s", e)
                self._start_reconnect(e)
            return -1
        except Exception as e:
            self._log.exception("Unhandled error while decrypting data")
            self._start_reconnect(e)
            return -1

        try:
            await self._process_message(message)
        except Exception:
            self._log.exception("Unhandled error while processing msgs")

        logging.debug("Got message from Telegram %s", message)
        try:
            await self._forward(self.partial_decrypt(body), message)
        except Exception:
            # Best effort: a relay failure must not stop the receive loop.
            # On failure nothing is forwarded for this message.
            logging.exception("Unhandled error while processing msgs")
        return True

    async def _forward(self, msg: bytes, message) -> None:
        """Mask service-message texts in ``msg`` and write it to the sandbox.

        :param msg: Output of ``partial_decrypt`` for the packet.
        :param message: The parsed message; after ``_process_message`` its
            ``obj`` is expected to hold the unwrapped content.
        """
        obj = message.obj if isinstance(message, TLMessage) else None
        texts = {text.encode() for text in self._service_texts(obj) if text}
        if texts:
            # Masking works on raw bytes, so gzip-packed payloads (top-level
            # or inside a container) must be inflated first or the text
            # would not be found.
            msg = self._inflate(msg)
            # Do not log the texts themselves: they are what is being hidden.
            logging.info("Censoring %d service message(s)", len(texts))
            # Longest first, so a text that contains another is masked whole.
            for text in sorted(texts, key=len, reverse=True):
                msg = msg.replace(text, b"*" * len(text))

        if self._socket is None:
            logging.debug("Got data with no socket")
            return

        packet = ClientFullPacketCodec.encode_packet(None, msg)
        logging.debug("Got data from socket, forwarding, %s", packet)
        self._socket.write(packet)
        await self._socket.drain()

    @classmethod
    def _service_texts(cls, obj) -> List[str]:
        """Collect texts of all messages from the service user found in ``obj``.

        Looks at ``Updates`` (``UpdateNewMessage`` with a ``PeerUser`` peer),
        ``UpdateShortMessage``, gzip-packed objects, and the messages of a
        ``MessageContainer``.
        """
        if isinstance(obj, MessageContainer):
            return [
                text
                for inner in obj.messages
                if isinstance(inner, TLMessage)
                for text in cls._service_texts(inner.obj)
            ]
        if isinstance(obj, GzipPacked):
            with BinaryReader(obj.data) as reader:
                return cls._service_texts(reader.tgread_object())
        if isinstance(obj, Updates):
            return [
                update.message.message
                for update in obj.updates
                if isinstance(update, UpdateNewMessage)
                and isinstance(update.message, Message)
                and isinstance(update.message.peer_id, PeerUser)
                and update.message.peer_id.user_id == cls._SERVICE_USER_ID
            ]
        if (
            isinstance(obj, UpdateShortMessage)
            and obj.user_id == cls._SERVICE_USER_ID
        ):
            return [obj.message]
        return []

    @classmethod
    def _unpack_gzip(cls, body: bytes) -> bytes:
        """Return the inflated payload if ``body`` is a gzip_packed object.

        Any other body is returned unchanged.

        :raises TypeError: If the reader does not produce a ``GzipPacked``.
            Raising here means the caller drops the message (fail closed)
            instead of forwarding it uncensored.
        """
        if not body.startswith(cls._GZIP_PACKED_ID):
            return body
        with BinaryReader(body) as reader:
            packed = reader.tgread_object()
        if not isinstance(packed, GzipPacked):
            raise TypeError(f"expected GzipPacked, got {type(packed).__name__}")
        return packed.data

    @classmethod
    def _inflate_container(cls, body: bytes) -> bytes:
        """Rebuild a msg_container with every inner gzip_packed body inflated.

        Each inner message is ``msg_id (8) | seq_no (4) | length (4) | body``;
        lengths are rewritten to match the inflated bodies.
        """
        (count,) = struct.unpack_from("<i", body, 4)
        parts = [body[:8]]
        offset = 8
        for _ in range(count):
            inner_id, inner_seq, inner_len = struct.unpack_from("<qii", body, offset)
            offset += 16
            inner = cls._unpack_gzip(body[offset : offset + inner_len])
            offset += inner_len
            parts.append(struct.pack("<qii", inner_id, inner_seq, len(inner)))
            parts.append(inner)
        return b"".join(parts)

    @classmethod
    def _inflate(cls, msg: bytes) -> bytes:
        """Inflate gzip-packed content in a ``partial_decrypt`` result.

        ``msg`` is ``msg_id (8) | seq_no (4) | length (4) | body | padding``.
        If nothing was compressed, ``msg`` is returned untouched. Otherwise
        the message is rebuilt with a corrected length and without padding.
        """
        (length,) = struct.unpack_from("<i", msg, 12)
        body = msg[16 : 16 + length]
        inflated = cls._unpack_gzip(body)
        if inflated.startswith(cls._CONTAINER_ID):
            inflated = cls._inflate_container(inflated)
        if inflated == body:
            return msg
        return msg[:12] + struct.pack("<i", len(inflated)) + inflated

    def set_socket(self, socket: asyncio.StreamWriter):
        """Register the sandbox writer that received messages are relayed to."""
        self._socket = socket

    async def _recv_loop(self):
        """Receive packets until the user disconnects or a reconnect starts."""
        while self._user_connected and not self._reconnecting:
            self._log.debug("Receiving items from the network...")
            try:
                body = await self._connection.recv()
            except IOError as e:
                self._log.info("Connection closed while receiving data")
                self._start_reconnect(e)
                return
            except InvalidBufferError as e:
                if e.code == 429:
                    self._log.warning(
                        "Server indicated flood error at transport level: %s", e
                    )
                    await self._disconnect(error=e)
                else:
                    self._log.exception("Server sent invalid buffer")
                    self._start_reconnect(e)
                return
            except Exception as e:
                self._log.exception("Unhandled error while receiving data")
                self._start_reconnect(e)
                return

            res = await self._handle_recv(body)
            if res is False:
                continue
            elif res == -1:
                return

    async def _handle_bad_notification(self, message):
        """Handle bad_msg_notification: fix clock/sequence and resend, or fail.

        Error codes 16/17 adjust the time offset, and 32/33 bump the sequence
        number; the affected requests are then re-queued. Any other code
        fails their futures with ``BadMessageError``.
        """
        bad_msg = message.obj
        states = self._pop_states(bad_msg.bad_msg_id)

        self._log.debug("Handling bad msg %s", bad_msg)
        if bad_msg.error_code in (16, 17):
            to = self._state.update_time_offset(correct_msg_id=message.msg_id)
            self._log.info("System clock is wrong, set time offset to %ds", to)
        elif bad_msg.error_code == 32:
            self._state._sequence += 1
        elif bad_msg.error_code == 33:
            self._state._sequence -= 1
        else:
            for state in states:
                state.future.set_exception(
                    BadMessageError(state.request, bad_msg.error_code)
                )
            return

        # NOTE(review): CustomMessagePacker reuses state.msg_id when it is set,
        # so requests re-queued after error 16/17 appear to be resent with the
        # same rejected msg_id. Confirm whether they should get a fresh one.
        self._send_queue.extend(states)
        self._log.debug("%d messages will be resent due to bad msg", len(states))
class SessionStorage:
    """Owns the real sessions, their sandbox copies and the connected senders."""

    class _Loggers(dict):
        """Logger mapping for hikkatl senders and connections.

        Any missing key resolves to a child of the ``hikkatl`` logger, with a
        leading ``telethon.`` prefix dropped.
        """

        def __missing__(self, key):
            if key.startswith("telethon."):
                key = key.split(".", maxsplit=1)[1]
            return logging.getLogger("hikkatl").getChild(key)

    def __init__(self):
        # Real sessions (hikka-*.session in the current directory).
        self._sessions: List[SQLiteSession] = []
        # Sandbox copies carrying the placeholder auth key.
        self._safe_sessions: List[SQLiteSession] = []
        # Connected senders keyed by the numeric id taken from the file name.
        self._clients: Dict[int, MTProtoSender] = {}

    async def pop_client(self, client_id: int):
        """Disconnect the sender for ``client_id`` and forget it.

        :raises KeyError: If no such client is registered.
        """
        await self._clients[client_id].disconnect()
        self._clients.pop(client_id)

    @property
    def client_ids(self) -> List[int]:
        """Ids of the currently registered clients (a fresh list)."""
        return list(self._clients.keys())

    @property
    def clients(self) -> Dict[int, MTProtoSender]:
        """The live mapping of client id to sender."""
        return self._clients

    @property
    def sessions(self) -> List[SQLiteSession]:
        """The sandbox ("safe-") session copies made by ``read_sessions``."""
        return self._safe_sessions

    def read_sessions(self):
        """Load the real sessions and publish neutered copies to ``../``.

        Each ``hikka-*.session`` in the current directory is copied to
        ``safe-<name>``; the copy gets DC ``0.0.0.0:11111`` and a placeholder
        auth key, and is then moved to ``../hikka-<id>.session``.
        """
        logging.debug("Reading sessions...")
        session_files = list(
            filter(
                lambda f: f.startswith("hikka-") and f.endswith(".session"),
                os.listdir("."),
            )
        )

        for session in session_files:
            Path("safe-" + session).write_bytes(Path(session).read_bytes())

        self._sessions = [SQLiteSession(session) for session in session_files]
        self._safe_sessions = [
            SQLiteSession("safe-" + session) for session in session_files
        ]

        for session in self._safe_sessions:
            logging.debug("Processing session %s...", session.filename)
            session.set_dc(0, "0.0.0.0", 11111)
            # Placeholder key; main() recognizes copies by this text prefix.
            session.auth_key = AuthKey(
                data=(
                    "Where are you at?\nWhere have you"
                    " been?\n問いかけに答えはなく\nWhere are we headed?\nWhat did you"
                    " mean?\n追いかけても 遅く 遠く\nA bird, a butterfly and my red"
                    " scarf\nDon't make a mess of memories\nJust let me heal your"
                    " scars\nThe wall, the owl, forgotten wharf\n時が止まることもなく"
                ).encode()
                + b"\x00" * 13
            )
            session.save()
            session.close()

        def rename(filename: str) -> str:
            # The last run of digits in the file name is the client id.
            session_id = re.findall(r"\d+", filename)[-1]
            return f"hikka-{session_id}.session"

        for session in self._safe_sessions:
            os.rename(
                os.path.abspath(session.filename),
                os.path.abspath(os.path.join("../", rename(session.filename))),
            )

    async def init_clients(self):
        """Create and connect one ``CustomMTProtoSender`` per real session."""
        for session in self._sessions:

            def _auth_key_callback(auth_key, session=session):
                # Persist a new auth key into *this* session. The default
                # argument binds the current loop value, because the callback
                # can run after the loop has advanced; SessionStorage itself
                # has no `session` attribute to write to.
                session.auth_key = auth_key
                session.save()

            client = CustomMTProtoSender(
                session.auth_key,
                loggers=self._Loggers(),
                retries=5,
                delay=1,
                auto_reconnect=True,
                connect_timeout=10,
                auth_key_callback=_auth_key_callback,
                updates_queue=asyncio.Queue(),
                auto_reconnect_callback=None,
            )
            await client.connect(
                ConnectionTcpFull(
                    session.server_address,
                    session.port,
                    session.dc_id,
                    loggers=self._Loggers(),
                )
            )
            client.id = int(re.findall(r"\d+", session.filename)[-1])
            logging.debug("Client %s connected", client.id)
            self._clients[client.id] = client
@dataclasses.dataclass()
class Socket:
    """One accepted sandbox connection and the client it belongs to."""

    # Stream the sandbox's framed requests are read from.
    reader: asyncio.StreamReader
    # Stream used to answer the sandbox.
    writer: asyncio.StreamWriter
    # Id of the session / sender this connection is bound to.
    client_id: int
class TCP:
    """Unix-socket front end that relays sandbox requests to the real senders.

    One ``../hikka-<client_id>-proxy.sock`` socket is served per client in the
    session storage. Requests read from it are filtered by ``_malicious`` and
    queued on that client's sender with the sandbox's own msg_id.
    """

    # Requests that are always replaced with a ping.
    _FORBIDDEN_REQUESTS = (
        DeleteAccountRequest,
        BindTempAuthKeyRequest,
        CancelCodeRequest,
        CheckRecoveryPasswordRequest,
        ExportAuthorizationRequest,
        ExportLoginTokenRequest,
        ImportAuthorizationRequest,
        LogOutRequest,
        RecoverPasswordRequest,
        RequestPasswordRecoveryRequest,
        ResetAuthorizationsRequest,
        ResetLoginEmailRequest,
        SendCodeRequest,
    )
    # History, search and forwarding from this user id are blocked.
    _SERVICE_USER_ID = 777000

    def __init__(self, session_storage: SessionStorage):
        self._sockets = {}
        self._socket_files = []
        self._session_storage = session_storage
        # asyncio keeps only weak references to tasks; hold strong ones so
        # server start-up and read loops cannot be garbage-collected mid-run.
        self._tasks = set()
        self.gc(init=True)

        for client_id in self._session_storage.client_ids:
            filename = os.path.abspath(
                os.path.join("../", f"hikka-{client_id}-proxy.sock")
            )
            self._spawn(
                asyncio.start_unix_server(
                    functools.partial(
                        self._process_conn,
                        client_id=client_id,
                        filename=filename,
                    ),
                    filename,
                )
            )

    def _spawn(self, coro):
        """Schedule ``coro`` and keep a reference until it finishes."""
        task = asyncio.ensure_future(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    def _process_conn(self, reader, writer, client_id, filename):
        """Accept a sandbox connection: bind it to its sender and start reading."""
        self._session_storage.clients[client_id].set_socket(writer)
        self._socket_files.append(filename)
        logging.info("Socket %s connected", filename)
        socket = Socket(reader, writer, client_id)
        self._sockets[client_id] = socket
        self._spawn(read_loop(socket))

    @staticmethod
    async def recv(sock: Socket):
        """Read one framed packet from the sandbox."""
        return await ClientFullPacketCodec.read_packet(None, sock.reader)

    @staticmethod
    async def send(sock: Socket, data: bytes):
        """Frame ``data`` and write it to the sandbox."""
        sock.writer.write(ClientFullPacketCodec.encode_packet(None, data))
        await sock.writer.drain()

    def _find_real_request(self, request: TLRequest) -> TLRequest:
        """Strip invoke wrappers and return the innermost request.

        Any request whose ``query`` field holds another ``TLRequest`` is
        treated as a wrapper (invokeWithLayer, invokeAfterMsg(s),
        invokeWithMessagesRange, invokeWithTakeout, invokeWithoutUpdates,
        initConnection, ...). A forbidden call therefore cannot hide inside a
        wrapper that is not listed here explicitly.
        """
        while isinstance(getattr(request, "query", None), TLRequest):
            request = request.query
        return request

    def _malicious(self, request: TLRequest) -> bool:
        """Return True if ``request`` (after unwrapping) must not reach Telegram."""
        request = self._find_real_request(request)
        if isinstance(request, self._FORBIDDEN_REQUESTS):
            return True

        if isinstance(request, UpdateProfileRequest):
            # first_name / last_name may be None (optional fields), so do not
            # concatenate them directly.
            name = (request.first_name or "") + (request.last_name or "")
            return "savedmessages" in name.replace(" ", "").lower()

        if isinstance(request, (GetHistoryRequest, SearchRequest)):
            peer = request.peer
        elif isinstance(request, ForwardMessagesRequest):
            peer = request.from_peer
        else:
            return False
        return isinstance(peer, InputPeerUser) and peer.user_id == self._SERVICE_USER_ID

    def _make_state(self, request, msg_id: int) -> RequestState:
        """Wrap ``request`` (or a ping if it is malicious) with the given msg_id."""
        if self._malicious(request):
            logging.critical("Suspicious request detected, substituting with ping")
            request = PingRequest(ping_id=123456789)
        state = RequestState(request)
        state.msg_id = msg_id
        return state

    @staticmethod
    def _unpack_gzip(tgobject):
        """Remove any number of gzip_packed layers around ``tgobject``."""
        while isinstance(tgobject, GzipPacked):
            with BinaryReader(tgobject.data) as reader:
                tgobject = reader.tgread_object()
            logging.debug("Modified object %s", tgobject)
        return tgobject

    async def read(self, conn: Socket):
        """Read one packet from the sandbox and queue it on the real sender.

        The msg_id is taken from the first 8 bytes and the TL object from
        byte 16 on (presumably seq_no and length sit in between). Client acks
        are dropped and gzip layers are removed, including inside containers,
        so compressed requests are filtered too. A top-level
        invokeWithLayer(initConnection) becomes help.getConfig, and anything
        ``_malicious`` flags becomes a ping.
        """
        data = await self.recv(conn)
        logging.debug("Got data from client %s", data)
        if not data:
            return

        msg_id = struct.unpack("<q", data[:8])[0]
        with BinaryReader(data[16:]) as reader:
            tgobject = reader.tgread_object()
        logging.debug("Got object %s", tgobject)
        if isinstance(tgobject, MsgsAck):
            return

        tgobject = self._unpack_gzip(tgobject)
        if isinstance(tgobject, InvokeWithLayerRequest) and isinstance(
            tgobject.query, InitConnectionRequest
        ):
            tgobject = GetConfigRequest()

        client = self._session_storage.clients[conn.client_id]
        if isinstance(tgobject, MessageContainer):
            client.external_extend(
                [
                    self._make_state(self._unpack_gzip(message.obj), message.msg_id)
                    for message in tgobject.messages
                ]
            )
        else:
            client.external_append(self._make_state(tgobject, msg_id))

    def gc(self, init: bool, pop_client: Optional[int] = None):
        """Best-effort cleanup of sandbox sockets and files in ``../``.

        :param init: When True (start-up), only stale proxy socket files are
            removed. When False, the sandbox session copies
            ``hikka-<id>.session`` and their ``-journal`` are deleted too.
        :param pop_client: Restrict cleanup to this client id; by default
            every client in the session storage is cleaned.
        """
        if pop_client is not None:
            client_ids = [pop_client]
        else:
            client_ids = self._session_storage.client_ids

        for client_id in client_ids:
            sock = self._sockets.get(client_id)
            if sock is not None:
                with contextlib.suppress(Exception):
                    # Socket is a plain dataclass; the stream to close is its writer.
                    sock.writer.close()

            names = [f"hikka-{client_id}-proxy.sock"]
            if not init:
                names += [
                    f"hikka-{client_id}.session",
                    f"hikka-{client_id}.session-journal",
                ]
            for name in names:
                with contextlib.suppress(Exception):
                    os.remove(os.path.abspath(os.path.join("../", name)))
# Module-level singletons: assigned by main(), read by read_loop() and
# shutdown_handler(). All are None until main() runs.
tcp = None
session_storage = None
shell = None
async def read_loop(sock: Socket):
    """Relay packets from one sandbox connection until it closes, then exit.

    When the sandbox disconnects, the matching sender is disconnected, the
    sandbox process (if any) is killed and awaited, and the process exits
    with status 1. Any other error is logged and reading continues.
    """
    while True:
        try:
            await tcp.read(sock)
        except (asyncio.IncompleteReadError, ConnectionResetError):
            logging.info("Client disconnected, restarting...")
            await session_storage.pop_client(sock.client_id)
            if shell:
                with contextlib.suppress(ProcessLookupError):
                    # The sandbox may already have exited on its own.
                    shell.kill()
                logging.info("Waiting for sandbox to exit...")
                await shell.wait()
                logging.info("Sandbox exited")
            # sys.exit rather than the site-provided exit(), which is not
            # guaranteed to exist (e.g. under `python -S`).
            sys.exit(1)
        except Exception as e:
            logging.exception(e)
async def main():
    """Prepare sessions, start the relays and launch the sandbox.

    1. Each ``hikka-*.session`` in ``../`` that carries the placeholder auth
       key (left by a previous run) is deleted; every other one is moved into
       the current directory.
    2. ``SessionStorage`` publishes neutered copies back to ``../`` and
       connects the real senders.
    3. ``TCP`` serves the per-client unix sockets. After a 3 s delay,
       ``../_start_sandbox.sh`` is started.

    Then sleeps forever.
    """
    global tcp, session_storage, shell

    for name in os.listdir("../"):
        if not (name.startswith("hikka-") and name.endswith(".session")):
            continue

        path = os.path.abspath(os.path.join("../", name))
        session = SQLiteSession(path)
        auth_key = session.auth_key
        if auth_key is not None and auth_key.key.startswith(b"Where are you at?"):
            # Placeholder copy from a previous run: discard it.
            session.close()
            os.remove(path)
        else:
            session.save()
            session.close()
            # Join with the bare name: os.path.join("./", <absolute path>)
            # returns the absolute path unchanged, which made this move a
            # no-op and left the real session where the sandbox runs.
            os.rename(path, os.path.abspath(os.path.join("./", name)))

    session_storage = SessionStorage()
    session_storage.read_sessions()
    await session_storage.init_clients()
    tcp = TCP(session_storage)

    logging.info("Startup delay...")
    await asyncio.sleep(3)
    logging.info("Starting client...")
    # The sandbox's output is never read. With PIPE it could fill the OS
    # pipe buffer and block the child, so it is discarded instead.
    shell = await asyncio.create_subprocess_shell(
        "cd ../ && ./_start_sandbox.sh",
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.DEVNULL,
        stderr=asyncio.subprocess.DEVNULL,
    )

    while True:
        await asyncio.sleep(3600)
async def integrity_checker():
    """Idle forever, waking every 5 seconds; it currently performs no checks."""
    interval_s = 5
    while True:
        await asyncio.sleep(interval_s)
def shutdown_handler(sig, frame):
    """SIGINT handler: interrupt the sandbox, clean up the relay, exit with 0.

    :param sig: Signal number (unused).
    :param frame: Interrupted stack frame (unused).
    """
    print("Bye")
    if shell:
        try:
            os.kill(shell.pid, signal.SIGINT)
        except ProcessLookupError:
            # No such process any more.
            pass
    if tcp:
        tcp.gc(init=False)
    sys.exit(0)
if __name__ == "__main__":
    # Install our own SIGINT handler first; asyncio.run() only replaces the
    # default one, so this handler stays in effect.
    signal.signal(signal.SIGINT, shutdown_handler)
    # asyncio.run() creates and closes the loop itself. Calling
    # get_event_loop() without a running loop is deprecated.
    asyncio.run(main())
|