Skip to content

feat(pt): add hard-coded aparam output gate for fitting nets#5495

Open
Jingbei-Bai wants to merge 6 commits into
deepmodeling:masterfrom
Jingbei-Bai:feat/aparam-output-gate
Open

feat(pt): add hard-coded aparam output gate for fitting nets#5495
Jingbei-Bai wants to merge 6 commits into
deepmodeling:masterfrom
Jingbei-Bai:feat/aparam-output-gate

Conversation

@Jingbei-Bai
Copy link
Copy Markdown

@Jingbei-Bai Jingbei-Bai commented Jun 4, 2026

Summary by CodeRabbit

  • New Features

    • Optional atomic-parameter output gating for fitting models (three new config options); example config updated.
    • New variational-Gaussian averaged smooth descriptor (se_a_vg) and supporting descriptor implementation and exports.
  • Bug Fixes / Behavior

    • Gating applied after output statistics/bias so gated atomic outputs honor aparam inputs.
    • Descriptor handling now accepts optional aparam without changing existing outputs.
  • Tests

    • Added tests for gate behavior, gate formula, serialization, and se_a_vg descriptor.

Co-authored-by: Cursor <cursoragent@cursor.com>
Copilot AI review requested due to automatic review settings June 4, 2026 10:58
@dosubot dosubot Bot added the new feature label Jun 4, 2026
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds an optional “aparam output gate” to fitting networks (PyTorch + array-API/dpmodel) and exposes it through configuration/argcheck, with tests and an example config update.

Changes:

  • Introduce use_aparam_output_gate, aparam_gate_norm, and aparam_gate_clamp arguments, including validation + serialization.
  • Apply the gate in PT fitting forward paths and dpmodel array-API fitting path.
  • Add a dedicated PT unit test and update an example training input JSON.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
source/tests/pt/model/test_aparam_output_gate.py Adds unit tests validating gating behavior + serialize roundtrip
examples/fparam/train/input_aparam.json Demonstrates new config knobs in an example input
deepmd/utils/argcheck.py Exposes new fitting arguments and documentation strings
deepmd/pt/model/task/sezm_ener.py Extracts raw aparam and applies output gate in SE(Z/M) energy path
deepmd/pt/model/task/fitting.py Implements gate computation/application and wires it into common forward path
deepmd/dpmodel/fitting/general_fitting.py Mirrors gate logic for dpmodel/array-API fitting path

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

aparam_gate_norm=1.0,
aparam_gate_clamp=True,
).to(device)
fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype))
aparam_gate_norm=norm,
aparam_gate_clamp=False,
).to(device)
fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype))

fitting_gate = fitting._compute_aparam_output_gate(aparam)
expected = (a_val * a_val) / (sigma * sigma * norm)
self.assertTrue(torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype)))
Comment on lines +818 to +823
if aparam.numel() % (nf * self.numb_aparam) != 0:
raise ValueError(
f"input aparam: cannot reshape {list(aparam.shape)} "
f"into ({nf}, nloc, {self.numb_aparam})."
)
aparam_raw = aparam.view([nf, -1, self.numb_aparam])
Comment on lines +698 to +703
if aparam.numel() % (nf * self.numb_aparam) != 0:
raise ValueError(
f"input aparam: cannot reshape {list(aparam.shape)} "
f"into ({nf}, nloc, {self.numb_aparam})."
)
aparam_raw = aparam.view([nf, -1, self.numb_aparam])
self,
aparam_raw: torch.Tensor,
) -> torch.Tensor:
"""Hard-coded gate g = a^2 / (sigma^2 * norm) from raw aparam."""
Comment on lines +774 to +777
if self.numb_aparam > 1:
gate = gate.prod(dim=-1, keepdim=True)
if self.aparam_gate_clamp:
gate = gate.clamp(0.0, 1.0)
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Jun 4, 2026

Review Change Stack

📝 Walkthrough

Walkthrough

Adds an optional multiplicative "aparam output gate" for fitting outputs and a new variational-Gaussian descriptor (se_a_vg). Implements gate compute/apply helpers in dpmodel and PT, wires gating into atomic output flow and SeZM, adds args/examples/tests, and adds VG descriptor code, exports, and tests.

Changes

Aparam Output Gate Feature

Layer / File(s) Summary
Configuration args, example, and tests
deepmd/utils/argcheck.py, examples/fparam/train/input_aparam.json, source/tests/pt/model/test_aparam_output_gate.py
Adds fitting_aparam_output_gate_args() and inserts the fields into fitting schemas; updates example JSON; adds tests verifying zero-aparam zeroing, formula, and serialize roundtrip.
dpmodel GeneralFitting and atomic hook
deepmd/dpmodel/fitting/general_fitting.py, deepmd/dpmodel/atomic_model/base_atomic_model.py
Adds use_aparam_output_gate, aparam_gate_norm, aparam_gate_clamp to constructor with validation and serialization; reshapes input aparam to aparam_raw in _call_common; adds _compute_aparam_output_gate/_apply_aparam_output_gate and apply_aparam_output_gate_to_atomic_output; applies gate post out-stat via atomic-model helper.
PyTorch GeneralFitting, atomic flow, and SeZM integration
deepmd/pt/model/task/fitting.py, deepmd/pt/model/atomic_model/base_atomic_model.py, deepmd/pt/model/atomic_model/dp_atomic_model.py, deepmd/pt/model/task/sezm_ener.py
Mirrors dpmodel in PT: ctor params/validation/attrs/serialize, gate compute/apply helpers (TorchScript wrapper), _forward_common builds/validates aparam_raw and uses it for aparam_embed, enforces presence when gating enabled, conditionally forwards aparam to descriptors, and applies gate after out-stat; SeZM forward builds aparam_raw and uses normalized aparam_embed.

se_a_vg Variational-Gaussian Descriptor

Layer / File(s) Summary
VG env-mat utilities and tabulation
deepmd/pt/model/descriptor/env_mat_vg.py
Adds VG radial kernels (vg_gaussian_radial_phi, vg_smooth_radial), neighbor sigma gathering, VG env-matrix builder _make_env_mat_vg and normalizer prod_env_mat_vg, plus two-stage tabulation helper tabulate_fusion_se_a_vg.
Descriptor block and wrapper
deepmd/pt/model/descriptor/se_a_vg.py
Adds DescrptBlockSeAVg with per-type embedding networks, stats collection (compute_input_stats), compression enablement, and forward; adds DescrptSeAVg wrapper with registration, forward, compression wiring, serialize/update_sel.
Integration, exports, and args
deepmd/pt/model/descriptor/__init__.py, deepmd/pt/model/descriptor/se_a.py, deepmd/utils/argcheck.py, deepmd/pt/model/model/__init__.py
Expose new descriptor classes, accept/ignore aparam in existing DescrptSeA.forward signature, register se_a_vg args as alias, include se_a_vg in compression options, and broaden get_spin_model sel-expansion to the se_a family.
se_a_vg unit tests
source/tests/pt/model/test_se_a_vg.py
Add tests verifying sigma=0 equivalence to SE, sensitivity to aparam, forward shapes, and optional compression numerical match when fused op exists.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • deepmodeling/deepmd-kit#5391: Modifies GeneralFitting (__init__ and _call_common) and extends results dict; code-level overlap with the aparam-gating plumbing.

Suggested reviewers

  • wanghan-iapcm
  • iProzd
  • njzjz
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 29.41% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'feat(pt): add hard-coded aparam output gate for fitting nets' clearly summarizes the main change: introducing a new aparam output gate feature for PyTorch fitting networks.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@source/tests/pt/model/test_aparam_output_gate.py`:
- Line 73: The assertion compares fitting_gate (which lives on env.DEVICE) to a
CPU tensor; create the expected tensor on the same device to avoid
device-mismatch failures: when constructing torch.tensor(expected, dtype=dtype)
in the test (the line comparing fitting_gate), pass device=env.DEVICE or call
.to(env.DEVICE) so the expected tensor matches fitting_gate's device before
calling torch.allclose.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 73a16adf-a7d6-43fb-adcb-ab9b7391afb8

📥 Commits

Reviewing files that changed from the base of the PR and between fb6ff93 and afcf8b0.

📒 Files selected for processing (6)
  • deepmd/dpmodel/fitting/general_fitting.py
  • deepmd/pt/model/task/fitting.py
  • deepmd/pt/model/task/sezm_ener.py
  • deepmd/utils/argcheck.py
  • examples/fparam/train/input_aparam.json
  • source/tests/pt/model/test_aparam_output_gate.py


fitting_gate = fitting._compute_aparam_output_gate(aparam)
expected = (a_val * a_val) / (sigma * sigma * norm)
self.assertTrue(torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype)))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Fix expected tensor device in gate-formula assertion.

At Line 73, torch.tensor(expected, dtype=dtype) is created on CPU, while fitting_gate is on env.DEVICE, which can fail on GPU/MPS.

Proposed fix
-        self.assertTrue(torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype)))
+        self.assertTrue(
+            torch.allclose(
+                fitting_gate,
+                torch.tensor(expected, dtype=dtype, device=device),
+            )
+        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/pt/model/test_aparam_output_gate.py` at line 73, The assertion
compares fitting_gate (which lives on env.DEVICE) to a CPU tensor; create the
expected tensor on the same device to avoid device-mismatch failures: when
constructing torch.tensor(expected, dtype=dtype) in the test (the line comparing
fitting_gate), pass device=env.DEVICE or call .to(env.DEVICE) so the expected
tensor matches fitting_gate's device before calling torch.allclose.

).to(device)
fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype))

descriptor = torch.randn(nf, nloc, dim_descrpt, dtype=dtype, device=device)
fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype))

descriptor = torch.randn(nf, nloc, dim_descrpt, dtype=dtype, device=device)
atype = torch.zeros(nf, nloc, dtype=torch.int64, device=device)
WAbjb1314 and others added 2 commits June 5, 2026 15:40
Gate previously multiplied only the fitting output before apply_out_stat added per-type out_bias, so sigma=0 inference still returned non-zero energy. Apply the gate after out_bias in the atomic model forward path.

Co-authored-by: Cursor <cursoragent@cursor.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt/model/task/fitting.py`:
- Around line 807-813: The current validation only checks aparam_raw.shape[-1]
against self.numb_aparam but does not ensure the leading dims match outs, which
can hide shape mismatches; update the block handling aparam in the method (the
aparam_raw conversion and return of _apply_aparam_output_gate) to explicitly
attempt to reshape aparam_raw to (outs.shape[0], outs.shape[1],
self.numb_aparam) and if that reshape is impossible raise a clear ValueError
describing the expected shape (using outs.shape[0], outs.shape[1],
self.numb_aparam), otherwise use the reshaped tensor for the subsequent call to
_apply_aparam_output_gate so broadcasting/multiplication errors are avoided.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: c9aa14cb-8701-4b9f-9544-1e433e241562

📥 Commits

Reviewing files that changed from the base of the PR and between 499b230 and e5c598b.

📒 Files selected for processing (6)
  • deepmd/dpmodel/atomic_model/base_atomic_model.py
  • deepmd/dpmodel/fitting/general_fitting.py
  • deepmd/pt/model/atomic_model/base_atomic_model.py
  • deepmd/pt/model/task/fitting.py
  • deepmd/pt/model/task/sezm_ener.py
  • source/tests/pt/model/test_aparam_output_gate.py
💤 Files with no reviewable changes (1)
  • deepmd/pt/model/task/sezm_ener.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • source/tests/pt/model/test_aparam_output_gate.py

Comment on lines +807 to +813
aparam_raw = aparam.to(self.prec)
if aparam_raw.shape[-1] != self.numb_aparam:
raise ValueError(
f"input aparam last dim {aparam_raw.shape[-1]} does not match "
f"numb_aparam={self.numb_aparam}"
)
return self._apply_aparam_output_gate(outs, aparam_raw)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Shape validation inconsistency with dpmodel backend.

The PyTorch implementation only validates aparam_raw.shape[-1] == self.numb_aparam (line 808), whereas the dpmodel version (deepmd/dpmodel/fitting/general_fitting.py lines 673-680) explicitly reshapes aparam to (outs.shape[0], outs.shape[1], self.numb_aparam) and raises a clear error if the reshape fails.

If aparam has an unexpected shape (e.g., (X, Y, numb_aparam) where X ≠ nf or Y ≠ nloc), the current PyTorch validation would pass but the subsequent multiplication outs * gate (line 813 → line 792) might produce incorrect results or fail with a generic broadcasting error.

Suggested fix to align shape validation with dpmodel
 `@torch.jit.export`
 def apply_aparam_output_gate_to_atomic_output(
     self,
     outs: torch.Tensor,
     aparam: torch.Tensor | None,
 ) -> torch.Tensor:
     """Apply the aparam gate to atomic outputs after out_bias is added."""
     if not self.use_aparam_output_gate:
         return outs
     if aparam is None:
         raise ValueError(
             "aparam is required when use_aparam_output_gate is enabled"
         )
-    aparam_raw = aparam.to(self.prec)
-    if aparam_raw.shape[-1] != self.numb_aparam:
+    aparam = aparam.to(self.prec)
+    nf, nloc = outs.shape[0], outs.shape[1]
+    if aparam.numel() != nf * nloc * self.numb_aparam:
         raise ValueError(
-            f"input aparam last dim {aparam_raw.shape[-1]} does not match "
-            f"numb_aparam={self.numb_aparam}"
+            f"input aparam: cannot reshape {list(aparam.shape)} "
+            f"into ({nf}, {nloc}, {self.numb_aparam})."
         )
+    aparam_raw = aparam.view(nf, nloc, self.numb_aparam)
     return self._apply_aparam_output_gate(outs, aparam_raw)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/task/fitting.py` around lines 807 - 813, The current
validation only checks aparam_raw.shape[-1] against self.numb_aparam but does
not ensure the leading dims match outs, which can hide shape mismatches; update
the block handling aparam in the method (the aparam_raw conversion and return of
_apply_aparam_output_gate) to explicitly attempt to reshape aparam_raw to
(outs.shape[0], outs.shape[1], self.numb_aparam) and if that reshape is
impossible raise a clear ValueError describing the expected shape (using
outs.shape[0], outs.shape[1], self.numb_aparam), otherwise use the reshaped
tensor for the subsequent call to _apply_aparam_output_gate so
broadcasting/multiplication errors are avoided.

WAbjb1314 and others added 2 commits June 5, 2026 16:47
Introduce variational-Gaussian smooth descriptor (se_a_vg) where sigma
enters the radial kernel and a fifth environment-matrix column per VGM II.
Wire aparam through the atomic model, register argcheck/compression hooks,
and keep fitting output gate unchanged.

Co-authored-by: Cursor <cursoragent@cursor.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/utils/argcheck.py (1)

2285-2290: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Enforce a positive lower bound for aparam_gate_norm.

Line 2285 currently accepts 0/negative values, but this factor is used in a denominator; that can cause divide-by-zero or invalid gate scaling at runtime. Add schema validation here to fail fast.

Proposed fix
         Argument(
             "aparam_gate_norm",
             float,
             optional=True,
             default=1.0,
+            extra_check=lambda x: x > 0.0,
+            extra_check_errmsg="must be greater than 0",
             doc=doc_aparam_gate_norm,
         ),
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/utils/argcheck.py` around lines 2285 - 2290, The schema entry for
"aparam_gate_norm" currently allows zero/negative values which can cause
divide-by-zero; update the argument schema in deepmd/utils/argcheck.py for the
"aparam_gate_norm" field to enforce a strict positive lower bound (e.g., min >
0) or add a validator that raises an error if the provided value is <= 0 so the
check fails fast; locate the schema definition containing "aparam_gate_norm",
adjust its validation rules (or add a custom validator function) to reject
non-positive values while keeping the default=1.0.
🧹 Nitpick comments (2)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

290-291: ⚡ Quick win

Avoid per-call inspect.signature(...) in the forward hot path.

Line 290 recomputes the signature every batch. Cache whether aparam is supported once in __init__ and reuse it in forward_atomic.

Proposed refactor
 class DPAtomicModel(BaseAtomicModel):
@@
     def __init__(
@@
         self.eval_descriptor_list = []
         self.eval_fitting_last_layer_list = []
+        self._descriptor_accepts_aparam = (
+            "aparam" in inspect.signature(self.descriptor.forward).parameters
+        )
@@
-        if "aparam" in inspect.signature(self.descriptor.forward).parameters:
+        if self._descriptor_accepts_aparam:
             descriptor_kwargs["aparam"] = aparam
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/atomic_model/dp_atomic_model.py` around lines 290 - 291, The
code currently calls inspect.signature(self.descriptor.forward) inside
forward_atomic for every batch; instead determine once in __init__ whether the
descriptor.forward accepts "aparam" (e.g. set self._descriptor_supports_aparam =
"aparam" in inspect.signature(self.descriptor.forward).parameters) and then in
forward_atomic replace the per-call inspect.signature check with that cached
boolean to conditionally add descriptor_kwargs["aparam"] = aparam; update any
initialization path that sets self.descriptor to ensure the cached flag is
computed after descriptor is assigned.
deepmd/pt/model/descriptor/se_a.py (1)

325-365: ⚡ Quick win

Document the aparam parameter in the docstring.

The forward method now accepts an aparam parameter (line 325) but it is not documented in the docstring (lines 333-364). Even though DescrptSeA does not use this parameter (it's immediately deleted on line 365), documenting it helps maintain API clarity and aids developers who may reference this signature.

📝 Suggested docstring addition
         comm_dict
             The data needed for communication for parallel inference.
+        aparam
+            Atomic parameters. Not used by this descriptor; accepted for
+            interface compatibility with aparam-aware descriptors.
 
         Returns
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/se_a.py` around lines 325 - 365, The forward
method signature includes an unused parameter aparam that is deleted
immediately; update the forward docstring to include a brief description of
aparam (type torch.Tensor | None, optional), its purpose or note that it is
accepted for API compatibility and currently unused, and mention that it will be
ignored (or deleted) within DescrptSeA.forward to clarify behavior for callers
and maintainers; locate the method by the forward function in this module and
add the aparam entry to the existing Parameters section.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt/model/descriptor/env_mat_vg.py`:
- Line 48: sigma_loc is being cast to nlist.dtype which is integer and truncates
fractional aparam values; remove the cast to nlist.dtype and instead preserve
aparam's floating dtype (or if device alignment is needed, cast only the device:
.to(device=nlist.device) or .to(dtype=aparam.dtype, device=nlist.device)).
Update the assignment for sigma_loc (from aparam[:, :nloc, 0]) to stop
converting to nlist.dtype so sigma retains its fractional values.

In `@deepmd/pt/model/descriptor/se_a_vg.py`:
- Around line 603-609: The serialization currently hardcodes "trainable": True;
change it to emit the actual flag from the descriptor object (e.g., use
obj.trainable or getattr(obj, "trainable", True)) in the serialize() output so a
descriptor created with trainable=False round-trips correctly; update the entry
replacing the literal True with the object's trainable attribute in the block
that builds the dict (the one referencing "`@variables`", "type_map", and
"type_one_side").
- Around line 282-305: The aparam tensor (aparam_t) is only sized for local
atoms but prod_env_mat_vg expects per-atom parameters for the extended atom set
used by nlist; expand aparam to the extended-atom domain before calling
prod_env_mat_vg by mapping local aparam values to extended indices (use the same
mapping that produced extended_atype/extended_coord from
extend_input_and_build_neighbor_list) to create extended_aparam (matching
extended_atype.shape/length) and pass that instead of aparam_t; apply the same
fix at the second call site around prod_env_mat_vg later in the file (the block
at ~378-401) so halo/periodic neighbors use the correct VG parameters.
- Around line 134-140: Normalize sel to a list before computing its length: move
or duplicate the normalization (self.sel = sel if isinstance(sel, list) else
[sel]) so that self.ntypes is set from len(self.sel) (not len(sel)), and then
compute self.sec using np.cumsum(self.sel); update the assignments around
self.ntypes, self.sel, and self.sec in __init__ (or set self.ntypes =
len(self.sel) immediately after the existing self.sel assignment) to avoid
calling len() on an int.

---

Outside diff comments:
In `@deepmd/utils/argcheck.py`:
- Around line 2285-2290: The schema entry for "aparam_gate_norm" currently
allows zero/negative values which can cause divide-by-zero; update the argument
schema in deepmd/utils/argcheck.py for the "aparam_gate_norm" field to enforce a
strict positive lower bound (e.g., min > 0) or add a validator that raises an
error if the provided value is <= 0 so the check fails fast; locate the schema
definition containing "aparam_gate_norm", adjust its validation rules (or add a
custom validator function) to reject non-positive values while keeping the
default=1.0.

---

Nitpick comments:
In `@deepmd/pt/model/atomic_model/dp_atomic_model.py`:
- Around line 290-291: The code currently calls
inspect.signature(self.descriptor.forward) inside forward_atomic for every
batch; instead determine once in __init__ whether the descriptor.forward accepts
"aparam" (e.g. set self._descriptor_supports_aparam = "aparam" in
inspect.signature(self.descriptor.forward).parameters) and then in
forward_atomic replace the per-call inspect.signature check with that cached
boolean to conditionally add descriptor_kwargs["aparam"] = aparam; update any
initialization path that sets self.descriptor to ensure the cached flag is
computed after descriptor is assigned.

In `@deepmd/pt/model/descriptor/se_a.py`:
- Around line 325-365: The forward method signature includes an unused parameter
aparam that is deleted immediately; update the forward docstring to include a
brief description of aparam (type torch.Tensor | None, optional), its purpose or
note that it is accepted for API compatibility and currently unused, and mention
that it will be ignored (or deleted) within DescrptSeA.forward to clarify
behavior for callers and maintainers; locate the method by the forward function
in this module and add the aparam entry to the existing Parameters section.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: e1861c11-00bd-407a-bffa-34f107386e4f

📥 Commits

Reviewing files that changed from the base of the PR and between e5c598b and 444982e.

📒 Files selected for processing (8)
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
  • deepmd/pt/model/descriptor/__init__.py
  • deepmd/pt/model/descriptor/env_mat_vg.py
  • deepmd/pt/model/descriptor/se_a.py
  • deepmd/pt/model/descriptor/se_a_vg.py
  • deepmd/pt/model/model/__init__.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/model/test_se_a_vg.py

) -> torch.Tensor:
"""Map per-atom aparam to neighbor-list sigma values."""
nf, _, nnei = nlist.shape
sigma_loc = aparam[:, :nloc, 0].to(dtype=nlist.dtype)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Don't cast sigma to the neighbor-index dtype.

nlist is integer typed, so line 48 truncates fractional aparam values before sigma_ij is computed. For typical sub-unit sigmas, that collapses the VG kernel toward the nongaussian path.

Suggested fix
-    sigma_loc = aparam[:, :nloc, 0].to(dtype=nlist.dtype)
+    sigma_loc = aparam[:, :nloc, 0]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
sigma_loc = aparam[:, :nloc, 0].to(dtype=nlist.dtype)
sigma_loc = aparam[:, :nloc, 0]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/env_mat_vg.py` at line 48, sigma_loc is being cast
to nlist.dtype which is integer and truncates fractional aparam values; remove
the cast to nlist.dtype and instead preserve aparam's floating dtype (or if
device alignment is needed, cast only the device: .to(device=nlist.device) or
.to(dtype=aparam.dtype, device=nlist.device)). Update the assignment for
sigma_loc (from aparam[:, :nloc, 0]) to stop converting to nlist.dtype so sigma
retains its fractional values.

Comment on lines +134 to +140
self.ntypes = len(sel)
self.type_one_side = type_one_side
self.seed = seed
self.reinit_exclude(exclude_types)

self.sel = sel if isinstance(sel, list) else [sel]
self.sec = [0, *np.cumsum(self.sel).tolist()]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Normalize scalar sel before calling len().

The public API accepts sel: int | list[int], but line 134 evaluates len(sel) before the scalar case is converted on line 139. sel=64 will fail in __init__.

Suggested fix
-        self.ntypes = len(sel)
+        self.sel = sel if isinstance(sel, list) else [sel]
+        self.ntypes = len(self.sel)
         self.type_one_side = type_one_side
         self.seed = seed
         self.reinit_exclude(exclude_types)
-
-        self.sel = sel if isinstance(sel, list) else [sel]
         self.sec = [0, *np.cumsum(self.sel).tolist()]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.ntypes = len(sel)
self.type_one_side = type_one_side
self.seed = seed
self.reinit_exclude(exclude_types)
self.sel = sel if isinstance(sel, list) else [sel]
self.sec = [0, *np.cumsum(self.sel).tolist()]
self.sel = sel if isinstance(sel, list) else [sel]
self.ntypes = len(self.sel)
self.type_one_side = type_one_side
self.seed = seed
self.reinit_exclude(exclude_types)
self.sec = [0, *np.cumsum(self.sel).tolist()]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/se_a_vg.py` around lines 134 - 140, Normalize sel
to a list before computing its length: move or duplicate the normalization
(self.sel = sel if isinstance(sel, list) else [sel]) so that self.ntypes is set
from len(self.sel) (not len(sel)), and then compute self.sec using
np.cumsum(self.sel); update the assignments around self.ntypes, self.sel, and
self.sec in __init__ (or set self.ntypes = len(self.sel) immediately after the
existing self.sel assignment) to avoid calling len() on an int.

Comment on lines +282 to +305
aparam_t = torch.tensor(
aparam_np[ff], dtype=self.prec, device=env.DEVICE
).reshape(1, nloc, 1)
extended_coord, extended_atype, _, nlist = (
extend_input_and_build_neighbor_list(
coord_t,
atype_t,
self.rcut,
self.sel,
mixed_types=False,
box=box_t,
)
)
env_mat, _, _ = prod_env_mat_vg(
extended_coord,
nlist,
extended_atype[:, :nloc],
aparam_t,
self.mean,
torch.ones_like(self.stddev),
self.rcut,
self.rcut_smth,
protection=self.env_protection,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Expand aparam to the extended-atom domain before building the VG env-mat.

Both call sites hand prod_env_mat_vg() an aparam tensor sized only for local atoms, while nlist indexes extended_coord. That makes every periodic/halo neighbor fall back to the zero-filled sigma path, so training stats and runtime descriptors diverge from the intended VG kernel whenever an extended image is selected. Please build an extended_aparam with the same mapping used for the extended coordinates before calling the env-mat helper.

Also applies to: 378-401

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/se_a_vg.py` around lines 282 - 305, The aparam
tensor (aparam_t) is only sized for local atoms but prod_env_mat_vg expects
per-atom parameters for the extended atom set used by nlist; expand aparam to
the extended-atom domain before calling prod_env_mat_vg by mapping local aparam
values to extended indices (use the same mapping that produced
extended_atype/extended_coord from extend_input_and_build_neighbor_list) to
create extended_aparam (matching extended_atype.shape/length) and pass that
instead of aparam_t; apply the same fix at the second call site around
prod_env_mat_vg later in the file (the block at ~378-401) so halo/periodic
neighbors use the correct VG parameters.

Comment on lines +603 to +609
"@variables": {
"davg": obj["davg"].detach().cpu().numpy(),
"dstd": obj["dstd"].detach().cpu().numpy(),
},
"type_map": self.type_map,
"trainable": True,
"type_one_side": obj.type_one_side,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Serialize the real trainable flag.

serialize() always emits "trainable": True, so a descriptor created with trainable=False will round-trip as trainable after save/load or compression setup.

Suggested fix
-            "trainable": True,
+            "trainable": obj.trainable,
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/se_a_vg.py` around lines 603 - 609, The
serialization currently hardcodes "trainable": True; change it to emit the
actual flag from the descriptor object (e.g., use obj.trainable or getattr(obj,
"trainable", True)) in the serialize() output so a descriptor created with
trainable=False round-trips correctly; update the entry replacing the literal
True with the object's trainable attribute in the block that builds the dict
(the one referencing "`@variables`", "type_map", and "type_one_side").

Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found two blocking issues in the new VG env-mat path. Both are inline below.

— OpenClaw 2026.5.28 (model: custom-chat-jinzhezeng-group/gpt-5.5)

) -> torch.Tensor:
"""Map per-atom aparam to neighbor-list sigma values."""
nf, _, nnei = nlist.shape
sigma_loc = aparam[:, :nloc, 0].to(dtype=nlist.dtype)
Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot Jun 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This casts sigma values to nlist.dtype (int64), so fractional aparam/sigma values such as 0.5 become 0 before building sigma_ij. That makes the VG descriptor ignore non-integer sigma values and also explains why the current tests only compare 0 vs 1. Please keep this in the floating dtype/device of aparam instead of the neighbor-list dtype.

device=sigma_loc.device,
)
sigma_ext[:, :nloc] = sigma_loc
index = nlist.reshape(nf, -1)
Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot Jun 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nlist can contain -1 padding entries, but this gathers with the raw neighbor list. torch.gather does not accept negative indices, so any padded neighbor list will raise at runtime. _make_env_mat_vg already builds nlist_safe; this path needs the same masking/safe-index handling before gathering sigma values.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants