"""Source code for aiopyramid.websocket.config.gunicorn."""
import asyncio
import inspect
import functools
import websockets
import gunicorn # noqa
from pyramid.response import Response
from aiopyramid.config import AsyncioMapperBase
def _connection_closed_to_none(func):
    """
    Wrap the coroutine function `func` so a closed connection yields `None`.

    This is a backwards compatibility shim for websockets 3+: instead of
    letting `websockets.exceptions.ConnectionClosed` propagate (the more
    Pythonic behaviour), the wrapper returns `None` so that the interface
    stays united with the uWSGI one.

    Any other exception raised by `func` propagates unchanged, and any
    value it produces is returned as-is.
    """

    @asyncio.coroutine
    @functools.wraps(func)
    def _none_when_closed(*args, **kwargs):
        try:
            return (yield from func(*args, **kwargs))
        except websockets.exceptions.ConnectionClosed:
            # A closed connection is reported to the caller as "no message".
            return None

    return _none_when_closed
def _use_bytes(func):
    """
    Wrap the coroutine function `func` so text results come back as bytes.

    Strings received from websockets are encoded to bytes to provide
    consistency with uwsgi, since the raw WebsocketFrame is not
    accessible here. Anything that is not a `str` is passed through
    untouched.
    """

    @asyncio.coroutine
    @functools.wraps(func)
    def _recv_as_bytes(*args, **kwargs):
        payload = yield from func(*args, **kwargs)
        # str.encode with no arguments uses the default UTF-8 encoding.
        return str.encode(payload) if isinstance(payload, str) else payload

    return _recv_as_bytes
class HandshakeInterator:
    """
    Single-pass iterator over an eagerly materialised copy of `app_iter`.

    The wrapped iterable is consumed completely in the constructor and its
    items are stored in `content`; `index` counts how many times
    `__next__` has been called, including calls that ended the iteration.
    """

    def __init__(self, app_iter):
        self.index = 0
        self.content = list(app_iter)

    def __iter__(self):
        return self

    def __next__(self):
        position = self.index
        # The counter advances on every call, even once the items have
        # run out.
        self.index = position + 1
        try:
            return self.content[position]
        except IndexError:
            raise StopIteration
class SwitchProtocolsResponse(Response):
    """
    Upgrade from a WSGI connection with the WebSocket handshake.

    The response starts out as ``101 Switching Protocols``. The request
    headers found in `environ` are validated with
    `websockets.handshake.check_request`:

    * if the request is not HTTP/1.1 or no handshake key is obtained, the
      response becomes a ``400`` carrying a short plain explanation;
    * otherwise the handshake response headers are written by
      `websockets.handshake.build_response`, the body iterator is replaced
      by a `HandshakeInterator`, and `switch_protocols` is attached to
      that iterator as its `close` attribute.

    :param environ: the WSGI environ of the upgrade request; must contain
        ``SERVER_PROTOCOL`` and the ``HTTP_*`` header entries.
    :param switch_protocols: zero-argument callable stored as the `close`
        attribute of the body iterator on a successful handshake.
    :raises KeyError: if `environ` lacks ``SERVER_PROTOCOL``. A missing
        header also raises `KeyError` inside the lookup handed to
        `check_request`; whether that propagates depends on the installed
        websockets version.
    """

    def __init__(self, environ, switch_protocols):
        super().__init__()
        self.status_int = 101

        http_1_1 = environ['SERVER_PROTOCOL'] == 'HTTP/1.1'

        # Case-insensitive index of the environ keys, built once instead of
        # on every header lookup.
        key_map = {name.upper(): name for name in environ}

        def get_header(k):
            # "Sec-WebSocket-Key" -> environ entry "HTTP_SEC_WEBSOCKET_KEY"
            return environ[key_map['HTTP_' + k.upper().replace('-', '_')]]

        key = websockets.handshake.check_request(get_header)

        if not http_1_1 or key is None:
            self.status_int = 400
            # The body has to be set through `text`; an attribute named
            # `content` is not part of the WebOb response API, so assigning
            # to it would leave the response body empty.
            self.text = "Invalid WebSocket handshake.\n"
        else:
            set_header = self.headers.__setitem__
            websockets.handshake.build_response(set_header, key)
            self.app_iter = HandshakeInterator(self.app_iter)
            self.app_iter.close = switch_protocols
class WebsocketMapper(AsyncioMapperBase):
    """
    View mapper that turns a websocket view into a regular view callable
    performing the WebSocket upgrade.
    """

    # When true, `str` messages returned by ``ws.recv`` are encoded to bytes
    # (see `_use_bytes`). Read from the class at connection time.
    use_bytes = False

    def launch_websocket_view(self, view):
        """
        Build the ``(context, request)`` view callable wrapping `view`.

        :param view: either a class, which is instantiated with
            ``(context, request)`` and whose instance is then called with
            the websocket, or a callable taking the websocket directly.
            Calling it with the websocket must produce something usable
            with ``yield from``.
        :returns: a view callable returning a `SwitchProtocolsResponse`.
        """

        def websocket_view(context, request):
            if inspect.isclass(view):
                view_callable = view(context, request)
            else:
                view_callable = view

            @asyncio.coroutine
            def _ensure_ws_close(ws):
                if WebsocketMapper.use_bytes:
                    ws.recv = _use_bytes(ws.recv)
                # Applied last so a closed connection surfaces as `None`
                # regardless of the bytes conversion.
                ws.recv = _connection_closed_to_none(ws.recv)

                try:
                    yield from view_callable(ws)
                finally:
                    # Close the websocket even when the view raises, so a
                    # failing view does not leave the connection open.
                    yield from ws.close()

            def switch_protocols():
                # TODO: Determine if there is a more standard way to do this
                ws_protocol = websockets.WebSocketCommonProtocol()
                transport = request.environ['async.writer']._transport
                http_protocol = request.environ['async.protocol']
                # Detach the HTTP protocol and hand its transport over to
                # the websocket protocol.
                http_protocol.connection_lost(None)
                transport._protocol = ws_protocol
                ws_protocol.connection_made(transport)
                # NOTE(review): the scheduled task is not retained anywhere;
                # confirm something else keeps it alive until it finishes.
                asyncio.ensure_future(_ensure_ws_close(ws_protocol))

            response = SwitchProtocolsResponse(
                request.environ,
                switch_protocols,
            )
            # convert iterator to avoid eof issues
            response.body = response.body
            return response

        return websocket_view

    def __call__(self, view):
        """ Accepts a view_callable class. """
        return self.launch_websocket_view(view)