feat(pt): add hard-coded aparam output gate for fitting nets#5495
feat(pt): add hard-coded aparam output gate for fitting nets#5495Jingbei-Bai wants to merge 6 commits into
Conversation
Co-authored-by: Cursor <cursoragent@cursor.com>
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds an optional “aparam output gate” to fitting networks (PyTorch + array-API/dpmodel) and exposes it through configuration/argcheck, with tests and an example config update.
Changes:
- Introduce
use_aparam_output_gate,aparam_gate_norm, andaparam_gate_clamparguments, including validation + serialization. - Apply the gate in PT fitting forward paths and dpmodel array-API fitting path.
- Add a dedicated PT unit test and update an example training input JSON.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| source/tests/pt/model/test_aparam_output_gate.py | Adds unit tests validating gating behavior + serialize roundtrip |
| examples/fparam/train/input_aparam.json | Demonstrates new config knobs in an example input |
| deepmd/utils/argcheck.py | Exposes new fitting arguments and documentation strings |
| deepmd/pt/model/task/sezm_ener.py | Extracts raw aparam and applies output gate in SE(Z/M) energy path |
| deepmd/pt/model/task/fitting.py | Implements gate computation/application and wires it into common forward path |
| deepmd/dpmodel/fitting/general_fitting.py | Mirrors gate logic for dpmodel/array-API fitting path |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| aparam_gate_norm=1.0, | ||
| aparam_gate_clamp=True, | ||
| ).to(device) | ||
| fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype)) |
| aparam_gate_norm=norm, | ||
| aparam_gate_clamp=False, | ||
| ).to(device) | ||
| fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype)) |
|
|
||
| fitting_gate = fitting._compute_aparam_output_gate(aparam) | ||
| expected = (a_val * a_val) / (sigma * sigma * norm) | ||
| self.assertTrue(torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype))) |
| if aparam.numel() % (nf * self.numb_aparam) != 0: | ||
| raise ValueError( | ||
| f"input aparam: cannot reshape {list(aparam.shape)} " | ||
| f"into ({nf}, nloc, {self.numb_aparam})." | ||
| ) | ||
| aparam_raw = aparam.view([nf, -1, self.numb_aparam]) |
| if aparam.numel() % (nf * self.numb_aparam) != 0: | ||
| raise ValueError( | ||
| f"input aparam: cannot reshape {list(aparam.shape)} " | ||
| f"into ({nf}, nloc, {self.numb_aparam})." | ||
| ) | ||
| aparam_raw = aparam.view([nf, -1, self.numb_aparam]) |
| self, | ||
| aparam_raw: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| """Hard-coded gate g = a^2 / (sigma^2 * norm) from raw aparam.""" |
| if self.numb_aparam > 1: | ||
| gate = gate.prod(dim=-1, keepdim=True) | ||
| if self.aparam_gate_clamp: | ||
| gate = gate.clamp(0.0, 1.0) |
for more information, see https://pre-commit.ci
📝 WalkthroughWalkthroughAdds an optional multiplicative "aparam output gate" for fitting outputs and a new variational-Gaussian descriptor (se_a_vg). Implements gate compute/apply helpers in dpmodel and PT, wires gating into atomic output flow and SeZM, adds args/examples/tests, and adds VG descriptor code, exports, and tests. ChangesAparam Output Gate Feature
se_a_vg Variational-Gaussian Descriptor
Estimated code review effort 🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@source/tests/pt/model/test_aparam_output_gate.py`:
- Line 73: The assertion compares fitting_gate (which lives on env.DEVICE) to a
CPU tensor; create the expected tensor on the same device to avoid
device-mismatch failures: when constructing torch.tensor(expected, dtype=dtype)
in the test (the line comparing fitting_gate), pass device=env.DEVICE or call
.to(env.DEVICE) so the expected tensor matches fitting_gate's device before
calling torch.allclose.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 73a16adf-a7d6-43fb-adcb-ab9b7391afb8
📒 Files selected for processing (6)
deepmd/dpmodel/fitting/general_fitting.pydeepmd/pt/model/task/fitting.pydeepmd/pt/model/task/sezm_ener.pydeepmd/utils/argcheck.pyexamples/fparam/train/input_aparam.jsonsource/tests/pt/model/test_aparam_output_gate.py
|
|
||
| fitting_gate = fitting._compute_aparam_output_gate(aparam) | ||
| expected = (a_val * a_val) / (sigma * sigma * norm) | ||
| self.assertTrue(torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype))) |
There was a problem hiding this comment.
Fix expected tensor device in gate-formula assertion.
At Line 73, torch.tensor(expected, dtype=dtype) is created on CPU, while fitting_gate is on env.DEVICE, which can fail on GPU/MPS.
Proposed fix
- self.assertTrue(torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype)))
+ self.assertTrue(
+ torch.allclose(
+ fitting_gate,
+ torch.tensor(expected, dtype=dtype, device=device),
+ )
+ )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@source/tests/pt/model/test_aparam_output_gate.py` at line 73, The assertion
compares fitting_gate (which lives on env.DEVICE) to a CPU tensor; create the
expected tensor on the same device to avoid device-mismatch failures: when
constructing torch.tensor(expected, dtype=dtype) in the test (the line comparing
fitting_gate), pass device=env.DEVICE or call .to(env.DEVICE) so the expected
tensor matches fitting_gate's device before calling torch.allclose.
| ).to(device) | ||
| fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype)) | ||
|
|
||
| descriptor = torch.randn(nf, nloc, dim_descrpt, dtype=dtype, device=device) |
| fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype)) | ||
|
|
||
| descriptor = torch.randn(nf, nloc, dim_descrpt, dtype=dtype, device=device) | ||
| atype = torch.zeros(nf, nloc, dtype=torch.int64, device=device) |
Gate previously multiplied only the fitting output before apply_out_stat added per-type out_bias, so sigma=0 inference still returned non-zero energy. Apply the gate after out_bias in the atomic model forward path. Co-authored-by: Cursor <cursoragent@cursor.com>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/model/task/fitting.py`:
- Around line 807-813: The current validation only checks aparam_raw.shape[-1]
against self.numb_aparam but does not ensure the leading dims match outs, which
can hide shape mismatches; update the block handling aparam in the method (the
aparam_raw conversion and return of _apply_aparam_output_gate) to explicitly
attempt to reshape aparam_raw to (outs.shape[0], outs.shape[1],
self.numb_aparam) and if that reshape is impossible raise a clear ValueError
describing the expected shape (using outs.shape[0], outs.shape[1],
self.numb_aparam), otherwise use the reshaped tensor for the subsequent call to
_apply_aparam_output_gate so broadcasting/multiplication errors are avoided.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: c9aa14cb-8701-4b9f-9544-1e433e241562
📒 Files selected for processing (6)
deepmd/dpmodel/atomic_model/base_atomic_model.pydeepmd/dpmodel/fitting/general_fitting.pydeepmd/pt/model/atomic_model/base_atomic_model.pydeepmd/pt/model/task/fitting.pydeepmd/pt/model/task/sezm_ener.pysource/tests/pt/model/test_aparam_output_gate.py
💤 Files with no reviewable changes (1)
- deepmd/pt/model/task/sezm_ener.py
🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/pt/model/test_aparam_output_gate.py
| aparam_raw = aparam.to(self.prec) | ||
| if aparam_raw.shape[-1] != self.numb_aparam: | ||
| raise ValueError( | ||
| f"input aparam last dim {aparam_raw.shape[-1]} does not match " | ||
| f"numb_aparam={self.numb_aparam}" | ||
| ) | ||
| return self._apply_aparam_output_gate(outs, aparam_raw) |
There was a problem hiding this comment.
Shape validation inconsistency with dpmodel backend.
The PyTorch implementation only validates aparam_raw.shape[-1] == self.numb_aparam (line 808), whereas the dpmodel version (deepmd/dpmodel/fitting/general_fitting.py lines 673-680) explicitly reshapes aparam to (outs.shape[0], outs.shape[1], self.numb_aparam) and raises a clear error if the reshape fails.
If aparam has an unexpected shape (e.g., (X, Y, numb_aparam) where X ≠ nf or Y ≠ nloc), the current PyTorch validation would pass but the subsequent multiplication outs * gate (line 813 → line 792) might produce incorrect results or fail with a generic broadcasting error.
Suggested fix to align shape validation with dpmodel
`@torch.jit.export`
def apply_aparam_output_gate_to_atomic_output(
self,
outs: torch.Tensor,
aparam: torch.Tensor | None,
) -> torch.Tensor:
"""Apply the aparam gate to atomic outputs after out_bias is added."""
if not self.use_aparam_output_gate:
return outs
if aparam is None:
raise ValueError(
"aparam is required when use_aparam_output_gate is enabled"
)
- aparam_raw = aparam.to(self.prec)
- if aparam_raw.shape[-1] != self.numb_aparam:
+ aparam = aparam.to(self.prec)
+ nf, nloc = outs.shape[0], outs.shape[1]
+ if aparam.numel() != nf * nloc * self.numb_aparam:
raise ValueError(
- f"input aparam last dim {aparam_raw.shape[-1]} does not match "
- f"numb_aparam={self.numb_aparam}"
+ f"input aparam: cannot reshape {list(aparam.shape)} "
+ f"into ({nf}, {nloc}, {self.numb_aparam})."
)
+ aparam_raw = aparam.view(nf, nloc, self.numb_aparam)
return self._apply_aparam_output_gate(outs, aparam_raw)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/model/task/fitting.py` around lines 807 - 813, The current
validation only checks aparam_raw.shape[-1] against self.numb_aparam but does
not ensure the leading dims match outs, which can hide shape mismatches; update
the block handling aparam in the method (the aparam_raw conversion and return of
_apply_aparam_output_gate) to explicitly attempt to reshape aparam_raw to
(outs.shape[0], outs.shape[1], self.numb_aparam) and if that reshape is
impossible raise a clear ValueError describing the expected shape (using
outs.shape[0], outs.shape[1], self.numb_aparam), otherwise use the reshaped
tensor for the subsequent call to _apply_aparam_output_gate so
broadcasting/multiplication errors are avoided.
Introduce variational-Gaussian smooth descriptor (se_a_vg) where sigma enters the radial kernel and a fifth environment-matrix column per VGM II. Wire aparam through the atomic model, register argcheck/compression hooks, and keep fitting output gate unchanged. Co-authored-by: Cursor <cursoragent@cursor.com>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/utils/argcheck.py (1)
2285-2290:⚠️ Potential issue | 🟠 Major | ⚡ Quick winEnforce a positive lower bound for
aparam_gate_norm.Line 2285 currently accepts
0/negative values, but this factor is used in a denominator; that can cause divide-by-zero or invalid gate scaling at runtime. Add schema validation here to fail fast.Proposed fix
Argument( "aparam_gate_norm", float, optional=True, default=1.0, + extra_check=lambda x: x > 0.0, + extra_check_errmsg="must be greater than 0", doc=doc_aparam_gate_norm, ),🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/utils/argcheck.py` around lines 2285 - 2290, The schema entry for "aparam_gate_norm" currently allows zero/negative values which can cause divide-by-zero; update the argument schema in deepmd/utils/argcheck.py for the "aparam_gate_norm" field to enforce a strict positive lower bound (e.g., min > 0) or add a validator that raises an error if the provided value is <= 0 so the check fails fast; locate the schema definition containing "aparam_gate_norm", adjust its validation rules (or add a custom validator function) to reject non-positive values while keeping the default=1.0.
🧹 Nitpick comments (2)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
290-291: ⚡ Quick winAvoid per-call
inspect.signature(...)in the forward hot path.Line 290 recomputes the signature every batch. Cache whether
aparamis supported once in__init__and reuse it inforward_atomic.Proposed refactor
class DPAtomicModel(BaseAtomicModel): @@ def __init__( @@ self.eval_descriptor_list = [] self.eval_fitting_last_layer_list = [] + self._descriptor_accepts_aparam = ( + "aparam" in inspect.signature(self.descriptor.forward).parameters + ) @@ - if "aparam" in inspect.signature(self.descriptor.forward).parameters: + if self._descriptor_accepts_aparam: descriptor_kwargs["aparam"] = aparam🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/atomic_model/dp_atomic_model.py` around lines 290 - 291, The code currently calls inspect.signature(self.descriptor.forward) inside forward_atomic for every batch; instead determine once in __init__ whether the descriptor.forward accepts "aparam" (e.g. set self._descriptor_supports_aparam = "aparam" in inspect.signature(self.descriptor.forward).parameters) and then in forward_atomic replace the per-call inspect.signature check with that cached boolean to conditionally add descriptor_kwargs["aparam"] = aparam; update any initialization path that sets self.descriptor to ensure the cached flag is computed after descriptor is assigned.deepmd/pt/model/descriptor/se_a.py (1)
325-365: ⚡ Quick winDocument the
aparamparameter in the docstring.The
forwardmethod now accepts anaparamparameter (line 325) but it is not documented in the docstring (lines 333-364). Even thoughDescrptSeAdoes not use this parameter (it's immediately deleted on line 365), documenting it helps maintain API clarity and aids developers who may reference this signature.📝 Suggested docstring addition
comm_dict The data needed for communication for parallel inference. + aparam + Atomic parameters. Not used by this descriptor; accepted for + interface compatibility with aparam-aware descriptors. Returns🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/descriptor/se_a.py` around lines 325 - 365, The forward method signature includes an unused parameter aparam that is deleted immediately; update the forward docstring to include a brief description of aparam (type torch.Tensor | None, optional), its purpose or note that it is accepted for API compatibility and currently unused, and mention that it will be ignored (or deleted) within DescrptSeA.forward to clarify behavior for callers and maintainers; locate the method by the forward function in this module and add the aparam entry to the existing Parameters section.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/model/descriptor/env_mat_vg.py`:
- Line 48: sigma_loc is being cast to nlist.dtype which is integer and truncates
fractional aparam values; remove the cast to nlist.dtype and instead preserve
aparam's floating dtype (or if device alignment is needed, cast only the device:
.to(device=nlist.device) or .to(dtype=aparam.dtype, device=nlist.device)).
Update the assignment for sigma_loc (from aparam[:, :nloc, 0]) to stop
converting to nlist.dtype so sigma retains its fractional values.
In `@deepmd/pt/model/descriptor/se_a_vg.py`:
- Around line 603-609: The serialization currently hardcodes "trainable": True;
change it to emit the actual flag from the descriptor object (e.g., use
obj.trainable or getattr(obj, "trainable", True)) in the serialize() output so a
descriptor created with trainable=False round-trips correctly; update the entry
replacing the literal True with the object's trainable attribute in the block
that builds the dict (the one referencing "`@variables`", "type_map", and
"type_one_side").
- Around line 282-305: The aparam tensor (aparam_t) is only sized for local
atoms but prod_env_mat_vg expects per-atom parameters for the extended atom set
used by nlist; expand aparam to the extended-atom domain before calling
prod_env_mat_vg by mapping local aparam values to extended indices (use the same
mapping that produced extended_atype/extended_coord from
extend_input_and_build_neighbor_list) to create extended_aparam (matching
extended_atype.shape/length) and pass that instead of aparam_t; apply the same
fix at the second call site around prod_env_mat_vg later in the file (the block
at ~378-401) so halo/periodic neighbors use the correct VG parameters.
- Around line 134-140: Normalize sel to a list before computing its length: move
or duplicate the normalization (self.sel = sel if isinstance(sel, list) else
[sel]) so that self.ntypes is set from len(self.sel) (not len(sel)), and then
compute self.sec using np.cumsum(self.sel); update the assignments around
self.ntypes, self.sel, and self.sec in __init__ (or set self.ntypes =
len(self.sel) immediately after the existing self.sel assignment) to avoid
calling len() on an int.
---
Outside diff comments:
In `@deepmd/utils/argcheck.py`:
- Around line 2285-2290: The schema entry for "aparam_gate_norm" currently
allows zero/negative values which can cause divide-by-zero; update the argument
schema in deepmd/utils/argcheck.py for the "aparam_gate_norm" field to enforce a
strict positive lower bound (e.g., min > 0) or add a validator that raises an
error if the provided value is <= 0 so the check fails fast; locate the schema
definition containing "aparam_gate_norm", adjust its validation rules (or add a
custom validator function) to reject non-positive values while keeping the
default=1.0.
---
Nitpick comments:
In `@deepmd/pt/model/atomic_model/dp_atomic_model.py`:
- Around line 290-291: The code currently calls
inspect.signature(self.descriptor.forward) inside forward_atomic for every
batch; instead determine once in __init__ whether the descriptor.forward accepts
"aparam" (e.g. set self._descriptor_supports_aparam = "aparam" in
inspect.signature(self.descriptor.forward).parameters) and then in
forward_atomic replace the per-call inspect.signature check with that cached
boolean to conditionally add descriptor_kwargs["aparam"] = aparam; update any
initialization path that sets self.descriptor to ensure the cached flag is
computed after descriptor is assigned.
In `@deepmd/pt/model/descriptor/se_a.py`:
- Around line 325-365: The forward method signature includes an unused parameter
aparam that is deleted immediately; update the forward docstring to include a
brief description of aparam (type torch.Tensor | None, optional), its purpose or
note that it is accepted for API compatibility and currently unused, and mention
that it will be ignored (or deleted) within DescrptSeA.forward to clarify
behavior for callers and maintainers; locate the method by the forward function
in this module and add the aparam entry to the existing Parameters section.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: e1861c11-00bd-407a-bffa-34f107386e4f
📒 Files selected for processing (8)
deepmd/pt/model/atomic_model/dp_atomic_model.pydeepmd/pt/model/descriptor/__init__.pydeepmd/pt/model/descriptor/env_mat_vg.pydeepmd/pt/model/descriptor/se_a.pydeepmd/pt/model/descriptor/se_a_vg.pydeepmd/pt/model/model/__init__.pydeepmd/utils/argcheck.pysource/tests/pt/model/test_se_a_vg.py
| ) -> torch.Tensor: | ||
| """Map per-atom aparam to neighbor-list sigma values.""" | ||
| nf, _, nnei = nlist.shape | ||
| sigma_loc = aparam[:, :nloc, 0].to(dtype=nlist.dtype) |
There was a problem hiding this comment.
Don't cast sigma to the neighbor-index dtype.
nlist is integer typed, so line 48 truncates fractional aparam values before sigma_ij is computed. For typical sub-unit sigmas, that collapses the VG kernel toward the nongaussian path.
Suggested fix
- sigma_loc = aparam[:, :nloc, 0].to(dtype=nlist.dtype)
+ sigma_loc = aparam[:, :nloc, 0]📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| sigma_loc = aparam[:, :nloc, 0].to(dtype=nlist.dtype) | |
| sigma_loc = aparam[:, :nloc, 0] |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/model/descriptor/env_mat_vg.py` at line 48, sigma_loc is being cast
to nlist.dtype which is integer and truncates fractional aparam values; remove
the cast to nlist.dtype and instead preserve aparam's floating dtype (or if
device alignment is needed, cast only the device: .to(device=nlist.device) or
.to(dtype=aparam.dtype, device=nlist.device)). Update the assignment for
sigma_loc (from aparam[:, :nloc, 0]) to stop converting to nlist.dtype so sigma
retains its fractional values.
| self.ntypes = len(sel) | ||
| self.type_one_side = type_one_side | ||
| self.seed = seed | ||
| self.reinit_exclude(exclude_types) | ||
|
|
||
| self.sel = sel if isinstance(sel, list) else [sel] | ||
| self.sec = [0, *np.cumsum(self.sel).tolist()] |
There was a problem hiding this comment.
Normalize scalar sel before calling len().
The public API accepts sel: int | list[int], but line 134 evaluates len(sel) before the scalar case is converted on line 139. sel=64 will fail in __init__.
Suggested fix
- self.ntypes = len(sel)
+ self.sel = sel if isinstance(sel, list) else [sel]
+ self.ntypes = len(self.sel)
self.type_one_side = type_one_side
self.seed = seed
self.reinit_exclude(exclude_types)
-
- self.sel = sel if isinstance(sel, list) else [sel]
self.sec = [0, *np.cumsum(self.sel).tolist()]📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| self.ntypes = len(sel) | |
| self.type_one_side = type_one_side | |
| self.seed = seed | |
| self.reinit_exclude(exclude_types) | |
| self.sel = sel if isinstance(sel, list) else [sel] | |
| self.sec = [0, *np.cumsum(self.sel).tolist()] | |
| self.sel = sel if isinstance(sel, list) else [sel] | |
| self.ntypes = len(self.sel) | |
| self.type_one_side = type_one_side | |
| self.seed = seed | |
| self.reinit_exclude(exclude_types) | |
| self.sec = [0, *np.cumsum(self.sel).tolist()] |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/model/descriptor/se_a_vg.py` around lines 134 - 140, Normalize sel
to a list before computing its length: move or duplicate the normalization
(self.sel = sel if isinstance(sel, list) else [sel]) so that self.ntypes is set
from len(self.sel) (not len(sel)), and then compute self.sec using
np.cumsum(self.sel); update the assignments around self.ntypes, self.sel, and
self.sec in __init__ (or set self.ntypes = len(self.sel) immediately after the
existing self.sel assignment) to avoid calling len() on an int.
| aparam_t = torch.tensor( | ||
| aparam_np[ff], dtype=self.prec, device=env.DEVICE | ||
| ).reshape(1, nloc, 1) | ||
| extended_coord, extended_atype, _, nlist = ( | ||
| extend_input_and_build_neighbor_list( | ||
| coord_t, | ||
| atype_t, | ||
| self.rcut, | ||
| self.sel, | ||
| mixed_types=False, | ||
| box=box_t, | ||
| ) | ||
| ) | ||
| env_mat, _, _ = prod_env_mat_vg( | ||
| extended_coord, | ||
| nlist, | ||
| extended_atype[:, :nloc], | ||
| aparam_t, | ||
| self.mean, | ||
| torch.ones_like(self.stddev), | ||
| self.rcut, | ||
| self.rcut_smth, | ||
| protection=self.env_protection, | ||
| ) |
There was a problem hiding this comment.
Expand aparam to the extended-atom domain before building the VG env-mat.
Both call sites hand prod_env_mat_vg() an aparam tensor sized only for local atoms, while nlist indexes extended_coord. That makes every periodic/halo neighbor fall back to the zero-filled sigma path, so training stats and runtime descriptors diverge from the intended VG kernel whenever an extended image is selected. Please build an extended_aparam with the same mapping used for the extended coordinates before calling the env-mat helper.
Also applies to: 378-401
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/model/descriptor/se_a_vg.py` around lines 282 - 305, The aparam
tensor (aparam_t) is only sized for local atoms but prod_env_mat_vg expects
per-atom parameters for the extended atom set used by nlist; expand aparam to
the extended-atom domain before calling prod_env_mat_vg by mapping local aparam
values to extended indices (use the same mapping that produced
extended_atype/extended_coord from extend_input_and_build_neighbor_list) to
create extended_aparam (matching extended_atype.shape/length) and pass that
instead of aparam_t; apply the same fix at the second call site around
prod_env_mat_vg later in the file (the block at ~378-401) so halo/periodic
neighbors use the correct VG parameters.
| "@variables": { | ||
| "davg": obj["davg"].detach().cpu().numpy(), | ||
| "dstd": obj["dstd"].detach().cpu().numpy(), | ||
| }, | ||
| "type_map": self.type_map, | ||
| "trainable": True, | ||
| "type_one_side": obj.type_one_side, |
There was a problem hiding this comment.
Serialize the real trainable flag.
serialize() always emits "trainable": True, so a descriptor created with trainable=False will round-trip as trainable after save/load or compression setup.
Suggested fix
- "trainable": True,
+ "trainable": obj.trainable,🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/model/descriptor/se_a_vg.py` around lines 603 - 609, The
serialization currently hardcodes "trainable": True; change it to emit the
actual flag from the descriptor object (e.g., use obj.trainable or getattr(obj,
"trainable", True)) in the serialize() output so a descriptor created with
trainable=False round-trips correctly; update the entry replacing the literal
True with the object's trainable attribute in the block that builds the dict
(the one referencing "`@variables`", "type_map", and "type_one_side").
njzjz-bot
left a comment
There was a problem hiding this comment.
I found two blocking issues in the new VG env-mat path. Both are inline below.
— OpenClaw 2026.5.28 (model: custom-chat-jinzhezeng-group/gpt-5.5)
| ) -> torch.Tensor: | ||
| """Map per-atom aparam to neighbor-list sigma values.""" | ||
| nf, _, nnei = nlist.shape | ||
| sigma_loc = aparam[:, :nloc, 0].to(dtype=nlist.dtype) |
There was a problem hiding this comment.
This casts sigma values to nlist.dtype (int64), so fractional aparam/sigma values such as 0.5 become 0 before building sigma_ij. That makes the VG descriptor ignore non-integer sigma values and also explains why the current tests only compare 0 vs 1. Please keep this in the floating dtype/device of aparam instead of the neighbor-list dtype.
| device=sigma_loc.device, | ||
| ) | ||
| sigma_ext[:, :nloc] = sigma_loc | ||
| index = nlist.reshape(nf, -1) |
There was a problem hiding this comment.
nlist can contain -1 padding entries, but this gathers with the raw neighbor list. torch.gather does not accept negative indices, so any padded neighbor list will raise at runtime. _make_env_mat_vg already builds nlist_safe; this path needs the same masking/safe-index handling before gathering sigma values.
Summary by CodeRabbit
New Features
Bug Fixes / Behavior
Tests