diff --git a/bigchaindb/web/websocket_server.py b/bigchaindb/web/websocket_server.py index 9d8f5ef9..6915d54a 100644 --- a/bigchaindb/web/websocket_server.py +++ b/bigchaindb/web/websocket_server.py @@ -29,15 +29,15 @@ class Dispatcher: self.event_source = event_source self.subscribers = {} - def subscribe(self, uuid, ws): + def subscribe(self, uuid, websocket): """Add a websocket to the list of subscribers. Args: uuid (str): a unique identifier for the websocket. - ws: the websocket to publish information. + websocket: the websocket to publish information. """ - self.subscribers[uuid] = ws + self.subscribers[uuid] = websocket @asyncio.coroutine def publish(self): @@ -47,8 +47,8 @@ class Dispatcher: event = yield from self.event_source.get() if event == POISON_PILL: return - for uuid, ws in self.subscribers.items(): - ws.send_str(event) + for uuid, websocket in self.subscribers.items(): + websocket.send_str(event) @asyncio.coroutine @@ -56,20 +56,20 @@ def websocket_handler(request): """Handle a new socket connection.""" logger.debug('New websocket connection.') - ws = web.WebSocketResponse() - yield from ws.prepare(request) + websocket = web.WebSocketResponse() + yield from websocket.prepare(request) uuid = uuid4() - request.app['dispatcher'].subscribe(uuid, ws) + request.app['dispatcher'].subscribe(uuid, websocket) while True: # Consume input buffer - msg = yield from ws.receive() + msg = yield from websocket.receive() if msg.type == aiohttp.WSMsgType.ERROR: - logger.debug('Websocket exception: {}'.format(ws.exception())) + logger.debug('Websocket exception: %s', websocket.exception()) return -def init_app(event_source, loop=None): +def init_app(event_source, *, loop=None): """Init the application server. Return: @@ -87,17 +87,33 @@ def init_app(event_source, loop=None): return app -@asyncio.coroutine -def constant_event_source(event_source): - while True: - yield from asyncio.sleep(1) - yield from event_source.put('meow') +def start(event_source, *, loop=None): + """Create and start the WebSocket server.""" + + if not loop: + loop = asyncio.get_event_loop() + + app = init_app(event_source, loop=loop) + aiohttp.web.run_app(app, port=9985) + + +def test_websocket_server(): + """Set up a server and output a message every second. + Used for testing purposes.""" + + @asyncio.coroutine + def constant_event_source(event_source): + """Put a message in ``event_source`` every second.""" + + while True: + yield from asyncio.sleep(1) + yield from event_source.put('meow') + + loop = asyncio.get_event_loop() + event_source = asyncio.Queue() + loop.create_task(constant_event_source(event_source)) + start(event_source, loop=loop) if __name__ == '__main__': - loop = asyncio.get_event_loop() - event_source = asyncio.Queue() - - loop.create_task(constant_event_source(event_source)) - app = init_app(event_source, loop=loop) - aiohttp.web.run_app(app, port=9985) + test_websocket_server()