Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

425 строк
17KB

  1. import time
  2. import unittest
  3. import six
  4. if six.PY3:
  5. from unittest import mock
  6. else:
  7. import mock
  8. from engineio import exceptions
  9. from engineio import packet
  10. from engineio import payload
  11. from engineio import socket
  12. class TestSocket(unittest.TestCase):
  13. def setUp(self):
  14. self.bg_tasks = []
  15. def _get_mock_server(self):
  16. mock_server = mock.Mock()
  17. mock_server.ping_timeout = 0.2
  18. mock_server.ping_interval = 0.2
  19. mock_server.async_handlers = True
  20. try:
  21. import queue
  22. except ImportError:
  23. import Queue as queue
  24. import threading
  25. mock_server._async = {'threading': threading.Thread,
  26. 'queue': queue.Queue,
  27. 'websocket': None}
  28. def bg_task(target, *args, **kwargs):
  29. th = threading.Thread(target=target, args=args, kwargs=kwargs)
  30. self.bg_tasks.append(th)
  31. th.start()
  32. return th
  33. def create_queue(*args, **kwargs):
  34. return queue.Queue(*args, **kwargs)
  35. mock_server.start_background_task = bg_task
  36. mock_server.create_queue = create_queue
  37. mock_server.get_queue_empty_exception.return_value = queue.Empty
  38. return mock_server
  39. def _join_bg_tasks(self):
  40. for task in self.bg_tasks:
  41. task.join()
  42. def test_create(self):
  43. mock_server = self._get_mock_server()
  44. s = socket.Socket(mock_server, 'sid')
  45. self.assertEqual(s.server, mock_server)
  46. self.assertEqual(s.sid, 'sid')
  47. self.assertFalse(s.upgraded)
  48. self.assertFalse(s.closed)
  49. self.assertTrue(hasattr(s.queue, 'get'))
  50. self.assertTrue(hasattr(s.queue, 'put'))
  51. self.assertTrue(hasattr(s.queue, 'task_done'))
  52. self.assertTrue(hasattr(s.queue, 'join'))
  53. def test_empty_poll(self):
  54. mock_server = self._get_mock_server()
  55. s = socket.Socket(mock_server, 'sid')
  56. self.assertRaises(exceptions.QueueEmpty, s.poll)
  57. def test_poll(self):
  58. mock_server = self._get_mock_server()
  59. s = socket.Socket(mock_server, 'sid')
  60. pkt1 = packet.Packet(packet.MESSAGE, data='hello')
  61. pkt2 = packet.Packet(packet.MESSAGE, data='bye')
  62. s.send(pkt1)
  63. s.send(pkt2)
  64. self.assertEqual(s.poll(), [pkt1, pkt2])
  65. def test_ping_pong(self):
  66. mock_server = self._get_mock_server()
  67. s = socket.Socket(mock_server, 'sid')
  68. s.receive(packet.Packet(packet.PING, data='abc'))
  69. r = s.poll()
  70. self.assertEqual(len(r), 1)
  71. self.assertTrue(r[0].encode(), b'3abc')
  72. def test_message_async_handler(self):
  73. mock_server = self._get_mock_server()
  74. s = socket.Socket(mock_server, 'sid')
  75. s.receive(packet.Packet(packet.MESSAGE, data='foo'))
  76. mock_server._trigger_event.assert_called_once_with('message', 'sid',
  77. 'foo',
  78. run_async=True)
  79. def test_message_sync_handler(self):
  80. mock_server = self._get_mock_server()
  81. mock_server.async_handlers = False
  82. s = socket.Socket(mock_server, 'sid')
  83. s.receive(packet.Packet(packet.MESSAGE, data='foo'))
  84. mock_server._trigger_event.assert_called_once_with('message', 'sid',
  85. 'foo',
  86. run_async=False)
  87. def test_invalid_packet(self):
  88. mock_server = self._get_mock_server()
  89. s = socket.Socket(mock_server, 'sid')
  90. self.assertRaises(exceptions.UnknownPacketError, s.receive,
  91. packet.Packet(packet.OPEN))
  92. def test_timeout(self):
  93. mock_server = self._get_mock_server()
  94. mock_server.ping_interval = -6
  95. s = socket.Socket(mock_server, 'sid')
  96. s.last_ping = time.time() - 1
  97. s.close = mock.MagicMock()
  98. s.send('packet')
  99. s.close.assert_called_once_with(wait=False, abort=False)
  100. def test_polling_read(self):
  101. mock_server = self._get_mock_server()
  102. s = socket.Socket(mock_server, 'foo')
  103. pkt1 = packet.Packet(packet.MESSAGE, data='hello')
  104. pkt2 = packet.Packet(packet.MESSAGE, data='bye')
  105. s.send(pkt1)
  106. s.send(pkt2)
  107. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
  108. start_response = mock.MagicMock()
  109. packets = s.handle_get_request(environ, start_response)
  110. self.assertEqual(packets, [pkt1, pkt2])
  111. def test_polling_read_error(self):
  112. mock_server = self._get_mock_server()
  113. s = socket.Socket(mock_server, 'foo')
  114. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
  115. start_response = mock.MagicMock()
  116. self.assertRaises(exceptions.QueueEmpty, s.handle_get_request, environ,
  117. start_response)
  118. def test_polling_write(self):
  119. mock_server = self._get_mock_server()
  120. mock_server.max_http_buffer_size = 1000
  121. pkt1 = packet.Packet(packet.MESSAGE, data='hello')
  122. pkt2 = packet.Packet(packet.MESSAGE, data='bye')
  123. p = payload.Payload(packets=[pkt1, pkt2]).encode()
  124. s = socket.Socket(mock_server, 'foo')
  125. s.receive = mock.MagicMock()
  126. environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo',
  127. 'CONTENT_LENGTH': len(p), 'wsgi.input': six.BytesIO(p)}
  128. s.handle_post_request(environ)
  129. self.assertEqual(s.receive.call_count, 2)
  130. def test_polling_write_too_large(self):
  131. mock_server = self._get_mock_server()
  132. pkt1 = packet.Packet(packet.MESSAGE, data='hello')
  133. pkt2 = packet.Packet(packet.MESSAGE, data='bye')
  134. p = payload.Payload(packets=[pkt1, pkt2]).encode()
  135. mock_server.max_http_buffer_size = len(p) - 1
  136. s = socket.Socket(mock_server, 'foo')
  137. s.receive = mock.MagicMock()
  138. environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo',
  139. 'CONTENT_LENGTH': len(p), 'wsgi.input': six.BytesIO(p)}
  140. self.assertRaises(exceptions.ContentTooLongError,
  141. s.handle_post_request, environ)
  142. def test_upgrade_handshake(self):
  143. mock_server = self._get_mock_server()
  144. s = socket.Socket(mock_server, 'foo')
  145. s._upgrade_websocket = mock.MagicMock()
  146. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
  147. 'HTTP_CONNECTION': 'Foo,Upgrade,Bar',
  148. 'HTTP_UPGRADE': 'websocket'}
  149. start_response = mock.MagicMock()
  150. s.handle_get_request(environ, start_response)
  151. s._upgrade_websocket.assert_called_once_with(environ, start_response)
  152. def test_upgrade(self):
  153. mock_server = self._get_mock_server()
  154. mock_server._async['websocket'] = mock.MagicMock()
  155. mock_ws = mock.MagicMock()
  156. mock_server._async['websocket'].return_value = mock_ws
  157. s = socket.Socket(mock_server, 'sid')
  158. s.connected = True
  159. environ = "foo"
  160. start_response = "bar"
  161. s._upgrade_websocket(environ, start_response)
  162. mock_server._async['websocket'].assert_called_once_with(
  163. s._websocket_handler)
  164. mock_ws.assert_called_once_with(environ, start_response)
  165. def test_upgrade_twice(self):
  166. mock_server = self._get_mock_server()
  167. mock_server._async['websocket'] = mock.MagicMock()
  168. s = socket.Socket(mock_server, 'sid')
  169. s.connected = True
  170. s.upgraded = True
  171. environ = "foo"
  172. start_response = "bar"
  173. self.assertRaises(IOError, s._upgrade_websocket,
  174. environ, start_response)
  175. def test_upgrade_packet(self):
  176. mock_server = self._get_mock_server()
  177. s = socket.Socket(mock_server, 'sid')
  178. s.connected = True
  179. s.receive(packet.Packet(packet.UPGRADE))
  180. r = s.poll()
  181. self.assertEqual(len(r), 1)
  182. self.assertEqual(r[0].encode(), packet.Packet(packet.NOOP).encode())
  183. def test_upgrade_no_probe(self):
  184. mock_server = self._get_mock_server()
  185. s = socket.Socket(mock_server, 'sid')
  186. s.connected = True
  187. ws = mock.MagicMock()
  188. ws.wait.return_value = packet.Packet(packet.NOOP).encode(
  189. always_bytes=False)
  190. s._websocket_handler(ws)
  191. self.assertFalse(s.upgraded)
  192. def test_upgrade_no_upgrade_packet(self):
  193. mock_server = self._get_mock_server()
  194. s = socket.Socket(mock_server, 'sid')
  195. s.connected = True
  196. s.queue.join = mock.MagicMock(return_value=None)
  197. ws = mock.MagicMock()
  198. probe = six.text_type('probe')
  199. ws.wait.side_effect = [
  200. packet.Packet(packet.PING, data=probe).encode(
  201. always_bytes=False),
  202. packet.Packet(packet.NOOP).encode(always_bytes=False)]
  203. s._websocket_handler(ws)
  204. ws.send.assert_called_once_with(packet.Packet(
  205. packet.PONG, data=probe).encode(always_bytes=False))
  206. self.assertEqual(s.queue.get().packet_type, packet.NOOP)
  207. self.assertFalse(s.upgraded)
  208. def test_close_packet(self):
  209. mock_server = self._get_mock_server()
  210. s = socket.Socket(mock_server, 'sid')
  211. s.connected = True
  212. s.close = mock.MagicMock()
  213. s.receive(packet.Packet(packet.CLOSE))
  214. s.close.assert_called_once_with(wait=False, abort=True)
  215. def test_invalid_packet_type(self):
  216. mock_server = self._get_mock_server()
  217. s = socket.Socket(mock_server, 'sid')
  218. pkt = packet.Packet(packet_type=99)
  219. self.assertRaises(exceptions.UnknownPacketError, s.receive, pkt)
  220. def test_upgrade_not_supported(self):
  221. mock_server = self._get_mock_server()
  222. mock_server._async['websocket'] = None
  223. s = socket.Socket(mock_server, 'sid')
  224. s.connected = True
  225. environ = "foo"
  226. start_response = "bar"
  227. s._upgrade_websocket(environ, start_response)
  228. mock_server._bad_request.assert_called_once_with()
  229. def test_websocket_read_write(self):
  230. mock_server = self._get_mock_server()
  231. s = socket.Socket(mock_server, 'sid')
  232. s.connected = False
  233. s.queue.join = mock.MagicMock(return_value=None)
  234. foo = six.text_type('foo')
  235. bar = six.text_type('bar')
  236. s.poll = mock.MagicMock(side_effect=[
  237. [packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
  238. ws = mock.MagicMock()
  239. ws.wait.side_effect = [
  240. packet.Packet(packet.MESSAGE, data=foo).encode(
  241. always_bytes=False),
  242. None]
  243. s._websocket_handler(ws)
  244. self._join_bg_tasks()
  245. self.assertTrue(s.connected)
  246. self.assertTrue(s.upgraded)
  247. self.assertEqual(mock_server._trigger_event.call_count, 2)
  248. mock_server._trigger_event.assert_has_calls([
  249. mock.call('message', 'sid', 'foo', run_async=True),
  250. mock.call('disconnect', 'sid', run_async=False)])
  251. ws.send.assert_called_with('4bar')
  252. def test_websocket_upgrade_read_write(self):
  253. mock_server = self._get_mock_server()
  254. s = socket.Socket(mock_server, 'sid')
  255. s.connected = True
  256. s.queue.join = mock.MagicMock(return_value=None)
  257. foo = six.text_type('foo')
  258. bar = six.text_type('bar')
  259. probe = six.text_type('probe')
  260. s.poll = mock.MagicMock(side_effect=[
  261. [packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
  262. ws = mock.MagicMock()
  263. ws.wait.side_effect = [
  264. packet.Packet(packet.PING, data=probe).encode(
  265. always_bytes=False),
  266. packet.Packet(packet.UPGRADE).encode(always_bytes=False),
  267. packet.Packet(packet.MESSAGE, data=foo).encode(
  268. always_bytes=False),
  269. None]
  270. s._websocket_handler(ws)
  271. self._join_bg_tasks()
  272. self.assertTrue(s.upgraded)
  273. self.assertEqual(mock_server._trigger_event.call_count, 2)
  274. mock_server._trigger_event.assert_has_calls([
  275. mock.call('message', 'sid', 'foo', run_async=True),
  276. mock.call('disconnect', 'sid', run_async=False)])
  277. ws.send.assert_called_with('4bar')
  278. def test_websocket_upgrade_with_payload(self):
  279. mock_server = self._get_mock_server()
  280. s = socket.Socket(mock_server, 'sid')
  281. s.connected = True
  282. s.queue.join = mock.MagicMock(return_value=None)
  283. probe = six.text_type('probe')
  284. ws = mock.MagicMock()
  285. ws.wait.side_effect = [
  286. packet.Packet(packet.PING, data=probe).encode(
  287. always_bytes=False),
  288. packet.Packet(packet.UPGRADE, data=b'2').encode(
  289. always_bytes=False)]
  290. s._websocket_handler(ws)
  291. self._join_bg_tasks()
  292. self.assertTrue(s.upgraded)
  293. def test_websocket_upgrade_with_backlog(self):
  294. mock_server = self._get_mock_server()
  295. s = socket.Socket(mock_server, 'sid')
  296. s.connected = True
  297. s.queue.join = mock.MagicMock(return_value=None)
  298. probe = six.text_type('probe')
  299. foo = six.text_type('foo')
  300. ws = mock.MagicMock()
  301. ws.wait.side_effect = [
  302. packet.Packet(packet.PING, data=probe).encode(
  303. always_bytes=False),
  304. packet.Packet(packet.UPGRADE, data=b'2').encode(
  305. always_bytes=False)]
  306. s.upgrading = True
  307. s.send(packet.Packet(packet.MESSAGE, data=foo))
  308. s._websocket_handler(ws)
  309. self._join_bg_tasks()
  310. self.assertTrue(s.upgraded)
  311. self.assertFalse(s.upgrading)
  312. self.assertEqual(s.packet_backlog, [])
  313. ws.send.assert_called_with('4foo')
  314. def test_websocket_read_write_wait_fail(self):
  315. mock_server = self._get_mock_server()
  316. s = socket.Socket(mock_server, 'sid')
  317. s.connected = False
  318. s.queue.join = mock.MagicMock(return_value=None)
  319. foo = six.text_type('foo')
  320. bar = six.text_type('bar')
  321. s.poll = mock.MagicMock(side_effect=[
  322. [packet.Packet(packet.MESSAGE, data=bar)],
  323. [packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
  324. ws = mock.MagicMock()
  325. ws.wait.side_effect = [
  326. packet.Packet(packet.MESSAGE, data=foo).encode(
  327. always_bytes=False),
  328. RuntimeError]
  329. ws.send.side_effect = [None, RuntimeError]
  330. s._websocket_handler(ws)
  331. self._join_bg_tasks()
  332. self.assertEqual(s.closed, True)
  333. def test_websocket_ignore_invalid_packet(self):
  334. mock_server = self._get_mock_server()
  335. s = socket.Socket(mock_server, 'sid')
  336. s.connected = False
  337. s.queue.join = mock.MagicMock(return_value=None)
  338. foo = six.text_type('foo')
  339. bar = six.text_type('bar')
  340. s.poll = mock.MagicMock(side_effect=[
  341. [packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
  342. ws = mock.MagicMock()
  343. ws.wait.side_effect = [
  344. packet.Packet(packet.OPEN).encode(always_bytes=False),
  345. packet.Packet(packet.MESSAGE, data=foo).encode(
  346. always_bytes=False),
  347. None]
  348. s._websocket_handler(ws)
  349. self._join_bg_tasks()
  350. self.assertTrue(s.connected)
  351. self.assertEqual(mock_server._trigger_event.call_count, 2)
  352. mock_server._trigger_event.assert_has_calls([
  353. mock.call('message', 'sid', foo, run_async=True),
  354. mock.call('disconnect', 'sid', run_async=False)])
  355. ws.send.assert_called_with('4bar')
  356. def test_send_after_close(self):
  357. mock_server = self._get_mock_server()
  358. s = socket.Socket(mock_server, 'sid')
  359. s.close(wait=False)
  360. self.assertRaises(exceptions.SocketIsClosedError, s.send,
  361. packet.Packet(packet.NOOP))
  362. def test_close_after_close(self):
  363. mock_server = self._get_mock_server()
  364. s = socket.Socket(mock_server, 'sid')
  365. s.close(wait=False)
  366. self.assertTrue(s.closed)
  367. self.assertEqual(mock_server._trigger_event.call_count, 1)
  368. mock_server._trigger_event.assert_called_once_with('disconnect', 'sid',
  369. run_async=False)
  370. s.close()
  371. self.assertEqual(mock_server._trigger_event.call_count, 1)
  372. def test_close_and_wait(self):
  373. mock_server = self._get_mock_server()
  374. s = socket.Socket(mock_server, 'sid')
  375. s.queue = mock.MagicMock()
  376. s.close(wait=True)
  377. s.queue.join.assert_called_once_with()
  378. def test_close_without_wait(self):
  379. mock_server = self._get_mock_server()
  380. s = socket.Socket(mock_server, 'sid')
  381. s.queue = mock.MagicMock()
  382. s.close(wait=False)
  383. self.assertEqual(s.queue.join.call_count, 0)