Skip to content

feat(pt): Add DPA4/SeZM descriptor & model 🎉🎉🎉#5448

Open
OutisLi wants to merge 7 commits into
deepmodeling:masterfrom
OutisLi:dpa4
Open

feat(pt): Add DPA4/SeZM descriptor & model 🎉🎉🎉#5448
OutisLi wants to merge 7 commits into
deepmodeling:masterfrom
OutisLi:dpa4

Conversation

@OutisLi
Copy link
Copy Markdown
Collaborator

@OutisLi OutisLi commented May 19, 2026

Summary

This PR adds PyTorch support for DPA4, the DeePMD-kit implementation of SeZM (Smooth Equivariant Zone-bridging Model). It introduces the DPA4/SeZM model, descriptor, fitting network, training integration, export path, documentation, examples, and tests.

Main Changes

  • Add the DPA4/SeZM PyTorch model stack:
    • model.type: "dpa4" / "sezm"
    • descriptor.type: "dpa4" / "sezm"
    • fitting_net.type: "dpa4_ener" / "sezm_ener"
  • Implement the SO(3)-equivariant descriptor with edge-local SO(2) convolutions, angular schedules, smooth radial envelopes, attention/focus streams, and environment-seeded initial features.
  • Add zone-bridging support for short-range analytical repulsion, including ZBL coupling and descriptor-side short-range clamping.
  • Add DPA4 training support for:
    • conservative energy/force training through loss.type: "ener"
    • experimental direct-force denoising through loss.type: "dens"
    • spin models in the PyTorch backend
    • shared-fitting multitask case FiLM conditioning
    • LoRA fine-tuning and merged checkpoint export
  • Add the DPA4 .pt2 freeze/export path using AOTInductor for checkpoints that cannot be represented by the regular TorchScript freeze path.
  • Add CLI, argcheck, validation, data-system, and inference integration needed to route DPA4 configs and exported models correctly.
  • Add water examples for standard DPA4, ZBL bridging, spin, DeNS, multitask/shared-fitting, LoRA fine-tuning, and LAMMPS inference.
  • Add official model documentation at doc/model/dpa4.md.

Tests

This PR adds coverage for:

  • DPA4/SeZM model and descriptor construction
  • DPA4 aliases in model, descriptor, and fitting configuration
  • SO(3)/SO(2) equivariance behavior
  • conservative energy/force paths
  • torch.compile eager/compiled consistency
  • DPA4 .pt2 export and DeepPot inference
  • spin model behavior
  • ZBL zone bridging
  • DeNS loss and direct-force mode
  • LoRA adapter injection, freezing, merging, and compile compatibility
  • optional Triton kernel dispatch and numerical consistency
  • supporting utility changes in neighbor-list, LMDB data, and distributed checks

Relevant test files include:

  • source/tests/pt/model/test_descriptor_sezm.py
  • source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py
  • source/tests/pt/model/test_descriptor_sezm_triton.py
  • source/tests/pt/model/test_sezm_model.py
  • source/tests/pt/model/test_sezm_spin_model.py
  • source/tests/pt/model/test_sezm_export.py
  • source/tests/pt/test_training.py
  • source/tests/pt/test_train_utils.py
  • source/tests/common/dpmodel/test_dist_check.py
  • source/tests/common/dpmodel/test_lmdb_data.py

Notes

DPA4 is currently implemented for the PyTorch backend. Model compression is not supported, and DPA4 checkpoints use the .pt2 export path instead of the regular TorchScript freeze path.

Summary by CodeRabbit

  • New Features

    • Added SeZM model family and DeNS denoising loss for training; new optimized ".pt2" export path with embedded metadata.
  • Improvements

    • LoRA fine-tuning workflow (apply/merge/strip) for lightweight adapters.
    • On-demand minimum pairwise-distance computation during data reads.
    • Better JAX neighbor-list handling and optional GPU/Triton-accelerated descriptor kernels for faster inference/training.

Review Change Stack

Copilot AI review requested due to automatic review settings May 19, 2026 04:15
@dosubot dosubot Bot added the new feature label May 19, 2026
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot wasn't able to review this pull request because it exceeds the maximum number of lines (20,000). Try reducing the number of changed lines and requesting a review from Copilot again.

@OutisLi OutisLi changed the title feat(pt): Add DPA-4 (a.k.a SeZM) descriptor & model 🎉🎉🎉 feat(pt): Add DPA-4 (aka SeZM) descriptor & model 🎉🎉🎉 May 19, 2026
@OutisLi OutisLi changed the title feat(pt): Add DPA-4 (aka SeZM) descriptor & model 🎉🎉🎉 feat(pt): Add DPA-4/SeZM descriptor & model 🎉🎉🎉 May 19, 2026
Comment thread deepmd/pt/model/descriptor/sezm_nn/so2.py Fixed
Comment thread deepmd/pt/model/descriptor/sezm.py Fixed
Comment thread deepmd/pt/model/atomic_model/sezm_atomic_model.py Fixed
# === Step 1. Setup dimensions ===
extended_coord = extended_coord.to(self.compute_dtype)
nf, nloc, nnei = nlist.shape
nall = extended_coord.shape[1]
try:
import triton # noqa: F401
except ImportError:
SEZM_TRITON_AVAILABLE = False
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Used

except ImportError:
SEZM_TRITON_AVAILABLE = False
else:
SEZM_TRITON_AVAILABLE = True
else:
SEZM_TRITON_AVAILABLE = True
else:
SEZM_TRITON_AVAILABLE = False
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 19, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds SeZM descriptor stack and NN building blocks (SO(3)/SO(2), radial, activations, FFN, attention), Triton kernels and autograd, DeNS fitting and loss, SeZM atomic model and model factories, PT2 AOTInductor freezer, and dpmodel utilities (min-pair-distance and JAX nlist fix).

Changes

SeZM descriptor, NN building blocks, and integrations

Layer / File(s) Summary
Public re-exports, utilities, indexing, and Wigner-D
deepmd/pt/model/descriptor/sezm_nn/__init__.py, .../indexing.py, .../wignerd.py, .../utils.py
Package re-export hub, SO(3) indexing helpers, quaternion/Wigner-D evaluator, and general SeZM utilities.
Activations, S2 grid, FFN, attention
.../activation.py, .../ffn.py, .../attention.py
Gated/SwiGLU activations, S2-grid projector, SwiGLUS2, EquivariantFFN, PointwiseGridMLP, and envelope-gated softmax attention.
SO(3)/SO(2) linears and norms
.../so3.py, .../so2.py, .../norm.py
Focus/Channel/SO3 linear blocks, SO2Linear, DynamicRadialDegreeMixer, SO2Convolution pipeline, and RMS-based normalization layers.
Radial, edge cache, embeddings
.../radial.py, .../edge_cache.py, .../embedding.py, .../lebedev.py
C3 cutoff/InnerClamp/BridgingSwitch, RadialBasis/MLP, EdgeFeatureCache builders (nlist/edges), geometric/environment/charge-spin embeddings, and Lebedev loader.
Interaction block and descriptor core
.../block.py, deepmd/pt/model/descriptor/sezm.py
SeZMInteractionBlock (SO2 + FFN units, attention-res modes), DescrptSeZM constructor, forward paths, block orchestration, schedules, helpers, LoRA-aware state loading, and serialization.
LoRA adapters and utilities
.../lora.py
LoRA wrappers for SO3/SO2 layers, apply/fold/merge utilities, fine-tuning policy, and compile-cache clearing.
DeNS fitting and DeNSLoss
.../sezm_nn/dens.py, deepmd/pt/loss/dens.py, deepmd/pt/loss/__init__.py
SeZM DeNS fitting net with force-embedding and vector heads; DeNSLoss implements denoising corruption, targets, split energy/force losses, metrics, and (de)serialization; exported in loss package.
SeZMAtomicModel and model factories
.../atomic_model/sezm_atomic_model.py, deepmd/pt/model/model/__init__.py, .../atomic_model/__init__.py
SeZMAtomicModel with ener/dens modes, dens-force stats, normalization, mode switching, serialization; get_sezm_model/get_sezm_spin_model factories and get_model routing, exports updated.

Triton kernels, autograd, dispatch, and custom ops

Layer / File(s) Summary
Triton constants, dispatch, and autograd API
.../triton/constants.py, .../triton/dispatch.py, .../triton/autograd.py, .../triton/__init__.py
Runtime SEZM_TRITON_AVAILABLE flag, Triton rotation-mode enum and resolver, eager fallbacks, autograd Function wrappers and public triton API functions.
Custom-op registration and launchers
.../triton/custom_ops.py
Registers torch.library triton_op launchers, small vs generic kernel dispatch, and fused edge-geometry/RBF op registration.
Generic & small Triton kernels, fused edge-geometry+RBF
.../triton/kernels_generic.py, .../triton/kernels_small.py, .../triton/kernels_edge_geometry_rbf.py
Tiled generic rotation kernels, specialized small-family kernels for lmax<=3, and fused edge-geometry+RBF Triton kernels (forward + two-part backward).

PT2 freezer, CLI, and JIT gating

Layer / File(s) Summary
freeze_sezm_to_pt2 and helpers
deepmd/pt/entrypoints/freeze_pt2.py
Implements SeZM/DPA4 checkpoint detection, head selection, sample-input creation, torch.export with dynamic shapes, move-to-device pass, AOTInductor compile/package, and metadata writes into .pt2 archives.
Entrypoint routing and DeepEval JIT gating
deepmd/pt/entrypoints/main.py, deepmd/pt/infer/deep_eval.py
freeze() routes SeZM checkpoints to .pt2 and preserves .pth handling for legacy; DeepEval disables JIT for SeZM/DPA4 and hardens type-embed detection.

dpmodel utilities and LMDB reader

Layer / File(s) Summary
JAX-safe nlist ghost shift
deepmd/dpmodel/utils/nlist.py
Adds a JAX-specific branch in extend_coord_with_ghosts computing shift_vec without tensordot to avoid JAX internal errors.
Min pair distance and LMDB integration
deepmd/dpmodel/utils/dist_check.py, deepmd/dpmodel/utils/lmdb_data.py
compute_min_pair_dist_single: memory-bounded blockwise min pairwise distance with optional PBC and early stop; LmdbDataReader derives missing min_pair_dist on-the-fly using the new function.

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested reviewers

  • iProzd
  • wanghan-iapcm
  • njzjz
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🧹 Nitpick comments (1)
deepmd/pt/model/descriptor/sezm_nn/so2.py (1)

1340-1342: 💤 Low value

Add strict=True to zip() calls to catch length mismatches.

The three module lists (so2_linears, so2_inter_norms, non_linearities) are constructed to have equal length, but adding strict=True provides a runtime guard against future refactoring errors.

Proposed fix
                 for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate(
-                    zip(self.so2_linears, self.so2_inter_norms, self.non_linearities)
+                    zip(self.so2_linears, self.so2_inter_norms, self.non_linearities, strict=True)
                 ):

And similarly for line 1365:

                 for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate(
-                    zip(self.so2_linears, self.so2_inter_norms, self.non_linearities)
+                    zip(self.so2_linears, self.so2_inter_norms, self.non_linearities, strict=True)
                 ):

Also applies to: 1364-1366

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/sezm_nn/so2.py` around lines 1340 - 1342, The zip
over the three module lists (so2_linears, so2_inter_norms, non_linearities)
should use strict=True to surface any length mismatches at runtime; update the
zip(...) calls in the for loops that iterate over (so2_linear, inter_norm,
non_linear) (and the similar loop around lines 1364–1366) to call
zip(self.so2_linears, self.so2_inter_norms, self.non_linearities, strict=True)
so a ValueError is raised if the lists differ in length.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/dpmodel/utils/dist_check.py`:
- Around line 45-56: The current implementation materializes diff =
real_coord[np.newaxis, :, :] - real_coord[:, np.newaxis, :] which creates an
(n_real, n_real, 3) array and O(n^2) memory; replace this with a blocked (or
streaming) pairwise-min computation to avoid full NxN allocation: iterate over
chunks/blocks of indices (e.g., outer blocks i and inner blocks j), compute the
block-wise displacement between real_coord[i:i_b] and real_coord[j:j_b], apply
the PBC transform using cell and inv_cell (reuse the same frac_diff -=
round(frac_diff) and back-transform logic on the block), compute block-wise
dist_sq, ignore self-pair entries when i==j (mask diagonal within the block),
and accumulate the global minimum distance (or a per-atom min array) instead of
assembling dist_sq for all pairs; update code locations around symbols diff,
frac_diff, cell, inv_cell, and dist_sq to perform block-wise operations and
maintain identical results with bounded memory.

In `@deepmd/pt/entrypoints/freeze_pt2.py`:
- Around line 377-380: The type check in freeze_sezm_to_pt2 currently only
allows "sezm" and rejects valid "dpa4" checkpoints; update the validation to
accept both "sezm" and "dpa4" (case-insensitive) by checking params.get("type")
against the set {"sezm", "dpa4"} so that is_sezm_checkpoint()-routed dpa4
checkpoints are not erroneously rejected.

In `@deepmd/pt/model/descriptor/sezm_nn/indexing.py`:
- Line 70: Rename the ambiguous single-letter identifier `l` to a clearer name
such as `ell` throughout indexing.py to satisfy Ruff E741; specifically update
list comprehensions like `[2 * l + 1 for l in range(lmax + 1)]` to `[2 * ell + 1
for ell in range(lmax + 1)]`, and rename the loop/parameter occurrences
referenced in the file (including the other reported occurrences around the same
blocks) so function signatures, list comprehensions, for-loops and any uses
inside functions/methods (e.g., variables used alongside `lmax`) are
consistently changed to `ell` (and all references updated) to avoid breaking
references. Ensure tests/uses of symbols from functions/methods in this module
reflect the new name and run `ruff check .` to confirm E741 is resolved.

In `@deepmd/pt/model/descriptor/sezm_nn/radial.py`:
- Around line 230-241: The constructor currently allows non-positive rcut which
later causes invalid divisions; in the class __init__ (the shown __init__) and
the other constructor/initializer that assigns self.rcut (the occurrence around
lines 487-488), add a strict check that rcut > 0 and raise ValueError("`rcut`
must be positive") if not, then convert rcut to float and assign to self.rcut
(and only derive widths or compute 1/rcut after that check) so all downstream
computations are safe.

---

Nitpick comments:
In `@deepmd/pt/model/descriptor/sezm_nn/so2.py`:
- Around line 1340-1342: The zip over the three module lists (so2_linears,
so2_inter_norms, non_linearities) should use strict=True to surface any length
mismatches at runtime; update the zip(...) calls in the for loops that iterate
over (so2_linear, inter_norm, non_linear) (and the similar loop around lines
1364–1366) to call zip(self.so2_linears, self.so2_inter_norms,
self.non_linearities, strict=True) so a ValueError is raised if the lists differ
in length.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 9c1d7bdc-93ba-406f-8498-2f63fd44c217

📥 Commits

Reviewing files that changed from the base of the PR and between 01bcf47 and eae375e.

📒 Files selected for processing (84)
  • deepmd/dpmodel/utils/dist_check.py
  • deepmd/dpmodel/utils/lmdb_data.py
  • deepmd/dpmodel/utils/nlist.py
  • deepmd/pt/entrypoints/freeze_pt2.py
  • deepmd/pt/entrypoints/main.py
  • deepmd/pt/infer/deep_eval.py
  • deepmd/pt/loss/__init__.py
  • deepmd/pt/loss/dens.py
  • deepmd/pt/model/atomic_model/__init__.py
  • deepmd/pt/model/atomic_model/sezm_atomic_model.py
  • deepmd/pt/model/descriptor/__init__.py
  • deepmd/pt/model/descriptor/sezm.py
  • deepmd/pt/model/descriptor/sezm_nn/__init__.py
  • deepmd/pt/model/descriptor/sezm_nn/activation.py
  • deepmd/pt/model/descriptor/sezm_nn/attention.py
  • deepmd/pt/model/descriptor/sezm_nn/attn_res.py
  • deepmd/pt/model/descriptor/sezm_nn/block.py
  • deepmd/pt/model/descriptor/sezm_nn/dens.py
  • deepmd/pt/model/descriptor/sezm_nn/edge_cache.py
  • deepmd/pt/model/descriptor/sezm_nn/embedding.py
  • deepmd/pt/model/descriptor/sezm_nn/ffn.py
  • deepmd/pt/model/descriptor/sezm_nn/indexing.py
  • deepmd/pt/model/descriptor/sezm_nn/lebedev.py
  • deepmd/pt/model/descriptor/sezm_nn/lebedev_rules.npz
  • deepmd/pt/model/descriptor/sezm_nn/lora.py
  • deepmd/pt/model/descriptor/sezm_nn/norm.py
  • deepmd/pt/model/descriptor/sezm_nn/radial.py
  • deepmd/pt/model/descriptor/sezm_nn/so2.py
  • deepmd/pt/model/descriptor/sezm_nn/so3.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/constants.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py
  • deepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.py
  • deepmd/pt/model/descriptor/sezm_nn/utils.py
  • deepmd/pt/model/descriptor/sezm_nn/wignerd.py
  • deepmd/pt/model/model/__init__.py
  • deepmd/pt/model/model/sezm_model.py
  • deepmd/pt/model/model/sezm_spin_model.py
  • deepmd/pt/model/model/spin_model.py
  • deepmd/pt/model/network/mlp.py
  • deepmd/pt/model/task/__init__.py
  • deepmd/pt/model/task/sezm_ener.py
  • deepmd/pt/train/training.py
  • deepmd/pt/train/utils.py
  • deepmd/pt/train/validation.py
  • deepmd/pt/utils/multi_task.py
  • deepmd/pt/utils/serialization.py
  • deepmd/utils/argcheck.py
  • deepmd/utils/data.py
  • deepmd/utils/data_system.py
  • doc/model/dpa4.md
  • doc/model/index.rst
  • examples/water/dpa4/.gitignore
  • examples/water/dpa4/README.md
  • examples/water/dpa4/input-spin.json
  • examples/water/dpa4/input-zbl.json
  • examples/water/dpa4/input.json
  • examples/water/dpa4/input_dens.json
  • examples/water/dpa4/input_multitask.json
  • examples/water/dpa4/input_multitask_sharefit-zbl.json
  • examples/water/dpa4/input_multitask_sharefit.json
  • examples/water/dpa4/lmp/.gitignore
  • examples/water/dpa4/lmp/README.md
  • examples/water/dpa4/lmp/in.lammps
  • examples/water/dpa4/lmp/input.json
  • examples/water/dpa4/lmp/pretrained.pt
  • examples/water/dpa4/lmp/water.lmp
  • examples/water/dpa4/lora_ft.json
  • pyproject.toml
  • source/tests/common/dpmodel/test_dist_check.py
  • source/tests/common/dpmodel/test_lmdb_data.py
  • source/tests/pt/model/test_descriptor_sezm.py
  • source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py
  • source/tests/pt/model/test_descriptor_sezm_triton.py
  • source/tests/pt/model/test_sezm_export.py
  • source/tests/pt/model/test_sezm_model.py
  • source/tests/pt/model/test_sezm_spin_model.py
  • source/tests/pt/requirements.txt
  • source/tests/pt/test_train_utils.py
  • source/tests/pt/test_training.py

Comment thread deepmd/dpmodel/utils/dist_check.py Outdated
Comment thread deepmd/pt/entrypoints/freeze_pt2.py Outdated
Comment thread deepmd/pt/model/descriptor/sezm_nn/indexing.py Outdated
Comment thread deepmd/pt/model/descriptor/sezm_nn/radial.py
Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the large DPA-4/SeZM integration. I focused this pass on the new model registration and the freeze/export path, since that is the user-facing path that determines whether trained checkpoints can be deployed.

I found one blocking issue in the new .pt2 freeze routing:

deepmd/pt/entrypoints/main.py::freeze() routes both type: "sezm" and type: "dpa4" checkpoints to freeze_sezm_to_pt2() because is_sezm_checkpoint() returns true for both aliases. However, freeze_sezm_to_pt2() then rejects anything whose normalized model_params["type"] is not exactly "sezm":

if str(params.get("type", "")).lower() != "sezm":
    raise ValueError(...)

As a result, a checkpoint saved with the documented/registered "dpa4" model alias cannot be frozen at all: the legacy TorchScript path is bypassed, then the new AOTInductor path raises before model construction. This should accept both aliases, e.g. not in {"sezm", "dpa4"}, and the error message should probably mention both names.

Please add a regression test that constructs or mocks a checkpoint with _extra_state.model_params.type == "dpa4" and verifies dp --pt freeze reaches the .pt2 exporter rather than raising at this guard.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)

Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed the PR more broadly and left inline comments on the blocking issues I found.

Authored by OpenClaw 2026.5.12 (model: gpt-5.5)

Comment thread deepmd/pt/entrypoints/freeze_pt2.py Outdated
Comment thread deepmd/pt/loss/dens.py Outdated
Comment thread deepmd/pt/model/model/sezm_model.py
Comment thread deepmd/pt/model/descriptor/sezm.py Outdated
Comment thread examples/water/dpa4/input_multitask_sharefit.json
Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few more issues I found in the non-Python/example files.

Authored by OpenClaw 2026.5.12 (model: gpt-5.5)

Comment thread examples/water/dpa4/lmp/README.md Outdated
Comment thread examples/water/dpa4/lmp/in.lammps Outdated
inv_cell = np.linalg.inv(cell)
else:
cell = None
inv_cell = None
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
deepmd/pt/model/descriptor/sezm_nn/so2.py (1)

1341-1343: 💤 Low value

Consider adding strict=True to zip() calls for defensive programming.

Both loops zip so2_linears, so2_inter_norms, and non_linearities, which are constructed with the same length in __init__. Adding strict=True would catch any future inconsistency at runtime.

Also applies to: 1365-1367

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/sezm_nn/so2.py` around lines 1341 - 1343, The two
for-loops that iterate with zip(self.so2_linears, self.so2_inter_norms,
self.non_linearities) should use zip(..., strict=True) to enforce equal-length
sequences at runtime; update both occurrences (the loop that assigns layer_idx,
so2_linear, inter_norm, non_linear and the later similar loop) so they pass
strict=True to zip to catch future mismatches defensively.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@deepmd/pt/model/descriptor/sezm_nn/so2.py`:
- Around line 1341-1343: The two for-loops that iterate with
zip(self.so2_linears, self.so2_inter_norms, self.non_linearities) should use
zip(..., strict=True) to enforce equal-length sequences at runtime; update both
occurrences (the loop that assigns layer_idx, so2_linear, inter_norm, non_linear
and the later similar loop) so they pass strict=True to zip to catch future
mismatches defensively.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 354878b9-f7dd-4177-8113-519352bf7892

📥 Commits

Reviewing files that changed from the base of the PR and between eae375e and ba9cb8d.

📒 Files selected for processing (19)
  • deepmd/dpmodel/utils/dist_check.py
  • deepmd/dpmodel/utils/lmdb_data.py
  • deepmd/pt/entrypoints/freeze_pt2.py
  • deepmd/pt/loss/dens.py
  • deepmd/pt/model/atomic_model/sezm_atomic_model.py
  • deepmd/pt/model/descriptor/sezm.py
  • deepmd/pt/model/descriptor/sezm_nn/indexing.py
  • deepmd/pt/model/descriptor/sezm_nn/radial.py
  • deepmd/pt/model/descriptor/sezm_nn/so2.py
  • deepmd/pt/model/model/sezm_model.py
  • deepmd/pt/model/model/sezm_spin_model.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • deepmd/utils/data.py
  • doc/model/dpa4.md
  • examples/water/dpa4/lmp/README.md
  • examples/water/dpa4/lmp/in.lammps
  • source/tests/common/test_examples.py
  • source/tests/pt/model/test_sezm_export.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • deepmd/dpmodel/utils/lmdb_data.py
  • deepmd/pt/model/descriptor/sezm_nn/indexing.py
  • deepmd/pt/model/descriptor/sezm_nn/radial.py
  • deepmd/pt/loss/dens.py
  • deepmd/pt/model/atomic_model/sezm_atomic_model.py

@OutisLi OutisLi requested review from iProzd, njzjz-bot and wanghan-iapcm and removed request for njzjz-bot May 19, 2026 06:34
@codecov
Copy link
Copy Markdown

codecov Bot commented May 19, 2026

Codecov Report

❌ Patch coverage is 68.69126% with 1799 lines in your changes missing coverage. Please review.
✅ Project coverage is 79.25%. Comparing base (01bcf47) to head (eacaeba).
⚠️ Report is 1 commits behind head on master.

Files with missing lines Patch % Lines
...t/model/descriptor/sezm_nn/triton/kernels_small.py 0.00% 322 Missing ⚠️
...riptor/sezm_nn/triton/kernels_edge_geometry_rbf.py 0.00% 180 Missing ⚠️
...pmd/pt/model/descriptor/sezm_nn/triton/autograd.py 10.10% 169 Missing ⚠️
...model/descriptor/sezm_nn/triton/kernels_generic.py 0.00% 145 Missing ⚠️
deepmd/pt/model/descriptor/sezm_nn/norm.py 48.94% 121 Missing ⚠️
...d/pt/model/descriptor/sezm_nn/triton/custom_ops.py 0.00% 119 Missing ⚠️
deepmd/pt/model/descriptor/sezm.py 86.38% 75 Missing ⚠️
deepmd/pt/model/descriptor/sezm_nn/lora.py 73.99% 71 Missing ⚠️
deepmd/pt/model/descriptor/sezm_nn/so2.py 87.45% 68 Missing ⚠️
deepmd/pt/model/descriptor/sezm_nn/activation.py 80.00% 64 Missing ⚠️
... and 20 more
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5448      +/-   ##
==========================================
- Coverage   82.48%   79.25%   -3.23%     
==========================================
  Files         830      865      +35     
  Lines       88521    95778    +7257     
  Branches     4232     4244      +12     
==========================================
+ Hits        73014    75913    +2899     
- Misses      14219    18744    +4525     
+ Partials     1288     1121     -167     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

♻️ Duplicate comments (1)
deepmd/pt/model/descriptor/sezm_nn/so2.py (1)

865-868: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Recompute the Triton-rotation gate from the runtime device.

self.use_triton_rotations is frozen from env.DEVICE in __init__(), but this module is later moved with model.to(...) during freeze/export. If it is constructed on CUDA and then exported on CPU, the eval path still takes the Triton branches and hits a device mismatch. Gate on the runtime tensor device instead of the cached construction device.

#!/bin/bash
# Verify that the Triton gate is computed once at init and reused in both forward branches.
rg -n -C2 'use_triton_rotations|sezm_triton_enabled|rotate_to_local_triton|rotate_back_triton|def _apply' deepmd/pt/model/descriptor/sezm_nn/so2.py

Expected result: self.use_triton_rotations is assigned in __init__, consumed in both rotation branches, and there is no _apply() override that recomputes it after model.to(...).

Also applies to: 1264-1265, 1410-1411

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/sezm_nn/so2.py` around lines 865 - 868, The flag
self.use_triton_rotations is set once in __init__ using
sezm_triton_enabled(device=self.device, dtype=self.dtype) and then reused during
forward rotations, causing device-mismatch after model.to(...); change the
branches in rotate_to_local_triton and rotate_back_triton (and any other
rotation call sites) to re-evaluate sezm_triton_enabled at runtime using the
actual tensor/device (e.g., use x.device or the rotation tensor's device and
dtype) instead of the cached self.use_triton_rotations, and remove reliance on a
frozen value (keep sezm_triton_enabled import and call it where the Triton vs
fallback branch is selected so the correct backend is chosen after
model.to(...)); also ensure there is no _apply override that leaves the old flag
stale.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt/entrypoints/freeze_pt2.py`:
- Around line 268-285: ntypes should be derived from the model's actual reported
type count rather than len(model.get_type_map()), because get_type_map() may be
empty; replace the ntypes assignment so it queries the model for its canonical
type count (e.g., model.get_type_count() or equivalent API) and fall back to
max(1, len(model.get_type_map())) if that method is unavailable, then use that
ntypes when populating atype_np in the loop that assigns i % ntypes.

In `@deepmd/pt/model/descriptor/sezm_nn/so2.py`:
- Around line 1341-1366: The two for-loops that iterate over
zip(self.so2_linears, self.so2_inter_norms, self.non_linearities) raise Ruff
B905; update both zips to use strict=True (i.e., zip(..., strict=True)) so
mismatched lengths will error and the linter warning is resolved — change the
zip in the first loop (around so2_layer_attn_res usage and so2_depth_sources
appends) and the zip in the else branch to include strict=True while leaving the
rest of the loop bodies unchanged.
- Around line 324-329: The cached tensor self._cached_weight can become stale or
on the wrong device after module.to(...) or load_state_dict(); update the class
to invalidate this cache on device/dtype moves and state loads by overriding
nn.Module._apply(self, fn) to set self._cached_weight = None (so moves clear the
cache) and/or override _load_from_state_dict(...) to clear self._cached_weight
when parameters are loaded; alternatively (or in addition) add a runtime check
where the cache is used (the weight-access/forward path that reads
self._cached_weight) to verify tensor.device and tensor.dtype match the module
parameters before reusing, and if not, set self._cached_weight = None and
rebuild.

In `@deepmd/pt/model/descriptor/sezm.py`:
- Around line 1937-1941: The current loop that strips every unexpected key under
the descriptor prefix (using prefix, state_dict(), expected_keys and iterating
over state_dict.keys()) must be replaced with a targeted whitelist of only known
transient keys that are intentionally rebuilt at construction; stop
blanket-popping unknown keys so load_state_dict(strict=True) can report
unexpected_keys. Update the logic to define an explicit set/list of transient
keys (or a clear pattern) that you will remove (e.g., transient_keys = {...} or
check exact suffixes), only pop entries that match those transient keys under
prefix, and then call load_state_dict(strict=True) so any other mismatches
surface as errors. Ensure references to prefix and expected_keys remain correct
when computing which keys are transient vs genuine parameters.
- Around line 1137-1152: forward_with_edges bypasses forward's normalization of
charge_spin, causing _apply_charge_spin_embedding to receive uncanonicalized
values; update forward_with_edges (the sparse-edge entrypoint where
type_embedding is applied) to canonicalize charge_spin the same way forward()
does: if charge_spin is None substitute default_chg_spin, ensure
shape/broadcasting matches (nf, nloc, 2) or the expected per-atom shape, and
convert dtype/device to self.compute_dtype and extended_coord/device before
passing into _apply_charge_spin_embedding; alternatively, call the existing
normalization helper used by forward() from within forward_with_edges to avoid
duplicating logic.

---

Duplicate comments:
In `@deepmd/pt/model/descriptor/sezm_nn/so2.py`:
- Around line 865-868: The flag self.use_triton_rotations is set once in
__init__ using sezm_triton_enabled(device=self.device, dtype=self.dtype) and
then reused during forward rotations, causing device-mismatch after
model.to(...); change the branches in rotate_to_local_triton and
rotate_back_triton (and any other rotation call sites) to re-evaluate
sezm_triton_enabled at runtime using the actual tensor/device (e.g., use
x.device or the rotation tensor's device and dtype) instead of the cached
self.use_triton_rotations, and remove reliance on a frozen value (keep
sezm_triton_enabled import and call it where the Triton vs fallback branch is
selected so the correct backend is chosen after model.to(...)); also ensure
there is no _apply override that leaves the old flag stale.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: bb4636e1-275a-4666-ad13-46895d7fbed7

📥 Commits

Reviewing files that changed from the base of the PR and between ba9cb8d and 703a441.

📒 Files selected for processing (20)
  • deepmd/dpmodel/utils/dist_check.py
  • deepmd/dpmodel/utils/lmdb_data.py
  • deepmd/pt/entrypoints/freeze_pt2.py
  • deepmd/pt/loss/dens.py
  • deepmd/pt/model/atomic_model/sezm_atomic_model.py
  • deepmd/pt/model/descriptor/sezm.py
  • deepmd/pt/model/descriptor/sezm_nn/indexing.py
  • deepmd/pt/model/descriptor/sezm_nn/radial.py
  • deepmd/pt/model/descriptor/sezm_nn/so2.py
  • deepmd/pt/model/model/sezm_model.py
  • deepmd/pt/model/model/sezm_spin_model.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • deepmd/utils/data.py
  • doc/model/dpa4.md
  • examples/water/dpa4/lmp/README.md
  • examples/water/dpa4/lmp/in.lammps
  • source/tests/common/test_examples.py
  • source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py
  • source/tests/pt/model/test_sezm_export.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • deepmd/pt/model/descriptor/sezm_nn/indexing.py
  • deepmd/dpmodel/utils/lmdb_data.py
  • deepmd/dpmodel/utils/dist_check.py
  • deepmd/pt/model/descriptor/sezm_nn/radial.py
  • deepmd/pt/model/atomic_model/sezm_atomic_model.py
  • deepmd/pt/loss/dens.py

Comment thread deepmd/pt/entrypoints/freeze_pt2.py
Comment thread deepmd/pt/model/descriptor/sezm_nn/so2.py
Comment thread deepmd/pt/model/descriptor/sezm_nn/so2.py Outdated
Comment thread deepmd/pt/model/descriptor/sezm.py
Comment thread deepmd/pt/model/descriptor/sezm.py
Comment thread deepmd/pt/model/descriptor/sezm.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt/entrypoints/freeze_pt2.py`:
- Around line 429-430: Update the docstring for the head parameter to reflect
current behavior: instead of stating "must be ``None``", document that the code
will auto-select "Default" for multitask checkpoints and otherwise requires an
explicit head name (or allow passing None to trigger auto-selection), and
mention the accepted values (e.g., "Default" or explicit head string) so callers
know the runtime semantics; target the docstring entry labeled head in
freeze_pt2.py (the docblock that currently says "Reserved for future multi-task
support; must be ``None``") and replace that sentence with the new, accurate
description.
- Around line 228-247: The metadata writer _collect_metadata() currently writes
an empty "type_map" (from model.get_type_map()) for SeZM/DPA4 models which leads
DeepEval._init_from_metadata() and DeepPotPTExpt to infer ntypes=0; fix
_collect_metadata() in freeze_pt2.py so that when list(model.get_type_map()) is
empty you populate a usable type map or explicit ntypes: e.g., set
metadata["type_map"] to a fallback list(range(N)) derived from
model.get_sel_type() or model.get_sel() (use len(model.get_sel_type()) or
len(model.get_sel()) to compute N), or include metadata["ntypes"]=int(N); ensure
the fallback logic uses model.get_sel_type()/model.get_sel() (and handles spin
if is_spin) so metadata-only loads give the correct ntypes for
DeepEval._init_from_metadata() and DeepPotPTExpt to consume.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: e9a3035f-d385-4b45-8538-5a266b7d86e4

📥 Commits

Reviewing files that changed from the base of the PR and between 703a441 and e43d4b8.

📒 Files selected for processing (21)
  • deepmd/dpmodel/utils/dist_check.py
  • deepmd/dpmodel/utils/lmdb_data.py
  • deepmd/pt/entrypoints/freeze_pt2.py
  • deepmd/pt/loss/dens.py
  • deepmd/pt/model/atomic_model/sezm_atomic_model.py
  • deepmd/pt/model/descriptor/sezm.py
  • deepmd/pt/model/descriptor/sezm_nn/indexing.py
  • deepmd/pt/model/descriptor/sezm_nn/radial.py
  • deepmd/pt/model/descriptor/sezm_nn/so2.py
  • deepmd/pt/model/model/sezm_model.py
  • deepmd/pt/model/model/sezm_spin_model.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • deepmd/utils/data.py
  • doc/model/dpa4.md
  • examples/water/dpa4/lmp/README.md
  • examples/water/dpa4/lmp/in.lammps
  • source/tests/common/test_examples.py
  • source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py
  • source/tests/pt/model/test_sezm_export.py
  • source/tests/pt/model/test_sezm_model.py
🚧 Files skipped from review as they are similar to previous changes (7)
  • deepmd/dpmodel/utils/dist_check.py
  • deepmd/pt/model/descriptor/sezm_nn/radial.py
  • deepmd/pt/model/descriptor/sezm_nn/indexing.py
  • deepmd/dpmodel/utils/lmdb_data.py
  • deepmd/pt/loss/dens.py
  • deepmd/pt/model/atomic_model/sezm_atomic_model.py
  • deepmd/pt/model/descriptor/sezm_nn/so2.py

Comment thread deepmd/pt/entrypoints/freeze_pt2.py
Comment thread deepmd/pt/entrypoints/freeze_pt2.py Outdated
Copy link
Copy Markdown
Contributor

@njzjz-bot njzjz-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follow-up review after the latest DPA4 fixes. Several earlier comments are addressed, but I found a couple of remaining issues in the new changes.

Authored by OpenClaw 2026.5.12 (model: gpt-5.5)

Comment thread examples/water/dpa4/lmp/README.md
Comment thread deepmd/pt/model/model/sezm_model.py
Comment thread deepmd/pt/model/model/sezm_spin_model.py
@OutisLi OutisLi force-pushed the dpa4 branch 2 times, most recently from c4cd1c9 to b42a7c1 Compare May 20, 2026 06:02
@OutisLi OutisLi requested a review from njzjz-bot May 20, 2026 06:49
Comment thread examples/water/dpa4/lmp/.gitignore Outdated
Comment thread examples/water/dpa4/.gitignore Outdated
@iProzd
Copy link
Copy Markdown
Member

iProzd commented May 20, 2026

Others LGTM.

@njzjz-bot
Copy link
Copy Markdown
Contributor

Re-reviewed PR #5448. Current state:

  • CI/checks are green (CUDA job is skipped as configured; other build/test/CodeQL/Codecov checks pass).
  • The previous CodeRabbit points around .pt2 metadata and the head docstring look addressed: _collect_metadata() now writes both type_map and explicit ntypes, and the public docstring matches the current multi-task head behavior.
  • iProzd’s remaining requested changes are just the two example .gitignore files:
    • examples/water/dpa4/.gitignore
    • examples/water/dpa4/lmp/.gitignore

I agree with that request. For packaged examples, these .gitignores are not needed and can hide generated local artifacts in a way that makes example maintenance less transparent. After deleting those two files, this looks mergeable to me.

— OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)

@OutisLi OutisLi requested review from iProzd and njzjz May 20, 2026 14:25
Comment thread pyproject.toml Outdated
@OutisLi OutisLi requested a review from njzjz May 20, 2026 14:41
Comment thread deepmd/pt/utils/multi_task.py
Comment thread deepmd/pt/model/atomic_model/sezm_atomic_model.py Outdated
Comment thread deepmd/dpmodel/utils/dist_check.py
Comment thread deepmd/dpmodel/utils/nlist.py Outdated
Comment thread deepmd/dpmodel/utils/nlist.py
Comment thread deepmd/pt/train/utils.py
Comment thread source/api_cc/src/DeepPotPTExpt.cc Outdated
@OutisLi OutisLi changed the title feat(pt): Add DPA-4/SeZM descriptor & model 🎉🎉🎉 feat(pt): Add DPA4/SeZM descriptor & model 🎉🎉🎉 May 21, 2026
@OutisLi OutisLi force-pushed the dpa4 branch 3 times, most recently from 9dbb2d2 to eacaeba Compare May 21, 2026 03:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants