websocket_server.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. # Author: Johan Hanssen Seferidis
  2. # License: MIT
  3. import re
  4. import sys
  5. import struct
  6. from base64 import b64encode
  7. from hashlib import sha1
  8. import logging
  9. if sys.version_info[0] < 3:
  10. from SocketServer import ThreadingMixIn, TCPServer, StreamRequestHandler
  11. else:
  12. from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler
  13. logger = logging.getLogger(__name__)
  14. logging.basicConfig()
'''
WebSocket base framing protocol (RFC 6455, section 5.2):

 0                   1                   2                   3
 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len |    Extended payload length    |
|I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
|N|V|V|V|       |S|             |   (if payload len==126/127)   |
| |1|2|3|       |K|             |                               |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|     Extended payload length continued, if payload len == 127  |
+ - - - - - - - - - - - - - - - +-------------------------------+
|                               |Masking-key, if MASK set to 1  |
+-------------------------------+-------------------------------+
| Masking-key (continued)       |          Payload Data         |
+-------------------------------- - - - - - - - - - - - - - - - +
:                     Payload Data continued ...                :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|                     Payload Data continued ...                |
+---------------------------------------------------------------+
'''
  30. FIN = 0x80
  31. OPCODE = 0x0f
  32. MASKED = 0x80
  33. PAYLOAD_LEN = 0x7f
  34. PAYLOAD_LEN_EXT16 = 0x7e
  35. PAYLOAD_LEN_EXT64 = 0x7f
  36. OPCODE_CONTINUATION = 0x0
  37. OPCODE_TEXT = 0x1
  38. OPCODE_BINARY = 0x2
  39. OPCODE_CLOSE_CONN = 0x8
  40. OPCODE_PING = 0x9
  41. OPCODE_PONG = 0xA
  42. # -------------------------------- API ---------------------------------
  43. class API():
  44. def run_forever(self):
  45. try:
  46. logger.info("Listening on port %d for clients.." % self.port)
  47. self.serve_forever()
  48. except KeyboardInterrupt:
  49. self.server_close()
  50. logger.info("Server terminated.")
  51. except Exception as e:
  52. logger.error(str(e), exc_info=True)
  53. exit(1)
  54. def new_client(self, client, server):
  55. pass
  56. def client_left(self, client, server):
  57. pass
  58. def message_received(self, client, server, message):
  59. pass
  60. def set_fn_new_client(self, fn):
  61. self.new_client = fn
  62. def set_fn_client_left(self, fn):
  63. self.client_left = fn
  64. def set_fn_message_received(self, fn):
  65. self.message_received = fn
  66. def send_message(self, client, msg):
  67. self._unicast_(client, msg)
  68. def send_message_to_all(self, msg):
  69. self._multicast_(msg)
  70. # ------------------------- Implementation -----------------------------
  71. class WebsocketServer(ThreadingMixIn, TCPServer, API):
  72. """
  73. A websocket server waiting for clients to connect.
  74. Args:
  75. port(int): Port to bind to
  76. host(str): Hostname or IP to listen for connections. By default 127.0.0.1
  77. is being used. To accept connections from any client, you should use
  78. 0.0.0.0.
  79. loglevel: Logging level from logging module to use for logging. By default
  80. warnings and errors are being logged.
  81. Properties:
  82. clients(list): A list of connected clients. A client is a dictionary
  83. like below.
  84. {
  85. 'id' : id,
  86. 'handler' : handler,
  87. 'address' : (addr, port)
  88. }
  89. """
  90. allow_reuse_address = True
  91. daemon_threads = True # comment to keep threads alive until finished
  92. clients = []
  93. id_counter = 0
  94. def __init__(self, port, host='127.0.0.1', loglevel=logging.WARNING):
  95. logger.setLevel(loglevel)
  96. self.port = port
  97. TCPServer.__init__(self, (host, port), WebSocketHandler)
  98. def _message_received_(self, handler, msg):
  99. self.message_received(self.handler_to_client(handler), self, msg)
  100. def _ping_received_(self, handler, msg):
  101. handler.send_pong(msg)
  102. def _pong_received_(self, handler, msg):
  103. pass
  104. def _new_client_(self, handler):
  105. self.id_counter += 1
  106. client = {
  107. 'id': self.id_counter,
  108. 'handler': handler,
  109. 'address': handler.client_address
  110. }
  111. self.clients.append(client)
  112. self.new_client(client, self)
  113. def _client_left_(self, handler):
  114. client = self.handler_to_client(handler)
  115. self.client_left(client, self)
  116. if client in self.clients:
  117. self.clients.remove(client)
  118. def _unicast_(self, to_client, msg):
  119. to_client['handler'].send_message(msg)
  120. def _multicast_(self, msg):
  121. for client in self.clients:
  122. self._unicast_(client, msg)
  123. def handler_to_client(self, handler):
  124. for client in self.clients:
  125. if client['handler'] == handler:
  126. return client
  127. class WebSocketHandler(StreamRequestHandler):
  128. def __init__(self, socket, addr, server):
  129. self.server = server
  130. StreamRequestHandler.__init__(self, socket, addr, server)
  131. def setup(self):
  132. StreamRequestHandler.setup(self)
  133. self.keep_alive = True
  134. self.handshake_done = False
  135. self.valid_client = False
  136. def handle(self):
  137. while self.keep_alive:
  138. if not self.handshake_done:
  139. self.handshake()
  140. elif self.valid_client:
  141. self.read_next_message()
  142. def read_bytes(self, num):
  143. # python3 gives ordinal of byte directly
  144. bytes = self.rfile.read(num)
  145. if sys.version_info[0] < 3:
  146. return map(ord, bytes)
  147. else:
  148. return bytes
  149. def read_next_message(self):
  150. try:
  151. b1, b2 = self.read_bytes(2)
  152. except ValueError as e:
  153. b1, b2 = 0, 0
  154. fin = b1 & FIN
  155. opcode = b1 & OPCODE
  156. masked = b2 & MASKED
  157. payload_length = b2 & PAYLOAD_LEN
  158. if not b1:
  159. logger.info("Client closed connection.")
  160. self.keep_alive = 0
  161. return
  162. if opcode == OPCODE_CLOSE_CONN:
  163. logger.info("Client asked to close connection.")
  164. self.keep_alive = 0
  165. return
  166. if not masked:
  167. logger.warn("Client must always be masked.")
  168. self.keep_alive = 0
  169. return
  170. if opcode == OPCODE_CONTINUATION:
  171. logger.warn("Continuation frames are not supported.")
  172. return
  173. elif opcode == OPCODE_BINARY:
  174. logger.warn("Binary frames are not supported.")
  175. return
  176. elif opcode == OPCODE_TEXT:
  177. opcode_handler = self.server._message_received_
  178. elif opcode == OPCODE_PING:
  179. opcode_handler = self.server._ping_received_
  180. elif opcode == OPCODE_PONG:
  181. opcode_handler = self.server._pong_received_
  182. else:
  183. logger.warn("Unknown opcode %#x." + opcode)
  184. self.keep_alive = 0
  185. return
  186. if payload_length == 126:
  187. payload_length = struct.unpack(">H", self.rfile.read(2))[0]
  188. elif payload_length == 127:
  189. payload_length = struct.unpack(">Q", self.rfile.read(8))[0]
  190. masks = self.read_bytes(4)
  191. decoded = ""
  192. for char in self.read_bytes(payload_length):
  193. char ^= masks[len(decoded) % 4]
  194. decoded += chr(char)
  195. opcode_handler(self, decoded)
  196. def send_message(self, message):
  197. self.send_text(message)
  198. def send_pong(self, message):
  199. self.send_text(message, OPCODE_PONG)
  200. def send_text(self, message, opcode=OPCODE_TEXT):
  201. """
  202. Important: Fragmented(=continuation) messages are not supported since
  203. their usage cases are limited - when we don't know the payload length.
  204. """
  205. # Validate message
  206. if isinstance(message, bytes):
  207. message = try_decode_UTF8(message) # this is slower but ensures we have UTF-8
  208. if not message:
  209. logger.warning("Can\'t send message, message is not valid UTF-8")
  210. return False
  211. elif sys.version_info < (3,0) and (isinstance(message, str) or isinstance(message, unicode)):
  212. pass
  213. elif isinstance(message, str):
  214. pass
  215. else:
  216. logger.warning('Can\'t send message, message has to be a string or bytes. Given type is %s' % type(message))
  217. return False
  218. header = bytearray()
  219. payload = encode_to_UTF8(message)
  220. payload_length = len(payload)
  221. # Normal payload
  222. if payload_length <= 125:
  223. header.append(FIN | opcode)
  224. header.append(payload_length)
  225. # Extended payload
  226. elif payload_length >= 126 and payload_length <= 65535:
  227. header.append(FIN | opcode)
  228. header.append(PAYLOAD_LEN_EXT16)
  229. header.extend(struct.pack(">H", payload_length))
  230. # Huge extended payload
  231. elif payload_length < 18446744073709551616:
  232. header.append(FIN | opcode)
  233. header.append(PAYLOAD_LEN_EXT64)
  234. header.extend(struct.pack(">Q", payload_length))
  235. else:
  236. raise Exception("Message is too big. Consider breaking it into chunks.")
  237. return
  238. self.request.send(header + payload)
  239. def handshake(self):
  240. message = self.request.recv(1024).decode().strip()
  241. upgrade = re.search('\nupgrade[\s]*:[\s]*websocket', message.lower())
  242. if not upgrade:
  243. self.keep_alive = False
  244. return
  245. key = re.search('\n[sS]ec-[wW]eb[sS]ocket-[kK]ey[\s]*:[\s]*(.*)\r\n', message)
  246. if key:
  247. key = key.group(1)
  248. else:
  249. logger.warning("Client tried to connect but was missing a key")
  250. self.keep_alive = False
  251. return
  252. response = self.make_handshake_response(key)
  253. self.handshake_done = self.request.send(response.encode())
  254. self.valid_client = True
  255. self.server._new_client_(self)
  256. def make_handshake_response(self, key):
  257. return \
  258. 'HTTP/1.1 101 Switching Protocols\r\n'\
  259. 'Upgrade: websocket\r\n' \
  260. 'Connection: Upgrade\r\n' \
  261. 'Sec-WebSocket-Accept: %s\r\n' \
  262. '\r\n' % self.calculate_response_key(key)
  263. def calculate_response_key(self, key):
  264. GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  265. hash = sha1(key.encode() + GUID.encode())
  266. response_key = b64encode(hash.digest()).strip()
  267. return response_key.decode('ASCII')
  268. def finish(self):
  269. self.server._client_left_(self)
  270. def encode_to_UTF8(data):
  271. try:
  272. return data.encode('UTF-8')
  273. except UnicodeEncodeError as e:
  274. logger.error("Could not encode data to UTF-8 -- %s" % e)
  275. return False
  276. except Exception as e:
  277. raise(e)
  278. return False
  279. def try_decode_UTF8(data):
  280. try:
  281. return data.decode('utf-8')
  282. except UnicodeDecodeError:
  283. return False
  284. except Exception as e:
  285. raise(e)