Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 23 additions & 65 deletions bluesky_httpserver/_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import secrets
import uuid as uuid_module
import warnings
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Any, Optional

from fastapi import APIRouter, Depends, Form, HTTPException, Query, Request, Response, Security, WebSocket
Expand All @@ -19,7 +19,7 @@
# int_from_bytes is deprecated, use int.from_bytes instead
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from jose import ExpiredSignatureError, JWTError, jwt
from jose import ExpiredSignatureError

import pydantic
from packaging import version
Expand All @@ -31,6 +31,9 @@
from pydantic_settings import BaseSettings

from . import schemas
from bluesky_authentication import tokens as auth_tokens
from bluesky_authentication.tokens import decode_token
from bluesky_authentication.utils import extract_scopes, find_proxied_authenticator
from .authorization._defaults import _DEFAULT_ANONYMOUS_PROVIDER_NAME
from .core import json_or_msgpack
from .database import orm
Expand All @@ -54,7 +57,6 @@
get_current_username,
)

ALGORITHM = "HS256"
UNIT_SECOND = timedelta(seconds=1)

# Device code flow constants
Expand All @@ -64,7 +66,7 @@

def utcnow():
"UTC now with second resolution"
return datetime.utcnow().replace(microsecond=0)
return datetime.now(timezone.utc).replace(microsecond=0)


class Token(BaseModel):
Expand Down Expand Up @@ -123,71 +125,26 @@ async def __call__(self, request: Request) -> Optional[str]:


def create_access_token(data, secret_key, expires_delta):
to_encode = data.copy()
expire = utcnow() + expires_delta
to_encode.update({"exp": expire, "type": "access"})
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
return encoded_jwt
return auth_tokens.create_access_token(
data,
secret_key,
expires_delta,
utcnow=utcnow,
)


def create_refresh_token(session_id, secret_key, expires_delta):
expire = utcnow() + expires_delta
to_encode = {
"type": "refresh",
"sid": session_id,
"exp": expire,
}
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
return encoded_jwt


def _decode_token_with_secret_keys(token, secret_keys):
# The first key in settings.secret_keys is used for *encoding*.
# All keys are tried for *decoding* until one works or they all
# fail. They support key rotation.
for secret_key in secret_keys:
try:
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
return payload
except ExpiredSignatureError:
# Do not let this be caught below with the other JWTError types.
raise
except JWTError:
# Try the next key in the key rotation.
continue
return None


def decode_token(token, secret_keys, proxied_authenticator=None):
credentials_exception = HTTPException(
status_code=401,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
return auth_tokens.create_refresh_token(
session_id,
secret_key,
expires_delta,
utcnow=utcnow,
)
payload = _decode_token_with_secret_keys(token, secret_keys)
if payload is not None:
return payload
if proxied_authenticator is not None:
return proxied_authenticator.decode_token(token)
raise credentials_exception


def _extract_scopes(decoded_access_token: dict[str, Any]) -> set[str]:
if "scp" in decoded_access_token:
scp = decoded_access_token["scp"]
return set(scp) if isinstance(scp, list) else set(scp.split(" "))
if "scope" in decoded_access_token:
return set(decoded_access_token["scope"].split(" "))
return set()


def _get_proxied_authenticator(authenticators):
if not authenticators:
return None
for authenticator in authenticators.values():
if hasattr(authenticator, "oauth2_schema") and hasattr(authenticator, "decode_token"):
return authenticator
return None
# Kept as a local alias for backward compatibility with any internal callers.
return extract_scopes(decoded_access_token)


async def get_api_key(
Expand All @@ -201,7 +158,7 @@ async def get_api_key(
return None


def get_current_principal(
async def get_current_principal(
request: Request,
security_scopes: SecurityScopes,
access_token: str = Depends(oauth2_scheme),
Expand Down Expand Up @@ -306,10 +263,11 @@ def get_current_principal(
request.state.cookies_to_set.append({"key": API_KEY_COOKIE_NAME, "value": api_key})
elif access_token is not None:
try:
proxied = find_proxied_authenticator(authenticators)
payload = decode_token(
access_token,
settings.secret_keys,
_get_proxied_authenticator(authenticators),
proxied_decoder=proxied.decode_token if proxied else None,
)
except ExpiredSignatureError:
raise HTTPException(
Expand Down Expand Up @@ -447,7 +405,7 @@ async def get_current_principal_websocket(
return None

try:
principal = get_current_principal(
principal = await get_current_principal(
request=websocket,
security_scopes=security_scopes,
access_token=access_token,
Expand Down
70 changes: 8 additions & 62 deletions bluesky_httpserver/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import urllib.parse
from functools import lru_cache, partial

from bluesky_authentication.protocols import ExternalAuthenticator, InternalAuthenticator
from bluesky_authentication.integration import AuthProviderRegistration
from bluesky_queueserver.manager.comms import validate_zmq_key
from bluesky_queueserver_api.zmq.aio import REManagerAPI
from fastapi import APIRouter, FastAPI, Request, Response
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi

Expand All @@ -22,6 +22,7 @@
from .resources import SERVER_RESOURCES as SR
from .routers import core_api
from .settings import get_settings
from .authentication import build_shared_authentication_router, oauth2_scheme
from .utils import (
API_KEY_COOKIE_NAME,
CSRF_COOKIE_NAME,
Expand Down Expand Up @@ -157,70 +158,15 @@ def build_app(authentication=None, api_access=None, resource_access=None, server
add_router(app, module_and_router_name=rn)
logger.info("All custom routers are included successfully.")

from .authentication import (
base_authentication_router,
build_auth_code_route,
build_authorize_route,
build_device_code_authorize_route,
build_device_code_form_route,
build_device_code_submit_route,
build_device_code_token_route,
build_handle_credentials_route,
oauth2_scheme,
)

authentication_router = APIRouter()
# This adds the universal routes like /session/refresh and /session/revoke.
# Below we will add routes specific to our authentication providers.
authentication_router.include_router(base_authentication_router)

if authentication.get("providers", []):
# For the OpenAPI schema, inject a OAuth2PasswordBearer URL.
first_provider = authentication["providers"][0]["provider"]
oauth2_scheme.model.flows.password.tokenUrl = f"/api/auth/provider/{first_provider}/token"
# Authenticators provide Router(s) for their particular flow.
# Collect them in the authentication_router.

for spec in authentication["providers"]:
provider = spec["provider"]
authenticator = spec["authenticator"]
if isinstance(authenticator, InternalAuthenticator):
authentication_router.post(f"/provider/{provider}/token")(
build_handle_credentials_route(authenticator, provider)
)
elif isinstance(authenticator, ExternalAuthenticator):
# Standard OAuth callback route (authorization code flow)
authentication_router.get(f"/provider/{provider}/code")(
build_auth_code_route(authenticator, provider)
)
authentication_router.post(f"/provider/{provider}/code")(
build_auth_code_route(authenticator, provider)
)
# Device code flow routes for CLI/headless clients
# GET /authorize - redirects browser to OIDC provider
authentication_router.get(f"/provider/{provider}/authorize")(
build_authorize_route(authenticator, provider)
)
# POST /authorize - initiates device code flow (returns device_code, user_code, etc.)
authentication_router.post(f"/provider/{provider}/authorize")(
build_device_code_authorize_route(authenticator, provider)
)
# GET /device_code - shows user code entry form
authentication_router.get(f"/provider/{provider}/device_code")(
build_device_code_form_route(authenticator, provider)
)
# POST /device_code - handles user code submission after browser auth
authentication_router.post(f"/provider/{provider}/device_code")(
build_device_code_submit_route(authenticator, provider)
)
# POST /token - CLI client polls this for tokens
authentication_router.post(f"/provider/{provider}/token")(
build_device_code_token_route(authenticator, provider)
)
else:
raise ValueError(f"unknown authenticator type {type(authenticator)}")
for custom_router in getattr(authenticator, "include_routers", []):
authentication_router.include_router(custom_router, prefix=f"/provider/{provider}")
provider_registrations = [
AuthProviderRegistration(provider=spec["provider"], authenticator=spec["authenticator"])
for spec in authentication.get("providers", [])
]
authentication_router = build_shared_authentication_router(provider_registrations)

# And add this authentication_router itself to the app.
app.include_router(authentication_router, prefix="/api/auth")
Expand Down
92 changes: 85 additions & 7 deletions bluesky_httpserver/authentication/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
from __future__ import annotations

from collections.abc import Iterable

from fastapi import APIRouter

from bluesky_authentication.integration import (
AuthProviderRegistration,
build_authentication_router,
)

from .._authentication import (
_extract_scopes,
base_authentication_router,
build_auth_code_route,
build_authorize_route,
build_device_code_authorize_route,
build_device_code_form_route,
build_device_code_submit_route,
build_device_code_token_route,
build_handle_credentials_route,
build_auth_code_route as _build_auth_code_route,
build_authorize_route as _build_authorize_route,
build_device_code_authorize_route as _build_device_code_authorize_route,
build_device_code_form_route as _build_device_code_form_route,
build_device_code_submit_route as _build_device_code_submit_route,
build_device_code_token_route as _build_device_code_token_route,
build_handle_credentials_route as _build_handle_credentials_route,
get_current_principal,
get_current_principal_websocket,
oauth2_scheme,
Expand All @@ -18,6 +29,71 @@
UserSessionState,
)


class HttpServerAuthRouteAdapter:
def include_base_routes(self, router: APIRouter) -> None:
router.include_router(base_authentication_router)

def build_internal_token_route(self, authenticator: InternalAuthenticator, provider: str):
return _build_handle_credentials_route(authenticator, provider)

def build_external_code_route(self, authenticator: ExternalAuthenticator, provider: str):
return _build_auth_code_route(authenticator, provider)

def build_external_authorize_route(
self, authenticator: ExternalAuthenticator, provider: str
):
return _build_authorize_route(authenticator, provider)

def build_device_code_authorize_route(
self, authenticator: ExternalAuthenticator, provider: str
):
return _build_device_code_authorize_route(authenticator, provider)

def build_device_code_form_route(
self, authenticator: ExternalAuthenticator, provider: str
):
return _build_device_code_form_route(authenticator, provider)

def build_device_code_submit_route(
self, authenticator: ExternalAuthenticator, provider: str
):
return _build_device_code_submit_route(authenticator, provider)

def build_device_code_token_route(
self, authenticator: ExternalAuthenticator, provider: str
):
return _build_device_code_token_route(authenticator, provider)

def include_authenticator_routes(
self,
router: APIRouter,
*,
provider: str,
authenticator: InternalAuthenticator | ExternalAuthenticator,
) -> None:
for custom_router in getattr(authenticator, "include_routers", []):
router.include_router(custom_router, prefix=f"/provider/{provider}")


def build_shared_authentication_router(
providers: Iterable[AuthProviderRegistration],
) -> APIRouter:
return build_authentication_router(
providers,
HttpServerAuthRouteAdapter(),
external_code_methods=("GET", "POST"),
)


build_auth_code_route = _build_auth_code_route
build_authorize_route = _build_authorize_route
build_device_code_authorize_route = _build_device_code_authorize_route
build_device_code_form_route = _build_device_code_form_route
build_device_code_submit_route = _build_device_code_submit_route
build_device_code_token_route = _build_device_code_token_route
build_handle_credentials_route = _build_handle_credentials_route

__all__ = [
"ExternalAuthenticator",
"InternalAuthenticator",
Expand All @@ -32,6 +108,8 @@
"build_device_code_form_route",
"build_device_code_submit_route",
"build_device_code_token_route",
"build_shared_authentication_router",
"build_handle_credentials_route",
"HttpServerAuthRouteAdapter",
"oauth2_scheme",
]
Loading
Loading