feat(pt): Add DPA4/SeZM descriptor & model 🎉🎉🎉#5448
Conversation
| # === Step 1. Setup dimensions === | ||
| extended_coord = extended_coord.to(self.compute_dtype) | ||
| nf, nloc, nnei = nlist.shape | ||
| nall = extended_coord.shape[1] |
| try: | ||
| import triton # noqa: F401 | ||
| except ImportError: | ||
| SEZM_TRITON_AVAILABLE = False |
| except ImportError: | ||
| SEZM_TRITON_AVAILABLE = False | ||
| else: | ||
| SEZM_TRITON_AVAILABLE = True |
| else: | ||
| SEZM_TRITON_AVAILABLE = True | ||
| else: | ||
| SEZM_TRITON_AVAILABLE = False |
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds SeZM descriptor stack and NN building blocks (SO(3)/SO(2), radial, activations, FFN, attention), Triton kernels and autograd, DeNS fitting and loss, SeZM atomic model and model factories, PT2 AOTInductor freezer, and dpmodel utilities (min-pair-distance and JAX nlist fix). ChangesSeZM descriptor, NN building blocks, and integrations
Triton kernels, autograd, dispatch, and custom ops
PT2 freezer, CLI, and JIT gating
dpmodel utilities and LMDB reader
Estimated code review effort 🎯 5 (Critical) | ⏱️ ~120 minutes Possibly related PRs
Suggested reviewers
✨ Finishing Touches🧪 Generate unit tests (beta)
|
There was a problem hiding this comment.
Actionable comments posted: 4
🧹 Nitpick comments (1)
deepmd/pt/model/descriptor/sezm_nn/so2.py (1)
1340-1342: 💤 Low valueAdd
strict=Truetozip()calls to catch length mismatches.The three module lists (
so2_linears,so2_inter_norms,non_linearities) are constructed to have equal length, but addingstrict=Trueprovides a runtime guard against future refactoring errors.Proposed fix
for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate( - zip(self.so2_linears, self.so2_inter_norms, self.non_linearities) + zip(self.so2_linears, self.so2_inter_norms, self.non_linearities, strict=True) ):And similarly for line 1365:
for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate( - zip(self.so2_linears, self.so2_inter_norms, self.non_linearities) + zip(self.so2_linears, self.so2_inter_norms, self.non_linearities, strict=True) ):Also applies to: 1364-1366
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/descriptor/sezm_nn/so2.py` around lines 1340 - 1342, The zip over the three module lists (so2_linears, so2_inter_norms, non_linearities) should use strict=True to surface any length mismatches at runtime; update the zip(...) calls in the for loops that iterate over (so2_linear, inter_norm, non_linear) (and the similar loop around lines 1364–1366) to call zip(self.so2_linears, self.so2_inter_norms, self.non_linearities, strict=True) so a ValueError is raised if the lists differ in length.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/utils/dist_check.py`:
- Around line 45-56: The current implementation materializes diff =
real_coord[np.newaxis, :, :] - real_coord[:, np.newaxis, :] which creates an
(n_real, n_real, 3) array and O(n^2) memory; replace this with a blocked (or
streaming) pairwise-min computation to avoid full NxN allocation: iterate over
chunks/blocks of indices (e.g., outer blocks i and inner blocks j), compute the
block-wise displacement between real_coord[i:i_b] and real_coord[j:j_b], apply
the PBC transform using cell and inv_cell (reuse the same frac_diff -=
round(frac_diff) and back-transform logic on the block), compute block-wise
dist_sq, ignore self-pair entries when i==j (mask diagonal within the block),
and accumulate the global minimum distance (or a per-atom min array) instead of
assembling dist_sq for all pairs; update code locations around symbols diff,
frac_diff, cell, inv_cell, and dist_sq to perform block-wise operations and
maintain identical results with bounded memory.
In `@deepmd/pt/entrypoints/freeze_pt2.py`:
- Around line 377-380: The type check in freeze_sezm_to_pt2 currently only
allows "sezm" and rejects valid "dpa4" checkpoints; update the validation to
accept both "sezm" and "dpa4" (case-insensitive) by checking params.get("type")
against the set {"sezm", "dpa4"} so that is_sezm_checkpoint()-routed dpa4
checkpoints are not erroneously rejected.
In `@deepmd/pt/model/descriptor/sezm_nn/indexing.py`:
- Line 70: Rename the ambiguous single-letter identifier `l` to a clearer name
such as `ell` throughout indexing.py to satisfy Ruff E741; specifically update
list comprehensions like `[2 * l + 1 for l in range(lmax + 1)]` to `[2 * ell + 1
for ell in range(lmax + 1)]`, and rename the loop/parameter occurrences
referenced in the file (including the other reported occurrences around the same
blocks) so function signatures, list comprehensions, for-loops and any uses
inside functions/methods (e.g., variables used alongside `lmax`) are
consistently changed to `ell` (and all references updated) to avoid breaking
references. Ensure tests/uses of symbols from functions/methods in this module
reflect the new name and run `ruff check .` to confirm E741 is resolved.
In `@deepmd/pt/model/descriptor/sezm_nn/radial.py`:
- Around line 230-241: The constructor currently allows non-positive rcut which
later causes invalid divisions; in the class __init__ (the shown __init__) and
the other constructor/initializer that assigns self.rcut (the occurrence around
lines 487-488), add a strict check that rcut > 0 and raise ValueError("`rcut`
must be positive") if not, then convert rcut to float and assign to self.rcut
(and only derive widths or compute 1/rcut after that check) so all downstream
computations are safe.
---
Nitpick comments:
In `@deepmd/pt/model/descriptor/sezm_nn/so2.py`:
- Around line 1340-1342: The zip over the three module lists (so2_linears,
so2_inter_norms, non_linearities) should use strict=True to surface any length
mismatches at runtime; update the zip(...) calls in the for loops that iterate
over (so2_linear, inter_norm, non_linear) (and the similar loop around lines
1364–1366) to call zip(self.so2_linears, self.so2_inter_norms,
self.non_linearities, strict=True) so a ValueError is raised if the lists differ
in length.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 9c1d7bdc-93ba-406f-8498-2f63fd44c217
📒 Files selected for processing (84)
deepmd/dpmodel/utils/dist_check.pydeepmd/dpmodel/utils/lmdb_data.pydeepmd/dpmodel/utils/nlist.pydeepmd/pt/entrypoints/freeze_pt2.pydeepmd/pt/entrypoints/main.pydeepmd/pt/infer/deep_eval.pydeepmd/pt/loss/__init__.pydeepmd/pt/loss/dens.pydeepmd/pt/model/atomic_model/__init__.pydeepmd/pt/model/atomic_model/sezm_atomic_model.pydeepmd/pt/model/descriptor/__init__.pydeepmd/pt/model/descriptor/sezm.pydeepmd/pt/model/descriptor/sezm_nn/__init__.pydeepmd/pt/model/descriptor/sezm_nn/activation.pydeepmd/pt/model/descriptor/sezm_nn/attention.pydeepmd/pt/model/descriptor/sezm_nn/attn_res.pydeepmd/pt/model/descriptor/sezm_nn/block.pydeepmd/pt/model/descriptor/sezm_nn/dens.pydeepmd/pt/model/descriptor/sezm_nn/edge_cache.pydeepmd/pt/model/descriptor/sezm_nn/embedding.pydeepmd/pt/model/descriptor/sezm_nn/ffn.pydeepmd/pt/model/descriptor/sezm_nn/indexing.pydeepmd/pt/model/descriptor/sezm_nn/lebedev.pydeepmd/pt/model/descriptor/sezm_nn/lebedev_rules.npzdeepmd/pt/model/descriptor/sezm_nn/lora.pydeepmd/pt/model/descriptor/sezm_nn/norm.pydeepmd/pt/model/descriptor/sezm_nn/radial.pydeepmd/pt/model/descriptor/sezm_nn/so2.pydeepmd/pt/model/descriptor/sezm_nn/so3.pydeepmd/pt/model/descriptor/sezm_nn/triton/__init__.pydeepmd/pt/model/descriptor/sezm_nn/triton/autograd.pydeepmd/pt/model/descriptor/sezm_nn/triton/constants.pydeepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.pydeepmd/pt/model/descriptor/sezm_nn/triton/dispatch.pydeepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.pydeepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.pydeepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.pydeepmd/pt/model/descriptor/sezm_nn/utils.pydeepmd/pt/model/descriptor/sezm_nn/wignerd.pydeepmd/pt/model/model/__init__.pydeepmd/pt/model/model/sezm_model.pydeepmd/pt/model/model/sezm_spin_model.pydeepmd/pt/model/model/spin_model.pydeepmd/pt/model/network/mlp.pydeepmd/pt/model/task/__init__.pydeepmd/pt/model/task/sezm_ener.pydeepmd/pt/train/training.pydeepmd/pt/train/utils.pydeepmd/pt/train/validation.pydeepmd/pt/utils/multi_task.pydeepmd/pt/utils/serialization.pydeepmd/utils/argcheck.pydeepmd/utils/data.pydeepmd/utils/data_system.pydoc/model/dpa4.mddoc/model/index.rstexamples/water/dpa4/.gitignoreexamples/water/dpa4/README.mdexamples/water/dpa4/input-spin.jsonexamples/water/dpa4/input-zbl.jsonexamples/water/dpa4/input.jsonexamples/water/dpa4/input_dens.jsonexamples/water/dpa4/input_multitask.jsonexamples/water/dpa4/input_multitask_sharefit-zbl.jsonexamples/water/dpa4/input_multitask_sharefit.jsonexamples/water/dpa4/lmp/.gitignoreexamples/water/dpa4/lmp/README.mdexamples/water/dpa4/lmp/in.lammpsexamples/water/dpa4/lmp/input.jsonexamples/water/dpa4/lmp/pretrained.ptexamples/water/dpa4/lmp/water.lmpexamples/water/dpa4/lora_ft.jsonpyproject.tomlsource/tests/common/dpmodel/test_dist_check.pysource/tests/common/dpmodel/test_lmdb_data.pysource/tests/pt/model/test_descriptor_sezm.pysource/tests/pt/model/test_descriptor_sezm_s2_equivariance.pysource/tests/pt/model/test_descriptor_sezm_triton.pysource/tests/pt/model/test_sezm_export.pysource/tests/pt/model/test_sezm_model.pysource/tests/pt/model/test_sezm_spin_model.pysource/tests/pt/requirements.txtsource/tests/pt/test_train_utils.pysource/tests/pt/test_training.py
njzjz-bot
left a comment
There was a problem hiding this comment.
Thanks for the large DPA-4/SeZM integration. I focused this pass on the new model registration and the freeze/export path, since that is the user-facing path that determines whether trained checkpoints can be deployed.
I found one blocking issue in the new .pt2 freeze routing:
deepmd/pt/entrypoints/main.py::freeze() routes both type: "sezm" and type: "dpa4" checkpoints to freeze_sezm_to_pt2() because is_sezm_checkpoint() returns true for both aliases. However, freeze_sezm_to_pt2() then rejects anything whose normalized model_params["type"] is not exactly "sezm":
if str(params.get("type", "")).lower() != "sezm":
raise ValueError(...)As a result, a checkpoint saved with the documented/registered "dpa4" model alias cannot be frozen at all: the legacy TorchScript path is bypassed, then the new AOTInductor path raises before model construction. This should accept both aliases, e.g. not in {"sezm", "dpa4"}, and the error message should probably mention both names.
Please add a regression test that constructs or mocks a checkpoint with _extra_state.model_params.type == "dpa4" and verifies dp --pt freeze reaches the .pt2 exporter rather than raising at this guard.
Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
njzjz-bot
left a comment
There was a problem hiding this comment.
Reviewed the PR more broadly and left inline comments on the blocking issues I found.
Authored by OpenClaw 2026.5.12 (model: gpt-5.5)
njzjz-bot
left a comment
There was a problem hiding this comment.
A few more issues I found in the non-Python/example files.
Authored by OpenClaw 2026.5.12 (model: gpt-5.5)
| inv_cell = np.linalg.inv(cell) | ||
| else: | ||
| cell = None | ||
| inv_cell = None |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
deepmd/pt/model/descriptor/sezm_nn/so2.py (1)
1341-1343: 💤 Low valueConsider adding
strict=Trueto zip() calls for defensive programming.Both loops zip
so2_linears,so2_inter_norms, andnon_linearities, which are constructed with the same length in__init__. Addingstrict=Truewould catch any future inconsistency at runtime.Also applies to: 1365-1367
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/descriptor/sezm_nn/so2.py` around lines 1341 - 1343, The two for-loops that iterate with zip(self.so2_linears, self.so2_inter_norms, self.non_linearities) should use zip(..., strict=True) to enforce equal-length sequences at runtime; update both occurrences (the loop that assigns layer_idx, so2_linear, inter_norm, non_linear and the later similar loop) so they pass strict=True to zip to catch future mismatches defensively.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@deepmd/pt/model/descriptor/sezm_nn/so2.py`:
- Around line 1341-1343: The two for-loops that iterate with
zip(self.so2_linears, self.so2_inter_norms, self.non_linearities) should use
zip(..., strict=True) to enforce equal-length sequences at runtime; update both
occurrences (the loop that assigns layer_idx, so2_linear, inter_norm, non_linear
and the later similar loop) so they pass strict=True to zip to catch future
mismatches defensively.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 354878b9-f7dd-4177-8113-519352bf7892
📒 Files selected for processing (19)
deepmd/dpmodel/utils/dist_check.pydeepmd/dpmodel/utils/lmdb_data.pydeepmd/pt/entrypoints/freeze_pt2.pydeepmd/pt/loss/dens.pydeepmd/pt/model/atomic_model/sezm_atomic_model.pydeepmd/pt/model/descriptor/sezm.pydeepmd/pt/model/descriptor/sezm_nn/indexing.pydeepmd/pt/model/descriptor/sezm_nn/radial.pydeepmd/pt/model/descriptor/sezm_nn/so2.pydeepmd/pt/model/model/sezm_model.pydeepmd/pt/model/model/sezm_spin_model.pydeepmd/pt/train/training.pydeepmd/utils/argcheck.pydeepmd/utils/data.pydoc/model/dpa4.mdexamples/water/dpa4/lmp/README.mdexamples/water/dpa4/lmp/in.lammpssource/tests/common/test_examples.pysource/tests/pt/model/test_sezm_export.py
🚧 Files skipped from review as they are similar to previous changes (5)
- deepmd/dpmodel/utils/lmdb_data.py
- deepmd/pt/model/descriptor/sezm_nn/indexing.py
- deepmd/pt/model/descriptor/sezm_nn/radial.py
- deepmd/pt/loss/dens.py
- deepmd/pt/model/atomic_model/sezm_atomic_model.py
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #5448 +/- ##
==========================================
- Coverage 82.48% 79.25% -3.23%
==========================================
Files 830 865 +35
Lines 88521 95778 +7257
Branches 4232 4244 +12
==========================================
+ Hits 73014 75913 +2899
- Misses 14219 18744 +4525
+ Partials 1288 1121 -167 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Actionable comments posted: 6
♻️ Duplicate comments (1)
deepmd/pt/model/descriptor/sezm_nn/so2.py (1)
865-868:⚠️ Potential issue | 🟠 Major | ⚡ Quick winRecompute the Triton-rotation gate from the runtime device.
self.use_triton_rotationsis frozen fromenv.DEVICEin__init__(), but this module is later moved withmodel.to(...)during freeze/export. If it is constructed on CUDA and then exported on CPU, the eval path still takes the Triton branches and hits a device mismatch. Gate on the runtime tensor device instead of the cached construction device.#!/bin/bash # Verify that the Triton gate is computed once at init and reused in both forward branches. rg -n -C2 'use_triton_rotations|sezm_triton_enabled|rotate_to_local_triton|rotate_back_triton|def _apply' deepmd/pt/model/descriptor/sezm_nn/so2.pyExpected result:
self.use_triton_rotationsis assigned in__init__, consumed in both rotation branches, and there is no_apply()override that recomputes it aftermodel.to(...).Also applies to: 1264-1265, 1410-1411
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/descriptor/sezm_nn/so2.py` around lines 865 - 868, The flag self.use_triton_rotations is set once in __init__ using sezm_triton_enabled(device=self.device, dtype=self.dtype) and then reused during forward rotations, causing device-mismatch after model.to(...); change the branches in rotate_to_local_triton and rotate_back_triton (and any other rotation call sites) to re-evaluate sezm_triton_enabled at runtime using the actual tensor/device (e.g., use x.device or the rotation tensor's device and dtype) instead of the cached self.use_triton_rotations, and remove reliance on a frozen value (keep sezm_triton_enabled import and call it where the Triton vs fallback branch is selected so the correct backend is chosen after model.to(...)); also ensure there is no _apply override that leaves the old flag stale.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/entrypoints/freeze_pt2.py`:
- Around line 268-285: ntypes should be derived from the model's actual reported
type count rather than len(model.get_type_map()), because get_type_map() may be
empty; replace the ntypes assignment so it queries the model for its canonical
type count (e.g., model.get_type_count() or equivalent API) and fall back to
max(1, len(model.get_type_map())) if that method is unavailable, then use that
ntypes when populating atype_np in the loop that assigns i % ntypes.
In `@deepmd/pt/model/descriptor/sezm_nn/so2.py`:
- Around line 1341-1366: The two for-loops that iterate over
zip(self.so2_linears, self.so2_inter_norms, self.non_linearities) raise Ruff
B905; update both zips to use strict=True (i.e., zip(..., strict=True)) so
mismatched lengths will error and the linter warning is resolved — change the
zip in the first loop (around so2_layer_attn_res usage and so2_depth_sources
appends) and the zip in the else branch to include strict=True while leaving the
rest of the loop bodies unchanged.
- Around line 324-329: The cached tensor self._cached_weight can become stale or
on the wrong device after module.to(...) or load_state_dict(); update the class
to invalidate this cache on device/dtype moves and state loads by overriding
nn.Module._apply(self, fn) to set self._cached_weight = None (so moves clear the
cache) and/or override _load_from_state_dict(...) to clear self._cached_weight
when parameters are loaded; alternatively (or in addition) add a runtime check
where the cache is used (the weight-access/forward path that reads
self._cached_weight) to verify tensor.device and tensor.dtype match the module
parameters before reusing, and if not, set self._cached_weight = None and
rebuild.
In `@deepmd/pt/model/descriptor/sezm.py`:
- Around line 1937-1941: The current loop that strips every unexpected key under
the descriptor prefix (using prefix, state_dict(), expected_keys and iterating
over state_dict.keys()) must be replaced with a targeted whitelist of only known
transient keys that are intentionally rebuilt at construction; stop
blanket-popping unknown keys so load_state_dict(strict=True) can report
unexpected_keys. Update the logic to define an explicit set/list of transient
keys (or a clear pattern) that you will remove (e.g., transient_keys = {...} or
check exact suffixes), only pop entries that match those transient keys under
prefix, and then call load_state_dict(strict=True) so any other mismatches
surface as errors. Ensure references to prefix and expected_keys remain correct
when computing which keys are transient vs genuine parameters.
- Around line 1137-1152: forward_with_edges bypasses forward's normalization of
charge_spin, causing _apply_charge_spin_embedding to receive uncanonicalized
values; update forward_with_edges (the sparse-edge entrypoint where
type_embedding is applied) to canonicalize charge_spin the same way forward()
does: if charge_spin is None substitute default_chg_spin, ensure
shape/broadcasting matches (nf, nloc, 2) or the expected per-atom shape, and
convert dtype/device to self.compute_dtype and extended_coord/device before
passing into _apply_charge_spin_embedding; alternatively, call the existing
normalization helper used by forward() from within forward_with_edges to avoid
duplicating logic.
---
Duplicate comments:
In `@deepmd/pt/model/descriptor/sezm_nn/so2.py`:
- Around line 865-868: The flag self.use_triton_rotations is set once in
__init__ using sezm_triton_enabled(device=self.device, dtype=self.dtype) and
then reused during forward rotations, causing device-mismatch after
model.to(...); change the branches in rotate_to_local_triton and
rotate_back_triton (and any other rotation call sites) to re-evaluate
sezm_triton_enabled at runtime using the actual tensor/device (e.g., use
x.device or the rotation tensor's device and dtype) instead of the cached
self.use_triton_rotations, and remove reliance on a frozen value (keep
sezm_triton_enabled import and call it where the Triton vs fallback branch is
selected so the correct backend is chosen after model.to(...)); also ensure
there is no _apply override that leaves the old flag stale.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: bb4636e1-275a-4666-ad13-46895d7fbed7
📒 Files selected for processing (20)
deepmd/dpmodel/utils/dist_check.pydeepmd/dpmodel/utils/lmdb_data.pydeepmd/pt/entrypoints/freeze_pt2.pydeepmd/pt/loss/dens.pydeepmd/pt/model/atomic_model/sezm_atomic_model.pydeepmd/pt/model/descriptor/sezm.pydeepmd/pt/model/descriptor/sezm_nn/indexing.pydeepmd/pt/model/descriptor/sezm_nn/radial.pydeepmd/pt/model/descriptor/sezm_nn/so2.pydeepmd/pt/model/model/sezm_model.pydeepmd/pt/model/model/sezm_spin_model.pydeepmd/pt/train/training.pydeepmd/utils/argcheck.pydeepmd/utils/data.pydoc/model/dpa4.mdexamples/water/dpa4/lmp/README.mdexamples/water/dpa4/lmp/in.lammpssource/tests/common/test_examples.pysource/tests/pt/model/test_descriptor_sezm_s2_equivariance.pysource/tests/pt/model/test_sezm_export.py
🚧 Files skipped from review as they are similar to previous changes (6)
- deepmd/pt/model/descriptor/sezm_nn/indexing.py
- deepmd/dpmodel/utils/lmdb_data.py
- deepmd/dpmodel/utils/dist_check.py
- deepmd/pt/model/descriptor/sezm_nn/radial.py
- deepmd/pt/model/atomic_model/sezm_atomic_model.py
- deepmd/pt/loss/dens.py
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/entrypoints/freeze_pt2.py`:
- Around line 429-430: Update the docstring for the head parameter to reflect
current behavior: instead of stating "must be ``None``", document that the code
will auto-select "Default" for multitask checkpoints and otherwise requires an
explicit head name (or allow passing None to trigger auto-selection), and
mention the accepted values (e.g., "Default" or explicit head string) so callers
know the runtime semantics; target the docstring entry labeled head in
freeze_pt2.py (the docblock that currently says "Reserved for future multi-task
support; must be ``None``") and replace that sentence with the new, accurate
description.
- Around line 228-247: The metadata writer _collect_metadata() currently writes
an empty "type_map" (from model.get_type_map()) for SeZM/DPA4 models which leads
DeepEval._init_from_metadata() and DeepPotPTExpt to infer ntypes=0; fix
_collect_metadata() in freeze_pt2.py so that when list(model.get_type_map()) is
empty you populate a usable type map or explicit ntypes: e.g., set
metadata["type_map"] to a fallback list(range(N)) derived from
model.get_sel_type() or model.get_sel() (use len(model.get_sel_type()) or
len(model.get_sel()) to compute N), or include metadata["ntypes"]=int(N); ensure
the fallback logic uses model.get_sel_type()/model.get_sel() (and handles spin
if is_spin) so metadata-only loads give the correct ntypes for
DeepEval._init_from_metadata() and DeepPotPTExpt to consume.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: e9a3035f-d385-4b45-8538-5a266b7d86e4
📒 Files selected for processing (21)
deepmd/dpmodel/utils/dist_check.pydeepmd/dpmodel/utils/lmdb_data.pydeepmd/pt/entrypoints/freeze_pt2.pydeepmd/pt/loss/dens.pydeepmd/pt/model/atomic_model/sezm_atomic_model.pydeepmd/pt/model/descriptor/sezm.pydeepmd/pt/model/descriptor/sezm_nn/indexing.pydeepmd/pt/model/descriptor/sezm_nn/radial.pydeepmd/pt/model/descriptor/sezm_nn/so2.pydeepmd/pt/model/model/sezm_model.pydeepmd/pt/model/model/sezm_spin_model.pydeepmd/pt/train/training.pydeepmd/utils/argcheck.pydeepmd/utils/data.pydoc/model/dpa4.mdexamples/water/dpa4/lmp/README.mdexamples/water/dpa4/lmp/in.lammpssource/tests/common/test_examples.pysource/tests/pt/model/test_descriptor_sezm_s2_equivariance.pysource/tests/pt/model/test_sezm_export.pysource/tests/pt/model/test_sezm_model.py
🚧 Files skipped from review as they are similar to previous changes (7)
- deepmd/dpmodel/utils/dist_check.py
- deepmd/pt/model/descriptor/sezm_nn/radial.py
- deepmd/pt/model/descriptor/sezm_nn/indexing.py
- deepmd/dpmodel/utils/lmdb_data.py
- deepmd/pt/loss/dens.py
- deepmd/pt/model/atomic_model/sezm_atomic_model.py
- deepmd/pt/model/descriptor/sezm_nn/so2.py
njzjz-bot
left a comment
There was a problem hiding this comment.
Follow-up review after the latest DPA4 fixes. Several earlier comments are addressed, but I found a couple of remaining issues in the new changes.
Authored by OpenClaw 2026.5.12 (model: gpt-5.5)
c4cd1c9 to
b42a7c1
Compare
|
Others LGTM. |
|
Re-reviewed PR #5448. Current state:
I agree with that request. For packaged examples, these — OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5) |
9dbb2d2 to
eacaeba
Compare
Summary
This PR adds PyTorch support for DPA4, the DeePMD-kit implementation of SeZM (Smooth Equivariant Zone-bridging Model). It introduces the DPA4/SeZM model, descriptor, fitting network, training integration, export path, documentation, examples, and tests.
Main Changes
model.type: "dpa4"/"sezm"descriptor.type: "dpa4"/"sezm"fitting_net.type: "dpa4_ener"/"sezm_ener"loss.type: "ener"loss.type: "dens".pt2freeze/export path using AOTInductor for checkpoints that cannot be represented by the regular TorchScript freeze path.doc/model/dpa4.md.Tests
This PR adds coverage for:
torch.compileeager/compiled consistency.pt2export and DeepPot inferenceRelevant test files include:
source/tests/pt/model/test_descriptor_sezm.pysource/tests/pt/model/test_descriptor_sezm_s2_equivariance.pysource/tests/pt/model/test_descriptor_sezm_triton.pysource/tests/pt/model/test_sezm_model.pysource/tests/pt/model/test_sezm_spin_model.pysource/tests/pt/model/test_sezm_export.pysource/tests/pt/test_training.pysource/tests/pt/test_train_utils.pysource/tests/common/dpmodel/test_dist_check.pysource/tests/common/dpmodel/test_lmdb_data.pyNotes
DPA4 is currently implemented for the PyTorch backend. Model compression is not supported, and DPA4 checkpoints use the
.pt2export path instead of the regular TorchScript freeze path.Summary by CodeRabbit
New Features
Improvements