|
- import time
- import unittest
-
- import six
- if six.PY3:
- from unittest import mock
- else:
- import mock
-
- from engineio import exceptions
- from engineio import packet
- from engineio import payload
- from engineio import socket
-
-
class TestSocket(unittest.TestCase):
    """Tests for engineio.socket.Socket driven by a mocked server object.

    The mock server built by ``_get_mock_server`` uses real threads and real
    queues, so background work started by the socket actually runs; tests
    that must wait for it call ``_join_bg_tasks``.
    """

    def setUp(self):
        # Threads started through the mock server's start_background_task
        # are collected here so a test can wait for them to finish.
        self.bg_tasks = []

    def _get_mock_server(self):
        """Build a mock server suitable for constructing a Socket.

        Returns:
            A ``mock.Mock`` configured with:

            - ``ping_timeout`` and ``ping_interval`` of 0.2;
            - ``async_handlers`` set to True;
            - an ``_async`` mapping using ``threading.Thread`` and
              ``queue.Queue``, with ``'websocket'`` set to None;
            - a ``start_background_task`` that runs targets on real threads
              tracked in ``self.bg_tasks``;
            - a ``create_queue`` factory returning ``queue.Queue`` objects;
            - ``get_queue_empty_exception()`` returning ``queue.Empty``.
        """
        try:
            import queue
        except ImportError:  # Python 2 module name
            import Queue as queue
        import threading

        mock_server = mock.Mock()
        mock_server.ping_timeout = 0.2
        mock_server.ping_interval = 0.2
        mock_server.async_handlers = True
        mock_server._async = {'threading': threading.Thread,
                              'queue': queue.Queue,
                              'websocket': None}

        def bg_task(target, *args, **kwargs):
            # Run on a real thread and remember it for _join_bg_tasks().
            th = threading.Thread(target=target, args=args, kwargs=kwargs)
            self.bg_tasks.append(th)
            th.start()
            return th

        def create_queue(*args, **kwargs):
            return queue.Queue(*args, **kwargs)

        mock_server.start_background_task = bg_task
        mock_server.create_queue = create_queue
        mock_server.get_queue_empty_exception.return_value = queue.Empty
        return mock_server

    def _join_bg_tasks(self, timeout=5):
        """Wait for every tracked background thread to finish.

        Args:
            timeout: Seconds to wait for each thread.  A thread that is
                still alive after that wait fails the test, rather than
                blocking the test run forever.
        """
        for task in self.bg_tasks:
            task.join(timeout)
            self.assertFalse(
                task.is_alive(),
                'background task %r did not finish within %s seconds'
                % (task, timeout))

    def test_create(self):
        """A new socket records server and sid and exposes a queue."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        self.assertEqual(s.server, mock_server)
        self.assertEqual(s.sid, 'sid')
        self.assertFalse(s.upgraded)
        self.assertFalse(s.closed)
        self.assertTrue(hasattr(s.queue, 'get'))
        self.assertTrue(hasattr(s.queue, 'put'))
        self.assertTrue(hasattr(s.queue, 'task_done'))
        self.assertTrue(hasattr(s.queue, 'join'))

    def test_empty_poll(self):
        """poll() with nothing queued raises QueueEmpty."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        self.assertRaises(exceptions.QueueEmpty, s.poll)

    def test_poll(self):
        """poll() returns the sent packets in the order they were sent."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        pkt1 = packet.Packet(packet.MESSAGE, data='hello')
        pkt2 = packet.Packet(packet.MESSAGE, data='bye')
        s.send(pkt1)
        s.send(pkt2)
        self.assertEqual(s.poll(), [pkt1, pkt2])

    def test_ping_pong(self):
        """Receiving a PING queues exactly one PONG encoding to b'3abc'."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        # Text data, as in the websocket tests, so the packet is not treated
        # as binary on Python 2.
        # NOTE(review): assumes Packet infers binary-ness from bytes data.
        s.receive(packet.Packet(packet.PING, data=six.text_type('abc')))
        r = s.poll()
        self.assertEqual(len(r), 1)
        self.assertEqual(r[0].packet_type, packet.PONG)
        # assertTrue(x, y) would treat y as the failure message and pass for
        # any truthy x, so compare the encoded packet explicitly.
        self.assertEqual(r[0].encode(), b'3abc')

    def test_message_async_handler(self):
        """With async_handlers on, a MESSAGE fires 'message' asynchronously."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.receive(packet.Packet(packet.MESSAGE, data='foo'))
        mock_server._trigger_event.assert_called_once_with(
            'message', 'sid', 'foo', run_async=True)

    def test_message_sync_handler(self):
        """With async_handlers off, a MESSAGE fires 'message' synchronously."""
        mock_server = self._get_mock_server()
        mock_server.async_handlers = False
        s = socket.Socket(mock_server, 'sid')
        s.receive(packet.Packet(packet.MESSAGE, data='foo'))
        mock_server._trigger_event.assert_called_once_with(
            'message', 'sid', 'foo', run_async=False)

    def test_invalid_packet(self):
        """Receiving an OPEN packet raises UnknownPacketError."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        self.assertRaises(exceptions.UnknownPacketError, s.receive,
                          packet.Packet(packet.OPEN))

    def test_timeout(self):
        """send() after the ping deadline calls close(wait=False, abort=False).

        A negative ping_interval plus a last_ping one second in the past
        puts the socket past its deadline before the send.
        """
        mock_server = self._get_mock_server()
        mock_server.ping_interval = -6
        s = socket.Socket(mock_server, 'sid')
        s.last_ping = time.time() - 1
        s.close = mock.MagicMock()
        s.send('packet')
        s.close.assert_called_once_with(wait=False, abort=False)

    def test_polling_read(self):
        """A polling GET returns the packets queued on the socket."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'foo')
        pkt1 = packet.Packet(packet.MESSAGE, data='hello')
        pkt2 = packet.Packet(packet.MESSAGE, data='bye')
        s.send(pkt1)
        s.send(pkt2)
        environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
        start_response = mock.MagicMock()
        packets = s.handle_get_request(environ, start_response)
        self.assertEqual(packets, [pkt1, pkt2])

    def test_polling_read_error(self):
        """A polling GET with nothing queued raises QueueEmpty."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'foo')
        environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
        start_response = mock.MagicMock()
        self.assertRaises(exceptions.QueueEmpty, s.handle_get_request,
                          environ, start_response)

    def test_polling_write(self):
        """A polling POST with a two-packet payload calls receive() twice."""
        mock_server = self._get_mock_server()
        mock_server.max_http_buffer_size = 1000
        pkt1 = packet.Packet(packet.MESSAGE, data='hello')
        pkt2 = packet.Packet(packet.MESSAGE, data='bye')
        p = payload.Payload(packets=[pkt1, pkt2]).encode()
        s = socket.Socket(mock_server, 'foo')
        s.receive = mock.MagicMock()
        environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo',
                   'CONTENT_LENGTH': len(p), 'wsgi.input': six.BytesIO(p)}
        s.handle_post_request(environ)
        self.assertEqual(s.receive.call_count, 2)

    def test_polling_write_too_large(self):
        """A POST body longer than max_http_buffer_size is rejected."""
        mock_server = self._get_mock_server()
        pkt1 = packet.Packet(packet.MESSAGE, data='hello')
        pkt2 = packet.Packet(packet.MESSAGE, data='bye')
        p = payload.Payload(packets=[pkt1, pkt2]).encode()
        mock_server.max_http_buffer_size = len(p) - 1
        s = socket.Socket(mock_server, 'foo')
        s.receive = mock.MagicMock()
        environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo',
                   'CONTENT_LENGTH': len(p), 'wsgi.input': six.BytesIO(p)}
        self.assertRaises(exceptions.ContentTooLongError,
                          s.handle_post_request, environ)

    def test_upgrade_handshake(self):
        """A GET carrying websocket upgrade headers is handed to
        _upgrade_websocket."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'foo')
        s._upgrade_websocket = mock.MagicMock()
        environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
                   'HTTP_CONNECTION': 'Foo,Upgrade,Bar',
                   'HTTP_UPGRADE': 'websocket'}
        start_response = mock.MagicMock()
        s.handle_get_request(environ, start_response)
        s._upgrade_websocket.assert_called_once_with(environ, start_response)

    def test_upgrade(self):
        """_upgrade_websocket wraps the handler and invokes the resulting
        websocket app with the request."""
        mock_server = self._get_mock_server()
        mock_server._async['websocket'] = mock.MagicMock()
        mock_ws = mock.MagicMock()
        mock_server._async['websocket'].return_value = mock_ws
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        environ = "foo"
        start_response = "bar"
        s._upgrade_websocket(environ, start_response)
        mock_server._async['websocket'].assert_called_once_with(
            s._websocket_handler)
        mock_ws.assert_called_once_with(environ, start_response)

    def test_upgrade_twice(self):
        """Upgrading an already-upgraded socket raises IOError."""
        mock_server = self._get_mock_server()
        mock_server._async['websocket'] = mock.MagicMock()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.upgraded = True
        environ = "foo"
        start_response = "bar"
        self.assertRaises(IOError, s._upgrade_websocket,
                          environ, start_response)

    def test_upgrade_packet(self):
        """Receiving UPGRADE queues a single NOOP packet."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.receive(packet.Packet(packet.UPGRADE))
        r = s.poll()
        self.assertEqual(len(r), 1)
        self.assertEqual(r[0].encode(), packet.Packet(packet.NOOP).encode())

    def test_upgrade_no_probe(self):
        """A websocket whose first message is not a probe does not
        upgrade."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        ws = mock.MagicMock()
        ws.wait.return_value = packet.Packet(packet.NOOP).encode(
            always_bytes=False)
        s._websocket_handler(ws)
        self.assertFalse(s.upgraded)

    def test_upgrade_no_upgrade_packet(self):
        """A probe is answered, but a NOOP in place of UPGRADE aborts the
        upgrade and leaves a NOOP on the queue."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.queue.join = mock.MagicMock(return_value=None)
        ws = mock.MagicMock()
        probe = six.text_type('probe')
        ws.wait.side_effect = [
            packet.Packet(packet.PING, data=probe).encode(
                always_bytes=False),
            packet.Packet(packet.NOOP).encode(always_bytes=False)]
        s._websocket_handler(ws)
        ws.send.assert_called_once_with(packet.Packet(
            packet.PONG, data=probe).encode(always_bytes=False))
        self.assertEqual(s.queue.get().packet_type, packet.NOOP)
        self.assertFalse(s.upgraded)

    def test_close_packet(self):
        """Receiving CLOSE calls close(wait=False, abort=True)."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.close = mock.MagicMock()
        s.receive(packet.Packet(packet.CLOSE))
        s.close.assert_called_once_with(wait=False, abort=True)

    def test_invalid_packet_type(self):
        """Receiving an unknown packet type raises UnknownPacketError."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        pkt = packet.Packet(packet_type=99)
        self.assertRaises(exceptions.UnknownPacketError, s.receive, pkt)

    def test_upgrade_not_supported(self):
        """With no websocket implementation, an upgrade is a bad request."""
        mock_server = self._get_mock_server()
        mock_server._async['websocket'] = None
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        environ = "foo"
        start_response = "bar"
        s._upgrade_websocket(environ, start_response)
        mock_server._bad_request.assert_called_once_with()

    def test_websocket_read_write(self):
        """A direct websocket connection delivers messages both ways and
        fires 'disconnect' when the socket ends."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = False
        s.queue.join = mock.MagicMock(return_value=None)
        foo = six.text_type('foo')
        bar = six.text_type('bar')
        s.poll = mock.MagicMock(side_effect=[
            [packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
        ws = mock.MagicMock()
        ws.wait.side_effect = [
            packet.Packet(packet.MESSAGE, data=foo).encode(
                always_bytes=False),
            None]
        s._websocket_handler(ws)
        self._join_bg_tasks()
        self.assertTrue(s.connected)
        self.assertTrue(s.upgraded)
        self.assertEqual(mock_server._trigger_event.call_count, 2)
        mock_server._trigger_event.assert_has_calls([
            mock.call('message', 'sid', 'foo', run_async=True),
            mock.call('disconnect', 'sid', run_async=False)])
        ws.send.assert_called_with('4bar')

    def test_websocket_upgrade_read_write(self):
        """A probe plus UPGRADE completes the upgrade, after which messages
        flow both ways."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.queue.join = mock.MagicMock(return_value=None)
        foo = six.text_type('foo')
        bar = six.text_type('bar')
        probe = six.text_type('probe')
        s.poll = mock.MagicMock(side_effect=[
            [packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
        ws = mock.MagicMock()
        ws.wait.side_effect = [
            packet.Packet(packet.PING, data=probe).encode(
                always_bytes=False),
            packet.Packet(packet.UPGRADE).encode(always_bytes=False),
            packet.Packet(packet.MESSAGE, data=foo).encode(
                always_bytes=False),
            None]
        s._websocket_handler(ws)
        self._join_bg_tasks()
        self.assertTrue(s.upgraded)
        self.assertEqual(mock_server._trigger_event.call_count, 2)
        mock_server._trigger_event.assert_has_calls([
            mock.call('message', 'sid', 'foo', run_async=True),
            mock.call('disconnect', 'sid', run_async=False)])
        ws.send.assert_called_with('4bar')

    def test_websocket_upgrade_with_payload(self):
        """An UPGRADE packet that carries data still completes the
        upgrade."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.queue.join = mock.MagicMock(return_value=None)
        probe = six.text_type('probe')
        ws = mock.MagicMock()
        ws.wait.side_effect = [
            packet.Packet(packet.PING, data=probe).encode(
                always_bytes=False),
            packet.Packet(packet.UPGRADE, data=b'2').encode(
                always_bytes=False)]
        s._websocket_handler(ws)
        self._join_bg_tasks()
        self.assertTrue(s.upgraded)

    def test_websocket_upgrade_with_backlog(self):
        """A packet sent while upgrading leaves the backlog empty and is
        written to the websocket once the upgrade completes."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.queue.join = mock.MagicMock(return_value=None)
        probe = six.text_type('probe')
        foo = six.text_type('foo')
        ws = mock.MagicMock()
        ws.wait.side_effect = [
            packet.Packet(packet.PING, data=probe).encode(
                always_bytes=False),
            packet.Packet(packet.UPGRADE, data=b'2').encode(
                always_bytes=False)]
        s.upgrading = True
        s.send(packet.Packet(packet.MESSAGE, data=foo))
        s._websocket_handler(ws)
        self._join_bg_tasks()
        self.assertTrue(s.upgraded)
        self.assertFalse(s.upgrading)
        self.assertEqual(s.packet_backlog, [])
        ws.send.assert_called_with('4foo')

    def test_websocket_read_write_wait_fail(self):
        """Errors raised by ws.wait() and ws.send() leave the socket
        closed."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = False
        s.queue.join = mock.MagicMock(return_value=None)
        foo = six.text_type('foo')
        bar = six.text_type('bar')
        s.poll = mock.MagicMock(side_effect=[
            [packet.Packet(packet.MESSAGE, data=bar)],
            [packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
        ws = mock.MagicMock()
        ws.wait.side_effect = [
            packet.Packet(packet.MESSAGE, data=foo).encode(
                always_bytes=False),
            RuntimeError]
        ws.send.side_effect = [None, RuntimeError]
        s._websocket_handler(ws)
        self._join_bg_tasks()
        self.assertTrue(s.closed)

    def test_websocket_ignore_invalid_packet(self):
        """An OPEN packet arriving over the websocket is ignored; the
        following MESSAGE is still delivered."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = False
        s.queue.join = mock.MagicMock(return_value=None)
        foo = six.text_type('foo')
        bar = six.text_type('bar')
        s.poll = mock.MagicMock(side_effect=[
            [packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
        ws = mock.MagicMock()
        ws.wait.side_effect = [
            packet.Packet(packet.OPEN).encode(always_bytes=False),
            packet.Packet(packet.MESSAGE, data=foo).encode(
                always_bytes=False),
            None]
        s._websocket_handler(ws)
        self._join_bg_tasks()
        self.assertTrue(s.connected)
        self.assertEqual(mock_server._trigger_event.call_count, 2)
        mock_server._trigger_event.assert_has_calls([
            mock.call('message', 'sid', foo, run_async=True),
            mock.call('disconnect', 'sid', run_async=False)])
        ws.send.assert_called_with('4bar')

    def test_send_after_close(self):
        """send() on a closed socket raises SocketIsClosedError."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.close(wait=False)
        self.assertRaises(exceptions.SocketIsClosedError, s.send,
                          packet.Packet(packet.NOOP))

    def test_close_after_close(self):
        """A second close() does not fire 'disconnect' again."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.close(wait=False)
        self.assertTrue(s.closed)
        self.assertEqual(mock_server._trigger_event.call_count, 1)
        mock_server._trigger_event.assert_called_once_with(
            'disconnect', 'sid', run_async=False)
        s.close()
        self.assertEqual(mock_server._trigger_event.call_count, 1)

    def test_close_and_wait(self):
        """close(wait=True) joins the socket's queue once."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.queue = mock.MagicMock()
        s.close(wait=True)
        s.queue.join.assert_called_once_with()

    def test_close_without_wait(self):
        """close(wait=False) does not join the socket's queue."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.queue = mock.MagicMock()
        s.close(wait=False)
        self.assertEqual(s.queue.join.call_count, 0)
|