950 lines
35 KiB
Python
950 lines
35 KiB
Python
import base64
|
|
import functools
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
import psycopg2
|
|
import random
|
|
import socket
|
|
import struct
|
|
import selectors
|
|
import threading
|
|
import time
|
|
from collections import defaultdict, deque
|
|
from contextlib import closing, suppress
|
|
from enum import IntEnum
|
|
from psycopg2.pool import PoolError
|
|
from urllib.parse import urlparse
|
|
from weakref import WeakSet
|
|
|
|
from werkzeug.local import LocalStack
|
|
from werkzeug.exceptions import BadRequest, HTTPException, ServiceUnavailable
|
|
|
|
import odoo
|
|
from odoo import api
|
|
from .models.bus import dispatch
|
|
from odoo.http import root, Request, Response, SessionExpiredException, get_default_session
|
|
from odoo.modules.registry import Registry
|
|
from odoo.service import model as service_model
|
|
from odoo.service.server import CommonServer
|
|
from odoo.service.security import check_session
|
|
from odoo.tools import config
|
|
|
|
_logger = logging.getLogger(__name__)


# Maximum number of attempts `acquire_cursor` makes before giving up.
MAX_TRY_ON_POOL_ERROR = 10
# Lower bound, in seconds, of the randomized pause between two attempts;
# the upper bound grows linearly with the attempt number.
DELAY_ON_POOL_ERROR = 0.03
|
|
|
|
|
|
def acquire_cursor(db):
    """
    Return a cursor on database `db`, retrying while the connection
    pool is exhausted.

    At most `MAX_TRY_ON_POOL_ERROR` attempts are made. After each failed
    attempt the caller sleeps for a random duration between
    `DELAY_ON_POOL_ERROR` and `DELAY_ON_POOL_ERROR * attempt` seconds.

    :raise PoolError: when every attempt failed.
    """
    attempt = 0
    while attempt < MAX_TRY_ON_POOL_ERROR:
        attempt += 1
        try:
            return odoo.registry(db).cursor()
        except PoolError:
            # Pool exhausted: back off a little longer after each failure.
            longest_pause = DELAY_ON_POOL_ERROR * attempt
        time.sleep(random.uniform(DELAY_ON_POOL_ERROR, longest_pause))
    raise PoolError('Failed to acquire cursor after %s retries' % MAX_TRY_ON_POOL_ERROR)
|
|
|
|
|
|
# ------------------------------------------------------
|
|
# EXCEPTIONS
|
|
# ------------------------------------------------------
|
|
|
|
class UpgradeRequired(HTTPException):
    """
    HTTP 426 response raised during the handshake when the client asks
    for a websocket protocol version that is not supported.
    """
    code = 426
    description = "Wrong websocket version was given during the handshake"

    def get_headers(self, environ=None, *args, **kwargs):
        """
        Extend the default headers with `Sec-WebSocket-Version` so the
        client knows which versions it may retry with.

        Extra positional/keyword arguments are forwarded untouched to the
        parent implementation, whose signature differs between Werkzeug
        releases.

        :return: list of `(name, value)` header tuples.
        """
        headers = super().get_headers(environ, *args, **kwargs)
        headers.append((
            'Sec-WebSocket-Version',
            # RFC 6455 (section 4.4) lists the supported versions
            # separated by commas; sort them for a deterministic value.
            ', '.join(sorted(WebsocketConnectionHandler.SUPPORTED_VERSIONS)),
        ))
        return headers
|
|
|
|
|
|
class WebsocketException(Exception):
    """ Base class of the websocket specific exceptions. """
|
|
|
|
|
|
class ConnectionClosed(WebsocketException):
    """
    Raised when the other end closes the socket without performing
    the closing handshake, i.e. when reading from the socket returns
    no data while a frame is being received.
    """
|
|
|
|
|
|
class InvalidCloseCodeException(WebsocketException):
    """
    Raised when a close frame is built with a close code that is not
    accepted.
    """

    def __init__(self, code):
        message = f"Invalid close code: {code}"
        super().__init__(message)
|
|
|
|
|
|
class InvalidDatabaseException(WebsocketException):
    """
    When raised: the database probably does not exist anymore, the
    database is corrupted or the database version doesn't match the
    server version.
    """
|
|
|
|
|
|
class InvalidStateException(WebsocketException):
    """
    Raised when an operation is forbidden in the current state, e.g.
    sending a message while the connection is not open.
    """
|
|
|
|
|
|
class InvalidWebsocketRequest(WebsocketException):
    """
    Raised when a websocket request is invalid: malformed JSON or
    missing mandatory keys.
    """
|
|
|
|
|
|
class PayloadTooLargeException(WebsocketException):
    """
    Raised when a websocket message, fragmented or not, is larger than
    the maximum message size.
    """
|
|
|
|
|
|
class ProtocolError(WebsocketException):
    """
    Raised when a frame format doesn't match expectations (unknown
    opcode, reserved bits set, unmasked frame, invalid control frame,
    unexpected fragmentation...).
    """
|
|
|
|
|
|
class RateLimitExceededException(Exception):
    """
    Raised when a client exceeds the number of requests allowed in a
    given time. Note that it derives directly from `Exception`.
    """
|
|
|
|
|
|
# ------------------------------------------------------
|
|
# WEBSOCKET LIFECYCLE
|
|
# ------------------------------------------------------
|
|
|
|
|
|
class LifecycleEvent(IntEnum):
    """ Events for which websocket lifecycle callbacks can be registered. """
    OPEN = 0
    CLOSE = 1
|
|
|
|
|
|
# ------------------------------------------------------
|
|
# WEBSOCKET
|
|
# ------------------------------------------------------
|
|
|
|
|
|
class Opcode(IntEnum):
    """ Frame types, as encoded in the low 4 bits of a frame's first byte. """
    CONTINUE = 0x00
    TEXT = 0x01
    BINARY = 0x02
    CLOSE = 0x08
    PING = 0x09
    PONG = 0x0A
|
|
|
|
|
|
class CloseCode(IntEnum):
    """
    Status codes carried by close frames.

    `ABNORMAL_CLOSURE` is the only member that is not a valid code for
    an outgoing close frame: it is used internally to drop the
    connection without a closing handshake.
    """
    CLEAN = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    INCORRECT_DATA = 1003
    ABNORMAL_CLOSURE = 1006
    INCONSISTENT_DATA = 1007
    MESSAGE_VIOLATING_POLICY = 1008
    MESSAGE_TOO_BIG = 1009
    EXTENSION_NEGOTIATION_FAILED = 1010
    SERVER_ERROR = 1011
    RESTART = 1012
    TRY_LATER = 1013
    BAD_GATEWAY = 1014
    # NOTE(review): the 4xxx codes look application specific (private
    # use range of RFC 6455) -- confirm against the client code.
    SESSION_EXPIRED = 4001
    KEEP_ALIVE_TIMEOUT = 4002
|
|
|
|
|
|
class ConnectionState(IntEnum):
    """ States a websocket connection goes through. """
    OPEN = 0
    # A close frame has been sent or received.
    CLOSING = 1
    # The underlying socket has been closed.
    CLOSED = 2
|
|
|
|
|
|
# Opcode families.
DATA_OP = {Opcode.TEXT, Opcode.BINARY}
HEARTBEAT_OP = {Opcode.PING, Opcode.PONG}
CTRL_OP = {Opcode.CLOSE} | HEARTBEAT_OP

# Every known close code can be sent in a close frame, except the one
# standing for a connection dropped without closing handshake.
VALID_CLOSE_CODES = set(CloseCode) - {CloseCode.ABNORMAL_CLOSURE}
# Close codes for which the peer's acknowledgment is awaited before the
# socket is closed.
CLEAN_CLOSE_CODES = {CloseCode.CLEAN, CloseCode.GOING_AWAY, CloseCode.RESTART}
# Additional range of codes accepted in close frames.
RESERVED_CLOSE_CODES = range(3000, 5000)

# _XOR_TABLE[key] is a 256 bytes translation table mapping each byte
# value to `value ^ key`; used to (un)mask payloads with bytes.translate.
_XOR_TABLE = [
    bytes(value ^ key for value in range(256))
    for key in range(256)
]
|
|
|
|
|
|
class Frame:
    """
    Plain container describing a single websocket frame.

    :param opcode: type of the frame.
    :param payload: data carried by the frame.
    :param fin: whether this frame is the last one of its message.
    :param rsv1: first reserved bit.
    :param rsv2: second reserved bit.
    :param rsv3: third reserved bit.
    """

    def __init__(self, opcode, payload=b'', fin=True, rsv1=False, rsv2=False, rsv3=False):
        self.opcode, self.payload, self.fin = opcode, payload, fin
        self.rsv1, self.rsv2, self.rsv3 = rsv1, rsv2, rsv3
|
|
|
|
|
|
class CloseFrame(Frame):
    """
    Close frame whose payload is the close code packed as a big-endian
    unsigned short, followed by the UTF-8 encoded reason, if any.

    :param code: close code, either one of `VALID_CLOSE_CODES` or a
        value within `RESERVED_CLOSE_CODES`.
    :param reason: optional human readable reason (`str` or falsy).
    :raise InvalidCloseCodeException: when `code` is not accepted.
    """

    def __init__(self, code, reason):
        code_is_accepted = code in VALID_CLOSE_CODES or code in RESERVED_CLOSE_CODES
        if not code_is_accepted:
            raise InvalidCloseCodeException(code)
        encoded_reason = reason.encode('utf-8') if reason else b''
        self.code = code
        self.reason = reason
        super().__init__(Opcode.CLOSE, struct.pack('!H', code) + encoded_reason)
|
|
|
|
|
|
# Every Websocket registers itself here on creation; being a WeakSet, it
# does not keep the instances alive.
_websocket_instances = WeakSet()
|
|
|
|
|
|
class Websocket:
    """
    Server side of a websocket connection.

    Frames are read from and written to the socket given at creation.
    A second, internal socket pair is used to wake the connection up
    when bus notifications have to be dispatched to the client.
    """
    # Callbacks registered through `onopen`/`onclose`, by lifecycle event.
    __event_callbacks = defaultdict(set)
    # Maximum size for a message in bytes, whether it is sent as one
    # frame or many fragmented ones.
    MESSAGE_MAX_SIZE = 2 ** 20
    # Proxies usually close a connection after 1 minute of inactivity.
    # Therefore, a PING frame has to be sent if no frame is either sent
    # or received within CONNECTION_TIMEOUT - 15 seconds.
    CONNECTION_TIMEOUT = 60
    INACTIVITY_TIMEOUT = CONNECTION_TIMEOUT - 15
    # How many requests can be made in excess of the given rate.
    RL_BURST = int(config['websocket_rate_limit_burst'])
    # How many seconds between each request.
    RL_DELAY = float(config['websocket_rate_limit_delay'])

    def __init__(self, sock, session):
        """
        :param sock: socket of the upgraded connection.
        :param session: session the connection was opened with; its
            `db`, `sid`, `uid` and `context` are used later on.
        """
        # Session linked to the current websocket connection.
        self._session = session
        self._db = session.db
        self.__socket = sock
        self._close_sent = False
        self._close_received = False
        self._timeout_manager = TimeoutManager()
        # Used for rate limiting.
        self._incoming_frame_timestamps = deque(maxlen=self.RL_BURST)
        # Used to notify the websocket that bus notifications are
        # available.
        self.__notif_sock_w, self.__notif_sock_r = socket.socketpair()
        self._channels = set()
        self._last_notif_sent_id = 0
        # Websocket start up
        self.__selector = (
            selectors.PollSelector()
            if odoo.evented and hasattr(selectors, 'PollSelector')
            else selectors.DefaultSelector()
        )
        self.__selector.register(self.__socket, selectors.EVENT_READ)
        self.__selector.register(self.__notif_sock_r, selectors.EVENT_READ)
        self.state = ConnectionState.OPEN
        _websocket_instances.add(self)
        self._trigger_lifecycle_event(LifecycleEvent.OPEN)

    # ------------------------------------------------------
    # PUBLIC METHODS
    # ------------------------------------------------------

    def get_messages(self):
        """
        Generator yielding the data messages received from the client
        until the connection is closed.

        Each iteration waits at most `INACTIVITY_TIMEOUT` seconds for
        one of the sockets to become readable. The connection is
        disconnected when it timed out, a PING frame is sent when
        nothing happened, pending bus notifications are dispatched and
        incoming frames are processed. Any exception is handed over to
        `_handle_transport_error`.
        """
        while self.state is not ConnectionState.CLOSED:
            try:
                readables = {
                    selector_key[0].fileobj for selector_key in
                    self.__selector.select(self.INACTIVITY_TIMEOUT)
                }
                if self._timeout_manager.has_timed_out() and self.state is ConnectionState.OPEN:
                    self.disconnect(
                        CloseCode.ABNORMAL_CLOSURE
                        if self._timeout_manager.timeout_reason is TimeoutReason.NO_RESPONSE
                        else CloseCode.KEEP_ALIVE_TIMEOUT
                    )
                    continue
                if not readables:
                    # Nothing happened during INACTIVITY_TIMEOUT seconds.
                    self._send_ping_frame()
                    continue
                if self.__notif_sock_r in readables:
                    self._dispatch_bus_notifications()
                if self.__socket in readables:
                    message = self._process_next_message()
                    if message is not None:
                        yield message
            except Exception as exc:
                self._handle_transport_error(exc)

    def disconnect(self, code, reason=None):
        """
        Initiate the closing handshake that is, send a close frame
        to the other end which will then send us back an
        acknowledgment. Upon the reception of this acknowledgment,
        the `_terminate` method will be called to perform an
        orderly shutdown. Note that we don't need to wait for the
        acknowledgment if the connection was failed beforehand.

        :param code: close code; `CloseCode.ABNORMAL_CLOSURE` terminates
            the connection right away without sending a close frame.
        :param reason: optional reason sent along with the close code.
        """
        if code is not CloseCode.ABNORMAL_CLOSURE:
            self._send_close_frame(code, reason)
        else:
            self._terminate()

    @classmethod
    def onopen(cls, func):
        """
        Register `func` to be called when a websocket is opened. Return
        `func` so this method can be used as a decorator.
        """
        cls.__event_callbacks[LifecycleEvent.OPEN].add(func)
        return func

    @classmethod
    def onclose(cls, func):
        """
        Register `func` to be called when a websocket is closed. Return
        `func` so this method can be used as a decorator.
        """
        cls.__event_callbacks[LifecycleEvent.CLOSE].add(func)
        return func

    def subscribe(self, channels, last):
        """
        Subscribe to bus channels.

        :param channels: channels replacing the current subscription.
        :param last: id of the last notification known by the client;
            only taken into account when greater than the id of the
            last notification sent.
        """
        self._channels = channels
        if self._last_notif_sent_id < last:
            self._last_notif_sent_id = last
        # Dispatch past notifications if there are any.
        self.trigger_notification_dispatching()

    def trigger_notification_dispatching(self):
        """
        Warn the socket that notifications are available. Ignore if a
        dispatch is already planned or if the socket is already in the
        closing state.
        """
        if self.state is not ConnectionState.OPEN:
            return
        readables = {
            selector_key[0].fileobj for selector_key in
            self.__selector.select(0)
        }
        if self.__notif_sock_r not in readables:
            # Send a single byte to mark the socket as readable.
            self.__notif_sock_w.send(b'x')

    # ------------------------------------------------------
    # PRIVATE METHODS
    # ------------------------------------------------------

    def _get_next_frame(self):
        """
        Read and decode the next frame from the socket.

        :return: the received `Frame`, payload unmasked.
        :raise RateLimitExceededException: when frames arrive too fast.
        :raise ConnectionClosed: when the socket returns no data.
        :raise ProtocolError: when the frame is malformed.
        :raise PayloadTooLargeException: when the announced payload
            length exceeds `MESSAGE_MAX_SIZE`.
        """
        #  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
        # +-+-+-+-+-------+-+-------------+-------------------------------+
        # |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
        # |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
        # |N|V|V|V|       |S|             |   (if payload len==126/127)   |
        # | |1|2|3|       |K|             |                               |
        # +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
        # |     Extended payload length continued, if payload len == 127  |
        # + - - - - - - - - - - - - - - - +-------------------------------+
        # |                               |Masking-key, if MASK set to 1  |
        # +-------------------------------+-------------------------------+
        # | Masking-key (continued)       |          Payload Data         |
        # +-------------------------------- - - - - - - - - - - - - - - - +
        # :                     Payload Data continued ...                :
        # + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
        # |                     Payload Data continued ...                |
        # +---------------------------------------------------------------+
        def recv_bytes(n):
            """ Pull n bytes from the socket """
            data = bytearray()
            while len(data) < n:
                received_data = self.__socket.recv(n - len(data))
                if not received_data:
                    raise ConnectionClosed()
                data.extend(received_data)
            return data

        def is_bit_set(byte, n):
            """
            Check whether nth bit of byte is set or not (from left
            to right).
            """
            return byte & (1 << (7 - n))

        def apply_mask(payload, mask):
            # see: https://www.willmcgugan.com/blog/tech/post/speeding-up-websockets-60x/
            a, b, c, d = (_XOR_TABLE[n] for n in mask)
            payload[::4] = payload[::4].translate(a)
            payload[1::4] = payload[1::4].translate(b)
            payload[2::4] = payload[2::4].translate(c)
            payload[3::4] = payload[3::4].translate(d)
            return payload

        self._limit_rate()
        first_byte, second_byte = recv_bytes(2)
        fin, rsv1, rsv2, rsv3 = (is_bit_set(first_byte, n) for n in range(4))
        try:
            opcode = Opcode(first_byte & 0b00001111)
        except ValueError as exc:
            raise ProtocolError(exc)
        payload_length = second_byte & 0b01111111

        if rsv1 or rsv2 or rsv3:
            raise ProtocolError("Reserved bits must be unset")
        if not is_bit_set(second_byte, 0):
            raise ProtocolError("Frame must be masked")
        if opcode in CTRL_OP:
            if not fin:
                raise ProtocolError("Control frames cannot be fragmented")
            if payload_length > 125:
                raise ProtocolError(
                    "Control frames payload must be smaller than 126"
                )
        # 126 and 127 announce an extended payload length of 2 and 8
        # bytes respectively.
        if payload_length == 126:
            payload_length = struct.unpack('!H', recv_bytes(2))[0]
        elif payload_length == 127:
            payload_length = struct.unpack('!Q', recv_bytes(8))[0]
        if payload_length > self.MESSAGE_MAX_SIZE:
            raise PayloadTooLargeException()

        mask = recv_bytes(4)
        payload = apply_mask(recv_bytes(payload_length), mask)
        frame = Frame(opcode, bytes(payload), fin, rsv1, rsv2, rsv3)
        self._timeout_manager.acknowledge_frame_receipt(frame)
        return frame

    def _process_next_message(self):
        """
        Process the next message coming through the socket. If a
        data message can be extracted, return its decoded payload.
        As per the RFC, only control frames will be processed once
        the connection reaches the closing state.

        :return: `str` for a text message, `bytes` for a binary one,
            `None` when no data message could be extracted.
        """
        frame = self._get_next_frame()
        if frame.opcode in CTRL_OP:
            self._handle_control_frame(frame)
            return
        if self.state is not ConnectionState.OPEN:
            # After receiving a control frame indicating the connection
            # should be closed, a peer discards any further data
            # received.
            return
        if frame.opcode is Opcode.CONTINUE:
            raise ProtocolError("Unexpected continuation frame")
        message = frame.payload
        if not frame.fin:
            message = self._recover_fragmented_message(frame)
        return (
            message.decode('utf-8')
            if message is not None and frame.opcode is Opcode.TEXT else message
        )

    def _recover_fragmented_message(self, initial_frame):
        """
        Read frames until the message started by `initial_frame` is
        complete.

        :return: the whole message payload as `bytes`, or `None` when
            the connection left the open state meanwhile.
        :raise ProtocolError: when a data frame which is not a
            continuation frame is received.
        :raise PayloadTooLargeException: when the message grows beyond
            `MESSAGE_MAX_SIZE`.
        """
        message_fragments = bytearray(initial_frame.payload)
        while True:
            frame = self._get_next_frame()
            if frame.opcode in CTRL_OP:
                # Control frames can be received in the middle of a
                # fragmented message, process them as soon as possible.
                self._handle_control_frame(frame)
                if self.state is not ConnectionState.OPEN:
                    return
                continue
            if frame.opcode is not Opcode.CONTINUE:
                raise ProtocolError("A continuation frame was expected")
            message_fragments.extend(frame.payload)
            if len(message_fragments) > self.MESSAGE_MAX_SIZE:
                raise PayloadTooLargeException()
            if frame.fin:
                return bytes(message_fragments)

    def _send(self, message):
        """
        Send a data message: as a binary frame when `message` is
        `bytes`/`bytearray`, as a text frame otherwise.

        :raise InvalidStateException: when the connection is not open.
        """
        if self.state is not ConnectionState.OPEN:
            raise InvalidStateException(
                "Trying to send a frame on a closed socket"
            )
        opcode = Opcode.BINARY
        if not isinstance(message, (bytes, bytearray)):
            opcode = Opcode.TEXT
        self._send_frame(Frame(opcode, message))

    def _send_frame(self, frame):
        """
        Serialize `frame` and write it to the socket.

        A `str` payload is UTF-8 encoded; a payload which is neither
        `str` nor `bytes`/`bytearray` is dumped to JSON first. The frame
        payload attribute is updated accordingly. Sending a close frame
        moves the connection to the closing state, and terminates it
        when the close code is not a clean one or when a close frame
        was already received.

        :raise ProtocolError: when a control frame payload is larger
            than 125 bytes.
        """
        if frame.opcode in CTRL_OP and len(frame.payload) > 125:
            raise ProtocolError(
                "Control frames should have a payload length smaller than 126"
            )
        if isinstance(frame.payload, str):
            frame.payload = frame.payload.encode('utf-8')
        elif not isinstance(frame.payload, (bytes, bytearray)):
            frame.payload = json.dumps(frame.payload).encode('utf-8')

        output = bytearray()
        first_byte = (
            (0b10000000 if frame.fin else 0)
            | (0b01000000 if frame.rsv1 else 0)
            | (0b00100000 if frame.rsv2 else 0)
            | (0b00010000 if frame.rsv3 else 0)
            | frame.opcode
        )
        payload_length = len(frame.payload)
        # The mask bit is never set: outgoing frames are not masked.
        if payload_length < 126:
            output.extend(
                struct.pack('!BB', first_byte, payload_length)
            )
        elif payload_length < 65536:
            output.extend(
                struct.pack('!BBH', first_byte, 126, payload_length)
            )
        else:
            output.extend(
                struct.pack('!BBQ', first_byte, 127, payload_length)
            )
        output.extend(frame.payload)
        self.__socket.sendall(output)
        self._timeout_manager.acknowledge_frame_sent(frame)
        if not isinstance(frame, CloseFrame):
            return
        self.state = ConnectionState.CLOSING
        self._close_sent = True
        if frame.code not in CLEAN_CLOSE_CODES or self._close_received:
            return self._terminate()
        # After sending a control frame indicating the connection
        # should be closed, a peer does not send any further data.
        self.__selector.unregister(self.__notif_sock_r)

    def _send_close_frame(self, code, reason=None):
        """ Send a close frame. """
        self._send_frame(CloseFrame(code, reason))

    def _send_ping_frame(self):
        """ Send a ping frame """
        self._send_frame(Frame(Opcode.PING))

    def _send_pong_frame(self, payload):
        """ Send a pong frame """
        self._send_frame(Frame(Opcode.PONG, payload))

    def _terminate(self):
        """
        Close the underlying TCP socket, mark the connection as closed,
        unsubscribe it from the bus dispatcher and trigger the `CLOSE`
        lifecycle event.
        """
        with suppress(OSError, TimeoutError):
            self.__socket.shutdown(socket.SHUT_WR)
            # Call recv until obtaining a return value of 0 indicating
            # the other end has performed an orderly shutdown. A timeout
            # is set to ensure the connection will be closed even if
            # the other end does not close the socket properly.
            self.__socket.settimeout(1)
            while self.__socket.recv(4096):
                pass
        self.__selector.unregister(self.__socket)
        self.__selector.close()
        self.__socket.close()
        self.state = ConnectionState.CLOSED
        dispatch.unsubscribe(self)
        self._trigger_lifecycle_event(LifecycleEvent.CLOSE)

    def _handle_control_frame(self, frame):
        """
        Answer a PING frame with a PONG frame carrying the same payload.
        On a CLOSE frame, move to the closing state then either echo
        the close frame (close initiated by the client) or terminate
        the connection (acknowledgment of our own close frame). Other
        control frames are ignored.

        :raise ProtocolError: when a close frame has a 1 byte payload.
        """
        if frame.opcode is Opcode.PING:
            self._send_pong_frame(frame.payload)
        elif frame.opcode is Opcode.CLOSE:
            self.state = ConnectionState.CLOSING
            self._close_received = True
            code, reason = CloseCode.CLEAN, None
            if len(frame.payload) >= 2:
                code = struct.unpack('!H', frame.payload[:2])[0]
                reason = frame.payload[2:].decode('utf-8')
            elif frame.payload:
                raise ProtocolError("Malformed closing frame")
            if not self._close_sent:
                self._send_close_frame(code, reason)
            else:
                self._terminate()

    def _handle_transport_error(self, exc):
        """
        Find out which close code should be sent according to given
        exception and call `self.disconnect` in order to close the
        connection cleanly.
        """
        code, reason = CloseCode.SERVER_ERROR, str(exc)
        if isinstance(exc, (ConnectionClosed, OSError)):
            code = CloseCode.ABNORMAL_CLOSURE
        elif isinstance(exc, (ProtocolError, InvalidCloseCodeException)):
            code = CloseCode.PROTOCOL_ERROR
        elif isinstance(exc, UnicodeDecodeError):
            code = CloseCode.INCONSISTENT_DATA
        elif isinstance(exc, PayloadTooLargeException):
            code = CloseCode.MESSAGE_TOO_BIG
        elif isinstance(exc, (PoolError, RateLimitExceededException)):
            code = CloseCode.TRY_LATER
        elif isinstance(exc, SessionExpiredException):
            code = CloseCode.SESSION_EXPIRED
        if code is CloseCode.SERVER_ERROR:
            # Unexpected error: do not send its message to the client.
            reason = None
            registry = Registry(self._session.db)
            sequence = registry.registry_sequence
            registry = registry.check_signaling()
            if sequence != registry.registry_sequence:
                _logger.warning("Bus operation aborted; registry has been reloaded")
            else:
                _logger.error(exc, exc_info=True)
        self.disconnect(code, reason)

    def _limit_rate(self):
        """
        This method is a simple rate limiter designed not to allow
        more than one request by `RL_DELAY` seconds. `RL_BURST` specify
        how many requests can be made in excess of the given rate at the
        beginning. When requests are received too fast, raises the
        `RateLimitExceededException`.
        """
        now = time.time()
        if len(self._incoming_frame_timestamps) >= self.RL_BURST:
            # The deque is bounded to RL_BURST entries: its first item
            # is the oldest accepted frame of the current window.
            elapsed_time = now - self._incoming_frame_timestamps[0]
            if elapsed_time < self.RL_DELAY * self.RL_BURST:
                raise RateLimitExceededException()
        self._incoming_frame_timestamps.append(now)

    def _trigger_lifecycle_event(self, event_type):
        """
        Trigger a lifecycle event that is, call every function
        registered for this event type. Every callback is given both the
        environment and the related websocket. Errors raised by a
        callback are logged and do not prevent the next ones from
        running.
        """
        if not self.__event_callbacks[event_type]:
            return
        with closing(acquire_cursor(self._db)) as cr:
            env = api.Environment(cr, self._session.uid, self._session.context)
            for callback in self.__event_callbacks[event_type]:
                try:
                    service_model.retrying(functools.partial(callback, env, self), env)
                except Exception:
                    _logger.warning(
                        'Error during Websocket %s callback',
                        LifecycleEvent(event_type).name,
                        exc_info=True
                    )

    def _dispatch_bus_notifications(self):
        """
        Dispatch notifications related to the registered channels. If
        the session is expired, close the connection with the
        `SESSION_EXPIRED` close code. If no cursor can be acquired,
        close the connection with the `TRY_LATER` close code.
        """
        session = root.session_store.get(self._session.sid)
        if not session:
            raise SessionExpiredException()
        with acquire_cursor(session.db) as cr:
            env = api.Environment(cr, session.uid, session.context)
            if session.uid is not None and not check_session(session, env):
                raise SessionExpiredException()
            # Mark the notification request as processed.
            self.__notif_sock_r.recv(1)
            notifications = env['bus.bus']._poll(self._channels, self._last_notif_sent_id)
        if not notifications:
            return
        self._last_notif_sent_id = notifications[-1]['id']
        self._send(notifications)
|
|
|
|
|
|
class TimeoutReason(IntEnum):
    """ Reasons why a connection is considered to have timed out. """
    # The connection has been opened for too long.
    KEEP_ALIVE = 0
    # The answer to a PING/CLOSE frame was not received in time.
    NO_RESPONSE = 1
|
|
|
|
|
|
class TimeoutManager:
    """
    This class handles the Websocket timeouts. If no response to a
    PING/CLOSE frame is received after `TIMEOUT` seconds or if the
    connection is opened for more than `self._keep_alive_timeout` seconds,
    the connection is considered to have timed out. To determine if the
    connection has timed out, use the `has_timed_out` method.
    """
    # Seconds the peer is given to answer a PING/CLOSE frame.
    TIMEOUT = 15
    # Timeout specifying how many seconds the connection should be kept
    # alive.
    KEEP_ALIVE_TIMEOUT = int(config['websocket_keep_alive_timeout'])

    def __init__(self):
        super().__init__()
        # Opcode of the frame expected in answer to the last PING/CLOSE
        # frame sent, `None` when no answer is awaited.
        self._awaited_opcode = None
        # Time at which the connection was opened.
        self._opened_at = time.time()
        # Custom keep alive timeout for each TimeoutManager to avoid multiple
        # connections timing out at the same time.
        self._keep_alive_timeout = (
            self.KEEP_ALIVE_TIMEOUT + random.uniform(0, self.KEEP_ALIVE_TIMEOUT / 2)
        )
        # `TimeoutReason` set by `has_timed_out` once a timeout is detected.
        self.timeout_reason = None
        # Start time recorded when we started awaiting an answer to a
        # PING/CLOSE frame.
        self._waiting_start_time = None

    def acknowledge_frame_receipt(self, frame):
        """
        Acknowledge a frame was received. If this frame is the awaited
        answer, stop waiting for it.
        """
        if self._awaited_opcode is frame.opcode:
            self._awaited_opcode = None
            self._waiting_start_time = None

    def acknowledge_frame_sent(self, frame):
        """
        Acknowledge a frame was sent. If this frame is a PING/CLOSE
        frame, start waiting for an answer. Any other frame leaves the
        pending wait, if any, untouched.
        """
        if self.has_timed_out():
            return
        if frame.opcode is Opcode.PING:
            awaited_opcode = Opcode.PONG
        elif frame.opcode is Opcode.CLOSE:
            awaited_opcode = Opcode.CLOSE
        else:
            # Sending data (or a PONG) must not restart the timer of an
            # answer which is already awaited, otherwise a peer which
            # never answers would not time out as long as frames are
            # sent to it.
            return
        self._awaited_opcode = awaited_opcode
        self._waiting_start_time = time.time()

    def has_timed_out(self):
        """
        Determine whether the connection has timed out or not. The
        connection times out when the answer to a CLOSE/PING frame
        is not received within `TIMEOUT` seconds or if the connection
        is opened for more than `self._keep_alive_timeout` seconds.

        :return: `True` if the connection timed out, in which case
            `self.timeout_reason` tells why, `False` otherwise.
        """
        now = time.time()
        if now - self._opened_at >= self._keep_alive_timeout:
            self.timeout_reason = TimeoutReason.KEEP_ALIVE
            return True
        if self._awaited_opcode is not None and now - self._waiting_start_time >= self.TIMEOUT:
            self.timeout_reason = TimeoutReason.NO_RESPONSE
            return True
        return False
|
|
|
|
|
|
# ------------------------------------------------------
|
|
# WEBSOCKET SERVING
|
|
# ------------------------------------------------------
|
|
|
|
|
|
# Stack of the websocket requests being served: a WebsocketRequest pushes
# itself when entered as a context manager and is popped on exit.
_wsrequest_stack = LocalStack()
# NOTE(review): presumably a proxy to the request on top of the stack
# (werkzeug LocalStack call) -- confirm against the werkzeug version in use.
wsrequest = _wsrequest_stack()
|
|
|
|
class WebsocketRequest:
    """
    Context in which a message received through a websocket is served.
    Used as a context manager, it registers itself on the websocket
    request stack for the duration of the processing.
    """

    def __init__(self, db, httprequest, websocket):
        """
        :param db: name of the database the message is served on.
        :param httprequest: HTTP request which opened the connection.
        :param websocket: `Websocket` the message was received from.
        """
        self.db = db
        self.httprequest = httprequest
        # Set by `serve_websocket_message`.
        self.session = None
        self.ws = websocket

    def __enter__(self):
        _wsrequest_stack.push(self)
        return self

    def __exit__(self, *args):
        _wsrequest_stack.pop()

    def serve_websocket_message(self, message):
        """
        Decode `message` and dispatch it to the ir.websocket model.

        :param message: JSON encoded object holding the mandatory
            `event_name` key and an optional `data` key.
        :raise InvalidWebsocketRequest: when the message is not valid
            JSON, is not a JSON object or lacks the `event_name` key.
        :raise SessionExpiredException: when the session of the
            websocket cannot be found anymore.
        :raise InvalidDatabaseException: when the registry of the
            database cannot be loaded.
        """
        try:
            jsonrequest = json.loads(message)
        except ValueError as exc:
            raise InvalidWebsocketRequest(
                f'Invalid JSON data, {exc.args[0]}'
            ) from exc
        if not isinstance(jsonrequest, dict):
            # Valid JSON which is not an object (list, number, string,
            # null...) cannot hold the expected keys: report it as an
            # invalid request rather than letting a TypeError escape.
            raise InvalidWebsocketRequest(
                f'Invalid JSON data, expected an object but got {type(jsonrequest).__name__}'
            )
        try:
            event_name = jsonrequest['event_name']  # mandatory
        except KeyError as exc:
            raise InvalidWebsocketRequest(
                f'Key {exc.args[0]!r} is missing from request'
            ) from exc
        data = jsonrequest.get('data')
        self.session = self._get_session()

        try:
            self.registry = Registry(self.db)
            self.registry.check_signaling()
        except (
            AttributeError, psycopg2.OperationalError, psycopg2.ProgrammingError
        ) as exc:
            raise InvalidDatabaseException() from exc

        with closing(acquire_cursor(self.db)) as cr:
            self.env = api.Environment(cr, self.session.uid, self.session.context)
            threading.current_thread().uid = self.env.uid
            service_model.retrying(
                functools.partial(self._serve_ir_websocket, event_name, data),
                self.env,
            )

    def _serve_ir_websocket(self, event_name, data):
        """
        Delegate most of the processing to the ir.websocket model
        which is extensible by applications. Directly call the
        appropriate ir.websocket method since only two events are
        tolerated: `subscribe` and `update_presence`.
        """
        self.env['ir.websocket']._authenticate()
        if event_name == 'subscribe':
            self.env['ir.websocket']._subscribe(data)
        if event_name == 'update_presence':
            self.env['ir.websocket']._update_bus_presence(**data)

    def _get_session(self):
        """
        Fetch the session of the websocket from the session store.

        :raise SessionExpiredException: when the session is not found.
        """
        session = root.session_store.get(self.ws._session.sid)
        if not session:
            raise SessionExpiredException()
        return session

    def update_env(self, user=None, context=None, su=None):
        """
        Update the environment of the current websocket request.
        """
        # Reuse the implementation of the HTTP request, applied to self.
        Request.update_env(self, user, context, su)

    def update_context(self, **overrides):
        """
        Override the environment context of the current request with the
        values of ``overrides``. To replace the entire context, please
        use :meth:`~update_env` instead.
        """
        self.update_env(context=dict(self.env.context, **overrides))
|
|
|
|
|
|
class WebsocketConnectionHandler:
    """
    Perform the websocket opening handshake on an HTTP request and
    serve the messages of the resulting connection.
    """
    SUPPORTED_VERSIONS = {'13'}
    # Given by the RFC in order to generate Sec-WebSocket-Accept from
    # Sec-WebSocket-Key value.
    _HANDSHAKE_GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
    _REQUIRED_HANDSHAKE_HEADERS = {
        'connection', 'host', 'sec-websocket-key',
        'sec-websocket-version', 'upgrade', 'origin',
    }

    @classmethod
    def websocket_allowed(cls, request):
        """ Websockets are refused when the registry is in test mode. """
        return not request.registry.in_test_mode()

    @classmethod
    def open_connection(cls, request):
        """
        Open a websocket connection if the handshake is successful.
        :return: Response indicating the server performed a connection
        upgrade.
        :raise: UpgradeRequired if there is no intersection between the
        versions the client supports and those we support.
        :raise: BadRequest if the handshake data is incorrect.
        :raise: ServiceUnavailable if websockets are not allowed.
        :raise: RuntimeError if the socket of the connection cannot be
        found in the WSGI environment.
        """
        if not cls.websocket_allowed(request):
            raise ServiceUnavailable("Websocket is disabled in test mode")
        cls._handle_public_configuration(request)
        try:
            response = cls._get_handshake_response(request.httprequest.headers)
            # Named `sock` so the `socket` module is not shadowed.
            sock = request.httprequest._HTTPRequest__environ['socket']
            session, db, httprequest = request.session, request.db, request.httprequest
            # The connection is served once the handshake response has
            # been sent, i.e. when the response gets closed.
            response.call_on_close(lambda: cls._serve_forever(
                Websocket(sock, session),
                db,
                httprequest,
            ))
            # Force save the session. Session must be persisted to handle
            # WebSocket authentication.
            request.session.is_dirty = True
            return response
        except KeyError as exc:
            raise RuntimeError(
                f"Couldn't bind the websocket. Is the connection opened on the evented port ({config['gevent_port']})?"
            ) from exc
        except HTTPException as exc:
            # The HTTP stack does not log exceptions derived from the
            # HTTPException class since they are valid responses.
            _logger.error(exc)
            raise

    @classmethod
    def _get_handshake_response(cls, headers):
        """
        :return: Response indicating the server performed a connection
        upgrade.
        :raise: BadRequest
        :raise: UpgradeRequired
        """
        cls._assert_handshake_validity(headers)
        # sha-1 is used as it is required by
        # https://datatracker.ietf.org/doc/html/rfc6455#page-7
        accept_header = hashlib.sha1(
            (headers['sec-websocket-key'] + cls._HANDSHAKE_GUID).encode()).digest()
        accept_header = base64.b64encode(accept_header)
        return Response(status=101, headers={
            'Upgrade': 'websocket',
            'Connection': 'Upgrade',
            'Sec-WebSocket-Accept': accept_header.decode(),
        })

    @classmethod
    def _handle_public_configuration(cls, request):
        """
        When the `ODOO_BUS_PUBLIC_SAMESITE_WS` environment variable is
        set, replace the session of requests whose origin does not match
        the host/scheme of the request by a fresh default session bound
        to the same database.
        """
        if not os.getenv('ODOO_BUS_PUBLIC_SAMESITE_WS'):
            return
        headers = request.httprequest.headers
        origin_url = urlparse(headers.get('origin'))
        if origin_url.netloc != headers.get('host') or origin_url.scheme != request.httprequest.scheme:
            # Read the database name *before* replacing the session: the
            # brand new session does not know it yet.
            db = request.session.db
            session = root.session_store.new()
            session.update(get_default_session(), db=db)
            session.is_explicit = True
            request.session = session

    @classmethod
    def _assert_handshake_validity(cls, headers):
        """
        :raise: UpgradeRequired if there is no intersection between
        the version the client supports and those we support.
        :raise: BadRequest in case of invalid handshake.
        """
        missing_or_empty_headers = {
            header for header in cls._REQUIRED_HANDSHAKE_HEADERS
            if header not in headers
        }
        if missing_or_empty_headers:
            raise BadRequest(
                f"""Empty or missing header(s): {', '.join(missing_or_empty_headers)}"""
            )

        if headers['upgrade'].lower() != 'websocket':
            raise BadRequest('Invalid upgrade header')
        if 'upgrade' not in headers['connection'].lower():
            raise BadRequest('Invalid connection header')
        if headers['sec-websocket-version'] not in cls.SUPPORTED_VERSIONS:
            raise UpgradeRequired()

        key = headers['sec-websocket-key']
        try:
            decoded_key = base64.b64decode(key, validate=True)
        except ValueError as exc:
            raise BadRequest("Sec-WebSocket-Key should be b64 encoded") from exc
        if len(decoded_key) != 16:
            raise BadRequest(
                "Sec-WebSocket-Key should be of length 16 once decoded"
            )

    @classmethod
    def _serve_forever(cls, websocket, db, httprequest):
        """
        Process incoming messages and dispatch them to the application.
        An expired session or an exhausted connection pool disconnects
        the websocket; any other error is logged and the next message
        is processed.
        """
        current_thread = threading.current_thread()
        current_thread.type = 'websocket'
        for message in websocket.get_messages():
            with WebsocketRequest(db, httprequest, websocket) as req:
                try:
                    req.serve_websocket_message(message)
                except SessionExpiredException:
                    websocket.disconnect(CloseCode.SESSION_EXPIRED)
                except PoolError:
                    websocket.disconnect(CloseCode.TRY_LATER)
                except Exception:
                    _logger.exception("Exception occurred during websocket request handling")
|
|
|
|
|
|
def _kick_all():
    """
    Start the closing handshake, with the `GOING_AWAY` close code, on
    every websocket which is still open.
    """
    for ws in _websocket_instances:
        if ws.state is not ConnectionState.OPEN:
            continue
        ws.disconnect(CloseCode.GOING_AWAY)
|
|
|
|
|
|
# Disconnect the websockets which are still open when the server stops.
CommonServer.on_stop(_kick_all)
|