diff --git a/bluesky_httpserver/_authentication.py b/bluesky_httpserver/_authentication.py index 992c5fb..7b9f06e 100644 --- a/bluesky_httpserver/_authentication.py +++ b/bluesky_httpserver/_authentication.py @@ -3,7 +3,7 @@ import secrets import uuid as uuid_module import warnings -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any, Optional from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response, Security, WebSocket @@ -19,7 +19,7 @@ # int_from_bytes is deprecated, use int.from_bytes instead with warnings.catch_warnings(): warnings.simplefilter("ignore") - from jose import ExpiredSignatureError, JWTError, jwt + from jose import ExpiredSignatureError import pydantic from packaging import version @@ -31,6 +31,9 @@ from pydantic_settings import BaseSettings from . import schemas +from bluesky_authentication import tokens as auth_tokens +from bluesky_authentication.tokens import decode_token +from bluesky_authentication.utils import extract_scopes, find_proxied_authenticator from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME from .core import json_or_msgpack from .database import orm @@ -54,7 +57,6 @@ get_current_username, ) -ALGORITHM = "HS256" UNIT_SECOND = timedelta(seconds=1) # Device code flow constants @@ -64,7 +66,7 @@ def utcnow(): "UTC now with second resolution" - return datetime.utcnow().replace(microsecond=0) + return datetime.now(timezone.utc).replace(microsecond=0) class Token(BaseModel): @@ -123,71 +125,26 @@ async def __call__(self, request: Request) -> Optional[str]: def create_access_token(data, secret_key, expires_delta): - to_encode = data.copy() - expire = utcnow() + expires_delta - to_encode.update({"exp": expire, "type": "access"}) - encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM) - return encoded_jwt + return auth_tokens.create_access_token( + data, + secret_key, + expires_delta, + utcnow=utcnow, + ) def create_refresh_token(session_id, secret_key, expires_delta): - expire = utcnow() + expires_delta - to_encode = { - "type": "refresh", - "sid": session_id, - "exp": expire, - } - encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM) - return encoded_jwt - - -def _decode_token_with_secret_keys(token, secret_keys): - # The first key in settings.secret_keys is used for *encoding*. - # All keys are tried for *decoding* until one works or they all - # fail. They support key rotation. - for secret_key in secret_keys: - try: - payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) - return payload - except ExpiredSignatureError: - # Do not let this be caught below with the other JWTError types. - raise - except JWTError: - # Try the next key in the key rotation. - continue - return None - - -def decode_token(token, secret_keys, proxied_authenticator=None): - credentials_exception = HTTPException( - status_code=401, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, + return auth_tokens.create_refresh_token( + session_id, + secret_key, + expires_delta, + utcnow=utcnow, ) - payload = _decode_token_with_secret_keys(token, secret_keys) - if payload is not None: - return payload - if proxied_authenticator is not None: - return proxied_authenticator.decode_token(token) - raise credentials_exception def _extract_scopes(decoded_access_token: dict[str, Any]) -> set[str]: - if "scp" in decoded_access_token: - scp = decoded_access_token["scp"] - return set(scp) if isinstance(scp, list) else set(scp.split(" ")) - if "scope" in decoded_access_token: - return set(decoded_access_token["scope"].split(" ")) - return set() - - -def _get_proxied_authenticator(authenticators): - if not authenticators: - return None - for authenticator in authenticators.values(): - if hasattr(authenticator, "oauth2_schema") and hasattr(authenticator, "decode_token"): - return authenticator - return None + # Kept as a local alias for backward compatibility with any internal callers. + return extract_scopes(decoded_access_token) async def get_api_key( @@ -201,7 +158,7 @@ async def get_api_key( return None -def get_current_principal( +async def get_current_principal( request: Request, security_scopes: SecurityScopes, access_token: str = Depends(oauth2_scheme), @@ -306,10 +263,11 @@ def get_current_principal( request.state.cookies_to_set.append({"key": API_KEY_COOKIE_NAME, "value": api_key}) elif access_token is not None: try: + proxied = find_proxied_authenticator(authenticators) payload = decode_token( access_token, settings.secret_keys, - _get_proxied_authenticator(authenticators), + proxied_decoder=proxied.decode_token if proxied else None, ) except ExpiredSignatureError: raise HTTPException( @@ -447,7 +405,7 @@ async def get_current_principal_websocket( return None try: - principal = get_current_principal( + principal = await get_current_principal( request=websocket, security_scopes=security_scopes, access_token=access_token, diff --git a/bluesky_httpserver/app.py b/bluesky_httpserver/app.py index 84470d1..dd4ab2a 100644 --- a/bluesky_httpserver/app.py +++ b/bluesky_httpserver/app.py @@ -9,10 +9,10 @@ import urllib.parse from functools import lru_cache, partial -from bluesky_authentication.protocols import ExternalAuthenticator, InternalAuthenticator +from bluesky_authentication.integration import AuthProviderRegistration from bluesky_queueserver.manager.comms import validate_zmq_key from bluesky_queueserver_api.zmq.aio import REManagerAPI -from fastapi import APIRouter, FastAPI, Request, Response +from fastapi import FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.openapi.utils import get_openapi @@ -22,6 +22,7 @@ from .resources import SERVER_RESOURCES as SR from .routers import core_api from .settings import get_settings +from .authentication import build_shared_authentication_router, oauth2_scheme from .utils import ( API_KEY_COOKIE_NAME, CSRF_COOKIE_NAME, @@ -157,70 +158,15 @@ def build_app(authentication=None, api_access=None, resource_access=None, server add_router(app, module_and_router_name=rn) logger.info("All custom routers are included successfully.") - from .authentication import ( - base_authentication_router, - build_auth_code_route, - build_authorize_route, - build_device_code_authorize_route, - build_device_code_form_route, - build_device_code_submit_route, - build_device_code_token_route, - build_handle_credentials_route, - oauth2_scheme, - ) - - authentication_router = APIRouter() - # This adds the universal routes like /session/refresh and /session/revoke. - # Below we will add routes specific to our authentication providers. - authentication_router.include_router(base_authentication_router) - if authentication.get("providers", []): # For the OpenAPI schema, inject a OAuth2PasswordBearer URL. first_provider = authentication["providers"][0]["provider"] oauth2_scheme.model.flows.password.tokenUrl = f"/api/auth/provider/{first_provider}/token" - # Authenticators provide Router(s) for their particular flow. - # Collect them in the authentication_router. - - for spec in authentication["providers"]: - provider = spec["provider"] - authenticator = spec["authenticator"] - if isinstance(authenticator, InternalAuthenticator): - authentication_router.post(f"/provider/{provider}/token")( - build_handle_credentials_route(authenticator, provider) - ) - elif isinstance(authenticator, ExternalAuthenticator): - # Standard OAuth callback route (authorization code flow) - authentication_router.get(f"/provider/{provider}/code")( - build_auth_code_route(authenticator, provider) - ) - authentication_router.post(f"/provider/{provider}/code")( - build_auth_code_route(authenticator, provider) - ) - # Device code flow routes for CLI/headless clients - # GET /authorize - redirects browser to OIDC provider - authentication_router.get(f"/provider/{provider}/authorize")( - build_authorize_route(authenticator, provider) - ) - # POST /authorize - initiates device code flow (returns device_code, user_code, etc.) - authentication_router.post(f"/provider/{provider}/authorize")( - build_device_code_authorize_route(authenticator, provider) - ) - # GET /device_code - shows user code entry form - authentication_router.get(f"/provider/{provider}/device_code")( - build_device_code_form_route(authenticator, provider) - ) - # POST /device_code - handles user code submission after browser auth - authentication_router.post(f"/provider/{provider}/device_code")( - build_device_code_submit_route(authenticator, provider) - ) - # POST /token - CLI client polls this for tokens - authentication_router.post(f"/provider/{provider}/token")( - build_device_code_token_route(authenticator, provider) - ) - else: - raise ValueError(f"unknown authenticator type {type(authenticator)}") - for custom_router in getattr(authenticator, "include_routers", []): - authentication_router.include_router(custom_router, prefix=f"/provider/{provider}") + provider_registrations = [ + AuthProviderRegistration(provider=spec["provider"], authenticator=spec["authenticator"]) + for spec in authentication.get("providers", []) + ] + authentication_router = build_shared_authentication_router(provider_registrations) # And add this authentication_router itself to the app. app.include_router(authentication_router, prefix="/api/auth") diff --git a/bluesky_httpserver/authentication/__init__.py b/bluesky_httpserver/authentication/__init__.py index 3475cd1..c4006cb 100644 --- a/bluesky_httpserver/authentication/__init__.py +++ b/bluesky_httpserver/authentication/__init__.py @@ -1,13 +1,24 @@ +from __future__ import annotations + +from collections.abc import Iterable + +from fastapi import APIRouter + +from bluesky_authentication.integration import ( + AuthProviderRegistration, + build_authentication_router, +) + from .._authentication import ( _extract_scopes, base_authentication_router, - build_auth_code_route, - build_authorize_route, - build_device_code_authorize_route, - build_device_code_form_route, - build_device_code_submit_route, - build_device_code_token_route, - build_handle_credentials_route, + build_auth_code_route as _build_auth_code_route, + build_authorize_route as _build_authorize_route, + build_device_code_authorize_route as _build_device_code_authorize_route, + build_device_code_form_route as _build_device_code_form_route, + build_device_code_submit_route as _build_device_code_submit_route, + build_device_code_token_route as _build_device_code_token_route, + build_handle_credentials_route as _build_handle_credentials_route, get_current_principal, get_current_principal_websocket, oauth2_scheme, @@ -18,6 +29,71 @@ UserSessionState, ) + +class HttpServerAuthRouteAdapter: + def include_base_routes(self, router: APIRouter) -> None: + router.include_router(base_authentication_router) + + def build_internal_token_route(self, authenticator: InternalAuthenticator, provider: str): + return _build_handle_credentials_route(authenticator, provider) + + def build_external_code_route(self, authenticator: ExternalAuthenticator, provider: str): + return _build_auth_code_route(authenticator, provider) + + def build_external_authorize_route( + self, authenticator: ExternalAuthenticator, provider: str + ): + return _build_authorize_route(authenticator, provider) + + def build_device_code_authorize_route( + self, authenticator: ExternalAuthenticator, provider: str + ): + return _build_device_code_authorize_route(authenticator, provider) + + def build_device_code_form_route( + self, authenticator: ExternalAuthenticator, provider: str + ): + return _build_device_code_form_route(authenticator, provider) + + def build_device_code_submit_route( + self, authenticator: ExternalAuthenticator, provider: str + ): + return _build_device_code_submit_route(authenticator, provider) + + def build_device_code_token_route( + self, authenticator: ExternalAuthenticator, provider: str + ): + return _build_device_code_token_route(authenticator, provider) + + def include_authenticator_routes( + self, + router: APIRouter, + *, + provider: str, + authenticator: InternalAuthenticator | ExternalAuthenticator, + ) -> None: + for custom_router in getattr(authenticator, "include_routers", []): + router.include_router(custom_router, prefix=f"/provider/{provider}") + + +def build_shared_authentication_router( + providers: Iterable[AuthProviderRegistration], +) -> APIRouter: + return build_authentication_router( + providers, + HttpServerAuthRouteAdapter(), + external_code_methods=("GET", "POST"), + ) + + +build_auth_code_route = _build_auth_code_route +build_authorize_route = _build_authorize_route +build_device_code_authorize_route = _build_device_code_authorize_route +build_device_code_form_route = _build_device_code_form_route +build_device_code_submit_route = _build_device_code_submit_route +build_device_code_token_route = _build_device_code_token_route +build_handle_credentials_route = _build_handle_credentials_route + __all__ = [ "ExternalAuthenticator", "InternalAuthenticator", @@ -32,6 +108,8 @@ "build_device_code_form_route", "build_device_code_submit_route", "build_device_code_token_route", + "build_shared_authentication_router", "build_handle_credentials_route", + "HttpServerAuthRouteAdapter", "oauth2_scheme", ] diff --git a/bluesky_httpserver/authentication/authenticator_base.py b/bluesky_httpserver/authentication/authenticator_base.py index ae7c599..e2f7969 100644 --- a/bluesky_httpserver/authentication/authenticator_base.py +++ b/bluesky_httpserver/authentication/authenticator_base.py @@ -1,31 +1,11 @@ -try: - from bluesky_authentication.protocols import ( - ExternalAuthenticator, - InternalAuthenticator, - UserSessionState, - ) -except ModuleNotFoundError: - from abc import ABC - from dataclasses import dataclass - from typing import Optional - - from fastapi import Request - - @dataclass - class UserSessionState: - """Data transfer class to communicate custom session state information.""" - - user_name: str - state: dict = None - - class InternalAuthenticator(ABC): - """Base class for authenticators that use username/password credentials.""" - - async def authenticate(self, username: str, password: str) -> Optional[UserSessionState]: - raise NotImplementedError - - class ExternalAuthenticator(ABC): - """Base class for authenticators that use external identity providers.""" - - async def authenticate(self, request: Request) -> Optional[UserSessionState]: - raise NotImplementedError +from bluesky_authentication.protocols import ( # noqa: F401 + ExternalAuthenticator, + InternalAuthenticator, + UserSessionState, +) + +__all__ = [ + "ExternalAuthenticator", + "InternalAuthenticator", + "UserSessionState", +]