You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

920 lines
39KB

  1. import asyncio
  2. import gzip
  3. import json
  4. import logging
  5. import sys
  6. import unittest
  7. import zlib
  8. import six
  9. if six.PY3:
  10. from unittest import mock
  11. else:
  12. import mock
  13. from engineio import asyncio_server
  14. from engineio.async_drivers import aiohttp as async_aiohttp
  15. from engineio import exceptions
  16. from engineio import packet
  17. from engineio import payload
  def AsyncMock(*args, **kwargs):
      """Build a coroutine function backed by a ``MagicMock``.

      All arguments are forwarded to the ``MagicMock`` constructor (for
      example ``return_value=...``). Awaiting the returned coroutine
      function records the call on that mock and produces whatever the mock
      returns. The mock is exposed as the ``.mock`` attribute so tests can
      inspect the recorded calls.
      """
      backing_mock = mock.MagicMock(*args, **kwargs)

      async def coroutine_function(*call_args, **call_kwargs):
          return backing_mock(*call_args, **call_kwargs)

      coroutine_function.mock = backing_mock
      return coroutine_function
  25. def _run(coro):
  26. """Run the given coroutine."""
  27. return asyncio.get_event_loop().run_until_complete(coro)
  28. @unittest.skipIf(sys.version_info < (3, 5), 'only for Python 3.5+')
  29. class TestAsyncServer(unittest.TestCase):
  30. @staticmethod
  31. def get_async_mock(environ={'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}):
  32. a = mock.MagicMock()
  33. a._async = {
  34. 'asyncio': True,
  35. 'create_route': mock.MagicMock(),
  36. 'translate_request': mock.MagicMock(),
  37. 'make_response': mock.MagicMock(),
  38. 'websocket': 'w',
  39. }
  40. a._async['translate_request'].return_value = environ
  41. a._async['make_response'].return_value = 'response'
  42. return a
  def _get_mock_socket(self):
      """Return a MagicMock standing in for a server-side socket.

      The state flags (connected, closed, closing, upgraded) start out
      False and the session is an empty dict. The coroutine methods are
      ``AsyncMock`` instances whose recorded calls live on ``.mock``.
      """
      sock = mock.MagicMock()
      for flag in ('connected', 'closed', 'closing', 'upgraded'):
          setattr(sock, flag, False)
      for method_name in ('send', 'handle_get_request',
                          'handle_post_request', 'check_ping_timeout',
                          'close'):
          setattr(sock, method_name, AsyncMock())
      sock.session = {}
      return sock
  @classmethod
  def setUpClass(cls):
      """Set ``AsyncServer._default_monitor_clients`` to False for this class.

      The previous value is saved first, so tearDownClass can restore it
      exactly instead of assuming it was True.
      """
      cls._saved_default_monitor_clients = getattr(
          asyncio_server.AsyncServer, '_default_monitor_clients', True)
      asyncio_server.AsyncServer._default_monitor_clients = False

  @classmethod
  def tearDownClass(cls):
      """Restore ``AsyncServer._default_monitor_clients`` saved in setUpClass."""
      asyncio_server.AsyncServer._default_monitor_clients = getattr(
          cls, '_saved_default_monitor_clients', True)
  def setUp(self):
      """Reset the 'engineio' logger to NOTSET before every test.

      test_logger changes logger levels, presumably on this same logger, so
      each test starts from a known level.
      """
      engineio_logger = logging.getLogger('engineio')
      engineio_logger.setLevel(logging.NOTSET)
  def tearDown(self):
      """Reinstall the standard ``json`` module as the packet codec.

      This guards against a test (such as test_custom_json) leaving a
      custom JSON implementation on ``packet.Packet``.
      """
      standard_json = json
      packet.Packet.json = standard_json
  def test_is_asyncio_based(self):
      """AsyncServer reports that it is asyncio based."""
      server = asyncio_server.AsyncServer()
      self.assertEqual(server.is_asyncio_based(), True)
  def test_async_modes(self):
      """The asyncio server lists exactly the four asyncio-capable modes."""
      server = asyncio_server.AsyncServer()
      expected_modes = ['aiohttp', 'sanic', 'tornado', 'asgi']
      self.assertEqual(server.async_modes(), expected_modes)
  def test_async_mode_aiohttp(self):
      """Explicit aiohttp mode wires in the aiohttp driver's helpers."""
      server = asyncio_server.AsyncServer(async_mode='aiohttp')
      self.assertEqual(server.async_mode, 'aiohttp')
      driver = server._async
      self.assertEqual(driver['asyncio'], True)
      for helper_name in ('create_route', 'translate_request',
                          'make_response'):
          self.assertEqual(driver[helper_name],
                           getattr(async_aiohttp, helper_name))
      self.assertEqual(driver['websocket'].__name__, 'WebSocket')
  @mock.patch('importlib.import_module')
  def test_async_mode_auto_aiohttp(self, import_module):
      """Without async_mode, aiohttp is chosen when its driver import works."""
      import_module.side_effect = [self.get_async_mock()]
      server = asyncio_server.AsyncServer()
      self.assertEqual(server.async_mode, 'aiohttp')
  def test_async_modes_wsgi(self):
      """Every non-asyncio (WSGI) async mode is rejected with ValueError."""
      for wsgi_mode in ('eventlet', 'gevent', 'gevent_uwsgi', 'threading'):
          self.assertRaises(ValueError, asyncio_server.AsyncServer,
                            async_mode=wsgi_mode)
  @mock.patch('importlib.import_module')
  def test_attach(self, import_module):
      """attach() normalizes engineio_path to a '/path/' form for the route."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      path_cases = [
          ('abc', '/abc/'),
          ('/def/', '/def/'),
          ('/ghi', '/ghi/'),
          ('jkl/', '/jkl/'),
      ]
      for given_path, normalized_path in path_cases:
          server.attach('app', engineio_path=given_path)
          async_mock._async['create_route'].assert_called_with(
              'app', server, normalized_path)
  def test_session(self):
      """Changes made inside ``async with session()`` are stored afterwards."""
      server = asyncio_server.AsyncServer()
      server.sockets['foo'] = self._get_mock_socket()

      async def scenario():
          async with server.session('foo') as sess:
              await server.sleep(0)
              sess['username'] = 'bar'
          stored = await server.get_session('foo')
          self.assertEqual(stored, {'username': 'bar'})

      _run(scenario())
  def test_disconnect(self):
      """Disconnecting one sid closes its socket once and unregisters it."""
      server = asyncio_server.AsyncServer()
      sock = self._get_mock_socket()
      server.sockets['foo'] = sock
      _run(server.disconnect('foo'))
      close_mock = sock.close.mock
      self.assertEqual(close_mock.call_count, 1)
      close_mock.assert_called_once_with()
      self.assertNotIn('foo', server.sockets)
  def test_disconnect_all(self):
      """disconnect() with no sid closes and unregisters every socket."""
      server = asyncio_server.AsyncServer()
      registered = {}
      for sid in ('foo', 'bar'):
          registered[sid] = server.sockets[sid] = self._get_mock_socket()
      _run(server.disconnect())
      for sid, sock in registered.items():
          self.assertEqual(sock.close.mock.call_count, 1)
          sock.close.mock.assert_called_once_with()
          self.assertNotIn(sid, server.sockets)
  @mock.patch('importlib.import_module')
  def test_jsonp_not_supported(self, import_module):
      """A GET carrying ``j=abc`` is answered with 400 BAD REQUEST."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'GET',
                                        'QUERY_STRING': 'j=abc'})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      result = _run(server.handle_request('request'))
      self.assertEqual(result, 'response')
      driver = async_mock._async
      driver['translate_request'].assert_called_once_with('request')
      make_response = driver['make_response']
      self.assertEqual(make_response.call_count, 1)
      status = make_response.call_args[0][0]
      self.assertEqual(status, '400 BAD REQUEST')
  @mock.patch('importlib.import_module')
  def test_jsonp_index(self, import_module):
      """A GET with ``j=233`` gets a 200 body wrapped as ``___eio[233]("...");``.

      The leftover debugging ``print`` of the response body was removed, so
      test runs no longer write stray output.
      """
      a = self.get_async_mock({'REQUEST_METHOD': 'GET',
                               'QUERY_STRING': 'j=233'})
      import_module.side_effect = [a]
      s = asyncio_server.AsyncServer()
      response = _run(s.handle_request('request'))
      self.assertEqual(response, 'response')
      a._async['translate_request'].assert_called_once_with('request')
      make_response = a._async['make_response']
      self.assertEqual(make_response.call_count, 1)
      self.assertEqual(make_response.call_args[0][0], '200 OK')
      body = make_response.call_args[0][2]
      # The payload must be wrapped in the JSONP callback for index 233.
      self.assertTrue(body.startswith(b'___eio[233]("'))
      self.assertTrue(body.endswith(b'");'))
  @mock.patch('importlib.import_module')
  def test_connect(self, import_module):
      """A bare GET opens a session and answers with a single OPEN packet."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      _run(server.handle_request('request'))
      self.assertEqual(len(server.sockets), 1)
      make_response = async_mock._async['make_response']
      self.assertEqual(make_response.call_count, 1)
      args = make_response.call_args[0]
      self.assertEqual(args[0], '200 OK')
      self.assertIn(('Content-Type', 'application/octet-stream'), args[1])
      received = payload.Payload(encoded_payload=args[2]).packets
      self.assertEqual(len(received), 1)
      handshake = received[0]
      self.assertEqual(handshake.packet_type, packet.OPEN)
      self.assertIn('upgrades', handshake.data)
      self.assertEqual(handshake.data['upgrades'], ['websocket'])
      self.assertIn('sid', handshake.data)
  @mock.patch('importlib.import_module')
  def test_connect_async_request_response_handlers(self, import_module):
      """Connecting also works when translate/make_response are coroutines."""
      async_mock = self.get_async_mock()
      driver = async_mock._async
      # Swap the synchronous helpers for coroutine versions that return the
      # same values.
      driver['translate_request'] = AsyncMock(
          return_value=driver['translate_request'].return_value)
      driver['make_response'] = AsyncMock(
          return_value=driver['make_response'].return_value)
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      _run(server.handle_request('request'))
      self.assertEqual(len(server.sockets), 1)
      make_response = driver['make_response'].mock
      self.assertEqual(make_response.call_count, 1)
      args = make_response.call_args[0]
      self.assertEqual(args[0], '200 OK')
      self.assertIn(('Content-Type', 'application/octet-stream'), args[1])
      received = payload.Payload(encoded_payload=args[2]).packets
      self.assertEqual(len(received), 1)
      handshake = received[0]
      self.assertEqual(handshake.packet_type, packet.OPEN)
      self.assertIn('upgrades', handshake.data)
      self.assertEqual(handshake.data['upgrades'], ['websocket'])
      self.assertIn('sid', handshake.data)
  @mock.patch('importlib.import_module')
  def test_connect_no_upgrades(self, import_module):
      """With allow_upgrades=False the handshake advertises no upgrades."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer(allow_upgrades=False)
      _run(server.handle_request('request'))
      body = async_mock._async['make_response'].call_args[0][2]
      handshake = payload.Payload(encoded_payload=body).packets[0]
      self.assertEqual(handshake.data['upgrades'], [])
  def _check_connect_b64(self, import_module, b64_value, content_type,
                         expected_body):
      """Connect with ``b64=<b64_value>`` and check how binary data is sent.

      Shared body of the four test_connect_b64_* tests, which were
      previously copy-pasted.

      :param import_module: the patched ``importlib.import_module`` mock.
      :param b64_value: value of the ``b64`` query-string parameter.
      :param content_type: Content-Type expected on the handshake response.
      :param expected_body: exact body expected from the next GET for the
                            session, after sending bytes ``00 01 02``.
      """
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'GET',
                                        'QUERY_STRING': 'b64=' + b64_value})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer(allow_upgrades=False)
      # Pin the session id so the follow-up request can address it.
      server._generate_id = mock.MagicMock(return_value='1')
      _run(server.handle_request('request'))
      make_response = async_mock._async['make_response']
      self.assertEqual(make_response.call_count, 1)
      self.assertEqual(make_response.call_args[0][0], '200 OK')
      self.assertIn(('Content-Type', content_type),
                    make_response.call_args[0][1])
      _run(server.send('1', b'\x00\x01\x02', binary=True))
      async_mock._async['translate_request'].return_value = {
          'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=' + b64_value}
      _run(server.handle_request('request'))
      self.assertEqual(make_response.call_args[0][2], expected_body)

  @mock.patch('importlib.import_module')
  def test_connect_b64_with_1(self, import_module):
      """``b64=1`` gives text/plain responses with base64 binary packets."""
      self._check_connect_b64(import_module, '1',
                              'text/plain; charset=UTF-8', b'6:b4AAEC')

  @mock.patch('importlib.import_module')
  def test_connect_b64_with_true(self, import_module):
      """``b64=true`` behaves like ``b64=1``."""
      self._check_connect_b64(import_module, 'true',
                              'text/plain; charset=UTF-8', b'6:b4AAEC')

  @mock.patch('importlib.import_module')
  def test_connect_b64_with_0(self, import_module):
      """``b64=0`` keeps the binary octet-stream payload encoding."""
      self._check_connect_b64(import_module, '0',
                              'application/octet-stream',
                              b'\x01\x04\xff\x04\x00\x01\x02')

  @mock.patch('importlib.import_module')
  def test_connect_b64_with_false(self, import_module):
      """``b64=false`` behaves like ``b64=0``."""
      self._check_connect_b64(import_module, 'false',
                              'application/octet-stream',
                              b'\x01\x04\xff\x04\x00\x01\x02')
  @mock.patch('importlib.import_module')
  def test_connect_custom_ping_times(self, import_module):
      """Custom ping settings are advertised in the handshake, in ms."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer(ping_timeout=123, ping_interval=456)
      _run(server.handle_request('request'))
      body = async_mock._async['make_response'].call_args[0][2]
      handshake_data = payload.Payload(encoded_payload=body).packets[0].data
      # Configured as 123 and 456 seconds, advertised as milliseconds.
      self.assertEqual(handshake_data['pingTimeout'], 123000)
      self.assertEqual(handshake_data['pingInterval'], 456000)
  @mock.patch('engineio.asyncio_socket.AsyncSocket')
  @mock.patch('importlib.import_module')
  def test_connect_bad_poll(self, import_module, AsyncSocket):
      """If polling the new socket raises QueueEmpty, the reply is a 400."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      fake_socket = self._get_mock_socket()
      AsyncSocket.return_value = fake_socket
      fake_socket.poll.side_effect = [exceptions.QueueEmpty]
      server = asyncio_server.AsyncServer()
      _run(server.handle_request('request'))
      make_response = async_mock._async['make_response']
      self.assertEqual(make_response.call_count, 1)
      self.assertEqual(make_response.call_args[0][0], '400 BAD REQUEST')
  @mock.patch('engineio.asyncio_socket.AsyncSocket')
  @mock.patch('importlib.import_module')
  def test_connect_transport_websocket(self, import_module, AsyncSocket):
      """A websocket connect sends an OPEN packet through the new socket."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'GET',
                                        'QUERY_STRING': 'transport=websocket'})
      import_module.side_effect = [async_mock]
      AsyncSocket.return_value = self._get_mock_socket()
      server = asyncio_server.AsyncServer()
      server._generate_id = mock.MagicMock(return_value='123')
      # Keep the socket open so it is still registered for the check below.
      AsyncSocket().closed = False
      _run(server.handle_request('request'))
      sent_packet = server.sockets['123'].send.mock.call_args[0][0]
      self.assertEqual(sent_packet.packet_type, packet.OPEN)
  @mock.patch('engineio.asyncio_socket.AsyncSocket')
  @mock.patch('importlib.import_module')
  def test_connect_transport_websocket_closed(self, import_module,
                                              AsyncSocket):
      """A websocket that closes during its handler is removed by the server."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'GET',
                                        'QUERY_STRING': 'transport=websocket'})
      import_module.side_effect = [async_mock]
      AsyncSocket.return_value = self._get_mock_socket()
      server = asyncio_server.AsyncServer()
      server._generate_id = mock.MagicMock(return_value='123')

      async def close_on_get(environ):
          # Mimic the end of a real websocket exchange: the socket closes.
          server.sockets['123'].closed = True

      AsyncSocket().handle_get_request = close_on_get
      _run(server.handle_request('request'))
      # The socket should be dropped once it reports itself closed.
      self.assertNotIn('123', server.sockets)
  @mock.patch('importlib.import_module')
  def test_connect_transport_invalid(self, import_module):
      """An unknown transport name is rejected with 400 BAD REQUEST."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'GET',
                                        'QUERY_STRING': 'transport=foo'})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      _run(server.handle_request('request'))
      make_response = async_mock._async['make_response']
      self.assertEqual(make_response.call_count, 1)
      self.assertEqual(make_response.call_args[0][0], '400 BAD REQUEST')
  @mock.patch('importlib.import_module')
  def test_connect_cors_headers(self, import_module):
      """By default any origin is allowed and credentials are permitted."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      _run(server.handle_request('request'))
      response_headers = async_mock._async['make_response'].call_args[0][1]
      for expected in (('Access-Control-Allow-Origin', '*'),
                       ('Access-Control-Allow-Credentials', 'true')):
          self.assertIn(expected, response_headers)
  @mock.patch('importlib.import_module')
  def test_connect_cors_allowed_origin(self, import_module):
      """An origin on the allow-list is echoed back in the CORS header."""
      request_environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
                         'HTTP_ORIGIN': 'b'}
      async_mock = self.get_async_mock(request_environ)
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer(cors_allowed_origins=['a', 'b'])
      _run(server.handle_request('request'))
      response_headers = async_mock._async['make_response'].call_args[0][1]
      self.assertIn(('Access-Control-Allow-Origin', 'b'), response_headers)
  @mock.patch('importlib.import_module')
  def test_connect_cors_not_allowed_origin(self, import_module):
      """An origin off the allow-list gets neither its own nor a '*' header."""
      request_environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
                         'HTTP_ORIGIN': 'c'}
      async_mock = self.get_async_mock(request_environ)
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer(cors_allowed_origins=['a', 'b'])
      _run(server.handle_request('request'))
      response_headers = async_mock._async['make_response'].call_args[0][1]
      for forbidden_origin in ('c', '*'):
          self.assertNotIn(('Access-Control-Allow-Origin', forbidden_origin),
                           response_headers)
  @mock.patch('importlib.import_module')
  def test_connect_cors_no_credentials(self, import_module):
      """cors_credentials=False suppresses the allow-credentials header."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer(cors_credentials=False)
      _run(server.handle_request('request'))
      response_headers = async_mock._async['make_response'].call_args[0][1]
      self.assertNotIn(('Access-Control-Allow-Credentials', 'true'),
                       response_headers)
  @mock.patch('importlib.import_module')
  def test_connect_cors_options(self, import_module):
      """An OPTIONS request lists the allowed methods in the CORS headers."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'OPTIONS',
                                        'QUERY_STRING': ''})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer(cors_credentials=False)
      _run(server.handle_request('request'))
      response_headers = async_mock._async['make_response'].call_args[0][1]
      allowed_methods = ('Access-Control-Allow-Methods', 'OPTIONS, GET, POST')
      self.assertIn(allowed_methods, response_headers)
  @mock.patch('importlib.import_module')
  def test_connect_event(self, import_module):
      """A connect handler returning True lets the session be created."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      server._generate_id = mock.MagicMock(return_value='123')

      def accept_connection(sid, environ):
          return True

      server.on('connect', handler=accept_connection)
      _run(server.handle_request('request'))
      self.assertEqual(len(server.sockets), 1)
  @mock.patch('importlib.import_module')
  def test_connect_event_rejects(self, import_module):
      """A connect handler returning False yields 401 and no session."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      server._generate_id = mock.MagicMock(return_value='123')

      @server.on('connect')
      def reject_connection(sid, environ):
          return False

      _run(server.handle_request('request'))
      self.assertEqual(len(server.sockets), 0)
      status = async_mock._async['make_response'].call_args[0][0]
      self.assertEqual(status, '401 UNAUTHORIZED')
  @mock.patch('importlib.import_module')
  def test_method_not_found(self, import_module):
      """An unsupported HTTP method (PUT) is refused with 405."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'PUT',
                                        'QUERY_STRING': ''})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      _run(server.handle_request('request'))
      self.assertEqual(len(server.sockets), 0)
      status = async_mock._async['make_response'].call_args[0][0]
      self.assertEqual(status, '405 METHOD NOT FOUND')
  @mock.patch('importlib.import_module')
  def test_get_request_with_bad_sid(self, import_module):
      """A GET for an unknown sid is rejected with 400 BAD REQUEST."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'GET',
                                        'QUERY_STRING': 'sid=foo'})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      _run(server.handle_request('request'))
      self.assertEqual(len(server.sockets), 0)
      status = async_mock._async['make_response'].call_args[0][0]
      self.assertEqual(status, '400 BAD REQUEST')
  @mock.patch('importlib.import_module')
  def test_post_request_with_bad_sid(self, import_module):
      """A POST for an unknown sid is rejected with 400 BAD REQUEST."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'POST',
                                        'QUERY_STRING': 'sid=foo'})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      _run(server.handle_request('request'))
      self.assertEqual(len(server.sockets), 0)
      status = async_mock._async['make_response'].call_args[0][0]
      self.assertEqual(status, '400 BAD REQUEST')
  @mock.patch('importlib.import_module')
  def test_send(self, import_module):
      """send() wraps the data in a single MESSAGE packet for the socket."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      sock = self._get_mock_socket()
      server.sockets['foo'] = sock
      _run(server.send('foo', 'hello'))
      send_mock = sock.send.mock
      self.assertEqual(send_mock.call_count, 1)
      sent_packet = send_mock.call_args[0][0]
      self.assertEqual(sent_packet.packet_type, packet.MESSAGE)
      self.assertEqual(sent_packet.data, 'hello')
  @mock.patch('importlib.import_module')
  def test_send_unknown_socket(self, import_module):
      """Sending to a sid that was never registered must not raise."""
      import_module.side_effect = [self.get_async_mock()]
      server = asyncio_server.AsyncServer()
      _run(server.send('foo', 'hello'))
  @mock.patch('importlib.import_module')
  def test_get_request(self, import_module):
      """A GET returns the socket's queued packets as a 200 payload."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'GET',
                                        'QUERY_STRING': 'sid=foo'})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      sock = self._get_mock_socket()
      server.sockets['foo'] = sock
      queued = [packet.Packet(packet.MESSAGE, data='hello')]
      sock.handle_get_request.mock.return_value = queued
      _run(server.handle_request('request'))
      args = async_mock._async['make_response'].call_args[0]
      self.assertEqual(args[0], '200 OK')
      received = payload.Payload(encoded_payload=args[2]).packets
      self.assertEqual(len(received), 1)
      self.assertEqual(received[0].packet_type, packet.MESSAGE)
  @mock.patch('importlib.import_module')
  def test_get_request_custom_response(self, import_module):
      """A non-list result from handle_get_request is returned unchanged."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'GET',
                                        'QUERY_STRING': 'sid=foo'})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      sock = self._get_mock_socket()
      server.sockets['foo'] = sock
      sock.handle_get_request.mock.return_value = 'resp'
      self.assertEqual(_run(server.handle_request('request')), 'resp')
  @mock.patch('importlib.import_module')
  def test_get_request_closes_socket(self, import_module):
      """A socket that closes while handling a GET is unregistered."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'GET',
                                        'QUERY_STRING': 'sid=foo'})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      sock = self._get_mock_socket()
      server.sockets['foo'] = sock

      async def get_then_close(*args, **kwargs):
          sock.closed = True
          return 'resp'

      sock.handle_get_request = get_then_close
      result = _run(server.handle_request('request'))
      self.assertEqual(result, 'resp')
      self.assertNotIn('foo', server.sockets)
  @mock.patch('importlib.import_module')
  def test_get_request_error(self, import_module):
      """QueueEmpty while handling a GET gives 400 and drops the socket."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'GET',
                                        'QUERY_STRING': 'sid=foo'})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      sock = self._get_mock_socket()
      server.sockets['foo'] = sock

      async def failing_get(*args, **kwargs):
          raise exceptions.QueueEmpty()

      sock.handle_get_request = failing_get
      _run(server.handle_request('request'))
      status = async_mock._async['make_response'].call_args[0][0]
      self.assertEqual(status, '400 BAD REQUEST')
      self.assertEqual(len(server.sockets), 0)
  @mock.patch('importlib.import_module')
  def test_post_request(self, import_module):
      """A POST to a known sid is answered with 200 OK."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'POST',
                                        'QUERY_STRING': 'sid=foo'})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      server.sockets['foo'] = self._get_mock_socket()
      _run(server.handle_request('request'))
      status = async_mock._async['make_response'].call_args[0][0]
      self.assertEqual(status, '200 OK')
  @mock.patch('importlib.import_module')
  def test_post_request_error(self, import_module):
      """ContentTooLongError while handling a POST gives 400 BAD REQUEST."""
      async_mock = self.get_async_mock({'REQUEST_METHOD': 'POST',
                                        'QUERY_STRING': 'sid=foo'})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer()
      sock = self._get_mock_socket()
      server.sockets['foo'] = sock

      async def failing_post(*args, **kwargs):
          raise exceptions.ContentTooLongError()

      sock.handle_post_request = failing_post
      _run(server.handle_request('request'))
      status = async_mock._async['make_response'].call_args[0][0]
      self.assertEqual(status, '400 BAD REQUEST')
  @staticmethod
  def _gzip_decompress(b):
      """Return the gzip-decompressed contents of the bytes *b*.

      Invalid gzip input raises IOError (OSError on Python 3). The
      uncompressed-response tests rely on that error.
      """
      with gzip.GzipFile(fileobj=six.BytesIO(b), mode='r') as archive:
          return archive.read()
  @mock.patch('importlib.import_module')
  def test_gzip_compression(self, import_module):
      """With gzip accepted and threshold 0, the body is gzip-compressed.

      Earlier versions only checked that the body could be decompressed.
      The decompressed bytes are now also decoded and must hold the single
      queued MESSAGE packet.
      """
      a = self.get_async_mock({'REQUEST_METHOD': 'GET',
                               'QUERY_STRING': 'sid=foo',
                               'HTTP_ACCEPT_ENCODING': 'gzip,deflate'})
      import_module.side_effect = [a]
      s = asyncio_server.AsyncServer(compression_threshold=0)
      s.sockets['foo'] = mock_socket = self._get_mock_socket()
      mock_socket.handle_get_request.mock.return_value = \
          [packet.Packet(packet.MESSAGE, data='hello')]
      _run(s.handle_request('request'))
      args = a._async['make_response'].call_args[0]
      self.assertIn(('Content-Encoding', 'gzip'), args[1])
      decompressed = self._gzip_decompress(args[2])
      packets = payload.Payload(encoded_payload=decompressed).packets
      self.assertEqual(len(packets), 1)
      self.assertEqual(packets[0].packet_type, packet.MESSAGE)
  @mock.patch('importlib.import_module')
  def test_deflate_compression(self, import_module):
      """When deflate is preferred, the body is deflate-compressed.

      Earlier versions only checked that ``zlib.decompress`` succeeded. The
      decompressed bytes are now also decoded and must hold the single
      queued MESSAGE packet.
      """
      a = self.get_async_mock({'REQUEST_METHOD': 'GET',
                               'QUERY_STRING': 'sid=foo',
                               'HTTP_ACCEPT_ENCODING': 'deflate;q=1,gzip'})
      import_module.side_effect = [a]
      s = asyncio_server.AsyncServer(compression_threshold=0)
      s.sockets['foo'] = mock_socket = self._get_mock_socket()
      mock_socket.handle_get_request.mock.return_value = \
          [packet.Packet(packet.MESSAGE, data='hello')]
      _run(s.handle_request('request'))
      args = a._async['make_response'].call_args[0]
      self.assertIn(('Content-Encoding', 'deflate'), args[1])
      decompressed = zlib.decompress(args[2])
      packets = payload.Payload(encoded_payload=decompressed).packets
      self.assertEqual(len(packets), 1)
      self.assertEqual(packets[0].packet_type, packet.MESSAGE)
  def _assert_response_not_compressed(self, import_module, accept_encoding,
                                      **server_kwargs):
      """Poll an open session and check that the response is uncompressed.

      Shared body of the four "no compression" tests below, which were
      previously copy-pasted.

      :param import_module: the patched ``importlib.import_module`` mock.
      :param accept_encoding: value of the request's Accept-Encoding header.
      :param server_kwargs: extra ``AsyncServer`` constructor arguments.
      """
      async_mock = self.get_async_mock({
          'REQUEST_METHOD': 'GET',
          'QUERY_STRING': 'sid=foo',
          'HTTP_ACCEPT_ENCODING': accept_encoding})
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer(**server_kwargs)
      sock = self._get_mock_socket()
      server.sockets['foo'] = sock
      sock.handle_get_request.mock.return_value = [
          packet.Packet(packet.MESSAGE, data='hello')]
      _run(server.handle_request('request'))
      args = async_mock._async['make_response'].call_args[0]
      header_names = [name for name, _ in args[1]]
      self.assertNotIn('Content-Encoding', header_names)
      # The body must not be gzip data either.
      self.assertRaises(IOError, self._gzip_decompress, args[2])

  @mock.patch('importlib.import_module')
  def test_gzip_compression_threshold(self, import_module):
      """Bodies smaller than compression_threshold are not compressed."""
      self._assert_response_not_compressed(import_module, 'gzip',
                                           compression_threshold=1000)

  @mock.patch('importlib.import_module')
  def test_compression_disabled(self, import_module):
      """http_compression=False turns compression off entirely."""
      self._assert_response_not_compressed(import_module, 'gzip',
                                           http_compression=False,
                                           compression_threshold=0)

  @mock.patch('importlib.import_module')
  def test_compression_unknown(self, import_module):
      """An unsupported Accept-Encoding ('rar') leaves the body as is."""
      self._assert_response_not_compressed(import_module, 'rar',
                                           compression_threshold=0)

  @mock.patch('importlib.import_module')
  def test_compression_no_encoding(self, import_module):
      """An empty Accept-Encoding leaves the body uncompressed."""
      self._assert_response_not_compressed(import_module, '',
                                           compression_threshold=0)
  @mock.patch('importlib.import_module')
  def test_cookie(self, import_module):
      """With cookie='sid' the handshake sets a cookie holding the sid."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer(cookie='sid')
      server._generate_id = mock.MagicMock(return_value='123')
      _run(server.handle_request('request'))
      response_headers = async_mock._async['make_response'].call_args[0][1]
      self.assertIn(('Set-Cookie', 'sid=123'), response_headers)
  @mock.patch('importlib.import_module')
  def test_no_cookie(self, import_module):
      """With cookie=None no Set-Cookie header is emitted."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer(cookie=None)
      server._generate_id = mock.MagicMock(return_value='123')
      _run(server.handle_request('request'))
      response_headers = async_mock._async['make_response'].call_args[0][1]
      header_names = [name for name, _ in response_headers]
      self.assertNotIn('Set-Cookie', header_names)
  def test_logger(self):
      """Check how the ``logger`` argument configures the server logger.

      ``logger=False`` gives an effective level of ERROR and ``logger=True``
      gives INFO, but an explicit level already set on the logger (WARNING
      below) is kept. A ``logging.Logger`` instance is used as the server
      logger.
      """
      s = asyncio_server.AsyncServer(logger=False)
      # The assertions below rely on default-configured servers sharing this
      # logger, so its level is shared state: reset it even if an assertion
      # fails, otherwise the level leaks into later tests.
      self.addCleanup(s.logger.setLevel, logging.NOTSET)
      self.assertEqual(s.logger.getEffectiveLevel(), logging.ERROR)
      s.logger.setLevel(logging.NOTSET)
      s = asyncio_server.AsyncServer(logger=True)
      self.assertEqual(s.logger.getEffectiveLevel(), logging.INFO)
      s.logger.setLevel(logging.WARNING)
      s = asyncio_server.AsyncServer(logger=True)
      self.assertEqual(s.logger.getEffectiveLevel(), logging.WARNING)
      s.logger.setLevel(logging.NOTSET)
      my_logger = logging.Logger('foo')
      s = asyncio_server.AsyncServer(logger=my_logger)
      self.assertEqual(s.logger, my_logger)
  def test_custom_json(self):
      """A custom ``json`` module given to the server is used by packets.

      Warning: constructing the server replaces ``packet.Packet.json`` for
      the whole process, so this test cannot run in parallel with other
      tests.
      """
      class CustomJSON(object):
          @staticmethod
          def dumps(*args, **kwargs):
              return '*** encoded ***'

          @staticmethod
          def loads(*args, **kwargs):
              return '+++ decoded +++'

      # Register the restore BEFORE the swap so a failing assertion cannot
      # leave CustomJSON installed for later tests. Restore the previous
      # value rather than assuming it was the stdlib json module.
      self.addCleanup(setattr, packet.Packet, 'json', packet.Packet.json)
      asyncio_server.AsyncServer(json=CustomJSON)
      pkt = packet.Packet(packet.MESSAGE, data={'foo': 'bar'})
      self.assertEqual(pkt.encode(), b'4*** encoded ***')
      pkt2 = packet.Packet(encoded_packet=pkt.encode())
      self.assertEqual(pkt2.data, '+++ decoded +++')
  def test_background_tasks(self):
      """A coroutine passed to start_background_task runs on the event loop."""
      r = []

      async def foo(arg):
          r.append(arg)

      s = asyncio_server.AsyncServer()
      s.start_background_task(foo, 'bar')
      loop = asyncio.get_event_loop()
      # asyncio.Task.all_tasks() was deprecated in 3.7 and removed in 3.9.
      # asyncio.all_tasks() (3.7+) must be given the loop explicitly because
      # no loop is running at this point; older Pythons use the classmethod.
      all_tasks = getattr(asyncio, 'all_tasks', None) or asyncio.Task.all_tasks
      pending = all_tasks(loop)
      loop.run_until_complete(asyncio.wait(pending))
      self.assertEqual(r, ['bar'])
  def test_sleep(self):
      """Awaiting ``sleep(0)`` completes without raising."""
      _run(asyncio_server.AsyncServer().sleep(0))
  def test_trigger_event_function(self):
      """A plain-function handler is invoked with the event's argument."""
      calls = []

      def on_message(arg):
          calls.extend(['ok', arg])

      server = asyncio_server.AsyncServer()
      server.on('message', handler=on_message)
      _run(server._trigger_event('message', 'bar'))
      self.assertEqual(calls, ['ok', 'bar'])
  def test_trigger_event_coroutine(self):
      """A coroutine handler is awaited with the event's argument."""
      calls = []

      async def on_message(arg):
          calls.extend(['ok', arg])

      server = asyncio_server.AsyncServer()
      server.on('message', handler=on_message)
      _run(server._trigger_event('message', 'bar'))
      self.assertEqual(calls, ['ok', 'bar'])
  def test_trigger_event_function_error(self):
      """Errors in plain-function handlers do not propagate.

      A failing 'connect' handler yields a falsy result; a failing
      'message' handler yields None.
      """
      def on_connect(arg):
          return 1 / 0

      def on_message(arg):
          return 1 / 0

      server = asyncio_server.AsyncServer()
      server.on('connect', handler=on_connect)
      server.on('message', handler=on_message)
      connect_result = _run(server._trigger_event('connect', '123'))
      self.assertFalse(connect_result)
      message_result = _run(server._trigger_event('message', 'bar'))
      self.assertIsNone(message_result)
  def test_trigger_event_coroutine_error(self):
      """Errors in coroutine handlers do not propagate.

      A failing 'connect' handler yields a falsy result; a failing
      'message' handler yields None.
      """
      async def on_connect(arg):
          return 1 / 0

      async def on_message(arg):
          return 1 / 0

      server = asyncio_server.AsyncServer()
      server.on('connect', handler=on_connect)
      server.on('message', handler=on_message)
      connect_result = _run(server._trigger_event('connect', '123'))
      self.assertFalse(connect_result)
      message_result = _run(server._trigger_event('message', 'bar'))
      self.assertIsNone(message_result)
  def test_trigger_event_function_async(self):
      """With run_async=True a plain handler runs when the result is awaited."""
      calls = []

      def on_message(arg):
          calls.extend(['ok', arg])

      server = asyncio_server.AsyncServer()
      server.on('message', handler=on_message)
      pending = _run(server._trigger_event('message', 'bar', run_async=True))
      # drive the returned awaitable to completion on the same loop
      _run(pending)
      self.assertEqual(calls, ['ok', 'bar'])
  def test_trigger_event_coroutine_async(self):
      """With run_async=True a coroutine handler runs when the result is awaited."""
      calls = []

      async def on_message(arg):
          calls.extend(['ok', arg])

      server = asyncio_server.AsyncServer()
      server.on('message', handler=on_message)
      pending = _run(server._trigger_event('message', 'bar', run_async=True))
      # drive the returned awaitable to completion on the same loop
      _run(pending)
      self.assertEqual(calls, ['ok', 'bar'])
  def test_trigger_event_function_async_error(self):
      """With run_async=True a plain handler's error surfaces when awaited."""
      calls = []

      def on_message(arg):
          calls.append(arg)
          return 1 / 0

      server = asyncio_server.AsyncServer()
      server.on('message', handler=on_message)
      pending = _run(server._trigger_event('message', 'bar', run_async=True))
      with self.assertRaises(ZeroDivisionError):
          _run(pending)
      # the handler did run (and recorded its argument) before failing
      self.assertEqual(calls, ['bar'])
  def test_trigger_event_coroutine_async_error(self):
      """With run_async=True a coroutine handler's error surfaces when awaited."""
      calls = []

      async def on_message(arg):
          calls.append(arg)
          return 1 / 0

      server = asyncio_server.AsyncServer()
      server.on('message', handler=on_message)
      pending = _run(server._trigger_event('message', 'bar', run_async=True))
      with self.assertRaises(ZeroDivisionError):
          _run(pending)
      # the handler did run (and recorded its argument) before failing
      self.assertEqual(calls, ['bar'])
  def test_create_queue(self):
      """A fresh queue raises the server's queue-empty exception on get_nowait()."""
      server = asyncio_server.AsyncServer()
      queue = server.create_queue()
      with self.assertRaises(server.get_queue_empty_exception()):
          queue.get_nowait()
  def test_create_event(self):
      """A created event starts unset and reports set after set()."""
      event = asyncio_server.AsyncServer().create_event()
      self.assertFalse(event.is_set())
      event.set()
      self.assertTrue(event.is_set())
  @mock.patch('importlib.import_module')
  def test_service_task_started(self, import_module):
      """With monitor_clients=True, handling a request calls _service_task once."""
      async_mock = self.get_async_mock()
      import_module.side_effect = [async_mock]
      server = asyncio_server.AsyncServer(monitor_clients=True)
      service_task = AsyncMock()
      server._service_task = service_task
      _run(server.handle_request('request'))
      service_task.mock.assert_called_once_with()