Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

953 linhas
39KB

  1. import gzip
  2. import importlib
  3. import json
  4. import logging
  5. import sys
  6. import time
  7. import unittest
  8. import zlib
  9. import six
  10. if six.PY3:
  11. from unittest import mock
  12. else:
  13. import mock
  14. from engineio import exceptions
  15. from engineio import packet
  16. from engineio import payload
  17. from engineio import server
  18. original_import_module = importlib.import_module
  19. def _mock_import(module, *args, **kwargs):
  20. if module.startswith('engineio.'):
  21. return original_import_module(module, *args, **kwargs)
  22. return module
  23. class TestServer(unittest.TestCase):
  24. _mock_async = mock.MagicMock()
  25. _mock_async._async = {
  26. 'thread': 't',
  27. 'queue': 'q',
  28. 'queue_empty': RuntimeError,
  29. 'websocket': 'w',
  30. }
  31. def _get_mock_socket(self):
  32. mock_socket = mock.MagicMock()
  33. mock_socket.closed = False
  34. mock_socket.closing = False
  35. mock_socket.upgraded = False
  36. mock_socket.session = {}
  37. return mock_socket
  38. @classmethod
  39. def setUpClass(cls):
  40. server.Server._default_monitor_clients = False
  41. @classmethod
  42. def tearDownClass(cls):
  43. server.Server._default_monitor_clients = True
  44. def setUp(self):
  45. logging.getLogger('engineio').setLevel(logging.NOTSET)
  46. def tearDown(self):
  47. # restore JSON encoder, in case a test changed it
  48. packet.Packet.json = json
  49. def test_is_asyncio_based(self):
  50. s = server.Server()
  51. self.assertEqual(s.is_asyncio_based(), False)
  52. def test_async_modes(self):
  53. s = server.Server()
  54. self.assertEqual(s.async_modes(), ['eventlet', 'gevent_uwsgi',
  55. 'gevent', 'threading'])
  56. def test_create(self):
  57. kwargs = {
  58. 'ping_timeout': 1,
  59. 'ping_interval': 2,
  60. 'max_http_buffer_size': 3,
  61. 'allow_upgrades': False,
  62. 'http_compression': False,
  63. 'compression_threshold': 4,
  64. 'cookie': 'foo',
  65. 'cors_allowed_origins': ['foo', 'bar', 'baz'],
  66. 'cors_credentials': False,
  67. 'async_handlers': False}
  68. s = server.Server(**kwargs)
  69. for arg in six.iterkeys(kwargs):
  70. self.assertEqual(getattr(s, arg), kwargs[arg])
  71. def test_create_ignores_kwargs(self):
  72. server.Server(foo='bar') # this should not raise
  73. def test_async_mode_threading(self):
  74. s = server.Server(async_mode='threading')
  75. self.assertEqual(s.async_mode, 'threading')
  76. import threading
  77. try:
  78. import queue
  79. except ImportError:
  80. import Queue as queue
  81. self.assertEqual(s._async['thread'], threading.Thread)
  82. self.assertEqual(s._async['queue'], queue.Queue)
  83. self.assertEqual(s._async['websocket'], None)
  84. def test_async_mode_eventlet(self):
  85. s = server.Server(async_mode='eventlet')
  86. self.assertEqual(s.async_mode, 'eventlet')
  87. from eventlet.green import threading
  88. from eventlet import queue
  89. from engineio.async_drivers import eventlet as async_eventlet
  90. self.assertEqual(s._async['thread'], threading.Thread)
  91. self.assertEqual(s._async['queue'], queue.Queue)
  92. self.assertEqual(s._async['websocket'], async_eventlet.WebSocketWSGI)
  93. @mock.patch('importlib.import_module', side_effect=_mock_import)
  94. def test_async_mode_gevent_uwsgi(self, import_module):
  95. sys.modules['gevent'] = mock.MagicMock()
  96. sys.modules['gevent'].queue = mock.MagicMock()
  97. sys.modules['gevent.queue'] = sys.modules['gevent'].queue
  98. sys.modules['gevent.queue'].JoinableQueue = 'foo'
  99. sys.modules['gevent.queue'].Empty = RuntimeError
  100. sys.modules['gevent.event'] = mock.MagicMock()
  101. sys.modules['gevent.event'].Event = 'bar'
  102. sys.modules['uwsgi'] = mock.MagicMock()
  103. s = server.Server(async_mode='gevent_uwsgi')
  104. self.assertEqual(s.async_mode, 'gevent_uwsgi')
  105. from engineio.async_drivers import gevent_uwsgi as async_gevent_uwsgi
  106. self.assertEqual(s._async['thread'], async_gevent_uwsgi.Thread)
  107. self.assertEqual(s._async['queue'], 'foo')
  108. self.assertEqual(s._async['queue_empty'], RuntimeError)
  109. self.assertEqual(s._async['event'], 'bar')
  110. self.assertEqual(s._async['websocket'],
  111. async_gevent_uwsgi.uWSGIWebSocket)
  112. del sys.modules['gevent']
  113. del sys.modules['gevent.queue']
  114. del sys.modules['gevent.event']
  115. del sys.modules['uwsgi']
  116. del sys.modules['engineio.async_drivers.gevent_uwsgi']
  117. @mock.patch('importlib.import_module', side_effect=_mock_import)
  118. def test_async_mode_gevent_uwsgi_without_uwsgi(self, import_module):
  119. sys.modules['gevent'] = mock.MagicMock()
  120. sys.modules['gevent'].queue = mock.MagicMock()
  121. sys.modules['gevent.queue'] = sys.modules['gevent'].queue
  122. sys.modules['gevent.queue'].JoinableQueue = 'foo'
  123. sys.modules['gevent.queue'].Empty = RuntimeError
  124. sys.modules['gevent.event'] = mock.MagicMock()
  125. sys.modules['gevent.event'].Event = 'bar'
  126. sys.modules['uwsgi'] = None
  127. self.assertRaises(ValueError, server.Server,
  128. async_mode='gevent_uwsgi')
  129. del sys.modules['gevent']
  130. del sys.modules['gevent.queue']
  131. del sys.modules['gevent.event']
  132. del sys.modules['uwsgi']
  133. @mock.patch('importlib.import_module', side_effect=_mock_import)
  134. def test_async_mode_gevent_uwsgi_without_websocket(self, import_module):
  135. sys.modules['gevent'] = mock.MagicMock()
  136. sys.modules['gevent'].queue = mock.MagicMock()
  137. sys.modules['gevent.queue'] = sys.modules['gevent'].queue
  138. sys.modules['gevent.queue'].JoinableQueue = 'foo'
  139. sys.modules['gevent.queue'].Empty = RuntimeError
  140. sys.modules['gevent.event'] = mock.MagicMock()
  141. sys.modules['gevent.event'].Event = 'bar'
  142. sys.modules['uwsgi'] = mock.MagicMock()
  143. del sys.modules['uwsgi'].websocket_handshake
  144. s = server.Server(async_mode='gevent_uwsgi')
  145. self.assertEqual(s.async_mode, 'gevent_uwsgi')
  146. from engineio.async_drivers import gevent_uwsgi as async_gevent_uwsgi
  147. self.assertEqual(s._async['thread'], async_gevent_uwsgi.Thread)
  148. self.assertEqual(s._async['queue'], 'foo')
  149. self.assertEqual(s._async['queue_empty'], RuntimeError)
  150. self.assertEqual(s._async['event'], 'bar')
  151. self.assertEqual(s._async['websocket'], None)
  152. del sys.modules['gevent']
  153. del sys.modules['gevent.queue']
  154. del sys.modules['gevent.event']
  155. del sys.modules['uwsgi']
  156. del sys.modules['engineio.async_drivers.gevent_uwsgi']
  157. @mock.patch('importlib.import_module', side_effect=_mock_import)
  158. def test_async_mode_gevent(self, import_module):
  159. sys.modules['gevent'] = mock.MagicMock()
  160. sys.modules['gevent'].queue = mock.MagicMock()
  161. sys.modules['gevent.queue'] = sys.modules['gevent'].queue
  162. sys.modules['gevent.queue'].JoinableQueue = 'foo'
  163. sys.modules['gevent.queue'].Empty = RuntimeError
  164. sys.modules['gevent.event'] = mock.MagicMock()
  165. sys.modules['gevent.event'].Event = 'bar'
  166. sys.modules['geventwebsocket'] = 'geventwebsocket'
  167. s = server.Server(async_mode='gevent')
  168. self.assertEqual(s.async_mode, 'gevent')
  169. from engineio.async_drivers import gevent as async_gevent
  170. self.assertEqual(s._async['thread'], async_gevent.Thread)
  171. self.assertEqual(s._async['queue'], 'foo')
  172. self.assertEqual(s._async['queue_empty'], RuntimeError)
  173. self.assertEqual(s._async['event'], 'bar')
  174. self.assertEqual(s._async['websocket'], async_gevent.WebSocketWSGI)
  175. del sys.modules['gevent']
  176. del sys.modules['gevent.queue']
  177. del sys.modules['gevent.event']
  178. del sys.modules['geventwebsocket']
  179. del sys.modules['engineio.async_drivers.gevent']
  180. @mock.patch('importlib.import_module', side_effect=_mock_import)
  181. def test_async_mode_gevent_without_websocket(self, import_module):
  182. sys.modules['gevent'] = mock.MagicMock()
  183. sys.modules['gevent'].queue = mock.MagicMock()
  184. sys.modules['gevent.queue'] = sys.modules['gevent'].queue
  185. sys.modules['gevent.queue'].JoinableQueue = 'foo'
  186. sys.modules['gevent.queue'].Empty = RuntimeError
  187. sys.modules['gevent.event'] = mock.MagicMock()
  188. sys.modules['gevent.event'].Event = 'bar'
  189. sys.modules['geventwebsocket'] = None
  190. s = server.Server(async_mode='gevent')
  191. self.assertEqual(s.async_mode, 'gevent')
  192. from engineio.async_drivers import gevent as async_gevent
  193. self.assertEqual(s._async['thread'], async_gevent.Thread)
  194. self.assertEqual(s._async['queue'], 'foo')
  195. self.assertEqual(s._async['queue_empty'], RuntimeError)
  196. self.assertEqual(s._async['event'], 'bar')
  197. self.assertEqual(s._async['websocket'], None)
  198. del sys.modules['gevent']
  199. del sys.modules['gevent.queue']
  200. del sys.modules['gevent.event']
  201. del sys.modules['geventwebsocket']
  202. del sys.modules['engineio.async_drivers.gevent']
  203. @unittest.skipIf(sys.version_info < (3, 5), 'only for Python 3.5+')
  204. @mock.patch('importlib.import_module', side_effect=_mock_import)
  205. def test_async_mode_aiohttp(self, import_module):
  206. sys.modules['aiohttp'] = mock.MagicMock()
  207. self.assertRaises(ValueError, server.Server, async_mode='aiohttp')
  208. @mock.patch('importlib.import_module', side_effect=[ImportError])
  209. def test_async_mode_invalid(self, import_module):
  210. self.assertRaises(ValueError, server.Server, async_mode='foo')
  211. @mock.patch('importlib.import_module', side_effect=[_mock_async])
  212. def test_async_mode_auto_eventlet(self, import_module):
  213. s = server.Server()
  214. self.assertEqual(s.async_mode, 'eventlet')
  215. @mock.patch('importlib.import_module', side_effect=[ImportError,
  216. _mock_async])
  217. def test_async_mode_auto_gevent_uwsgi(self, import_module):
  218. s = server.Server()
  219. self.assertEqual(s.async_mode, 'gevent_uwsgi')
  220. @mock.patch('importlib.import_module', side_effect=[ImportError,
  221. ImportError,
  222. _mock_async])
  223. def test_async_mode_auto_gevent(self, import_module):
  224. s = server.Server()
  225. self.assertEqual(s.async_mode, 'gevent')
  226. @mock.patch('importlib.import_module', side_effect=[ImportError,
  227. ImportError,
  228. ImportError,
  229. _mock_async])
  230. def test_async_mode_auto_threading(self, import_module):
  231. s = server.Server()
  232. self.assertEqual(s.async_mode, 'threading')
  233. def test_generate_id(self):
  234. s = server.Server()
  235. self.assertNotEqual(s._generate_id(), s._generate_id())
  236. def test_on_event(self):
  237. s = server.Server()
  238. @s.on('connect')
  239. def foo():
  240. pass
  241. s.on('disconnect', foo)
  242. self.assertEqual(s.handlers['connect'], foo)
  243. self.assertEqual(s.handlers['disconnect'], foo)
  244. def test_on_event_invalid(self):
  245. s = server.Server()
  246. self.assertRaises(ValueError, s.on, 'invalid')
  247. def test_trigger_event(self):
  248. s = server.Server()
  249. f = {}
  250. @s.on('connect')
  251. def foo(sid, environ):
  252. return sid + environ
  253. @s.on('message')
  254. def bar(sid, data):
  255. f['bar'] = sid + data
  256. return 'bar'
  257. r = s._trigger_event('connect', 1, 2, run_async=False)
  258. self.assertEqual(r, 3)
  259. r = s._trigger_event('message', 3, 4, run_async=True)
  260. r.join()
  261. self.assertEqual(f['bar'], 7)
  262. r = s._trigger_event('message', 5, 6)
  263. self.assertEqual(r, 'bar')
  264. def test_trigger_event_error(self):
  265. s = server.Server()
  266. @s.on('connect')
  267. def foo(sid, environ):
  268. return 1 / 0
  269. @s.on('message')
  270. def bar(sid, data):
  271. return 1 / 0
  272. r = s._trigger_event('connect', 1, 2, run_async=False)
  273. self.assertEqual(r, False)
  274. r = s._trigger_event('message', 3, 4, run_async=False)
  275. self.assertEqual(r, None)
  276. def test_session(self):
  277. s = server.Server()
  278. mock_socket = self._get_mock_socket()
  279. s.sockets['foo'] = mock_socket
  280. with s.session('foo') as session:
  281. self.assertEqual(session, {})
  282. session['username'] = 'bar'
  283. self.assertEqual(s.get_session('foo'), {'username': 'bar'})
  284. def test_close_one_socket(self):
  285. s = server.Server()
  286. mock_socket = self._get_mock_socket()
  287. s.sockets['foo'] = mock_socket
  288. s.disconnect('foo')
  289. self.assertEqual(mock_socket.close.call_count, 1)
  290. self.assertNotIn('foo', s.sockets)
  291. def test_close_all_sockets(self):
  292. s = server.Server()
  293. mock_sockets = {}
  294. for sid in ['foo', 'bar', 'baz']:
  295. mock_sockets[sid] = self._get_mock_socket()
  296. s.sockets[sid] = mock_sockets[sid]
  297. s.disconnect()
  298. for socket in six.itervalues(mock_sockets):
  299. self.assertEqual(socket.close.call_count, 1)
  300. self.assertEqual(s.sockets, {})
  301. def test_upgrades(self):
  302. s = server.Server()
  303. s.sockets['foo'] = self._get_mock_socket()
  304. self.assertEqual(s._upgrades('foo', 'polling'), ['websocket'])
  305. self.assertEqual(s._upgrades('foo', 'websocket'), [])
  306. s.sockets['foo'].upgraded = True
  307. self.assertEqual(s._upgrades('foo', 'polling'), [])
  308. self.assertEqual(s._upgrades('foo', 'websocket'), [])
  309. s.allow_upgrades = False
  310. s.sockets['foo'].upgraded = True
  311. self.assertEqual(s._upgrades('foo', 'polling'), [])
  312. self.assertEqual(s._upgrades('foo', 'websocket'), [])
  313. def test_transport(self):
  314. s = server.Server()
  315. s.sockets['foo'] = self._get_mock_socket()
  316. s.sockets['foo'].upgraded = False
  317. s.sockets['bar'] = self._get_mock_socket()
  318. s.sockets['bar'].upgraded = True
  319. self.assertEqual(s.transport('foo'), 'polling')
  320. self.assertEqual(s.transport('bar'), 'websocket')
  321. def test_bad_session(self):
  322. s = server.Server()
  323. s.sockets['foo'] = 'client'
  324. self.assertRaises(KeyError, s._get_socket, 'bar')
  325. def test_closed_socket(self):
  326. s = server.Server()
  327. s.sockets['foo'] = self._get_mock_socket()
  328. s.sockets['foo'].closed = True
  329. self.assertRaises(KeyError, s._get_socket, 'foo')
  330. def test_jsonp_with_bad_index(self):
  331. s = server.Server()
  332. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'j=abc'}
  333. start_response = mock.MagicMock()
  334. s.handle_request(environ, start_response)
  335. self.assertEqual(start_response.call_args[0][0],
  336. '400 BAD REQUEST')
  337. def test_jsonp_index(self):
  338. s = server.Server()
  339. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'j=233'}
  340. start_response = mock.MagicMock()
  341. r = s.handle_request(environ, start_response)
  342. self.assertEqual(start_response.call_args[0][0],
  343. '200 OK')
  344. self.assertTrue(r[0].startswith(b'___eio[233]("'))
  345. self.assertTrue(r[0].endswith(b'");'))
  346. def test_connect(self):
  347. s = server.Server()
  348. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
  349. start_response = mock.MagicMock()
  350. r = s.handle_request(environ, start_response)
  351. self.assertEqual(len(s.sockets), 1)
  352. self.assertEqual(start_response.call_count, 1)
  353. self.assertEqual(start_response.call_args[0][0], '200 OK')
  354. self.assertIn(('Content-Type', 'application/octet-stream'),
  355. start_response.call_args[0][1])
  356. self.assertEqual(len(r), 1)
  357. packets = payload.Payload(encoded_payload=r[0]).packets
  358. self.assertEqual(len(packets), 1)
  359. self.assertEqual(packets[0].packet_type, packet.OPEN)
  360. self.assertIn('upgrades', packets[0].data)
  361. self.assertEqual(packets[0].data['upgrades'], ['websocket'])
  362. self.assertIn('sid', packets[0].data)
  363. def test_connect_no_upgrades(self):
  364. s = server.Server(allow_upgrades=False)
  365. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
  366. start_response = mock.MagicMock()
  367. r = s.handle_request(environ, start_response)
  368. packets = payload.Payload(encoded_payload=r[0]).packets
  369. self.assertEqual(packets[0].data['upgrades'], [])
  370. def test_connect_b64_with_1(self):
  371. s = server.Server(allow_upgrades=False)
  372. s._generate_id = mock.MagicMock(return_value='1')
  373. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'b64=1'}
  374. start_response = mock.MagicMock()
  375. s.handle_request(environ, start_response)
  376. self.assertTrue(start_response.call_args[0][0], '200 OK')
  377. self.assertIn(('Content-Type', 'text/plain; charset=UTF-8'),
  378. start_response.call_args[0][1])
  379. s.send('1', b'\x00\x01\x02', binary=True)
  380. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=1'}
  381. r = s.handle_request(environ, start_response)
  382. self.assertEqual(r[0], b'6:b4AAEC')
  383. def test_connect_b64_with_true(self):
  384. s = server.Server(allow_upgrades=False)
  385. s._generate_id = mock.MagicMock(return_value='1')
  386. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'b64=true'}
  387. start_response = mock.MagicMock()
  388. s.handle_request(environ, start_response)
  389. self.assertTrue(start_response.call_args[0][0], '200 OK')
  390. self.assertIn(('Content-Type', 'text/plain; charset=UTF-8'),
  391. start_response.call_args[0][1])
  392. s.send('1', b'\x00\x01\x02', binary=True)
  393. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=true'}
  394. r = s.handle_request(environ, start_response)
  395. self.assertEqual(r[0], b'6:b4AAEC')
  396. def test_connect_b64_with_0(self):
  397. s = server.Server(allow_upgrades=False)
  398. s._generate_id = mock.MagicMock(return_value='1')
  399. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'b64=0'}
  400. start_response = mock.MagicMock()
  401. s.handle_request(environ, start_response)
  402. self.assertTrue(start_response.call_args[0][0], '200 OK')
  403. self.assertIn(('Content-Type', 'application/octet-stream'),
  404. start_response.call_args[0][1])
  405. s.send('1', b'\x00\x01\x02', binary=True)
  406. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=0'}
  407. r = s.handle_request(environ, start_response)
  408. self.assertEqual(r[0], b'\x01\x04\xff\x04\x00\x01\x02')
  409. def test_connect_b64_with_false(self):
  410. s = server.Server(allow_upgrades=False)
  411. s._generate_id = mock.MagicMock(return_value='1')
  412. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'b64=false'}
  413. start_response = mock.MagicMock()
  414. s.handle_request(environ, start_response)
  415. self.assertTrue(start_response.call_args[0][0], '200 OK')
  416. self.assertIn(('Content-Type', 'application/octet-stream'),
  417. start_response.call_args[0][1])
  418. s.send('1', b'\x00\x01\x02', binary=True)
  419. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=false'}
  420. r = s.handle_request(environ, start_response)
  421. self.assertEqual(r[0], b'\x01\x04\xff\x04\x00\x01\x02')
  422. def test_connect_custom_ping_times(self):
  423. s = server.Server(ping_timeout=123, ping_interval=456)
  424. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
  425. start_response = mock.MagicMock()
  426. r = s.handle_request(environ, start_response)
  427. packets = payload.Payload(encoded_payload=r[0]).packets
  428. self.assertEqual(packets[0].data['pingTimeout'], 123000)
  429. self.assertEqual(packets[0].data['pingInterval'], 456000)
  430. @mock.patch('engineio.socket.Socket.poll',
  431. side_effect=exceptions.QueueEmpty)
  432. def test_connect_bad_poll(self, poll):
  433. s = server.Server()
  434. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
  435. start_response = mock.MagicMock()
  436. s.handle_request(environ, start_response)
  437. self.assertEqual(start_response.call_args[0][0],
  438. '400 BAD REQUEST')
  439. @mock.patch('engineio.socket.Socket',
  440. return_value=mock.MagicMock(connected=False, closed=False))
  441. def test_connect_transport_websocket(self, Socket):
  442. s = server.Server()
  443. s._generate_id = mock.MagicMock(return_value='123')
  444. environ = {'REQUEST_METHOD': 'GET',
  445. 'QUERY_STRING': 'transport=websocket'}
  446. start_response = mock.MagicMock()
  447. # force socket to stay open, so that we can check it later
  448. Socket().closed = False
  449. s.handle_request(environ, start_response)
  450. self.assertEqual(s.sockets['123'].send.call_args[0][0].packet_type,
  451. packet.OPEN)
  452. @mock.patch('engineio.socket.Socket',
  453. return_value=mock.MagicMock(connected=False, closed=False))
  454. def test_connect_transport_websocket_closed(self, Socket):
  455. s = server.Server()
  456. s._generate_id = mock.MagicMock(return_value='123')
  457. environ = {'REQUEST_METHOD': 'GET',
  458. 'QUERY_STRING': 'transport=websocket'}
  459. start_response = mock.MagicMock()
  460. def mock_handle(environ, start_response):
  461. s.sockets['123'].closed = True
  462. Socket().handle_get_request = mock_handle
  463. s.handle_request(environ, start_response)
  464. self.assertNotIn('123', s.sockets)
  465. def test_connect_transport_invalid(self):
  466. s = server.Server()
  467. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'transport=foo'}
  468. start_response = mock.MagicMock()
  469. s.handle_request(environ, start_response)
  470. self.assertEqual(start_response.call_args[0][0],
  471. '400 BAD REQUEST')
  472. def test_connect_cors_headers(self):
  473. s = server.Server()
  474. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
  475. start_response = mock.MagicMock()
  476. s.handle_request(environ, start_response)
  477. headers = start_response.call_args[0][1]
  478. self.assertIn(('Access-Control-Allow-Origin', '*'), headers)
  479. self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)
  480. def test_connect_cors_allowed_origin(self):
  481. s = server.Server(cors_allowed_origins=['a', 'b'])
  482. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
  483. 'HTTP_ORIGIN': 'b'}
  484. start_response = mock.MagicMock()
  485. s.handle_request(environ, start_response)
  486. headers = start_response.call_args[0][1]
  487. self.assertIn(('Access-Control-Allow-Origin', 'b'), headers)
  488. def test_connect_cors_not_allowed_origin(self):
  489. s = server.Server(cors_allowed_origins=['a', 'b'])
  490. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
  491. 'HTTP_ORIGIN': 'c'}
  492. start_response = mock.MagicMock()
  493. s.handle_request(environ, start_response)
  494. headers = start_response.call_args[0][1]
  495. self.assertNotIn(('Access-Control-Allow-Origin', 'c'), headers)
  496. self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)
  497. def test_connect_cors_headers_all_origins(self):
  498. s = server.Server(cors_allowed_origins='*')
  499. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
  500. start_response = mock.MagicMock()
  501. s.handle_request(environ, start_response)
  502. headers = start_response.call_args[0][1]
  503. self.assertIn(('Access-Control-Allow-Origin', '*'), headers)
  504. self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)
  505. def test_connect_cors_headers_one_origin(self):
  506. s = server.Server(cors_allowed_origins='a')
  507. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
  508. 'HTTP_ORIGIN': 'a'}
  509. start_response = mock.MagicMock()
  510. s.handle_request(environ, start_response)
  511. headers = start_response.call_args[0][1]
  512. self.assertIn(('Access-Control-Allow-Origin', 'a'), headers)
  513. self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)
  514. def test_connect_cors_headers_one_origin_not_allowed(self):
  515. s = server.Server(cors_allowed_origins='a')
  516. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
  517. 'HTTP_ORIGIN': 'b'}
  518. start_response = mock.MagicMock()
  519. s.handle_request(environ, start_response)
  520. headers = start_response.call_args[0][1]
  521. self.assertNotIn(('Access-Control-Allow-Origin', 'b'), headers)
  522. self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)
  523. def test_connect_cors_no_credentials(self):
  524. s = server.Server(cors_credentials=False)
  525. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
  526. start_response = mock.MagicMock()
  527. s.handle_request(environ, start_response)
  528. headers = start_response.call_args[0][1]
  529. self.assertNotIn(('Access-Control-Allow-Credentials', 'true'), headers)
  530. def test_cors_options(self):
  531. s = server.Server()
  532. environ = {'REQUEST_METHOD': 'OPTIONS', 'QUERY_STRING': ''}
  533. start_response = mock.MagicMock()
  534. s.handle_request(environ, start_response)
  535. headers = start_response.call_args[0][1]
  536. self.assertIn(('Access-Control-Allow-Methods', 'OPTIONS, GET, POST'),
  537. headers)
  538. def test_cors_request_headers(self):
  539. s = server.Server()
  540. environ = {'REQUEST_METHOD': 'GET',
  541. 'HTTP_ACCESS_CONTROL_REQUEST_HEADERS': 'Foo, Bar'}
  542. start_response = mock.MagicMock()
  543. s.handle_request(environ, start_response)
  544. headers = start_response.call_args[0][1]
  545. self.assertIn(('Access-Control-Allow-Headers', 'Foo, Bar'), headers)
  546. def test_connect_event(self):
  547. s = server.Server()
  548. s._generate_id = mock.MagicMock(return_value='123')
  549. mock_event = mock.MagicMock()
  550. s.on('connect')(mock_event)
  551. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
  552. start_response = mock.MagicMock()
  553. s.handle_request(environ, start_response)
  554. mock_event.assert_called_once_with('123', environ)
  555. self.assertEqual(len(s.sockets), 1)
  556. def test_connect_event_rejects(self):
  557. s = server.Server()
  558. s._generate_id = mock.MagicMock(return_value='123')
  559. mock_event = mock.MagicMock(return_value=False)
  560. s.on('connect')(mock_event)
  561. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
  562. start_response = mock.MagicMock()
  563. s.handle_request(environ, start_response)
  564. self.assertEqual(len(s.sockets), 0)
  565. self.assertEqual(start_response.call_args[0][0], '401 UNAUTHORIZED')
  566. def test_method_not_found(self):
  567. s = server.Server()
  568. environ = {'REQUEST_METHOD': 'PUT', 'QUERY_STRING': ''}
  569. start_response = mock.MagicMock()
  570. s.handle_request(environ, start_response)
  571. self.assertEqual(start_response.call_args[0][0],
  572. '405 METHOD NOT FOUND')
  573. def test_get_request_with_bad_sid(self):
  574. s = server.Server()
  575. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
  576. start_response = mock.MagicMock()
  577. s.handle_request(environ, start_response)
  578. self.assertEqual(start_response.call_args[0][0],
  579. '400 BAD REQUEST')
  580. def test_post_request_with_bad_sid(self):
  581. s = server.Server()
  582. environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo'}
  583. start_response = mock.MagicMock()
  584. s.handle_request(environ, start_response)
  585. self.assertEqual(start_response.call_args[0][0],
  586. '400 BAD REQUEST')
  587. def test_send(self):
  588. s = server.Server()
  589. mock_socket = self._get_mock_socket()
  590. s.sockets['foo'] = mock_socket
  591. s.send('foo', 'hello')
  592. self.assertEqual(mock_socket.send.call_count, 1)
  593. self.assertEqual(mock_socket.send.call_args[0][0].packet_type,
  594. packet.MESSAGE)
  595. self.assertEqual(mock_socket.send.call_args[0][0].data, 'hello')
  596. def test_send_unknown_socket(self):
  597. s = server.Server()
  598. # just ensure no exceptions are raised
  599. s.send('foo', 'hello')
  600. def test_get_request(self):
  601. s = server.Server()
  602. mock_socket = self._get_mock_socket()
  603. mock_socket.handle_get_request = mock.MagicMock(return_value=[
  604. packet.Packet(packet.MESSAGE, data='hello')])
  605. s.sockets['foo'] = mock_socket
  606. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
  607. start_response = mock.MagicMock()
  608. r = s.handle_request(environ, start_response)
  609. self.assertEqual(start_response.call_args[0][0],
  610. '200 OK')
  611. self.assertEqual(len(r), 1)
  612. packets = payload.Payload(encoded_payload=r[0]).packets
  613. self.assertEqual(len(packets), 1)
  614. self.assertEqual(packets[0].packet_type, packet.MESSAGE)
  615. def test_get_request_custom_response(self):
  616. s = server.Server()
  617. mock_socket = self._get_mock_socket()
  618. mock_socket.handle_get_request = mock.MagicMock(side_effect=['resp'])
  619. s.sockets['foo'] = mock_socket
  620. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
  621. start_response = mock.MagicMock()
  622. self.assertEqual(s.handle_request(environ, start_response), 'resp')
  623. def test_get_request_closes_socket(self):
  624. s = server.Server()
  625. mock_socket = self._get_mock_socket()
  626. def mock_get_request(*args, **kwargs):
  627. mock_socket.closed = True
  628. return 'resp'
  629. mock_socket.handle_get_request = mock_get_request
  630. s.sockets['foo'] = mock_socket
  631. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
  632. start_response = mock.MagicMock()
  633. self.assertEqual(s.handle_request(environ, start_response), 'resp')
  634. self.assertNotIn('foo', s.sockets)
  635. def test_get_request_error(self):
  636. s = server.Server()
  637. mock_socket = self._get_mock_socket()
  638. mock_socket.handle_get_request = mock.MagicMock(
  639. side_effect=[exceptions.QueueEmpty])
  640. s.sockets['foo'] = mock_socket
  641. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
  642. start_response = mock.MagicMock()
  643. s.handle_request(environ, start_response)
  644. self.assertEqual(start_response.call_args[0][0],
  645. '400 BAD REQUEST')
  646. self.assertEqual(len(s.sockets), 0)
  647. def test_post_request(self):
  648. s = server.Server()
  649. mock_socket = self._get_mock_socket()
  650. mock_socket.handle_post_request = mock.MagicMock()
  651. s.sockets['foo'] = mock_socket
  652. environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo'}
  653. start_response = mock.MagicMock()
  654. s.handle_request(environ, start_response)
  655. self.assertEqual(start_response.call_args[0][0],
  656. '200 OK')
  657. def test_post_request_error(self):
  658. s = server.Server()
  659. mock_socket = self._get_mock_socket()
  660. mock_socket.handle_post_request = mock.MagicMock(
  661. side_effect=[exceptions.EngineIOError])
  662. s.sockets['foo'] = mock_socket
  663. environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo'}
  664. start_response = mock.MagicMock()
  665. s.handle_request(environ, start_response)
  666. self.assertEqual(start_response.call_args[0][0],
  667. '400 BAD REQUEST')
  668. self.assertNotIn('foo', s.sockets)
  669. @staticmethod
  670. def _gzip_decompress(b):
  671. bytesio = six.BytesIO(b)
  672. with gzip.GzipFile(fileobj=bytesio, mode='r') as gz:
  673. return gz.read()
  674. def test_gzip_compression(self):
  675. s = server.Server(compression_threshold=0)
  676. mock_socket = self._get_mock_socket()
  677. mock_socket.handle_get_request = mock.MagicMock(return_value=[
  678. packet.Packet(packet.MESSAGE, data='hello')])
  679. s.sockets['foo'] = mock_socket
  680. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
  681. 'HTTP_ACCEPT_ENCODING': 'gzip,deflate'}
  682. start_response = mock.MagicMock()
  683. r = s.handle_request(environ, start_response)
  684. self.assertIn(('Content-Encoding', 'gzip'),
  685. start_response.call_args[0][1])
  686. self._gzip_decompress(r[0])
  def test_deflate_compression(self):
      """When the client lists deflate first, the response is deflate-encoded
      and zlib-decompresses to a body containing the queued message."""
      s = server.Server(compression_threshold=0)
      mock_socket = self._get_mock_socket()
      mock_socket.handle_get_request = mock.MagicMock(return_value=[
          packet.Packet(packet.MESSAGE, data='hello')])
      s.sockets['foo'] = mock_socket
      environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
                 'HTTP_ACCEPT_ENCODING': 'deflate;q=1,gzip'}
      start_response = mock.MagicMock()
      r = s.handle_request(environ, start_response)
      self.assertIn(('Content-Encoding', 'deflate'),
                    start_response.call_args[0][1])
      # Decompressing alone only proves the bytes form a valid zlib stream;
      # also check that the message text survived the round trip.
      self.assertIn(b'hello', zlib.decompress(r[0]))
  def test_gzip_compression_threshold(self):
      """A response below compression_threshold is sent uncompressed even
      though the client accepts gzip."""
      srv = server.Server(compression_threshold=1000)
      sock = self._get_mock_socket()
      hello = packet.Packet(packet.MESSAGE, data='hello')
      sock.handle_get_request = mock.MagicMock(return_value=[hello])
      srv.sockets['foo'] = sock
      request_env = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
                     'HTTP_ACCEPT_ENCODING': 'gzip'}
      responder = mock.MagicMock()
      body = srv.handle_request(request_env, responder)
      header_names = [name for name, _ in responder.call_args[0][1]]
      self.assertNotIn('Content-Encoding', header_names)
      self.assertRaises(IOError, self._gzip_decompress, body[0])
  def test_compression_disabled(self):
      """http_compression=False leaves the body uncompressed even with a
      zero threshold and a gzip-capable client."""
      srv = server.Server(http_compression=False, compression_threshold=0)
      sock = self._get_mock_socket()
      hello = packet.Packet(packet.MESSAGE, data='hello')
      sock.handle_get_request = mock.MagicMock(return_value=[hello])
      srv.sockets['foo'] = sock
      request_env = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
                     'HTTP_ACCEPT_ENCODING': 'gzip'}
      responder = mock.MagicMock()
      body = srv.handle_request(request_env, responder)
      header_names = [name for name, _ in responder.call_args[0][1]]
      self.assertNotIn('Content-Encoding', header_names)
      self.assertRaises(IOError, self._gzip_decompress, body[0])
  def test_compression_unknown(self):
      """An Accept-Encoding naming only an unsupported scheme ('rar') yields
      an uncompressed body with no Content-Encoding header."""
      srv = server.Server(compression_threshold=0)
      sock = self._get_mock_socket()
      hello = packet.Packet(packet.MESSAGE, data='hello')
      sock.handle_get_request = mock.MagicMock(return_value=[hello])
      srv.sockets['foo'] = sock
      request_env = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
                     'HTTP_ACCEPT_ENCODING': 'rar'}
      responder = mock.MagicMock()
      body = srv.handle_request(request_env, responder)
      header_names = [name for name, _ in responder.call_args[0][1]]
      self.assertNotIn('Content-Encoding', header_names)
      self.assertRaises(IOError, self._gzip_decompress, body[0])
  def test_compression_no_encoding(self):
      """An empty Accept-Encoding yields an uncompressed body with no
      Content-Encoding header."""
      srv = server.Server(compression_threshold=0)
      sock = self._get_mock_socket()
      hello = packet.Packet(packet.MESSAGE, data='hello')
      sock.handle_get_request = mock.MagicMock(return_value=[hello])
      srv.sockets['foo'] = sock
      request_env = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
                     'HTTP_ACCEPT_ENCODING': ''}
      responder = mock.MagicMock()
      body = srv.handle_request(request_env, responder)
      header_names = [name for name, _ in responder.call_args[0][1]]
      self.assertNotIn('Content-Encoding', header_names)
      self.assertRaises(IOError, self._gzip_decompress, body[0])
  def test_cookie(self):
      """With cookie='sid', a GET without a sid gets a Set-Cookie header
      carrying the generated id."""
      srv = server.Server(cookie='sid')
      # Pin the generated id so the expected cookie value is known.
      srv._generate_id = mock.MagicMock(return_value='123')
      responder = mock.MagicMock()
      srv.handle_request({'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''},
                         responder)
      sent_headers = responder.call_args[0][1]
      self.assertIn(('Set-Cookie', 'sid=123'), sent_headers)
  def test_no_cookie(self):
      """With cookie=None, no Set-Cookie header is sent."""
      srv = server.Server(cookie=None)
      srv._generate_id = mock.MagicMock(return_value='123')
      responder = mock.MagicMock()
      srv.handle_request({'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''},
                         responder)
      header_names = [name for name, _ in responder.call_args[0][1]]
      self.assertNotIn('Set-Cookie', header_names)
  def test_logger(self):
      """Check the logger levels/instances chosen for each logger= value."""
      # logger=False: the effective level is expected to be ERROR.
      srv = server.Server(logger=False)
      self.assertEqual(srv.logger.getEffectiveLevel(), logging.ERROR)
      srv.logger.setLevel(logging.NOTSET)
      # logger=True with no level set: expected to become INFO.
      srv = server.Server(logger=True)
      self.assertEqual(srv.logger.getEffectiveLevel(), logging.INFO)
      # logger=True after a level was set explicitly: that level is kept.
      srv.logger.setLevel(logging.WARNING)
      srv = server.Server(logger=True)
      self.assertEqual(srv.logger.getEffectiveLevel(), logging.WARNING)
      srv.logger.setLevel(logging.NOTSET)
      # An explicit Logger instance is used as the server's logger.
      custom_logger = logging.Logger('foo')
      srv = server.Server(logger=custom_logger)
      self.assertEqual(srv.logger, custom_logger)
  def test_custom_json(self):
      """Passing json= to Server changes how packets encode and decode."""
      # Warning: this test cannot run in parallel with other tests, as it
      # changes the JSON encoding/decoding functions
      class CustomJSON(object):
          @staticmethod
          def dumps(*args, **kwargs):
              return '*** encoded ***'
          @staticmethod
          def loads(*args, **kwargs):
              return '+++ decoded +++'
      server.Server(json=CustomJSON)
      outgoing = packet.Packet(packet.MESSAGE, data={'foo': 'bar'})
      self.assertEqual(outgoing.encode(), b'4*** encoded ***')
      incoming = packet.Packet(encoded_packet=outgoing.encode())
      self.assertEqual(incoming.data, '+++ decoded +++')
      # restore the default JSON module
      packet.Packet.json = json
  def test_background_tasks(self):
      """A callable passed to start_background_task has run once the
      returned task handle is joined."""
      results = {}
      def worker():
          results['task'] = True
      srv = server.Server()
      handle = srv.start_background_task(worker)
      handle.join()
      self.assertIn('task', results)
      self.assertTrue(results['task'])
  def test_sleep(self):
      """Server.sleep(0.1) blocks for at least 0.1 seconds."""
      s = server.Server()
      # Use a monotonic clock where available so wall-clock adjustments
      # cannot distort the measurement; Python 2 only has time.time.
      clock = getattr(time, 'monotonic', time.time)
      start = clock()
      s.sleep(0.1)
      # Sleep guarantees *at least* the requested delay, so an elapsed time
      # of exactly 0.1 is correct; the previous strict '>' rejected it.
      self.assertGreaterEqual(clock() - start, 0.1)
  def test_create_queue(self):
      """A get() with a timeout on an empty queue from create_queue raises the
      exception returned by get_queue_empty_exception."""
      srv = server.Server()
      queue = srv.create_queue()
      empty_exc = srv.get_queue_empty_exception()
      with self.assertRaises(empty_exc):
          queue.get(timeout=0.01)
  def test_create_event(self):
      """An event from create_event starts unset and reports set after
      set() is called."""
      srv = server.Server()
      event = srv.create_event()
      self.assertFalse(event.is_set())
      event.set()
      self.assertTrue(event.is_set())
  823. def test_service_task_started(self):
  824. s = server.Server(monitor_clients=True)
  825. s._service_task = mock.MagicMock()
  826. environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
  827. start_response = mock.MagicMock()
  828. s.handle_request(environ, start_response)
  829. s._service_task.assert_called_once_with()