Skip to content

Mix batch 0429#5439

Open
littlepeachs wants to merge 11 commits into
deepmodeling:masterfrom
littlepeachs:mix_batch_0429
Open

Mix batch 0429#5439
littlepeachs wants to merge 11 commits into
deepmodeling:masterfrom
littlepeachs:mix_batch_0429

Conversation

@littlepeachs
Copy link
Copy Markdown
Collaborator

@littlepeachs littlepeachs commented May 9, 2026

Add mixed-nloc LMDB batching for PyTorch training

Summary

This PR adds optional mixed_batch support to the PyTorch LMDB training and validation data path.

When mixed_batch is enabled, one LMDB batch may contain frames with different atom counts / nloc values. Atom-wise fields are flattened across the batch, frame ownership is tracked with batch and ptr, and the model runs through a flat forward path without padding.

The default behavior is unchanged. mixed_batch is false by default, so LMDB batches still use the existing same-nloc sampler unless the option is explicitly enabled.

Motivation

LMDB batches previously required all frames in a batch to have the same nloc. This can make batching inefficient for datasets that mix systems with different atom counts. Mixed batching lets those frames train together while keeping atom indexing explicit in a flattened graph layout.

Main Changes

  • Add mixed_batch to PyTorch LMDB training_data and validation_data.
  • Keep mix_batch as an alias for the same option.
  • In mixed_batch=true mode, interpret batch_size as the number of frames/systems in one mixed batch.
  • Add a mixed LMDB collate path that concatenates atom-wise fields, stacks frame-wise fields, and emits batch and ptr for frame ownership.
  • Precompute flat graph metadata and indices during collation for the DPA3/RepFlow path.
  • Add flat forward support through the training wrapper, energy model, atomic model, descriptors, fitting net, and energy loss.
  • Keep atom-wise outputs flat and reduce frame-wise quantities using the global atom-to-frame mapping.

Usage

{
  "training": {
    "training_data": {
      "systems": "path/to/train.lmdb",
      "batch_size": 8,
      "mixed_batch": true
    },
    "validation_data": {
      "systems": "path/to/valid.lmdb",
      "batch_size": 8,
      "mixed_batch": true
    }
  }
}

With mixed_batch=true, batch_size controls how many frames/systems are merged into one flattened batch.

With mixed_batch=false, the existing same-nloc LMDB batching behavior is preserved.

Tests

  • Add unit tests for mixed-nloc LMDB collation.
  • Cover flattened batch / ptr layout and flat graph precomputation.
  • Cover the mix_batch config alias.
  • Add a DPA3 LMDB mixed-batch example config.
  • Add test_mixed_batch.sh for end-to-end training coverage.

Scope

  • This change applies to the PyTorch LMDB data path.
  • The flat graph path is currently used by the DPA3/RepFlow descriptor flow.

Summary by CodeRabbit

  • New Features

    • Mixed-batch LMDB training: flat atom-wise collation, optional precomputed flat-graph inputs, and model paths to consume them.
    • Trainer and model wrapper accept flat-graph inputs for mixed-nloc workloads.
    • Deterministic default initialization seeds for model components.
  • Improvements

    • Per-frame normalization for energy and virial losses in mixed batches.
  • Documentation

    • Added mixed-batch LMDB batching and flat-graph preprocessing guide.
  • Tests

    • Added unit/integration tests, sample configs, and a runnable mixed-batch training script.

Review Change Stack

@github-actions github-actions Bot added the Python label May 9, 2026
@dosubot dosubot Bot added the new feature label May 9, 2026
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds LMDB mixed-batch support: flat-graph precomputation, collate functions that flatten atom-wise fields, new flat forward paths through descriptor/fitting/atomic/model layers, per-frame loss normalization, and training integration with sampler/collate changes and tests/docs.

Changes

Mixed-batch LMDB Training Pipeline

Layer / File(s) Summary
Configuration and CLI args
deepmd/utils/argcheck.py
Adds mixed_batch (mix_batch alias) docs and boolean args for training and validation.
Graph precomputation & graph indices
deepmd/pt/utils/nlist.py, deepmd/pt/model/network/graph_utils_flat.py
Adds FlatGraphData, build_precomputed_flat_graph, rebuild_extended_coord_from_flat_graph, get_central_ext_index, image-aware ghost extension, and get_graph_index_flat for edges/angles.
Descriptor flat-batch implementations
deepmd/pt/model/descriptor/env_mat.py, deepmd/pt/model/descriptor/repflows.py, deepmd/pt/model/descriptor/dpa3.py
Adds prod_env_mat_flat, DescrptBlockRepflows.forward_flat, and DescrptDPA3.forward_flat to consume flat-graph tensors and return descriptor/rotation outputs.
Fitting and atomic-model flat forward
deepmd/pt/model/task/invar_fitting.py, deepmd/pt/model/atomic_model/dp_atomic_model.py
Adds InvarFitting.forward_flat to pad/unpad flat inputs and DPAtomicModel.forward_common_atomic_flat to call descriptor/fitting flat paths, mask outputs, and apply stats.
Model routing, seeds, and derivatives
deepmd/pt/model/model/__init__.py, deepmd/pt/model/model/ener_model.py, deepmd/pt/model/model/make_model.py
Adds deterministic default seeds, extends EnergyModel.forward to route to forward_common_flat, and implements forward_common_flat_native, forward_common_lower_flat, and autograd _compute_derivatives_flat.
Loss normalization and training wrapper
deepmd/pt/loss/ener.py, deepmd/pt/train/wrapper.py
EnergyStdLoss uses per-frame normalization when ptr is present; ModelWrapper.forward accepts and injects flat-graph tensors and updates natoms handling.
LMDB mixed-batch collate and dataset
deepmd/pt/utils/lmdb_dataset.py
Adds _collate_lmdb_mixed_batch, make_lmdb_mixed_batch_collate (optional precompute via graph_config), and LmdbDataset mixed-batch init path using an internal DataLoader.
Trainer, dataloader, and entrypoint integration
deepmd/pt/entrypoints/main.py, deepmd/pt/train/training.py
Entrypoint forwards mixed_batch; trainer builds mixed-batch DataLoader with seeded sampler or sequential sampler, constructs graph_config, extends input keys for flat-graph tensors, and handles device transfer for mixed batches.
Tests, examples, and documentation
source/tests/pt/test_lmdb_dataloader.py, source/tests/pt/test_loss.py, test_mptraj/*, test_mixed_batch.sh, doc/development/lmdb-mixed-system-batching.md
Adds unit tests for collate/dataset/loss, example JSON configs and Bash runner, and developer documentation describing the flat-graph mixed-batch flow.

Sequence Diagram(s)

sequenceDiagram
  participant DataLoader
  participant Trainer
  participant Wrapper
  participant Model_CM
  participant Descriptor_Repflows
  participant FittingNet
  participant Loss_EnerStd
  DataLoader->>Trainer: mixed-batch batch (coord, atype, batch, ptr, ...)
  Trainer->>Wrapper: device-transfer & extended keys
  Wrapper->>Model_CM: forward_common_flat args (flat graph tensors)
  Model_CM->>Descriptor_Repflows: descriptor.forward_flat(extended_coord, nlist, ...)
  Descriptor_Repflows->>FittingNet: node descriptors -> fitting_net.forward_flat
  FittingNet->>Model_CM: atomic predictions (atom-wise)
  Model_CM->>Loss_EnerStd: aggregated predictions + ptr/batch for per-frame norms
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

Examples

Suggested reviewers

  • iProzd
  • wanghan-iapcm
  • njzjz
🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 45.83% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Title check ❓ Inconclusive The title 'Mix batch 0429' is vague and does not clearly convey the main change; it appears to reference a branch name/issue number rather than describing the feature. Revise the title to clearly describe the feature, such as 'Add mixed-nloc LMDB batch support for PyTorch training' or 'Support mixed batch sizes with variable atom counts in LMDB datasets'.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 14

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/loss/ener.py (1)

412-424: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Generalized-force loss still assumes a single natoms for every frame.

This branch reshapes force and drdq with natoms * 3, which only works for uniform-size batches. In mixed-batch mode those tensors are flattened across frames, so enabling generalized-force loss here will mis-shape the data or crash.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/loss/ener.py` around lines 412 - 424, The generalized-force branch
in ener.py assumes a single natoms by reshaping force_pred, force_label and drdq
with natoms * 3, which breaks for mixed-size batches; instead, use frame-aware
shapes: do not flatten all frames into one dimension—reshape or iterate using
per-frame atom counts (or keep forces as (nframes, natoms_i, 3) and drdq as
(nframes, natoms_i*3, n_gcoord)) and compute gen_force and gen_force_label
per-frame (or via batched operations that use the original nframes dimension) so
tensor dims align; also apply find_drdq/pref_gf per-frame. Locate symbols:
has_gf, drdq, find_drdq, pref_gf, force_pred, force_label,
force_reshape_nframes, force_label_reshape_nframes, drdq_reshape, gen_force,
gen_force_label and update the reshaping/Einstein-summation to operate over the
nframes dimension (or loop over frames) instead of using natoms * 3 across the
whole batch.
🧹 Nitpick comments (13)
deepmd/pt/model/descriptor/repflows.py (1)

840-874: 💤 Low value

Make the disjoint-frame invariant explicit.

The synthetic batch=1 wrap is only correct because RepFlowLayer.forward is strictly per-atom + per-neighbor, and because nlist/a_nlist here never cross frame boundaries (the LMDB collator builds per-frame extended atoms and offsets neighbor indices accordingly, then maps back via per-frame mapping). A brief comment near this block stating that invariant — and noting that any future RepFlowLayer change introducing global ops over the batch axis would silently mix frames — will save a future debugging session.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/repflows.py` around lines 840 - 874, Add a short
explanatory comment above the synthetic batch=1 wrapping that states the
disjoint-frame invariant: explain that the temporary batching of tensors like
node_ebd_batched, edge_ebd_batched, h2_batched, angle_ebd_batched,
nlist_batched, a_nlist_batched, etc. is safe because RepFlowLayer.forward
operates strictly per-atom and per-neighbor and the LMDB collator guarantees
nlist/a_nlist never reference atoms across different frames (mapping/offsets
keep frames disjoint); also warn that if RepFlowLayer.forward (or any layer in
self.layers) later introduces global ops over the batch axis the synthetic batch
will mix frames and break correctness, so such changes must preserve the
per-frame isolation or remove this batching trick.
deepmd/pt/utils/nlist.py (2)

101-170: 🏗️ Heavy lift

Per-frame Python loop with .item() calls forces a CPU/GPU sync each iteration.

For large batch_size (e.g., 128 in lmdb_baseline.json), this loop incurs 2 * batch_size device syncs from int(ptr[...].item()) plus per-frame kernel launches for extend_coord_with_ghosts_with_images and build_neighbor_list. Even though it runs in dataloader workers, this can dominate end-to-end throughput when frames have small atom counts. Two relatively cheap mitigations:

  • Pull ptr to CPU once before the loop (ptr_cpu = ptr.cpu().tolist()) and index lists directly, eliminating per-iteration syncs.
  • If the data loader ultimately runs on CPU tensors (typical for collate_fn), the per-frame work is largely free of device sync but the loop still costs Python overhead — consider a vectorized batched ghost-extension that operates on a padded tensor and then re-flattens, mirroring the existing extend_input_and_build_neighbor_list_with_images.

Not blocking, but worth measuring before this lands as the default mixed-batch path.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/utils/nlist.py` around lines 101 - 170, The loop is causing
repeated device syncs via int(ptr[frame_idx].item()) and per-frame kernel
overhead in extend_coord_with_ghosts_with_images and build_neighbor_list; fix by
materializing ptr on CPU once (e.g., ptr_cpu = ptr.cpu().tolist()) and use
ptr_cpu[frame_idx]/ptr_cpu[frame_idx+1] for start_idx/end_idx to eliminate
.item() calls, and where possible call
normalize_coord/extend_coord_with_ghosts_with_images/build_neighbor_list on
already-CPU tensors or add a batched/padded variant (similar to
extend_input_and_build_neighbor_list_with_images) to process frames in a
vectorized way to reduce Python per-frame overhead.

262-281: 💤 Low value

Document the "first nloc extended atoms per frame are local" assumption.

get_central_ext_index relies on the layout produced by _extend_coord_with_ghosts_impl, where shift indices are sorted by L2 norm and (0,0,0) is therefore the first image — making the first nloc extended atoms in each frame the local ones. This invariant is implicit; a one-line comment will save a future maintainer who changes the ghost layout from silently breaking this helper.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/utils/nlist.py` around lines 262 - 281, Add a one-line doc comment
inside get_central_ext_index stating the invariant it relies on: that the
extended layout produced by _extend_coord_with_ghosts_impl places the unshifted
image (shift (0,0,0)) first for each frame and thus the first nloc extended
atoms per frame are the local atoms; mention that shift indices are sorted by L2
norm so the (0,0,0) image is first and changing that ordering will break
get_central_ext_index.
.gitignore (1)

81-81: 💤 Low value

deepmd-kit/ ignore pattern is overly broad.

Ignoring a directory named after the project itself can mask legitimate content if the repo or a checkout is ever placed under a deepmd-kit/ subdirectory (e.g., nested workspaces, tooling that clones into a project-named folder). Consider narrowing or scoping the rule to only the artifact path you intend to ignore.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @.gitignore at line 81, The .gitignore entry "deepmd-kit/" is too broad and
may unintentionally ignore legitimately nested repos; update the ignore rule
that currently uses deepmd-kit/ to a more specific pattern (for example scope it
to the repo root by changing to "/deepmd-kit/" or better yet target only
build/dist/artifact paths like "/deepmd-kit/build", "/deepmd-kit/dist", or
specific filenames/patterns such as "deepmd-kit/*.egg-info") so only the
intended artifact directories/files are ignored; locate the deepmd-kit/ entry in
the .gitignore and replace it with one or more narrow, explicit patterns.
deepmd/pt/model/task/invar_fitting.py (2)

299-315: 💤 Low value

Atom-wise output detection is shape-heuristic; consider being explicit.

The current rule un-pads any tensor whose first two dims happen to be (nframes, max_nloc) and leaves everything else untouched. That works for the current outputs (var_name, middle_output), but a future output that returns frame-level shape (nframes, max_nloc) would be silently un-padded, and any per-atom output whose first axis is something other than nframes would be silently passed through. Consider keying off output_def() (which already declares atomic vs. frame-level variables) instead of tensor shape.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/task/invar_fitting.py` around lines 299 - 315, The tensor
un-padding currently uses a shape heuristic in the loop that builds result_flat;
instead use the model's declared output types from output_def() to decide which
keys are atom-wise. Replace the shape-based condition in the for key, value in
result_batch loop with a lookup like output_def()[key].is_atomic (or equivalent
in your model API) and only apply valid_atom_mask to values for keys declared
atomic; leave frame-level or other outputs untouched and handle non-tensor
values as before (keep variable names result_batch, result_flat,
valid_atom_mask, and call output_def() to determine atomic vs frame-level).

259-264: 💤 Low value

Padding atype = -1 silently wraps to the last type in PyTorch indexing.

Per-type lookups inside self.forward (bias_atom_e/case embedding/etc.) use tensor[atype], where -1 does not error out — it wraps to the last type. Outputs are masked away later, so this does not corrupt the result, but it does waste compute on the padding rows and can interact unexpectedly with any branch that special-cases negative types. Using a valid sentinel like 0 is safer and semantically clearer; the valid_atom_mask continues to discard the rows.

♻️ Suggested change
-        atype_batch = torch.full(
-            (nframes, max_nloc),
-            -1,
-            dtype=atype.dtype,
-            device=device,
-        )
+        # Use 0 as a safe dummy type for padding rows; outputs are discarded
+        # by valid_atom_mask below so the value is not observable.
+        atype_batch = torch.zeros(
+            (nframes, max_nloc),
+            dtype=atype.dtype,
+            device=device,
+        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/task/invar_fitting.py` around lines 259 - 264, The padding
value for atype_batch is set to -1 which causes PyTorch indexing (e.g., in
self.forward when doing per-type lookups like bias_atom_e[atype] or embeddings)
to wrap to the last type; change the sentinel to a valid non-negative index
(e.g., 0 or a dedicated PAD type index) when creating atype_batch so lookups
won't accidentally reference real types, and keep using valid_atom_mask to
discard those rows downstream; update places that assume negative sentinel
behavior to use the new PAD index (refer to atype_batch, self.forward,
bias_atom_e, and valid_atom_mask).
deepmd/pt/model/network/graph_utils_flat.py (1)

41-89: 💤 Low value

LGTM.

Edge index construction (n2e_index / n_ext2e_index) and the angle index mapping via edge_lookup are consistent with the consumers in repflows.py and the dynamic-selection path. The a_nlist_mask_3d logic correctly requires both neighbors of an angle to be valid before producing a triplet. The single CPU sync via n_edge = ... .item() is necessary for torch.arange.

One small optional suggestion: an assert a_sel <= nnei near the top would surface mis-configured inputs as a clear error rather than a downstream broadcast failure (DescrptDPA3 already enforces e_sel >= a_sel, so this is mostly defensive).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/network/graph_utils_flat.py` around lines 41 - 89, Add a
defensive check to guard against mis-configured inputs by asserting that a_sel
(from a_nlist_mask.shape[1]) does not exceed nnei (from nlist_flat.shape[1])
near the top of the function; insert something like an assert or raise
ValueError after computing nnei and a_sel so that functions using n2e_index,
edge_lookup_a, and the a_nlist_mask_3d broadcasts fail fast with a clear message
instead of producing downstream broadcast errors.
deepmd/pt/model/model/make_model.py (3)

520-599: 💤 Low value

forward_common_flat is a pure passthrough to forward_common_flat_native.

The wrapper adds no behavior — same parameters, same return — and just forwards every argument. Two near-identical method definitions (with duplicated docstrings) is a maintenance hazard and adds friction for readers tracing the call. Either (a) inline forward_common_flat_native into forward_common_flat and drop the _native variant, or (b) document a clear reason for the split (e.g. JIT export boundary, planned non-native variant) so the duplication is justified.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/model/make_model.py` around lines 520 - 599,
forward_common_flat is a pure passthrough to forward_common_flat_native with
duplicated docstrings, creating maintenance overhead; either remove the
redundant wrapper by merging forward_common_flat_native into forward_common_flat
(rename/remove the _native variant and keep a single method and docstring) or
add a concise comment/docstring to forward_common_flat explaining the explicit
reason for the split (e.g., JIT/export boundary or future divergence). Update
any call sites that reference forward_common_flat_native to use the single
chosen method name (forward_common_flat) and ensure tests/type hints still pass.

296-309: 💤 Low value

Move the helper import to module level and tidy the validation order.

Two minor suggestions:

  1. from deepmd.pt.utils.nlist import rebuild_extended_coord_from_flat_graph is imported every call inside the hot forward path (line 296). Promote it to the existing nlist import block at the top of the file.
  2. The requires_grad_(True) work on coord/box at lines 280-282 happens before the field-presence validation at lines 283-309. If the validation raises, the clone+detach is wasted. Reorder so the validation runs first.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/model/make_model.py` around lines 296 - 309, Move the helper
import for rebuild_extended_coord_from_flat_graph out of the hot forward path
and add it to the module-level nlist import block with the other
deepmd.pt.utils.nlist imports so it is imported once at load time; also change
the validation order inside the flat mixed-batch branch so you check for
presence of the required graph fields (mapping, extended_batch, extended_image,
etc.) and raise the RuntimeError before mutating tensors, and only call
coord.requires_grad_(True) / box.requires_grad_(True) (and any clone+detach)
after the validation passes to avoid wasted work when validation fails.

311-314: 💤 Low value

Redundant assertions after the explicit None-check.

Lines 283-295 already verify extended_atype, extended_batch, mapping, nlist, nlist_ext, etc. are not None (otherwise the else raises), so the four assert statements at lines 311-314 are redundant at runtime. If they exist solely for type-narrowing under static checkers, a comment to that effect would help future readers. Otherwise they can be removed.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/model/make_model.py` around lines 311 - 314, The four assert
statements for extended_atype, extended_batch, mapping, and nlist are redundant
because the preceding explicit None-check/else already guarantees they are not
None; either remove those assert lines or replace them with a short explanatory
comment indicating they exist purely for static type-narrowing (for tools like
mypy) so future readers understand their purpose; locate the checks around the
make_model function where extended_atype, extended_batch, mapping, and nlist are
validated and update accordingly.
deepmd/pt/train/training.py (1)

1690-1723: 💤 Low value

Minor structural cleanup in get_data for mixed-batch handling.

The current loop skips batch/ptr (lines 1697-1698) only to transfer them separately at lines 1722-1723. Since they are tensors, they would be handled correctly by the generic not isinstance(..., list) branch at line 1699-1701. You can drop the explicit skip and the post-loop transfer, then just extend input_keys when is_mixed_batch is true:

♻️ Proposed simplification
         for key in batch_data.keys():
             if key == "sid" or key == "fid" or "find_" in key:
                 continue
-            # Skip batch and ptr for now, will handle them separately
-            elif key == "batch" or key == "ptr":
-                continue
             elif not isinstance(batch_data[key], list):
                 if batch_data[key] is not None:
                     batch_data[key] = batch_data[key].to(DEVICE, non_blocking=True)
@@
         # Mixed-nloc LMDB batches include precomputed flat-graph tensors.
         if is_mixed_batch:
             input_keys = input_keys + list(_FLAT_GRAPH_INPUT_KEYS)
-            batch_data["batch"] = batch_data["batch"].to(DEVICE, non_blocking=True)
-            batch_data["ptr"] = batch_data["ptr"].to(DEVICE, non_blocking=True)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/train/training.py` around lines 1690 - 1723, In get_data, remove
the special-case skipping of "batch" and "ptr" inside the batch_data loop and
delete the separate post-loop device transfers for batch_data["batch"] and
batch_data["ptr"]; instead let the existing non-list tensor branch handle them
(they'll be moved to DEVICE automatically), and only keep the is_mixed_batch
detection so you can extend input_keys with list(_FLAT_GRAPH_INPUT_KEYS) when
is_mixed_batch is true (symbols: get_data, batch_data, is_mixed_batch,
input_keys, _FLAT_GRAPH_INPUT_KEYS, DEVICE).
deepmd/pt/model/model/ener_model.py (1)

92-176: ⚡ Quick win

Extract the shared output-assembly block to avoid maintaining two near-identical paths.

Lines 117-143 (flat path) and 153-174 (regular path) build model_predict in essentially the same way (atom_energy/energy, conditional force/dforce, virials, mask, hessian). Two parallel implementations are a maintenance hazard — future output keys or shape adjustments must be mirrored in both, and any drift becomes silent. Consider extracting a small helper like _assemble_energy_predict(model_ret, do_atomic_virial) and calling it from both branches, with the flat branch additionally adding model_predict["updated_coord"] += coord in the no-fitting-net case.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/model/ener_model.py` around lines 92 - 176, The two branches
in forward (flat and regular) duplicate the same output-assembly logic for
model_predict; extract that into a helper method (e.g.
_assemble_energy_predict(model_ret, do_atomic_virial)) that encapsulates setting
"atom_energy", "energy", choosing "force" vs "dforce" using do_grad_r("energy"),
adding "virial" and "atom_virial" via do_grad_c("energy") and do_atomic_virial,
including "mask" and "hessian" when present/when _hessian_enabled; then call
this helper from both the flat-path and regular-path after receiving model_ret,
and keep the flat/regular-specific behavior (the no-fitting-net case that does
model_predict["updated_coord"] += coord) in the respective branch only. Ensure
the helper references self.do_grad_r, self.do_grad_c, self._hessian_enabled and
self.get_fitting_net() behavior remains consistent.
deepmd/pt/train/wrapper.py (1)

237-247: 💤 Low value

natoms semantics now diverge between flat and regular batches; document the contract.

For regular batches atype.dim() > 1 and natoms = atype.shape[-1] is per-frame nloc, but for mixed batches atype is 1-D and natoms = atype.shape[0] is the total flattened atom count across all frames. Downstream loss code must distinguish these by inspecting ptr in input_dict to normalize correctly — that's a fragile, implicit contract spread across wrapper.py, ener.py, and make_model.py.

Consider either (a) renaming/splitting the argument so callees see total_atoms vs nloc_per_frame explicitly, or (b) adding a docstring/inline note here that clearly states the dual semantics and that the loss must branch on ptr. The current inline comment on lines 237-238 helps but does not surface the downstream consumer expectation.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/train/wrapper.py` around lines 237 - 247, The natoms value
currently has dual semantics (per-frame nloc when atype.dim()>1 vs total
flattened atom count when 1-D) which is implicit and fragile; fix by making the
contract explicit: compute both nloc_per_frame (e.g., atype.shape[-1] when
atype.dim()>1 else None) and total_atoms (atype.numel() or atype.shape[0]) in
wrapper.py and change the call to self.loss[task_key](input_dict,
self.model[task_key], label, nloc_per_frame=nloc_per_frame,
total_atoms=total_atoms, learning_rate=cur_lr), then update all loss
implementations (ener.py) and any factory in make_model.py to accept these new
named args (or provide backward-compatible handling) so downstream code no
longer must infer semantics from ptr.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In @.gitignore:
- Line 78: The .gitignore entry "test_mptraj/" conflicts with tracked files
committed in this PR (test_mptraj/lmdb_baseline.json and
test_mptraj/lmdb_mixed_batch.json); fix by either removing the directory-level
ignore line "test_mptraj/" or by adding explicit negation patterns for the
tracked files (e.g., "!test_mptraj/lmdb_baseline.json" and
"!test_mptraj/lmdb_mixed_batch.json") so future edits to those config files are
not hidden and git add won't require -f; update .gitignore accordingly and
ensure the two committed files remain tracked.

In `@deepmd/pt/loss/ener.py`:
- Around line 228-240: The mixed-batch normalization currently computes a single
scalar atom_norm via atom_norm = mean(1.0 / natoms_per_frame) which misweights
frames with different atom counts; instead compute per-frame normalizers and
apply them before reducing l2_ener_loss and l2_virial_loss: when is_mixed_batch
(detect via ptr and natoms_per_frame derived from ptr), build per-frame
atom_norms = 1.0 / natoms_per_frame.float() and multiply each frame's loss by
its corresponding atom_norm before taking the final mean/sum (or perform a
weighted reduction using these per-frame weights) so that l2_ener_loss and
l2_virial_loss are normalized per-frame correctly rather than using a single
averaged scalar.

In `@deepmd/pt/model/atomic_model/dp_atomic_model.py`:
- Around line 378-390: The flat-path call to descriptor.forward_flat currently
passes fparam unmodified and thus skips the fallback logic used in
forward_atomic; update the flat-path before calling descriptor.forward_flat (the
block around descriptor.forward_flat) to mirror forward_atomic's behavior: if
self.add_chg_spin_ebd (or the equivalent flag) is True and fparam is None,
compute the default via self.get_default_fparam(...) using the same inputs/shape
logic used in forward_atomic and assign it to fparam, then pass that fparam into
descriptor.forward_flat so mixed-batch runs use the same default frame
parameters.

In `@deepmd/pt/model/model/__init__.py`:
- Around line 90-98: The hybrid-descriptor logic in
_set_default_descriptor_init_seed (and the analogous block later) always derives
child seeds from the passed-in default `seed` instead of the effective parent
config seed; update the logic to first determine the parent_seed =
params.get("seed", seed) (i.e., use params["seed"] when present, otherwise fall
back to the incoming seed) and then call child_seed(parent_seed, idx) when
recursing into each hybrid child (preserving the existing
isinstance(descriptor_params, dict) check and recursive call to
_set_default_descriptor_init_seed).

In `@deepmd/pt/model/model/make_model.py`:
- Around line 497-518: The flat-path branch in make_model.py quietly skips
atomic virial when do_grad_c("energy") is true and do_atomic_virial is true (the
pass under if do_atomic_virial), causing fit_ret to miss
atom_virial/energy_derv_c keys and silently breaking callers (e.g.,
EnergyModel.forward reading model_ret["energy_derv_c"]); update the branch in
the block handling energy_derv_c_redu so that if do_atomic_virial is true you
either (A) compute and populate the appropriate atomic virial / energy_derv_c
entries in fit_ret consistent with the reduced-to-atom shape used elsewhere (use
energy_redu, energy_derv_c_redu and ptr to derive per-atom contributions), or if
that is nontrivial, (B) raise a clear NotImplementedError (or emit a one-time
logged warning and raise) referencing do_atomic_virial and the flat-path so
callers know the feature is unsupported; ensure the error/warning mentions
make_model.make_model (or the surrounding function) and that fit_ret must
include the keys model_predict/EnergyModel.forward expects (e.g.,
"energy_derv_c" / atom_virial) to avoid silent failures.
- Around line 217-352: The flat forward path forward_common_flat_native is
missing the same input/output precision normalization used in forward_common;
call the model's _input_type_cast on incoming tensors (coord, box, fparam,
aparam, atype, batch, mapping, extended_* inputs as appropriate) at the start so
coord/box are promoted to global_pt_float_precision (and energy-related inputs
to global_pt_ener_float_precision if required) before doing
clone().detach().requires_grad_(True) or rebuilding extended_coord, and after
computing model_predict_lower run _output_type_cast to promote outputs
(especially energy_redu) back to global_pt_ener_float_precision; ensure you
reference the existing helpers _input_type_cast and _output_type_cast and keep
their casting behavior identical to forward_common.

In `@deepmd/pt/train/training.py`:
- Around line 264-299: When _data.mixed_batch is true but the model descriptor
lacks repflows, raise an immediate, clear error instead of constructing a
dataloader with graph_config=None; inside the mixed_batch branch (where
descriptor is obtained from model_for_graph.atomic_model.descriptor and
make_lmdb_mixed_batch_collate is called), check hasattr(descriptor, "repflows")
and if False raise a ValueError/RuntimeError with a message like
"mixed_batch=True requires a flat-graph capable descriptor with 'repflows' (used
by make_lmdb_mixed_batch_collate and
EnergyModel.forward/forward_common_flat_native) — disable mixed_batch or provide
a compatible descriptor." This ensures failure occurs at dataloader construction
rather than later in EnergyModel.forward.

In `@deepmd/pt/utils/lmdb_dataset.py`:
- Around line 42-49: The __all__ export list is not alphabetically sorted which
triggers Ruff RUF022; reorder the list items in the __all__ variable so they are
lexicographically sorted (e.g., "_collate_lmdb_batch",
"_collate_lmdb_mixed_batch", "LmdbDataset", "LmdbTestData", "is_lmdb",
"make_lmdb_mixed_batch_collate") while preserving the existing string entries,
quotes, and list structure so the symbol names (LmdbDataset, LmdbTestData,
_collate_lmdb_batch, _collate_lmdb_mixed_batch, make_lmdb_mixed_batch_collate,
is_lmdb) remain unchanged.
- Around line 247-255: Mixed batching currently replaces the shuffle-based
DataLoader and ignores computed block weights, so when mixed_batch is True you
must construct _inner_dataloader to honor _block_targets computed by
compute_block_targets() rather than using DataLoader(..., shuffle=True). Replace
the plain shuffle=True approach in the mixed_batch branch by creating and
passing a sampler that uses _block_targets (e.g.,
torch.utils.data.WeightedRandomSampler or an equivalent custom sampler) to
DataLoader, keep num_workers=0 and collate_fn=_collate_lmdb_mixed_batch, and
ensure _inner_dataloader is assigned the DataLoader configured with that sampler
so auto_prob_style/block weighting still affects sampling.
- Around line 51-61: The whitelisted atom-wise keys in
_ATOMWISE_MIXED_BATCH_KEYS are missing supported labels (e.g. "atom_ener",
"drdq"), causing those fields to be handled by collate_tensor_fn and break
mixed-batch logic; update the frozenset declaration named
_ATOMWISE_MIXED_BATCH_KEYS to include all atom-wise labels used elsewhere (at
minimum "atom_ener" and "drdq") and make the same addition to the duplicate
whitelist further down (the other frozenset at lines ~99-104) so mixed-batch
flattening consistently handles these fields instead of falling back to
collate_tensor_fn.

In `@deepmd/pt/utils/nlist.py`:
- Around line 246-259: The code currently inverts atom_cell (atom_cell =
cell[extended_batch]) per extended atom which is wasteful; instead compute the
inverse once per frame by calling torch.linalg.inv_ex on cell (after cell =
box.reshape(-1,3,3)) and then index the resulting rec_cell with extended_batch
(i.e., rec_cell_full, _ = torch.linalg.inv_ex(cell); rec_cell =
rec_cell_full[extended_batch]) so subsequent einsum uses the indexed per-atom
inverses without repeating the matrix inversion.

In `@test_mixed_batch.sh`:
- Around line 15-19: The script prints "Check mixed_batch_train.log for details"
but never writes that file; fix by redirecting the training command's
stdout/stderr to mixed_batch_train.log (e.g., change the dp --pt train
test_mptraj/lmdb_mixed_batch.json --skip-neighbor-stat invocation to append or
write both stdout and stderr into mixed_batch_train.log) or, if you prefer not
to create a log, update the echo lines to not reference mixed_batch_train.log;
modify the dp invocation or the echo accordingly to keep the message accurate.

In `@test_mptraj/lmdb_baseline.json`:
- Around line 195-205: The baseline config currently hardcodes
developer-specific absolute paths in the keys stat_file, training_data.systems,
and validation_data.systems which breaks portability; change these values to
reusable placeholders or environment-variable expansions (e.g. ${DATASET_PATH})
or point them to a small example dataset shipped in the repo (mirror other
example configs), update any code that loads the config to resolve those env
vars, and add a short note in the config or README explaining how
test_mixed_batch.sh sets or substitutes those variables before running.

In `@test_mptraj/lmdb_mixed_batch.json`:
- Around line 74-84: The fixture uses hard-coded absolute paths in stat_file and
LMDB system entries (stat_file, training_data.systems, validation_data.systems)
which prevents CI portability; update the test_mptraj/lmdb_mixed_batch.json
fixture to reference repo-managed test data or injected paths by replacing
absolute paths with relative paths inside the repo (e.g., a test-data directory)
or with placeholders/env variables read by the test runner, and ensure the test
harness supplies those fixtures before running so the config uses
repo-controlled files rather than machine-specific absolute paths.

---

Outside diff comments:
In `@deepmd/pt/loss/ener.py`:
- Around line 412-424: The generalized-force branch in ener.py assumes a single
natoms by reshaping force_pred, force_label and drdq with natoms * 3, which
breaks for mixed-size batches; instead, use frame-aware shapes: do not flatten
all frames into one dimension—reshape or iterate using per-frame atom counts (or
keep forces as (nframes, natoms_i, 3) and drdq as (nframes, natoms_i*3,
n_gcoord)) and compute gen_force and gen_force_label per-frame (or via batched
operations that use the original nframes dimension) so tensor dims align; also
apply find_drdq/pref_gf per-frame. Locate symbols: has_gf, drdq, find_drdq,
pref_gf, force_pred, force_label, force_reshape_nframes,
force_label_reshape_nframes, drdq_reshape, gen_force, gen_force_label and update
the reshaping/Einstein-summation to operate over the nframes dimension (or loop
over frames) instead of using natoms * 3 across the whole batch.

---

Nitpick comments:
In @.gitignore:
- Line 81: The .gitignore entry "deepmd-kit/" is too broad and may
unintentionally ignore legitimately nested repos; update the ignore rule that
currently uses deepmd-kit/ to a more specific pattern (for example scope it to
the repo root by changing to "/deepmd-kit/" or better yet target only
build/dist/artifact paths like "/deepmd-kit/build", "/deepmd-kit/dist", or
specific filenames/patterns such as "deepmd-kit/*.egg-info") so only the
intended artifact directories/files are ignored; locate the deepmd-kit/ entry in
the .gitignore and replace it with one or more narrow, explicit patterns.

In `@deepmd/pt/model/descriptor/repflows.py`:
- Around line 840-874: Add a short explanatory comment above the synthetic
batch=1 wrapping that states the disjoint-frame invariant: explain that the
temporary batching of tensors like node_ebd_batched, edge_ebd_batched,
h2_batched, angle_ebd_batched, nlist_batched, a_nlist_batched, etc. is safe
because RepFlowLayer.forward operates strictly per-atom and per-neighbor and the
LMDB collator guarantees nlist/a_nlist never reference atoms across different
frames (mapping/offsets keep frames disjoint); also warn that if
RepFlowLayer.forward (or any layer in self.layers) later introduces global ops
over the batch axis the synthetic batch will mix frames and break correctness,
so such changes must preserve the per-frame isolation or remove this batching
trick.

In `@deepmd/pt/model/model/ener_model.py`:
- Around line 92-176: The two branches in forward (flat and regular) duplicate
the same output-assembly logic for model_predict; extract that into a helper
method (e.g. _assemble_energy_predict(model_ret, do_atomic_virial)) that
encapsulates setting "atom_energy", "energy", choosing "force" vs "dforce" using
do_grad_r("energy"), adding "virial" and "atom_virial" via do_grad_c("energy")
and do_atomic_virial, including "mask" and "hessian" when present/when
_hessian_enabled; then call this helper from both the flat-path and regular-path
after receiving model_ret, and keep the flat/regular-specific behavior (the
no-fitting-net case that does model_predict["updated_coord"] += coord) in the
respective branch only. Ensure the helper references self.do_grad_r,
self.do_grad_c, self._hessian_enabled and self.get_fitting_net() behavior
remains consistent.

In `@deepmd/pt/model/model/make_model.py`:
- Around line 520-599: forward_common_flat is a pure passthrough to
forward_common_flat_native with duplicated docstrings, creating maintenance
overhead; either remove the redundant wrapper by merging
forward_common_flat_native into forward_common_flat (rename/remove the _native
variant and keep a single method and docstring) or add a concise
comment/docstring to forward_common_flat explaining the explicit reason for the
split (e.g., JIT/export boundary or future divergence). Update any call sites
that reference forward_common_flat_native to use the single chosen method name
(forward_common_flat) and ensure tests/type hints still pass.
- Around line 296-309: Move the helper import for
rebuild_extended_coord_from_flat_graph out of the hot forward path and add it to
the module-level nlist import block with the other deepmd.pt.utils.nlist imports
so it is imported once at load time; also change the validation order inside the
flat mixed-batch branch so you check for presence of the required graph fields
(mapping, extended_batch, extended_image, etc.) and raise the RuntimeError
before mutating tensors, and only call coord.requires_grad_(True) /
box.requires_grad_(True) (and any clone+detach) after the validation passes to
avoid wasted work when validation fails.
- Around line 311-314: The four assert statements for extended_atype,
extended_batch, mapping, and nlist are redundant because the preceding explicit
None-check/else already guarantees they are not None; either remove those assert
lines or replace them with a short explanatory comment indicating they exist
purely for static type-narrowing (for tools like mypy) so future readers
understand their purpose; locate the checks around the make_model function where
extended_atype, extended_batch, mapping, and nlist are validated and update
accordingly.

In `@deepmd/pt/model/network/graph_utils_flat.py`:
- Around line 41-89: Add a defensive check to guard against mis-configured
inputs by asserting that a_sel (from a_nlist_mask.shape[1]) does not exceed nnei
(from nlist_flat.shape[1]) near the top of the function; insert something like
an assert or raise ValueError after computing nnei and a_sel so that functions
using n2e_index, edge_lookup_a, and the a_nlist_mask_3d broadcasts fail fast
with a clear message instead of producing downstream broadcast errors.

In `@deepmd/pt/model/task/invar_fitting.py`:
- Around line 299-315: The tensor un-padding currently uses a shape heuristic in
the loop that builds result_flat; instead use the model's declared output types
from output_def() to decide which keys are atom-wise. Replace the shape-based
condition in the for key, value in result_batch loop with a lookup like
output_def()[key].is_atomic (or equivalent in your model API) and only apply
valid_atom_mask to values for keys declared atomic; leave frame-level or other
outputs untouched and handle non-tensor values as before (keep variable names
result_batch, result_flat, valid_atom_mask, and call output_def() to determine
atomic vs frame-level).
- Around line 259-264: The padding value for atype_batch is set to -1 which
causes PyTorch indexing (e.g., in self.forward when doing per-type lookups like
bias_atom_e[atype] or embeddings) to wrap to the last type; change the sentinel
to a valid non-negative index (e.g., 0 or a dedicated PAD type index) when
creating atype_batch so lookups won't accidentally reference real types, and
keep using valid_atom_mask to discard those rows downstream; update places that
assume negative sentinel behavior to use the new PAD index (refer to
atype_batch, self.forward, bias_atom_e, and valid_atom_mask).

In `@deepmd/pt/train/training.py`:
- Around line 1690-1723: In get_data, remove the special-case skipping of
"batch" and "ptr" inside the batch_data loop and delete the separate post-loop
device transfers for batch_data["batch"] and batch_data["ptr"]; instead let the
existing non-list tensor branch handle them (they'll be moved to DEVICE
automatically), and only keep the is_mixed_batch detection so you can extend
input_keys with list(_FLAT_GRAPH_INPUT_KEYS) when is_mixed_batch is true
(symbols: get_data, batch_data, is_mixed_batch, input_keys,
_FLAT_GRAPH_INPUT_KEYS, DEVICE).

In `@deepmd/pt/train/wrapper.py`:
- Around line 237-247: The natoms value currently has dual semantics (per-frame
nloc when atype.dim()>1 vs total flattened atom count when 1-D) which is
implicit and fragile; fix by making the contract explicit: compute both
nloc_per_frame (e.g., atype.shape[-1] when atype.dim()>1 else None) and
total_atoms (atype.numel() or atype.shape[0]) in wrapper.py and change the call
to self.loss[task_key](input_dict, self.model[task_key], label,
nloc_per_frame=nloc_per_frame, total_atoms=total_atoms, learning_rate=cur_lr),
then update all loss implementations (ener.py) and any factory in make_model.py
to accept these new named args (or provide backward-compatible handling) so
downstream code no longer must infer semantics from ptr.

In `@deepmd/pt/utils/nlist.py`:
- Around line 101-170: The loop is causing repeated device syncs via
int(ptr[frame_idx].item()) and per-frame kernel overhead in
extend_coord_with_ghosts_with_images and build_neighbor_list; fix by
materializing ptr on CPU once (e.g., ptr_cpu = ptr.cpu().tolist()) and use
ptr_cpu[frame_idx]/ptr_cpu[frame_idx+1] for start_idx/end_idx to eliminate
.item() calls, and where possible call
normalize_coord/extend_coord_with_ghosts_with_images/build_neighbor_list on
already-CPU tensors or add a batched/padded variant (similar to
extend_input_and_build_neighbor_list_with_images) to process frames in a
vectorized way to reduce Python per-frame overhead.
- Around line 262-281: Add a one-line doc comment inside get_central_ext_index
stating the invariant it relies on: that the extended layout produced by
_extend_coord_with_ghosts_impl places the unshifted image (shift (0,0,0)) first
for each frame and thus the first nloc extended atoms per frame are the local
atoms; mention that shift indices are sorted by L2 norm so the (0,0,0) image is
first and changing that ordering will break get_central_ext_index.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: b6ad87fb-4157-40de-8f34-510d8232bbef

📥 Commits

Reviewing files that changed from the base of the PR and between 78bffde and 9f4ea0a.

📒 Files selected for processing (21)
  • .gitignore
  • deepmd/pt/entrypoints/main.py
  • deepmd/pt/loss/ener.py
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
  • deepmd/pt/model/descriptor/dpa3.py
  • deepmd/pt/model/descriptor/env_mat.py
  • deepmd/pt/model/descriptor/repflows.py
  • deepmd/pt/model/model/__init__.py
  • deepmd/pt/model/model/ener_model.py
  • deepmd/pt/model/model/make_model.py
  • deepmd/pt/model/network/graph_utils_flat.py
  • deepmd/pt/model/task/invar_fitting.py
  • deepmd/pt/train/training.py
  • deepmd/pt/train/wrapper.py
  • deepmd/pt/utils/lmdb_dataset.py
  • deepmd/pt/utils/nlist.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_lmdb_dataloader.py
  • test_mixed_batch.sh
  • test_mptraj/lmdb_baseline.json
  • test_mptraj/lmdb_mixed_batch.json

Comment thread .gitignore Outdated
Comment thread deepmd/pt/loss/ener.py
Comment thread deepmd/pt/model/atomic_model/dp_atomic_model.py Outdated
Comment thread deepmd/pt/model/model/__init__.py
Comment thread deepmd/pt/model/model/make_model.py Outdated
Comment thread deepmd/pt/utils/lmdb_dataset.py
Comment thread deepmd/pt/utils/nlist.py
Comment thread test_mixed_batch.sh Outdated
Comment thread test_mptraj/lmdb_baseline.json Outdated
Comment thread test_mptraj/lmdb_mixed_batch.json Outdated
Signed-off-by: liwentao <liwt24@mails.tsinghua.edu.cn>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
deepmd/pt/loss/ener.py (1)

237-249: ⚠️ Potential issue | 🔴 Critical | 🏗️ Heavy lift

Mixed-batch normalization computes the wrong weighted average.

The current implementation averages 1/N_i across frames first (line 247), then multiplies that scalar by the frame-averaged loss. This is not equivalent to applying each frame's 1/N_i weight before averaging.

Example demonstrating the mismatch:

  • Frame 1: N₁=10, error²=100
  • Frame 2: N₂=100, error²=400

Current (incorrect):
mean(1/10, 1/100) × mean(100, 400) = 0.055 × 250 = 13.75

Correct:
mean((1/10)×100, (1/100)×400) = mean(10, 4) = 7.0

Impact: Energy and virial losses (MSE/MAE/Huber paths) and their metrics are all misweighted, breaking gradient correctness for mixed-batch training.

Required fix: Compute per-frame atom_norms = 1.0 / natoms_per_frame (and raise to norm_exp when needed), apply them element-wise to per-frame errors before the reduction (mean/Huber), then remove the post-reduction scalar multiplication by atom_norm.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/loss/ener.py` around lines 237 - 249, The mixed-batch path
currently computes a scalar atom_norm by averaging 1/N_i (is_mixed_batch / ptr /
natoms_per_frame -> atom_norm) and multiplies it after the loss reduction, which
is incorrect; instead compute per-frame weights atom_norms = 1.0 /
natoms_per_frame (and if a norm exponent is used, apply atom_norms =
atom_norms.pow(norm_exp)), then apply these weights element-wise to the
per-frame error terms before calling the reduction (mean/Huber) for energy and
virial loss paths, and remove the later multiplication by the scalar atom_norm;
adjust the branches that handle MSE/MAE/Huber so they accept per-frame-weighted
errors rather than multiplying a scalar after reduction.
🧹 Nitpick comments (2)
deepmd/pt/model/descriptor/dpa3.py (1)

613-613: ⚡ Quick win

Align forward_flat return type with actual None-capable values.

The signature says dict[str, torch.Tensor], but "rot_mat", "g2", and "h2" may be None. This can mislead callers and type checks.

Proposed fix
-    ) -> dict[str, torch.Tensor]:
+    ) -> dict[str, torch.Tensor | None]:

Also applies to: 704-718

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/model/descriptor/dpa3.py` at line 613, The return type of
forward_flat is too narrow: some keys ("rot_mat", "g2", "h2") can be None, so
update the signature to reflect Optional values (e.g. change -> dict[str,
Optional[torch.Tensor]]) and import typing.Optional; apply the same change to
the other similar function(s) noted (the second return signature around the
704–718 region) so callers and type checkers know those keys may be None.
deepmd/pt/loss/ener.py (1)

243-243: 💤 Low value

Remove unused variable.

nframes is computed but never used.

🧹 Proposed fix
-        nframes = ptr.numel() - 1
         # Compute natoms for each frame
         natoms_per_frame = ptr[1:] - ptr[:-1]  # [nframes]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/loss/ener.py` at line 243, The variable nframes is computed but
never used; remove the unused assignment "nframes = ptr.numel() - 1" (or replace
it with a deliberate discard like "_ = ptr.numel() - 1" if a side-effect is
required) in the function that contains the ptr reference so no dead local
remains and linter warnings go away.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/pt/utils/lmdb_dataset.py`:
- Line 102: The call to collate_tensor_fn in lmdb_dataset.py is failing because
collate_tensor_fn is not defined; import PyTorch's collate utility and wire it
up. Add an import for default_collate from torch.utils.data._utils.collate (or
the PyTorch collate function available in your env) and either replace
collate_tensor_fn calls with default_collate or add an alias like
collate_tensor_fn = default_collate at module scope so the line result[key] =
collate_tensor_fn(tensors) works; reference the symbol collate_tensor_fn in
lmdb_dataset.py to locate where to add the import/alias.

---

Duplicate comments:
In `@deepmd/pt/loss/ener.py`:
- Around line 237-249: The mixed-batch path currently computes a scalar
atom_norm by averaging 1/N_i (is_mixed_batch / ptr / natoms_per_frame ->
atom_norm) and multiplies it after the loss reduction, which is incorrect;
instead compute per-frame weights atom_norms = 1.0 / natoms_per_frame (and if a
norm exponent is used, apply atom_norms = atom_norms.pow(norm_exp)), then apply
these weights element-wise to the per-frame error terms before calling the
reduction (mean/Huber) for energy and virial loss paths, and remove the later
multiplication by the scalar atom_norm; adjust the branches that handle
MSE/MAE/Huber so they accept per-frame-weighted errors rather than multiplying a
scalar after reduction.

---

Nitpick comments:
In `@deepmd/pt/loss/ener.py`:
- Line 243: The variable nframes is computed but never used; remove the unused
assignment "nframes = ptr.numel() - 1" (or replace it with a deliberate discard
like "_ = ptr.numel() - 1" if a side-effect is required) in the function that
contains the ptr reference so no dead local remains and linter warnings go away.

In `@deepmd/pt/model/descriptor/dpa3.py`:
- Line 613: The return type of forward_flat is too narrow: some keys ("rot_mat",
"g2", "h2") can be None, so update the signature to reflect Optional values
(e.g. change -> dict[str, Optional[torch.Tensor]]) and import typing.Optional;
apply the same change to the other similar function(s) noted (the second return
signature around the 704–718 region) so callers and type checkers know those
keys may be None.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 9f5f5f91-6d22-4772-adac-e7c96dc0c041

📥 Commits

Reviewing files that changed from the base of the PR and between 9f4ea0a and 38c86d6.

📒 Files selected for processing (6)
  • deepmd/pt/loss/ener.py
  • deepmd/pt/model/descriptor/dpa3.py
  • deepmd/pt/model/descriptor/repflows.py
  • deepmd/pt/train/training.py
  • deepmd/pt/utils/lmdb_dataset.py
  • deepmd/utils/argcheck.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • deepmd/utils/argcheck.py
  • deepmd/pt/train/training.py
  • deepmd/pt/model/descriptor/repflows.py

Comment thread deepmd/pt/utils/lmdb_dataset.py
@github-actions github-actions Bot added the Docs label May 20, 2026
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/train/training.py (1)

933-960: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Distributed path references DDP and LOCAL_RANK without imports.

This branch will fail with NameError in distributed runs at lines 934, 955-959.

💡 Proposed fix
 from deepmd.pt.utils.env import (
     DEVICE,
     JIT,
+    LOCAL_RANK,
     NUM_WORKERS,
     SAMPLER_RECORD,
 )
+from torch.nn.parallel import (
+    DistributedDataParallel as DDP,
+)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/train/training.py` around lines 933 - 960, The distributed branch
references DDP, dist and LOCAL_RANK but they are not imported/defined; add the
necessary imports and LOCAL_RANK definition: import torch.distributed as dist
and import DistributedDataParallel as DDP (from torch.nn.parallel import
DistributedDataParallel as DDP), and ensure LOCAL_RANK is defined (e.g.
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", "0")) or equivalent) before use so
torch.cuda.set_device(LOCAL_RANK), dist.broadcast(...) and DDP(...) work
correctly.
♻️ Duplicate comments (1)
deepmd/pt/utils/lmdb_dataset.py (1)

106-112: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

collate_tensor_fn is undefined in mixed-batch collation.

result[key] = collate_tensor_fn(tensors) will raise NameError on the first frame-wise key that hits this branch.

💡 Proposed fix
 from torch.utils.data import (
     DataLoader,
     Dataset,
     Sampler,
 )
+from torch.utils.data._utils.collate import (
+    collate_tensor_fn,
+)
#!/bin/bash
# Verify unresolved symbol + missing import in this file.
rg -n "collate_tensor_fn" deepmd/pt/utils/lmdb_dataset.py
python - <<'PY'
import ast, pathlib
p = pathlib.Path("deepmd/pt/utils/lmdb_dataset.py")
tree = ast.parse(p.read_text())
imports = set()
for n in ast.walk(tree):
    if isinstance(n, ast.ImportFrom):
        mod = n.module or ""
        for a in n.names:
            imports.add((mod, a.name))
print(("torch.utils.data._utils.collate", "collate_tensor_fn") in imports)
PY
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/utils/lmdb_dataset.py` around lines 106 - 112, The code calls an
undefined symbol collate_tensor_fn in deepmd/pt/utils/lmdb_dataset.py (inside
the mixed-batch collation branch that assigns result[key] when key not in
_ATOMWISE_MIXED_BATCH_KEYS), which causes a NameError; fix by importing the
proper collate helper from PyTorch (e.g., add from
torch.utils.data._utils.collate import collate_tensor_fn at the top) or replace
collate_tensor_fn with the correct public API (e.g.,
torch.utils.data.default_collate) and update the call in the block that builds
result[key] so the collate function is defined and used consistently.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@deepmd/pt/train/training.py`:
- Around line 933-960: The distributed branch references DDP, dist and
LOCAL_RANK but they are not imported/defined; add the necessary imports and
LOCAL_RANK definition: import torch.distributed as dist and import
DistributedDataParallel as DDP (from torch.nn.parallel import
DistributedDataParallel as DDP), and ensure LOCAL_RANK is defined (e.g.
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", "0")) or equivalent) before use so
torch.cuda.set_device(LOCAL_RANK), dist.broadcast(...) and DDP(...) work
correctly.

---

Duplicate comments:
In `@deepmd/pt/utils/lmdb_dataset.py`:
- Around line 106-112: The code calls an undefined symbol collate_tensor_fn in
deepmd/pt/utils/lmdb_dataset.py (inside the mixed-batch collation branch that
assigns result[key] when key not in _ATOMWISE_MIXED_BATCH_KEYS), which causes a
NameError; fix by importing the proper collate helper from PyTorch (e.g., add
from torch.utils.data._utils.collate import collate_tensor_fn at the top) or
replace collate_tensor_fn with the correct public API (e.g.,
torch.utils.data.default_collate) and update the call in the block that builds
result[key] so the collate function is defined and used consistently.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 0a52198d-56d6-4626-8400-09f27d5b022c

📥 Commits

Reviewing files that changed from the base of the PR and between 38c86d6 and 57e3ffd.

📒 Files selected for processing (8)
  • deepmd/pt/loss/ener.py
  • deepmd/pt/model/atomic_model/dp_atomic_model.py
  • deepmd/pt/model/model/make_model.py
  • deepmd/pt/train/training.py
  • deepmd/pt/utils/lmdb_dataset.py
  • doc/development/lmdb-mixed-system-batching.md
  • source/tests/pt/test_lmdb_dataloader.py
  • source/tests/pt/test_loss.py

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
deepmd/utils/argcheck.py (1)

3763-3813: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Clarify that batch_size becomes a frame count in mixed mode.

The new mixed_batch docs explain the flattened LMDB path, but they never say that batch_size is reinterpreted as the number of frames/systems in a mixed batch. That is the part most likely to change users’ memory sizing and throughput expectations when they flip this flag.

Also applies to: 3858-3918

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/utils/argcheck.py` around lines 3763 - 3813, Update the documentation
strings so users know that when the Argument "mixed_batch" (alias "mix_batch")
is True the meaning of the "batch_size" Argument changes: it is interpreted as
the number of frames/systems (i.e., frame count) in a mixed batch rather than
the per-sample/atom count used in non-mixed mode. Modify the doc_mixed_batch
text (and also the doc_batch_size string referenced by the "batch_size"
Argument) to explicitly state this behavioral change and its implications for
memory sizing and throughput so both locations (the mixed_batch Argument and the
batch_size Argument) describe the reinterpreted unit.
deepmd/pt/loss/ener.py (1)

447-466: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard drdq loss when ptr indicates a mixed batch.

This branch still reshapes force/drdq with a single natoms, so it cannot preserve frame boundaries once a batch contains mixed nloc. In mixed mode this will either mis-shape tensors or compute generalized-force terms across the wrong atoms.

💡 Minimal safe fix
             if self.has_gf and "drdq" in label:
+                if is_mixed_batch:
+                    raise NotImplementedError(
+                        "Generalized force loss is not supported with mixed_batch inputs."
+                    )
                 drdq = label["drdq"]
                 find_drdq = label.get("find_drdq", 0.0)
                 pref_gf = pref_gf * find_drdq
                 force_reshape_nframes = force_pred.reshape(-1, natoms * 3)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt/loss/ener.py` around lines 447 - 466, The generalized-force block
(has_gf branch using drdq, drdq_reshape, force_pred, force_label, natoms,
numb_generalized_coord) assumes a single uniform natoms and thus breaks for
mixed batches indicated by ptr; either skip computing l2_gen_force_loss when ptr
shows mixed nloc or compute it per-frame using ptr slices so frame boundaries
are preserved. Update the branch to check the batch pointer (ptr) / mixed-batch
indicator before reshaping: if ptr indicates mixed sizes, loop over frames using
ptr to slice force_pred, force_label and drdq and compute
gen_force/gen_force_label per slice (accumulating weighted loss), otherwise keep
the existing vectorized reshape path; ensure the guard uses the same symbol the
code provides for the pointer and does not try to reshape using a single natoms
when ptr denotes mixed frames.
♻️ Duplicate comments (1)
test_mptraj/lmdb_mixed_batch.json (1)

74-83: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Replace remaining absolute dataset paths with injectable placeholders.

Line 76 and Line 82 still hard-code machine-specific LMDB paths, so this fixture is not reproducible in CI/other dev environments. Please switch both to env placeholders (consistent with lmdb_baseline.json) or repo-managed test paths.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@test_mptraj/lmdb_mixed_batch.json` around lines 74 - 83, The fixture
hard-codes machine-specific LMDB paths in the JSON keys training_data.systems
and validation_data.systems; replace those absolute paths with injectable
placeholders (e.g. "${LMDB_MPTRAJ_PATH}" and "${LMDB_WBM_PATH}" or the same
placeholders used in lmdb_baseline.json) or point them to repo-managed test
paths, keeping batch_size and mixed_batch entries unchanged so CI/dev
environments can supply paths via env vars or test config.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@deepmd/pt/loss/ener.py`:
- Around line 447-466: The generalized-force block (has_gf branch using drdq,
drdq_reshape, force_pred, force_label, natoms, numb_generalized_coord) assumes a
single uniform natoms and thus breaks for mixed batches indicated by ptr; either
skip computing l2_gen_force_loss when ptr shows mixed nloc or compute it
per-frame using ptr slices so frame boundaries are preserved. Update the branch
to check the batch pointer (ptr) / mixed-batch indicator before reshaping: if
ptr indicates mixed sizes, loop over frames using ptr to slice force_pred,
force_label and drdq and compute gen_force/gen_force_label per slice
(accumulating weighted loss), otherwise keep the existing vectorized reshape
path; ensure the guard uses the same symbol the code provides for the pointer
and does not try to reshape using a single natoms when ptr denotes mixed frames.

In `@deepmd/utils/argcheck.py`:
- Around line 3763-3813: Update the documentation strings so users know that
when the Argument "mixed_batch" (alias "mix_batch") is True the meaning of the
"batch_size" Argument changes: it is interpreted as the number of frames/systems
(i.e., frame count) in a mixed batch rather than the per-sample/atom count used
in non-mixed mode. Modify the doc_mixed_batch text (and also the doc_batch_size
string referenced by the "batch_size" Argument) to explicitly state this
behavioral change and its implications for memory sizing and throughput so both
locations (the mixed_batch Argument and the batch_size Argument) describe the
reinterpreted unit.

---

Duplicate comments:
In `@test_mptraj/lmdb_mixed_batch.json`:
- Around line 74-83: The fixture hard-codes machine-specific LMDB paths in the
JSON keys training_data.systems and validation_data.systems; replace those
absolute paths with injectable placeholders (e.g. "${LMDB_MPTRAJ_PATH}" and
"${LMDB_WBM_PATH}" or the same placeholders used in lmdb_baseline.json) or point
them to repo-managed test paths, keeping batch_size and mixed_batch entries
unchanged so CI/dev environments can supply paths via env vars or test config.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: ee28998d-94e8-4b3f-9a65-611aef50db4f

📥 Commits

Reviewing files that changed from the base of the PR and between 57e3ffd and 57369be.

📒 Files selected for processing (6)
  • .gitignore
  • deepmd/pt/loss/ener.py
  • deepmd/utils/argcheck.py
  • test_mixed_batch.sh
  • test_mptraj/lmdb_baseline.json
  • test_mptraj/lmdb_mixed_batch.json
✅ Files skipped from review due to trivial changes (2)
  • .gitignore
  • test_mixed_batch.sh


# Compute virial: dE/dh
if self.do_grad_c("energy"):
nframes = ptr.numel() - 1
Copy link
Copy Markdown
Member

@iProzd iProzd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Implementation in other backend (at least dpmodel backend) should be added.

  2. Tests for new forward method, such as universal tests in source/tests/universal, end-to-end training in source/tests/pt/test_training.py and consistent tests in source/tests/consistent should be added.

Comment thread .gitignore
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This modification should be removed.

Comment thread test_mixed_batch.sh
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This modification should be removed.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This modification should be removed.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This modification should be removed.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doc should be written in english, and placed as a feature introduction instead of a development doc.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants