Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions scripts/inject_strict_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,24 @@
# `root-model-extra`), so skip them. Their behavior is governed by the
# inner type, which on its own enforces strict validation.
CLASS_RE = re.compile(r"^class\s+([A-Za-z_][\w]*)\s*\(\s*(BaseModel)\s*\)\s*:\s*$")
CONFIG_LINE = " model_config = ConfigDict(extra='forbid', populate_by_name=True)"
CONFIG_LINE_STRICT = " model_config = ConfigDict(extra='forbid', populate_by_name=True)"
CONFIG_LINE_TOLERANT = " model_config = ConfigDict(extra='ignore', populate_by_name=True)"


def _is_response_shape(class_name: str) -> bool:
"""Response-shape classes tolerate unknown fields (Postel's Law)."""
if class_name[0].islower():
return False
if class_name.endswith(("Request", "Params")):
return False
return bool(
class_name.endswith(("Dto", "Response"))
or class_name.startswith(("SingleValueResponse", "TableValueResult", "CursorPage"))
)


# Keep the old name for backward compat in case anything imports it
CONFIG_LINE = CONFIG_LINE_STRICT

# Doc-banner injections keyed by class name. Inserted as a leading docstring
# inside the target class so the note shows up in IDE hovers and stays put
Expand Down Expand Up @@ -130,6 +147,7 @@ def inject(source: str) -> tuple[str, int]:
i += 1
continue
class_name = m.group(1)
config_line = CONFIG_LINE_TOLERANT if _is_response_shape(class_name) else CONFIG_LINE_STRICT
# Look at the very next line. If it's already model_config or pass,
# leave the class alone (idempotency / empty class).
next_idx = i + 1
Expand All @@ -141,10 +159,10 @@ def inject(source: str) -> tuple[str, int]:
out.append(f' """{banner}"""\n')
modified += 1
if "model_config" in next_line:
# Upgrade the existing config line to include populate_by_name=True
# if it isn't already there. Idempotent across re-runs.
if "populate_by_name" not in next_line:
out.append(CONFIG_LINE + "\n")
# Replace the existing config line if it doesn't match the
# desired strictness or is missing populate_by_name.
if next_line.strip() != config_line.strip():
out.append(config_line + "\n")
i += 2 # replace the existing model_config line
modified += 1
continue
Expand All @@ -154,11 +172,11 @@ def inject(source: str) -> tuple[str, int]:
# exact match (NOT startswith) — fields like `passed: Annotated[...]`
# also start with "pass" but are not empty class markers.
if next_line.strip() in ("pass", "pass\n"):
out.append(CONFIG_LINE + "\n")
out.append(config_line + "\n")
i += 2 # skip the pass
modified += 1
continue
out.append(CONFIG_LINE + "\n")
out.append(config_line + "\n")
modified += 1
i += 1
return "".join(out), modified
Expand Down
Loading
Loading