from __future__ import annotations import functools import hmac import http from typing import Any, Awaitable, Callable, Iterable, Tuple, cast from ..datastructures import Headers from ..exceptions import InvalidHeader from ..headers import build_www_authenticate_basic, parse_authorization_basic from .server import HTTPResponse, WebSocketServerProtocol __all__ = ["BasicAuthWebSocketServerProtocol", "basic_auth_protocol_factory"] # Change to tuple[str, str] when dropping Python < 3.9. Credentials = Tuple[str, str] def is_credentials(value: Any) -> bool: try: username, password = value except (TypeError, ValueError): return False else: return isinstance(username, str) and isinstance(password, str) class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol): """ WebSocket server protocol that enforces HTTP Basic Auth. """ realm: str = "" """ Scope of protection. If provided, it should contain only ASCII characters because the encoding of non-ASCII characters is undefined. """ username: str | None = None """Username of the authenticated user.""" def __init__( self, *args: Any, realm: str | None = None, check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, **kwargs: Any, ) -> None: if realm is not None: self.realm = realm # shadow class attribute self._check_credentials = check_credentials super().__init__(*args, **kwargs) async def check_credentials(self, username: str, password: str) -> bool: """ Check whether credentials are authorized. This coroutine may be overridden in a subclass, for example to authenticate against a database or an external service. Args: username: HTTP Basic Auth username. password: HTTP Basic Auth password. Returns: :obj:`True` if the handshake should continue; :obj:`False` if it should fail with an HTTP 401 error. """ if self._check_credentials is not None: return await self._check_credentials(username, password) return False async def process_request( self, path: str, request_headers: Headers, ) -> HTTPResponse | None: """ Check HTTP Basic Auth and return an HTTP 401 response if needed. """ try: authorization = request_headers["Authorization"] except KeyError: return ( http.HTTPStatus.UNAUTHORIZED, [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], b"Missing credentials\n", ) try: username, password = parse_authorization_basic(authorization) except InvalidHeader: return ( http.HTTPStatus.UNAUTHORIZED, [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], b"Unsupported credentials\n", ) if not await self.check_credentials(username, password): return ( http.HTTPStatus.UNAUTHORIZED, [("WWW-Authenticate", build_www_authenticate_basic(self.realm))], b"Invalid credentials\n", ) self.username = username return await super().process_request(path, request_headers) def basic_auth_protocol_factory( realm: str | None = None, credentials: Credentials | Iterable[Credentials] | None = None, check_credentials: Callable[[str, str], Awaitable[bool]] | None = None, create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None, ) -> Callable[..., BasicAuthWebSocketServerProtocol]: """ Protocol factory that enforces HTTP Basic Auth. :func:`basic_auth_protocol_factory` is designed to integrate with :func:`~websockets.legacy.server.serve` like this:: serve( ..., create_protocol=basic_auth_protocol_factory( realm="my dev server", credentials=("hello", "iloveyou"), ) ) Args: realm: Scope of protection. It should contain only ASCII characters because the encoding of non-ASCII characters is undefined. Refer to section 2.2 of :rfc:`7235` for details. credentials: Hard coded authorized credentials. It can be a ``(username, password)`` pair or a list of such pairs. check_credentials: Coroutine that verifies credentials. It receives ``username`` and ``password`` arguments and returns a :class:`bool`. One of ``credentials`` or ``check_credentials`` must be provided but not both. create_protocol: Factory that creates the protocol. By default, this is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced by a subclass. Raises: TypeError: If the ``credentials`` or ``check_credentials`` argument is wrong. """ if (credentials is None) == (check_credentials is None): raise TypeError("provide either credentials or check_credentials") if credentials is not None: if is_credentials(credentials): credentials_list = [cast(Credentials, credentials)] elif isinstance(credentials, Iterable): credentials_list = list(cast(Iterable[Credentials], credentials)) if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") else: raise TypeError(f"invalid credentials argument: {credentials}") credentials_dict = dict(credentials_list) async def check_credentials(username: str, password: str) -> bool: try: expected_password = credentials_dict[username] except KeyError: return False return hmac.compare_digest(expected_password, password) if create_protocol is None: create_protocol = BasicAuthWebSocketServerProtocol # Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] | # Callable[..., BasicAuthWebSocketServerProtocol]" not callable [misc] create_protocol = cast( Callable[..., BasicAuthWebSocketServerProtocol], create_protocol ) return functools.partial( create_protocol, realm=realm, check_credentials=check_credentials, )