123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349 |
- # Author: Johan Hanssen Seferidis
- # License: MIT
- import re
- import sys
- import struct
- from base64 import b64encode
- from hashlib import sha1
- import logging
- if sys.version_info[0] < 3:
- from SocketServer import ThreadingMixIn, TCPServer, StreamRequestHandler
- else:
- from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler
- logger = logging.getLogger(__name__)
- logging.basicConfig()
- '''
- +-+-+-+-+-------+-+-------------+-------------------------------+
- 0 1 2 3
- 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
- +-+-+-+-+-------+-+-------------+-------------------------------+
- |F|R|R|R| opcode|M| Payload len | Extended payload length |
- |I|S|S|S| (4) |A| (7) | (16/64) |
- |N|V|V|V| |S| | (if payload len==126/127) |
- | |1|2|3| |K| | |
- +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
- | Extended payload length continued, if payload len == 127 |
- + - - - - - - - - - - - - - - - +-------------------------------+
- | Payload Data continued ... |
- +---------------------------------------------------------------+
- '''
- FIN = 0x80
- OPCODE = 0x0f
- MASKED = 0x80
- PAYLOAD_LEN = 0x7f
- PAYLOAD_LEN_EXT16 = 0x7e
- PAYLOAD_LEN_EXT64 = 0x7f
- OPCODE_CONTINUATION = 0x0
- OPCODE_TEXT = 0x1
- OPCODE_BINARY = 0x2
- OPCODE_CLOSE_CONN = 0x8
- OPCODE_PING = 0x9
- OPCODE_PONG = 0xA
- # -------------------------------- API ---------------------------------
- class API():
- def run_forever(self):
- try:
- logger.info("Listening on port %d for clients.." % self.port)
- self.serve_forever()
- except KeyboardInterrupt:
- self.server_close()
- logger.info("Server terminated.")
- except Exception as e:
- logger.error(str(e), exc_info=True)
- exit(1)
- def new_client(self, client, server):
- pass
- def client_left(self, client, server):
- pass
- def message_received(self, client, server, message):
- pass
- def set_fn_new_client(self, fn):
- self.new_client = fn
- def set_fn_client_left(self, fn):
- self.client_left = fn
- def set_fn_message_received(self, fn):
- self.message_received = fn
- def send_message(self, client, msg):
- self._unicast_(client, msg)
- def send_message_to_all(self, msg):
- self._multicast_(msg)
- # ------------------------- Implementation -----------------------------
- class WebsocketServer(ThreadingMixIn, TCPServer, API):
- """
- A websocket server waiting for clients to connect.
- Args:
- port(int): Port to bind to
- host(str): Hostname or IP to listen for connections. By default 127.0.0.1
- is being used. To accept connections from any client, you should use
- 0.0.0.0.
- loglevel: Logging level from logging module to use for logging. By default
- warnings and errors are being logged.
- Properties:
- clients(list): A list of connected clients. A client is a dictionary
- like below.
- {
- 'id' : id,
- 'handler' : handler,
- 'address' : (addr, port)
- }
- """
- allow_reuse_address = True
- daemon_threads = True # comment to keep threads alive until finished
- clients = []
- id_counter = 0
- def __init__(self, port, host='127.0.0.1', loglevel=logging.WARNING):
- logger.setLevel(loglevel)
- self.port = port
- TCPServer.__init__(self, (host, port), WebSocketHandler)
- def _message_received_(self, handler, msg):
- self.message_received(self.handler_to_client(handler), self, msg)
- def _ping_received_(self, handler, msg):
- handler.send_pong(msg)
- def _pong_received_(self, handler, msg):
- pass
- def _new_client_(self, handler):
- self.id_counter += 1
- client = {
- 'id': self.id_counter,
- 'handler': handler,
- 'address': handler.client_address
- }
- self.clients.append(client)
- self.new_client(client, self)
- def _client_left_(self, handler):
- client = self.handler_to_client(handler)
- self.client_left(client, self)
- if client in self.clients:
- self.clients.remove(client)
- def _unicast_(self, to_client, msg):
- to_client['handler'].send_message(msg)
- def _multicast_(self, msg):
- for client in self.clients:
- self._unicast_(client, msg)
- def handler_to_client(self, handler):
- for client in self.clients:
- if client['handler'] == handler:
- return client
- class WebSocketHandler(StreamRequestHandler):
- def __init__(self, socket, addr, server):
- self.server = server
- StreamRequestHandler.__init__(self, socket, addr, server)
- def setup(self):
- StreamRequestHandler.setup(self)
- self.keep_alive = True
- self.handshake_done = False
- self.valid_client = False
- def handle(self):
- while self.keep_alive:
- if not self.handshake_done:
- self.handshake()
- elif self.valid_client:
- self.read_next_message()
- def read_bytes(self, num):
- # python3 gives ordinal of byte directly
- bytes = self.rfile.read(num)
- if sys.version_info[0] < 3:
- return map(ord, bytes)
- else:
- return bytes
- def read_next_message(self):
- try:
- b1, b2 = self.read_bytes(2)
- except ValueError as e:
- b1, b2 = 0, 0
- fin = b1 & FIN
- opcode = b1 & OPCODE
- masked = b2 & MASKED
- payload_length = b2 & PAYLOAD_LEN
- if not b1:
- logger.info("Client closed connection.")
- self.keep_alive = 0
- return
- if opcode == OPCODE_CLOSE_CONN:
- logger.info("Client asked to close connection.")
- self.keep_alive = 0
- return
- if not masked:
- logger.warn("Client must always be masked.")
- self.keep_alive = 0
- return
- if opcode == OPCODE_CONTINUATION:
- logger.warn("Continuation frames are not supported.")
- return
- elif opcode == OPCODE_BINARY:
- logger.warn("Binary frames are not supported.")
- return
- elif opcode == OPCODE_TEXT:
- opcode_handler = self.server._message_received_
- elif opcode == OPCODE_PING:
- opcode_handler = self.server._ping_received_
- elif opcode == OPCODE_PONG:
- opcode_handler = self.server._pong_received_
- else:
- logger.warn("Unknown opcode %#x." + opcode)
- self.keep_alive = 0
- return
- if payload_length == 126:
- payload_length = struct.unpack(">H", self.rfile.read(2))[0]
- elif payload_length == 127:
- payload_length = struct.unpack(">Q", self.rfile.read(8))[0]
- masks = self.read_bytes(4)
- decoded = ""
- for char in self.read_bytes(payload_length):
- char ^= masks[len(decoded) % 4]
- decoded += chr(char)
- opcode_handler(self, decoded)
- def send_message(self, message):
- self.send_text(message)
- def send_pong(self, message):
- self.send_text(message, OPCODE_PONG)
- def send_text(self, message, opcode=OPCODE_TEXT):
- """
- Important: Fragmented(=continuation) messages are not supported since
- their usage cases are limited - when we don't know the payload length.
- """
- # Validate message
- if isinstance(message, bytes):
- message = try_decode_UTF8(message) # this is slower but ensures we have UTF-8
- if not message:
- logger.warning("Can\'t send message, message is not valid UTF-8")
- return False
- elif sys.version_info < (3,0) and (isinstance(message, str) or isinstance(message, unicode)):
- pass
- elif isinstance(message, str):
- pass
- else:
- logger.warning('Can\'t send message, message has to be a string or bytes. Given type is %s' % type(message))
- return False
- header = bytearray()
- payload = encode_to_UTF8(message)
- payload_length = len(payload)
- # Normal payload
- if payload_length <= 125:
- header.append(FIN | opcode)
- header.append(payload_length)
- # Extended payload
- elif payload_length >= 126 and payload_length <= 65535:
- header.append(FIN | opcode)
- header.append(PAYLOAD_LEN_EXT16)
- header.extend(struct.pack(">H", payload_length))
- # Huge extended payload
- elif payload_length < 18446744073709551616:
- header.append(FIN | opcode)
- header.append(PAYLOAD_LEN_EXT64)
- header.extend(struct.pack(">Q", payload_length))
- else:
- raise Exception("Message is too big. Consider breaking it into chunks.")
- return
- self.request.send(header + payload)
- def handshake(self):
- message = self.request.recv(1024).decode().strip()
- upgrade = re.search('\nupgrade[\s]*:[\s]*websocket', message.lower())
- if not upgrade:
- self.keep_alive = False
- return
- key = re.search('\n[sS]ec-[wW]eb[sS]ocket-[kK]ey[\s]*:[\s]*(.*)\r\n', message)
- if key:
- key = key.group(1)
- else:
- logger.warning("Client tried to connect but was missing a key")
- self.keep_alive = False
- return
- response = self.make_handshake_response(key)
- self.handshake_done = self.request.send(response.encode())
- self.valid_client = True
- self.server._new_client_(self)
- def make_handshake_response(self, key):
- return \
- 'HTTP/1.1 101 Switching Protocols\r\n'\
- 'Upgrade: websocket\r\n' \
- 'Connection: Upgrade\r\n' \
- 'Sec-WebSocket-Accept: %s\r\n' \
- '\r\n' % self.calculate_response_key(key)
- def calculate_response_key(self, key):
- GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
- hash = sha1(key.encode() + GUID.encode())
- response_key = b64encode(hash.digest()).strip()
- return response_key.decode('ASCII')
- def finish(self):
- self.server._client_left_(self)
- def encode_to_UTF8(data):
- try:
- return data.encode('UTF-8')
- except UnicodeEncodeError as e:
- logger.error("Could not encode data to UTF-8 -- %s" % e)
- return False
- except Exception as e:
- raise(e)
- return False
- def try_decode_UTF8(data):
- try:
- return data.decode('utf-8')
- except UnicodeDecodeError:
- return False
- except Exception as e:
- raise(e)
|