Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions app/data/repositories/layer1.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@ def get_column_units(self, catalog: model.RawCatalog) -> dict[str, str]:
)
return {row["column_name"]: row["unit"] for row in rows}

def get_catalog_columns(self, schema: str, table: str) -> list[dict[str, Any]]:
return self._storage.query(
"""
SELECT c.column_name,
c.data_type::text AS data_type,
(c.is_nullable = 'NO') AS not_null,
ci.param
FROM information_schema.columns c
LEFT JOIN meta.column_info ci
ON ci.schema_name = c.table_schema
AND ci.table_name = c.table_name
AND ci.column_name = c.column_name
WHERE c.table_schema = %s AND c.table_name = %s
ORDER BY c.ordinal_position
""",
params=[schema, table],
)

def save_structured_data(
self,
table: str,
Expand Down
6 changes: 5 additions & 1 deletion app/domain/adminapi/actions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import final

from app.data import repositories
from app.domain.adminapi import crossmatch, layer1_write, login, sources, table_upload
from app.domain.adminapi import catalogs, crossmatch, layer1_write, login, sources, table_upload
from app.lib import auth, cache, clients
from app.presentation import adminapi

Expand Down Expand Up @@ -29,6 +29,7 @@ def __init__(
)
self.crossmatch_manager = crossmatch.CrossmatchManager(layer0_repo, layer1_repo, layer2_repo)
self.layer1_writer = layer1_write.Layer1Writer(layer1_repo)
self.catalog_manager = catalogs.CatalogManager(layer1_repo)

def create_source(self, r: adminapi.CreateSourceRequest) -> adminapi.CreateSourceResponse:
return self.source_manager.create_source(r)
Expand All @@ -54,6 +55,9 @@ def get_table(self, r: adminapi.GetTableRequest) -> adminapi.GetTableResponse:
def get_table_list(self, r: adminapi.GetTableListRequest) -> adminapi.GetTableListResponse:
return self.table_upload_manager.get_table_list(r)

def get_catalogs(self) -> adminapi.GetCatalogsResponse:
return self.catalog_manager.get_catalogs()

def get_records(self, r: adminapi.GetRecordsRequest) -> adminapi.GetRecordsResponse:
return self.table_upload_manager.get_records(r)

Expand Down
59 changes: 59 additions & 0 deletions app/domain/adminapi/catalogs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Any, final

from app.data import model, repositories
from app.presentation import adminapi

_INTERNAL_COLUMNS = frozenset({"record_id", "object_id", "id", "modification_time"})

_CATALOG_DISPLAY: dict[model.RawCatalog, tuple[str, str]] = {
model.RawCatalog.ICRS: ("ICRS", "Equatorial coordinates in the ICRS frame."),
model.RawCatalog.DESIGNATION: ("Designations", "Object designations."),
model.RawCatalog.REDSHIFT: ("Redshift", "Heliocentric velocity (cz)."),
model.RawCatalog.NATURE: ("Nature", "Object type classification."),
model.RawCatalog.PHOTOMETRY__TOTAL: ("Photometry (total)", "Total magnitudes per band and method."),
model.RawCatalog.PHOTOMETRY__ISOPHOTAL: ("Photometry (isophotal)", "Isophotal magnitudes per band and level."),
model.RawCatalog.GEOMETRY: ("Geometry", "Isophotal ellipse geometry."),
model.RawCatalog.NOTE: ("Note", "Free-text notes attached to records."),
}


def _field_from_row(row: dict[str, Any]) -> adminapi.CatalogField:
param = row.get("param") or {}
if not isinstance(param, dict):
param = {}
description = param.get("description")
return adminapi.CatalogField(
name=row["column_name"],
data_type=adminapi.postgres_type_to_datatype(row["data_type"]),
unit=param.get("unit"),
required=bool(row["not_null"]),
ucd=param.get("ucd"),
description=str(description) if description else "",
)


@final
class CatalogManager:
def __init__(self, layer1_repo: repositories.Layer1Repository) -> None:
self._layer1_repo = layer1_repo

def get_catalogs(self) -> adminapi.GetCatalogsResponse:
catalogs: list[adminapi.CatalogSchema] = []
for catalog in model.RawCatalog:
if catalog in model.RUNTIME_RAW_CATALOGS:
continue
object_cls = model.get_catalog_object_type(catalog)
layer1_table = object_cls.layer1_table()
schema, table = layer1_table.split(".", maxsplit=1)
rows = self._layer1_repo.get_catalog_columns(schema, table)
fields = [_field_from_row(row) for row in rows if row["column_name"] not in _INTERNAL_COLUMNS]
title, description = _CATALOG_DISPLAY.get(catalog, (catalog.value, ""))
catalogs.append(
adminapi.CatalogSchema(
catalog=catalog.value,
title=title,
description=description,
fields=fields,
)
)
return adminapi.GetCatalogsResponse(catalogs=catalogs)
41 changes: 41 additions & 0 deletions app/presentation/adminapi/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,43 @@ class AssignRecordPgcsResponse(pydantic.BaseModel):
pass


class CatalogField(pydantic.BaseModel):
name: str
data_type: DatatypeEnum
unit: str | None = None
required: bool = True
ucd: str | None = None
description: str = ""


class CatalogSchema(pydantic.BaseModel):
catalog: str
title: str
description: str
fields: list[CatalogField]


class GetCatalogsResponse(pydantic.BaseModel):
catalogs: list[CatalogSchema]


def postgres_type_to_datatype(pg_type: str) -> DatatypeEnum:
normalized = pg_type.lower().strip()
if normalized in {"text", "character varying", "character", "char", "user-defined"}:
return DatatypeEnum["str"]
if normalized in {"double precision", "real", "numeric"}:
return DatatypeEnum["float"]
if normalized in {"integer", "smallint"}:
return DatatypeEnum["int"]
if normalized == "bigint":
return DatatypeEnum["long"]
if normalized == "timestamp without time zone":
return DatatypeEnum["timestamp without time zone"]
if normalized in DatatypeEnum.__members__:
return DatatypeEnum[normalized]
raise ValueError(f"unsupported postgres type: {pg_type}")


class Actions(abc.ABC):
@abc.abstractmethod
def add_data(self, r: AddDataRequest) -> AddDataResponse:
Expand All @@ -394,6 +431,10 @@ def get_table(self, r: GetTableRequest) -> GetTableResponse:
def get_table_list(self, r: GetTableListRequest) -> GetTableListResponse:
pass

@abc.abstractmethod
def get_catalogs(self) -> GetCatalogsResponse:
pass

@abc.abstractmethod
def patch_table(self, r: PatchTableRequest) -> PatchTableResponse:
pass
Expand Down
11 changes: 11 additions & 0 deletions app/presentation/adminapi/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def get_table_list(
response = self.actions.get_table_list(request)
return server.APIOkResponse(data=response)

def get_catalogs(self) -> server.APIOkResponse[interface.GetCatalogsResponse]:
response = self.actions.get_catalogs()
return server.APIOkResponse(data=response)

def get_records(
self, request: Annotated[interface.GetRecordsRequest, fastapi.Query()]
) -> server.APIOkResponse[interface.GetRecordsResponse]:
Expand Down Expand Up @@ -172,6 +176,13 @@ def __init__(
"List tables",
"Returns a paginated list of tables matching the search query by name or description",
),
server.Route(
"/v1/catalogs",
http.HTTPMethod.GET,
api.get_catalogs,
"List layer 1 catalog structures",
"Returns the writable column structure of every layer 1 catalog.",
),
server.Route(
"/v1/records",
http.HTTPMethod.GET,
Expand Down
84 changes: 84 additions & 0 deletions tests/integration/catalogs_api_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import unittest

import structlog
from starlette import testclient

from app.data import repositories
from app.domain import adminapi as domain
from app.domain.adminapi.mock import get_mock_table_stats_cache
from app.lib import audit, auth, clients
from app.lib.web import server
from app.presentation.adminapi.server import Server
from tests import lib


class CatalogsAPITest(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.pg_storage = lib.TestPostgresStorage.get()

def setUp(self) -> None:
pg = self.pg_storage.get_storage()
log = structlog.get_logger()
layer0_repo = repositories.Layer0Repository(pg, log)
self.actions = domain.Actions(
common_repo=repositories.CommonRepository(pg, log),
layer0_repo=layer0_repo,
layer1_repo=repositories.Layer1Repository(pg, log),
layer2_repo=repositories.Layer2Repository(pg, log),
authenticator=auth.NoopAuthenticator(),
clients=clients.Clients(ads_token="test"),
table_stats_cache=get_mock_table_stats_cache(),
)
cfg = server.ServerConfig(host="127.0.0.1", port=0, path_prefix="/admin/api")
self.client = testclient.TestClient(
Server(
self.actions,
cfg,
log,
auth.NoopAuthenticator(),
audit.NoopActionRecorder(),
auth_enabled=False,
).app
)

def tearDown(self) -> None:
self.pg_storage.clear()

def _catalogs_by_name(self) -> dict[str, dict]:
response = self.client.get("/admin/api/v1/catalogs")
self.assertEqual(response.status_code, 200)
return {c["catalog"]: c for c in response.json()["data"]["catalogs"]}

def test_get_catalogs_icrs(self) -> None:
catalogs = self._catalogs_by_name()
icrs = catalogs["icrs"]
fields = {f["name"]: f for f in icrs["fields"]}
self.assertEqual(set(fields), {"ra", "dec", "e_ra", "e_dec"})
for name in ("ra", "dec", "e_ra", "e_dec"):
self.assertTrue(fields[name]["required"])
self.assertEqual(fields[name]["data_type"], "float")
self.assertEqual(fields["ra"]["unit"], "deg")
self.assertEqual(fields["dec"]["unit"], "deg")
self.assertEqual(fields["e_ra"]["unit"], "deg")
self.assertEqual(fields["e_dec"]["unit"], "deg")

def test_get_catalogs_geometry(self) -> None:
catalogs = self._catalogs_by_name()
geometry = catalogs["geometry"]
fields = {f["name"]: f for f in geometry["fields"]}
self.assertEqual(
set(fields),
{"band", "method", "level", "a", "e_a", "b", "e_b", "pa", "e_pa", "isophote", "e_isophote"},
)
self.assertTrue(fields["band"]["required"])
self.assertTrue(fields["method"]["required"])
self.assertFalse(fields["level"]["required"])
self.assertFalse(fields["pa"]["required"])
self.assertEqual(fields["a"]["unit"], "arcsec")
self.assertEqual(fields["isophote"]["unit"], "mag/arcmin2")
self.assertEqual(fields["method"]["data_type"], "str")

def test_runtime_catalogs_excluded(self) -> None:
catalogs = self._catalogs_by_name()
self.assertNotIn("additional_designations", catalogs)
Loading