Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

442 строки
17KB

  1. import asyncio
  2. import sys
  3. import time
  4. import unittest
  5. import six
  6. if six.PY3:
  7. from unittest import mock
  8. else:
  9. import mock
  10. from engineio import asyncio_socket
  11. from engineio import exceptions
  12. from engineio import packet
  13. from engineio import payload
  def AsyncMock(*args, **kwargs):
      """Create a coroutine function that delegates to a ``MagicMock``.

      All arguments are forwarded to the ``mock.MagicMock`` constructor.
      Awaiting a call to the returned coroutine function invokes that mock
      with the same call arguments and produces its result. The mock itself
      is exposed as the ``mock`` attribute of the returned function so tests
      can configure and inspect it.
      """
      delegate = mock.MagicMock(*args, **kwargs)

      async def mock_coro(*call_args, **call_kwargs):
          return delegate(*call_args, **call_kwargs)

      mock_coro.mock = delegate
      return mock_coro
  21. def _run(coro):
  22. """Run the given coroutine."""
  23. return asyncio.get_event_loop().run_until_complete(coro)
  24. @unittest.skipIf(sys.version_info < (3, 5), 'only for Python 3.5+')
  25. class TestSocket(unittest.TestCase):
  def _get_read_mock_coro(self, payload):
      """Return a mock input stream whose ``read()`` coroutine yields payload.

      The result is a ``MagicMock`` whose ``read`` attribute is an
      ``AsyncMock`` coroutine function configured to return ``payload``.
      """
      read_coro = AsyncMock()
      read_coro.mock.return_value = payload
      stream = mock.MagicMock()
      stream.read = read_coro
      return stream
  def _get_mock_server(self):
      """Build the mock server object shared by the socket tests.

      The mock carries 0.2 ping timeout/interval values, ``async_handlers``
      set to False, an ``_async`` settings dict, an awaitable
      ``_trigger_event`` and a ``create_queue`` factory returning
      ``asyncio.Queue`` objects that carry an ``Empty`` attribute set to
      ``asyncio.QueueEmpty``.
      """
      def create_queue(*args, **kwargs):
          new_queue = asyncio.Queue(*args, **kwargs)
          new_queue.Empty = asyncio.QueueEmpty
          return new_queue

      server = mock.Mock(ping_timeout=0.2, ping_interval=0.2,
                         async_handlers=False)
      server._async = {
          'asyncio': True,
          'create_route': mock.MagicMock(),
          'translate_request': mock.MagicMock(return_value='request'),
          'make_response': mock.MagicMock(return_value='response'),
          'websocket': 'w',
      }
      server._trigger_event = AsyncMock()
      server.create_queue = create_queue
      return server
  def test_create(self):
      # A freshly built socket keeps its server and sid, is neither upgraded
      # nor closed, and owns a queue exposing the usual queue methods.
      server = self._get_mock_server()
      sock = asyncio_socket.AsyncSocket(server, 'sid')
      self.assertEqual(sock.server, server)
      self.assertEqual(sock.sid, 'sid')
      self.assertFalse(sock.upgraded)
      self.assertFalse(sock.closed)
      for method_name in ('get', 'put', 'task_done', 'join'):
          self.assertTrue(hasattr(sock.queue, method_name))
  def test_empty_poll(self):
      # Polling a socket that has had nothing sent raises QueueEmpty.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      pending_poll = sock.poll()
      with self.assertRaises(exceptions.QueueEmpty):
          _run(pending_poll)
  def test_poll(self):
      # Packets sent on the socket are returned by poll() in send order.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      sent = [packet.Packet(packet.MESSAGE, data=text)
              for text in ('hello', 'bye')]
      for pkt in sent:
          _run(sock.send(pkt))
      self.assertEqual(_run(sock.poll()), sent)
  def test_poll_none(self):
      # A None entry placed in the queue makes poll() return an empty list.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      _run(sock.queue.put(None))
      polled = _run(sock.poll())
      self.assertEqual(polled, [])
  def test_ping_pong(self):
      # Receiving a PING carrying 'abc' must queue exactly one reply packet
      # whose encoding is b'3abc'.
      mock_server = self._get_mock_server()
      s = asyncio_socket.AsyncSocket(mock_server, 'sid')
      _run(s.receive(packet.Packet(packet.PING, data='abc')))
      r = _run(s.poll())
      self.assertEqual(len(r), 1)
      # assertEqual rather than assertTrue: assertTrue takes its second
      # argument as the failure message, so it would pass for any non-empty
      # encoding and never compare against b'3abc'.
      self.assertEqual(r[0].encode(), b'3abc')
  def test_message_sync_handler(self):
      # With async_handlers left at False, a received MESSAGE triggers the
      # server's 'message' event with run_async=False.
      server = self._get_mock_server()
      sock = asyncio_socket.AsyncSocket(server, 'sid')
      incoming = packet.Packet(packet.MESSAGE, data='foo')
      _run(sock.receive(incoming))
      trigger = server._trigger_event.mock
      trigger.assert_called_once_with(
          'message', 'sid', 'foo', run_async=False)
  def test_message_async_handler(self):
      # With async_handlers switched to True, a received MESSAGE triggers the
      # server's 'message' event with run_async=True.
      server = self._get_mock_server()
      sock = asyncio_socket.AsyncSocket(server, 'sid')
      server.async_handlers = True
      incoming = packet.Packet(packet.MESSAGE, data='foo')
      _run(sock.receive(incoming))
      trigger = server._trigger_event.mock
      trigger.assert_called_once_with(
          'message', 'sid', 'foo', run_async=True)
  def test_invalid_packet(self):
      # Receiving an OPEN packet is rejected with UnknownPacketError.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      pending_receive = sock.receive(packet.Packet(packet.OPEN))
      with self.assertRaises(exceptions.UnknownPacketError):
          _run(pending_receive)
  def test_timeout(self):
      # With a negative ping interval and a last ping one second in the past,
      # send() is expected to close the socket with wait=False, abort=False.
      server = self._get_mock_server()
      server.ping_interval = -6
      sock = asyncio_socket.AsyncSocket(server, 'sid')
      one_second_ago = time.time() - 1
      sock.last_ping = one_second_ago
      sock.close = AsyncMock()
      _run(sock.send('packet'))
      close_mock = sock.close.mock
      close_mock.assert_called_once_with(wait=False, abort=False)
  def test_polling_read(self):
      # A polling GET request returns the packets previously sent on the
      # socket, in send order.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'foo')
      sent = [packet.Packet(packet.MESSAGE, data=text)
              for text in ('hello', 'bye')]
      for pkt in sent:
          _run(sock.send(pkt))
      environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
      received = _run(sock.handle_get_request(environ))
      self.assertEqual(received, sent)
  def test_polling_read_error(self):
      # A polling GET request on a socket with nothing sent raises QueueEmpty.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'foo')
      environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
      pending_get = sock.handle_get_request(environ)
      with self.assertRaises(exceptions.QueueEmpty):
          _run(pending_get)
  def test_polling_write(self):
      # A POST whose body is a two-packet payload results in receive() being
      # called once per packet.
      server = self._get_mock_server()
      server.max_http_buffer_size = 1000
      packets = [packet.Packet(packet.MESSAGE, data=text)
                 for text in ('hello', 'bye')]
      body = payload.Payload(packets=packets).encode()
      sock = asyncio_socket.AsyncSocket(server, 'foo')
      sock.receive = AsyncMock()
      environ = {
          'REQUEST_METHOD': 'POST',
          'QUERY_STRING': 'sid=foo',
          'CONTENT_LENGTH': len(body),
          'wsgi.input': self._get_read_mock_coro(body),
      }
      _run(sock.handle_post_request(environ))
      self.assertEqual(sock.receive.mock.call_count, 2)
  def test_polling_write_too_large(self):
      # A POST body one byte longer than max_http_buffer_size is rejected
      # with ContentTooLongError.
      server = self._get_mock_server()
      packets = [packet.Packet(packet.MESSAGE, data=text)
                 for text in ('hello', 'bye')]
      body = payload.Payload(packets=packets).encode()
      server.max_http_buffer_size = len(body) - 1
      sock = asyncio_socket.AsyncSocket(server, 'foo')
      sock.receive = AsyncMock()
      environ = {
          'REQUEST_METHOD': 'POST',
          'QUERY_STRING': 'sid=foo',
          'CONTENT_LENGTH': len(body),
          'wsgi.input': self._get_read_mock_coro(body),
      }
      pending_post = sock.handle_post_request(environ)
      with self.assertRaises(exceptions.ContentTooLongError):
          _run(pending_post)
  def test_upgrade_handshake(self):
      # A GET whose headers ask for a websocket upgrade is handed to
      # _upgrade_websocket() together with the request environ.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'foo')
      sock._upgrade_websocket = AsyncMock()
      environ = {
          'REQUEST_METHOD': 'GET',
          'QUERY_STRING': 'sid=foo',
          'HTTP_CONNECTION': 'Foo,Upgrade,Bar',
          'HTTP_UPGRADE': 'websocket',
      }
      _run(sock.handle_get_request(environ))
      upgrade_mock = sock._upgrade_websocket.mock
      upgrade_mock.assert_called_once_with(environ)
  def test_upgrade(self):
      # Upgrading wraps the socket's websocket handler with the server's
      # websocket factory and then awaits the wrapped object with the environ.
      server = self._get_mock_server()
      mock_ws = AsyncMock()
      server._async['websocket'] = mock.MagicMock(return_value=mock_ws)
      sock = asyncio_socket.AsyncSocket(server, 'sid')
      sock.connected = True
      environ = "foo"
      _run(sock._upgrade_websocket(environ))
      ws_factory = server._async['websocket']
      ws_factory.assert_called_once_with(sock._websocket_handler)
      mock_ws.mock.assert_called_once_with(environ)
  def test_upgrade_twice(self):
      # Asking an already upgraded socket to upgrade again raises IOError.
      server = self._get_mock_server()
      server._async['websocket'] = mock.MagicMock()
      sock = asyncio_socket.AsyncSocket(server, 'sid')
      sock.connected = True
      sock.upgraded = True
      environ = "foo"
      second_upgrade = sock._upgrade_websocket(environ)
      with self.assertRaises(IOError):
          _run(second_upgrade)
  def test_upgrade_packet(self):
      # Receiving an UPGRADE packet queues a single packet that encodes the
      # same as a NOOP packet.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      sock.connected = True
      _run(sock.receive(packet.Packet(packet.UPGRADE)))
      queued = _run(sock.poll())
      self.assertEqual(len(queued), 1)
      actual = queued[0].encode()
      expected = packet.Packet(packet.NOOP).encode()
      self.assertEqual(actual, expected)
  def test_upgrade_no_probe(self):
      # If the websocket delivers a NOOP where the handler waits for its
      # first packet, the socket is left not upgraded.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      sock.connected = True
      noop = packet.Packet(packet.NOOP).encode(always_bytes=False)
      ws = mock.MagicMock()
      ws.wait = AsyncMock(return_value=noop)
      _run(sock._websocket_handler(ws))
      self.assertFalse(sock.upgraded)
  def test_upgrade_no_upgrade_packet(self):
      # The websocket delivers a PING 'probe' followed by a NOOP instead of
      # an UPGRADE: the probe is answered with a PONG, a NOOP is left in the
      # queue, and the socket is not upgraded.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      sock.connected = True
      sock.queue.join = AsyncMock(return_value=None)
      probe = six.text_type('probe')
      ping_probe = packet.Packet(packet.PING, data=probe).encode(
          always_bytes=False)
      noop = packet.Packet(packet.NOOP).encode(always_bytes=False)
      ws = mock.MagicMock()
      ws.send = AsyncMock()
      ws.wait = AsyncMock(side_effect=[ping_probe, noop])
      _run(sock._websocket_handler(ws))
      pong_probe = packet.Packet(packet.PONG, data=probe).encode(
          always_bytes=False)
      ws.send.mock.assert_called_once_with(pong_probe)
      leftover = _run(sock.queue.get())
      self.assertEqual(leftover.packet_type, packet.NOOP)
      self.assertFalse(sock.upgraded)
  def test_upgrade_not_supported(self):
      # With the server's websocket setting set to None, an upgrade attempt
      # results in a call to the server's _bad_request().
      server = self._get_mock_server()
      server._async['websocket'] = None
      sock = asyncio_socket.AsyncSocket(server, 'sid')
      sock.connected = True
      environ = "foo"
      _run(sock._upgrade_websocket(environ))
      bad_request = server._bad_request
      bad_request.assert_called_once_with()
  def test_close_packet(self):
      # Receiving a CLOSE packet closes the socket with wait=False, abort=True.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      sock.connected = True
      sock.close = AsyncMock()
      close_request = packet.Packet(packet.CLOSE)
      _run(sock.receive(close_request))
      sock.close.mock.assert_called_once_with(wait=False, abort=True)
  def test_websocket_read_write(self):
      # Websocket handler on a socket that starts unconnected: one incoming
      # MESSAGE 'foo' then None from the websocket, and one outgoing MESSAGE
      # 'bar' then None from poll(). The socket ends up connected and
      # upgraded, 'message' and 'disconnect' events fire, and '4bar' is the
      # last thing written to the websocket.
      server = self._get_mock_server()
      sock = asyncio_socket.AsyncSocket(server, 'sid')
      sock.connected = False
      sock.queue.join = AsyncMock(return_value=None)
      foo = six.text_type('foo')
      bar = six.text_type('bar')
      outgoing = [packet.Packet(packet.MESSAGE, data=bar)]
      sock.poll = AsyncMock(side_effect=[outgoing, None])
      incoming = packet.Packet(packet.MESSAGE, data=foo).encode(
          always_bytes=False)
      ws = mock.MagicMock()
      ws.send = AsyncMock()
      ws.wait = AsyncMock(side_effect=[incoming, None])
      _run(sock._websocket_handler(ws))
      self.assertTrue(sock.connected)
      self.assertTrue(sock.upgraded)
      trigger = server._trigger_event.mock
      self.assertEqual(trigger.call_count, 2)
      expected_events = [
          mock.call('message', 'sid', 'foo', run_async=False),
          mock.call('disconnect', 'sid'),
      ]
      trigger.assert_has_calls(expected_events)
      ws.send.mock.assert_called_with('4bar')
  def test_websocket_upgrade_read_write(self):
      # Websocket handler on an already connected socket: the websocket
      # delivers PING 'probe', UPGRADE, MESSAGE 'foo', then None, while
      # poll() yields MESSAGE 'bar' and then raises QueueEmpty. The socket
      # ends up upgraded, 'message' and 'disconnect' events fire, and '4bar'
      # is the last thing written to the websocket.
      server = self._get_mock_server()
      sock = asyncio_socket.AsyncSocket(server, 'sid')
      sock.connected = True
      sock.queue.join = AsyncMock(return_value=None)
      foo = six.text_type('foo')
      bar = six.text_type('bar')
      probe = six.text_type('probe')
      outgoing = [packet.Packet(packet.MESSAGE, data=bar)]
      sock.poll = AsyncMock(side_effect=[outgoing, exceptions.QueueEmpty])
      ping_probe = packet.Packet(packet.PING, data=probe).encode(
          always_bytes=False)
      upgrade = packet.Packet(packet.UPGRADE).encode(always_bytes=False)
      message = packet.Packet(packet.MESSAGE, data=foo).encode(
          always_bytes=False)
      ws = mock.MagicMock()
      ws.send = AsyncMock()
      ws.wait = AsyncMock(side_effect=[ping_probe, upgrade, message, None])
      _run(sock._websocket_handler(ws))
      self.assertTrue(sock.upgraded)
      trigger = server._trigger_event.mock
      self.assertEqual(trigger.call_count, 2)
      expected_events = [
          mock.call('message', 'sid', 'foo', run_async=False),
          mock.call('disconnect', 'sid'),
      ]
      trigger.assert_has_calls(expected_events)
      ws.send.mock.assert_called_with('4bar')
  def test_websocket_upgrade_with_payload(self):
      # An UPGRADE packet that carries data (b'2') after the probe still
      # leaves the socket upgraded.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      sock.connected = True
      sock.queue.join = AsyncMock(return_value=None)
      probe = six.text_type('probe')
      ping_probe = packet.Packet(packet.PING, data=probe).encode(
          always_bytes=False)
      upgrade = packet.Packet(packet.UPGRADE, data=b'2').encode(
          always_bytes=False)
      ws = mock.MagicMock()
      ws.send = AsyncMock()
      ws.wait = AsyncMock(side_effect=[ping_probe, upgrade])
      _run(sock._websocket_handler(ws))
      self.assertTrue(sock.upgraded)
  def test_websocket_upgrade_with_backlog(self):
      # A MESSAGE 'foo' is sent while the socket is flagged as upgrading.
      # After the websocket handler runs, the socket is upgraded, no longer
      # upgrading, its packet backlog is empty, and '4foo' is the last thing
      # written to the websocket.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      sock.connected = True
      sock.queue.join = AsyncMock(return_value=None)
      probe = six.text_type('probe')
      foo = six.text_type('foo')
      ping_probe = packet.Packet(packet.PING, data=probe).encode(
          always_bytes=False)
      upgrade = packet.Packet(packet.UPGRADE, data=b'2').encode(
          always_bytes=False)
      ws = mock.MagicMock()
      ws.send = AsyncMock()
      ws.wait = AsyncMock(side_effect=[ping_probe, upgrade])
      sock.upgrading = True
      _run(sock.send(packet.Packet(packet.MESSAGE, data=foo)))
      _run(sock._websocket_handler(ws))
      self.assertTrue(sock.upgraded)
      self.assertFalse(sock.upgrading)
      self.assertEqual(sock.packet_backlog, [])
      ws.send.mock.assert_called_with('4foo')
  def test_websocket_read_write_wait_fail(self):
      # The websocket's second wait() and second send() both raise
      # RuntimeError; the socket is expected to end up closed.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      sock.connected = False
      sock.queue.join = AsyncMock(return_value=None)
      foo = six.text_type('foo')
      bar = six.text_type('bar')
      first_batch = [packet.Packet(packet.MESSAGE, data=bar)]
      second_batch = [packet.Packet(packet.MESSAGE, data=bar)]
      sock.poll = AsyncMock(
          side_effect=[first_batch, second_batch, exceptions.QueueEmpty])
      incoming = packet.Packet(packet.MESSAGE, data=foo).encode(
          always_bytes=False)
      ws = mock.MagicMock()
      ws.send = AsyncMock()
      ws.wait = AsyncMock(side_effect=[incoming, RuntimeError])
      ws.send.mock.side_effect = [None, RuntimeError]
      _run(sock._websocket_handler(ws))
      self.assertEqual(sock.closed, True)
  def test_websocket_ignore_invalid_packet(self):
      # The websocket delivers an OPEN packet before a MESSAGE 'foo' and then
      # None. Only the 'message' and 'disconnect' events are triggered, the
      # socket ends up connected, and '4bar' is the last thing written to
      # the websocket.
      server = self._get_mock_server()
      sock = asyncio_socket.AsyncSocket(server, 'sid')
      sock.connected = False
      sock.queue.join = AsyncMock(return_value=None)
      foo = six.text_type('foo')
      bar = six.text_type('bar')
      outgoing = [packet.Packet(packet.MESSAGE, data=bar)]
      sock.poll = AsyncMock(side_effect=[outgoing, exceptions.QueueEmpty])
      invalid = packet.Packet(packet.OPEN).encode(always_bytes=False)
      message = packet.Packet(packet.MESSAGE, data=foo).encode(
          always_bytes=False)
      ws = mock.MagicMock()
      ws.send = AsyncMock()
      ws.wait = AsyncMock(side_effect=[invalid, message, None])
      _run(sock._websocket_handler(ws))
      self.assertTrue(sock.connected)
      trigger = server._trigger_event.mock
      self.assertEqual(trigger.call_count, 2)
      expected_events = [
          mock.call('message', 'sid', foo, run_async=False),
          mock.call('disconnect', 'sid'),
      ]
      trigger.assert_has_calls(expected_events)
      ws.send.mock.assert_called_with('4bar')
  def test_send_after_close(self):
      # Sending on a socket that was already closed raises
      # SocketIsClosedError.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      _run(sock.close(wait=False))
      late_send = sock.send(packet.Packet(packet.NOOP))
      with self.assertRaises(exceptions.SocketIsClosedError):
          _run(late_send)
  def test_close_after_close(self):
      # The first close marks the socket closed and triggers one 'disconnect'
      # event; closing again triggers no further event.
      server = self._get_mock_server()
      sock = asyncio_socket.AsyncSocket(server, 'sid')
      _run(sock.close(wait=False))
      self.assertTrue(sock.closed)
      trigger = server._trigger_event.mock
      self.assertEqual(trigger.call_count, 1)
      trigger.assert_called_once_with('disconnect', 'sid')
      _run(sock.close())
      self.assertEqual(trigger.call_count, 1)
  def test_close_and_wait(self):
      # close(wait=True) awaits the queue's join() exactly once.
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      fake_queue = mock.MagicMock()
      sock.queue = fake_queue
      fake_queue.put = AsyncMock()
      fake_queue.join = AsyncMock()
      _run(sock.close(wait=True))
      sock.queue.join.mock.assert_called_once_with()
  def test_close_without_wait(self):
      # close(wait=False) never calls the queue's join().
      sock = asyncio_socket.AsyncSocket(self._get_mock_server(), 'sid')
      fake_queue = mock.MagicMock()
      sock.queue = fake_queue
      fake_queue.put = AsyncMock()
      fake_queue.join = AsyncMock()
      _run(sock.close(wait=False))
      join_calls = sock.queue.join.mock.call_count
      self.assertEqual(join_calls, 0)