Files
wallter/plotter-app/venv/lib/python3.8/site-packages/websockets/legacy/auth.py
2024-11-20 17:17:14 +01:00

191 lines
6.4 KiB
Python

from __future__ import annotations
import functools
import hmac
import http
from typing import Any, Awaitable, Callable, Iterable, Tuple, cast
from ..datastructures import Headers
from ..exceptions import InvalidHeader
from ..headers import build_www_authenticate_basic, parse_authorization_basic
from .server import HTTPResponse, WebSocketServerProtocol
# Public API of this module.
__all__ = [
    "BasicAuthWebSocketServerProtocol",
    "basic_auth_protocol_factory",
]
# A ``(username, password)`` pair of strings.
# Replace with the builtin ``tuple[str, str]`` once Python < 3.9 is dropped.
Credentials = Tuple[str, str]
def is_credentials(value: Any) -> bool:
    """
    Tell whether ``value`` is a single ``(username, password)`` pair.

    ``value`` qualifies when it unpacks into exactly two items that are both
    :class:`str`. A plain string is rejected up front: unpacking a
    two-character string such as ``"ab"`` would otherwise yield the bogus
    pair ``("a", "b")``.

    Note that the unpacking probe consumes items from one-shot iterators
    (e.g. generators); pass re-iterable objects if that matters.

    Args:
        value: Candidate credentials of any type.

    Returns:
        :obj:`True` if ``value`` is a pair of strings, :obj:`False` otherwise.

    """
    if isinstance(value, str):
        return False
    try:
        username, password = value
    except (TypeError, ValueError):
        # Not iterable, or not exactly two items.
        return False
    return isinstance(username, str) and isinstance(password, str)
class BasicAuthWebSocketServerProtocol(WebSocketServerProtocol):
    """
    WebSocket server protocol that enforces HTTP Basic Auth.

    Credentials are checked in :meth:`process_request`. Requests with a
    missing, malformed or repeated ``Authorization`` header, or whose
    credentials fail :meth:`check_credentials`, get an HTTP 401 response
    carrying a ``WWW-Authenticate`` challenge for :attr:`realm`.

    """

    realm: str = ""
    """
    Scope of protection.

    If provided, it should contain only ASCII characters because the
    encoding of non-ASCII characters is undefined.

    """

    username: str | None = None
    """Username of the authenticated user."""

    def __init__(
        self,
        *args: Any,
        realm: str | None = None,
        check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            realm: When not :obj:`None`, overrides the class-level
                :attr:`realm` for this instance.
            check_credentials: Coroutine function awaited by the default
                :meth:`check_credentials`. When :obj:`None`, every request
                is rejected unless a subclass overrides that method.
            *args, **kwargs: Forwarded to :class:`WebSocketServerProtocol`.

        """
        if realm is not None:
            self.realm = realm  # shadow class attribute
        self._check_credentials = check_credentials
        super().__init__(*args, **kwargs)

    async def check_credentials(self, username: str, password: str) -> bool:
        """
        Check whether credentials are authorized.

        This coroutine may be overridden in a subclass, for example to
        authenticate against a database or an external service.

        Args:
            username: HTTP Basic Auth username.
            password: HTTP Basic Auth password.

        Returns:
            :obj:`True` if the handshake should continue;
            :obj:`False` if it should fail with an HTTP 401 error.

        """
        if self._check_credentials is not None:
            return await self._check_credentials(username, password)
        return False

    async def process_request(
        self,
        path: str,
        request_headers: Headers,
    ) -> HTTPResponse | None:
        """
        Check HTTP Basic Auth and return an HTTP 401 response if needed.

        On success, store the username in :attr:`username` and defer to the
        parent implementation.

        """

        def unauthorized(body: bytes) -> HTTPResponse:
            # Every rejection carries a Basic challenge for the configured realm.
            return (
                http.HTTPStatus.UNAUTHORIZED,
                [("WWW-Authenticate", build_www_authenticate_basic(self.realm))],
                body,
            )

        try:
            authorization = request_headers["Authorization"]
        except KeyError:
            return unauthorized(b"Missing credentials\n")
        except LookupError:
            # NOTE: Headers is expected to raise MultipleValuesError (a
            # LookupError, but not a KeyError) when the header is repeated.
            # Reject with 401 rather than letting the exception escape
            # process_request.
            return unauthorized(b"Unsupported credentials\n")

        try:
            username, password = parse_authorization_basic(authorization)
        except InvalidHeader:
            return unauthorized(b"Unsupported credentials\n")

        if not await self.check_credentials(username, password):
            return unauthorized(b"Invalid credentials\n")

        self.username = username

        return await super().process_request(path, request_headers)
def basic_auth_protocol_factory(
    realm: str | None = None,
    credentials: Credentials | Iterable[Credentials] | None = None,
    check_credentials: Callable[[str, str], Awaitable[bool]] | None = None,
    create_protocol: Callable[..., BasicAuthWebSocketServerProtocol] | None = None,
) -> Callable[..., BasicAuthWebSocketServerProtocol]:
    """
    Protocol factory that enforces HTTP Basic Auth.

    :func:`basic_auth_protocol_factory` is designed to integrate with
    :func:`~websockets.legacy.server.serve` like this::

        serve(
            ...,
            create_protocol=basic_auth_protocol_factory(
                realm="my dev server",
                credentials=("hello", "iloveyou"),
            )
        )

    Args:
        realm: Scope of protection. It should contain only ASCII characters
            because the encoding of non-ASCII characters is undefined.
            Refer to section 2.2 of :rfc:`7235` for details.
        credentials: Hard coded authorized credentials. It can be a
            ``(username, password)`` pair or an iterable of such pairs;
            one-shot iterators such as ``zip(...)`` or generators are
            accepted too.
        check_credentials: Coroutine that verifies credentials.
            It receives ``username`` and ``password`` arguments
            and returns a :class:`bool`. One of ``credentials`` or
            ``check_credentials`` must be provided but not both.
        create_protocol: Factory that creates the protocol. By default, this
            is :class:`BasicAuthWebSocketServerProtocol`. It can be replaced
            by a subclass.

    Returns:
        A callable that builds the protocol with ``realm`` and
        ``check_credentials`` bound as keyword arguments.

    Raises:
        TypeError: If the ``credentials`` or ``check_credentials`` argument is
            wrong.

    """
    if (credentials is None) == (check_credentials is None):
        raise TypeError("provide either credentials or check_credentials")

    if credentials is not None:
        from collections.abc import Iterator

        candidate: Any = credentials
        # is_credentials() probes its argument by unpacking it, which would
        # silently eat the first items of a one-shot iterator (for example
        # zip(users, passwords)) before list() below sees them. Snapshot such
        # iterators so that every pair is kept.
        if isinstance(candidate, Iterator):
            candidate = list(candidate)

        if is_credentials(candidate):
            credentials_list = [cast(Credentials, candidate)]
        elif isinstance(candidate, Iterable):
            credentials_list = list(cast(Iterable[Credentials], candidate))
            if not all(is_credentials(item) for item in credentials_list):
                raise TypeError(f"invalid credentials argument: {credentials}")
        else:
            raise TypeError(f"invalid credentials argument: {credentials}")

        credentials_dict = dict(credentials_list)

        async def check_credentials(username: str, password: str) -> bool:
            try:
                expected_password = credentials_dict[username]
            except KeyError:
                return False
            # hmac.compare_digest() raises TypeError for str arguments that
            # contain non-ASCII characters; comparing the UTF-8 encodings keeps
            # the constant-time check working for any password.
            return hmac.compare_digest(
                expected_password.encode(),
                password.encode(),
            )

    if create_protocol is None:
        create_protocol = BasicAuthWebSocketServerProtocol

    # Help mypy and avoid this error: "type[BasicAuthWebSocketServerProtocol] |
    # Callable[..., BasicAuthWebSocketServerProtocol]" not callable [misc]
    create_protocol = cast(
        Callable[..., BasicAuthWebSocketServerProtocol], create_protocol
    )
    return functools.partial(
        create_protocol,
        realm=realm,
        check_credentials=check_credentials,
    )