Mix batch 0429#5439
📝 Walkthrough
Adds LMDB mixed-batch support: flat-graph precomputation, collate functions that flatten atom-wise fields, new flat forward paths through descriptor/fitting/atomic/model layers, per-frame loss normalization, and training integration with sampler/collate changes and tests/docs.
Changes
Mixed-batch LMDB Training Pipeline
Sequence Diagram(s)
sequenceDiagram
participant DataLoader
participant Trainer
participant Wrapper
participant Model_CM
participant Descriptor_Repflows
participant FittingNet
participant Loss_EnerStd
DataLoader->>Trainer: mixed-batch batch (coord, atype, batch, ptr, ...)
Trainer->>Wrapper: device-transfer & extended keys
Wrapper->>Model_CM: forward_common_flat args (flat graph tensors)
Model_CM->>Descriptor_Repflows: descriptor.forward_flat(extended_coord, nlist, ...)
Descriptor_Repflows->>FittingNet: node descriptors -> fitting_net.forward_flat
FittingNet->>Model_CM: atomic predictions (atom-wise)
Model_CM->>Loss_EnerStd: aggregated predictions + ptr/batch for per-frame norms
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
Actionable comments posted: 14
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/loss/ener.py (1)
412-424: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift
Generalized-force loss still assumes a single `natoms` for every frame.
This branch reshapes `force` and `drdq` with `natoms * 3`, which only works for uniform-size batches. In mixed-batch mode those tensors are flattened across frames, so enabling generalized-force loss here will mis-shape the data or crash.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/loss/ener.py` around lines 412 - 424, The generalized-force branch in ener.py assumes a single natoms by reshaping force_pred, force_label and drdq with natoms * 3, which breaks for mixed-size batches; instead, use frame-aware shapes: do not flatten all frames into one dimension—reshape or iterate using per-frame atom counts (or keep forces as (nframes, natoms_i, 3) and drdq as (nframes, natoms_i*3, n_gcoord)) and compute gen_force and gen_force_label per-frame (or via batched operations that use the original nframes dimension) so tensor dims align; also apply find_drdq/pref_gf per-frame. Locate symbols: has_gf, drdq, find_drdq, pref_gf, force_pred, force_label, force_reshape_nframes, force_label_reshape_nframes, drdq_reshape, gen_force, gen_force_label and update the reshaping/Einstein-summation to operate over the nframes dimension (or loop over frames) instead of using natoms * 3 across the whole batch.
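A minimal sketch of the frame-aware reduction described above, assuming flat forces of shape `(total_atoms, 3)`, a `ptr` tensor of cumulative atom offsets, and a `drdq` tensor flattened frame by frame. The helper name `generalized_force_per_frame` and the flat `drdq` layout are hypothetical, chosen for illustration; the PR's actual data layout may differ.

```python
import torch


def generalized_force_per_frame(force, drdq, ptr, n_gcoord):
    """Project flat forces onto generalized coordinates, one frame at a time.

    force: (total_atoms, 3) flat forces across all frames.
    drdq:  1-D tensor; frame i contributes natoms_i * 3 * n_gcoord entries
           (assumed layout, not confirmed by the source).
    ptr:   (nframes + 1,) cumulative atom offsets.
    """
    ptr_list = ptr.tolist()  # single host transfer, no per-frame .item()
    out = []
    offset = 0
    for start, end in zip(ptr_list[:-1], ptr_list[1:]):
        n3 = (end - start) * 3
        f = force[start:end].reshape(n3)
        d = drdq[offset : offset + n3 * n_gcoord].reshape(n3, n_gcoord)
        offset += n3 * n_gcoord
        out.append(f @ d)  # (n_gcoord,) for this frame
    return torch.stack(out)  # (nframes, n_gcoord)


# Two frames with 2 and 3 atoms, 4 generalized coordinates.
ptr = torch.tensor([0, 2, 5])
force = torch.ones(5, 3)
drdq = torch.ones((2 * 3 + 3 * 3) * 4)
gen_force = generalized_force_per_frame(force, drdq, ptr, n_gcoord=4)
```

The loop keeps every reshape local to one frame, so no single `natoms` value is ever applied to the whole batch.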
🧹 Nitpick comments (13)
deepmd/pt/model/descriptor/repflows.py (1)
840-874: 💤 Low value
Make the disjoint-frame invariant explicit.
The synthetic batch=1 wrap is only correct because `RepFlowLayer.forward` is strictly per-atom + per-neighbor, and because `nlist`/`a_nlist` here never cross frame boundaries (the LMDB collator builds per-frame extended atoms and offsets neighbor indices accordingly, then maps back via per-frame `mapping`). A brief comment near this block stating that invariant, and noting that any future RepFlowLayer change introducing global ops over the batch axis would silently mix frames, will save a future debugging session.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/descriptor/repflows.py` around lines 840 - 874, Add a short explanatory comment above the synthetic batch=1 wrapping that states the disjoint-frame invariant: explain that the temporary batching of tensors like node_ebd_batched, edge_ebd_batched, h2_batched, angle_ebd_batched, nlist_batched, a_nlist_batched, etc. is safe because RepFlowLayer.forward operates strictly per-atom and per-neighbor and the LMDB collator guarantees nlist/a_nlist never reference atoms across different frames (mapping/offsets keep frames disjoint); also warn that if RepFlowLayer.forward (or any layer in self.layers) later introduces global ops over the batch axis the synthetic batch will mix frames and break correctness, so such changes must preserve the per-frame isolation or remove this batching trick.
deepmd/pt/utils/nlist.py (2)
101-170: 🏗️ Heavy lift
Per-frame Python loop with `.item()` calls forces a CPU/GPU sync each iteration.
For large `batch_size` (e.g., 128 in `lmdb_baseline.json`), this loop incurs `2 * batch_size` device syncs from `int(ptr[...].item())` plus per-frame kernel launches for `extend_coord_with_ghosts_with_images` and `build_neighbor_list`. Even though it runs in dataloader workers, this can dominate end-to-end throughput when frames have small atom counts. Two relatively cheap mitigations:
- Pull `ptr` to CPU once before the loop (`ptr_cpu = ptr.cpu().tolist()`) and index lists directly, eliminating per-iteration syncs.
- If the data loader ultimately runs on CPU tensors (typical for collate_fn), the per-frame work is largely free of device sync but the loop still costs Python overhead; consider a vectorized batched ghost-extension that operates on a padded tensor and then re-flattens, mirroring the existing `extend_input_and_build_neighbor_list_with_images`.
Not blocking, but worth measuring before this lands as the default mixed-batch path.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/utils/nlist.py` around lines 101 - 170, The loop is causing repeated device syncs via int(ptr[frame_idx].item()) and per-frame kernel overhead in extend_coord_with_ghosts_with_images and build_neighbor_list; fix by materializing ptr on CPU once (e.g., ptr_cpu = ptr.cpu().tolist()) and use ptr_cpu[frame_idx]/ptr_cpu[frame_idx+1] for start_idx/end_idx to eliminate .item() calls, and where possible call normalize_coord/extend_coord_with_ghosts_with_images/build_neighbor_list on already-CPU tensors or add a batched/padded variant (similar to extend_input_and_build_neighbor_list_with_images) to process frames in a vectorized way to reduce Python per-frame overhead.
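The first mitigation can be sketched as follows. This is an illustrative example only: `frame_slices` is a hypothetical helper name, and the toy `ptr`/`coord` tensors stand in for the real batch contents.

```python
import torch


def frame_slices(ptr):
    """Return (start, end) atom offsets per frame with one host transfer."""
    # One .cpu().tolist() replaces two .item() syncs per frame.
    ptr_cpu = ptr.cpu().tolist()
    return list(zip(ptr_cpu[:-1], ptr_cpu[1:]))


# Three frames with 3, 4 and 2 atoms, stored as one flat coordinate tensor.
ptr = torch.tensor([0, 3, 7, 9])
coord = torch.arange(27, dtype=torch.float64).reshape(9, 3)

# The loop body now indexes with plain Python ints.
frames = [coord[start:end] for start, end in frame_slices(ptr)]
```

When `ptr` already lives on the CPU, `.cpu()` is a no-op, so the same code serves both the collate path and a device-side caller.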
262-281: 💤 Low value
Document the "first nloc extended atoms per frame are local" assumption.
`get_central_ext_index` relies on the layout produced by `_extend_coord_with_ghosts_impl`, where shift indices are sorted by L2 norm and `(0,0,0)` is therefore the first image, making the first `nloc` extended atoms in each frame the local ones. This invariant is implicit; a one-line comment will save a future maintainer who changes the ghost layout from silently breaking this helper.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/utils/nlist.py` around lines 262 - 281, Add a one-line doc comment inside get_central_ext_index stating the invariant it relies on: that the extended layout produced by _extend_coord_with_ghosts_impl places the unshifted image (shift (0,0,0)) first for each frame and thus the first nloc extended atoms per frame are the local atoms; mention that shift indices are sorted by L2 norm so the (0,0,0) image is first and changing that ordering will break get_central_ext_index.
.gitignore (1)
81-81: 💤 Low value
`deepmd-kit/` ignore pattern is overly broad.
Ignoring a directory named after the project itself can mask legitimate content if the repo or a checkout is ever placed under a `deepmd-kit/` subdirectory (e.g., nested workspaces, tooling that clones into a project-named folder). Consider narrowing or scoping the rule to only the artifact path you intend to ignore.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In @.gitignore at line 81, The .gitignore entry "deepmd-kit/" is too broad and may unintentionally ignore legitimately nested repos; update the ignore rule that currently uses deepmd-kit/ to a more specific pattern (for example scope it to the repo root by changing to "/deepmd-kit/" or better yet target only build/dist/artifact paths like "/deepmd-kit/build", "/deepmd-kit/dist", or specific filenames/patterns such as "deepmd-kit/*.egg-info") so only the intended artifact directories/files are ignored; locate the deepmd-kit/ entry in the .gitignore and replace it with one or more narrow, explicit patterns.
deepmd/pt/model/task/invar_fitting.py (2)
299-315: 💤 Low value
Atom-wise output detection is shape-heuristic; consider being explicit.
The current rule un-pads any tensor whose first two dims happen to be `(nframes, max_nloc)` and leaves everything else untouched. That works for the current outputs (`var_name`, `middle_output`), but a future output that returns frame-level shape `(nframes, max_nloc)` would be silently un-padded, and any per-atom output whose first axis is something other than `nframes` would be silently passed through. Consider keying off `output_def()` (which already declares atomic vs. frame-level variables) instead of tensor shape.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/task/invar_fitting.py` around lines 299 - 315, The tensor un-padding currently uses a shape heuristic in the loop that builds result_flat; instead use the model's declared output types from output_def() to decide which keys are atom-wise. Replace the shape-based condition in the for key, value in result_batch loop with a lookup like output_def()[key].is_atomic (or equivalent in your model API) and only apply valid_atom_mask to values for keys declared atomic; leave frame-level or other outputs untouched and handle non-tensor values as before (keep variable names result_batch, result_flat, valid_atom_mask, and call output_def() to determine atomic vs frame-level).
259-264: 💤 Low value
Padding `atype = -1` silently wraps to the last type in PyTorch indexing.
Per-type lookups inside `self.forward` (`bias_atom_e`/case embedding/etc.) use `tensor[atype]`, where `-1` does not error out; it wraps to the last type. Outputs are masked away later, so this does not corrupt the result, but it does waste compute on the padding rows and can interact unexpectedly with any branch that special-cases negative types. Using a valid sentinel like `0` is safer and semantically clearer; the `valid_atom_mask` continues to discard the rows.
♻️ Suggested change
- atype_batch = torch.full(
-     (nframes, max_nloc),
-     -1,
-     dtype=atype.dtype,
-     device=device,
- )
+ # Use 0 as a safe dummy type for padding rows; outputs are discarded
+ # by valid_atom_mask below so the value is not observable.
+ atype_batch = torch.zeros(
+     (nframes, max_nloc),
+     dtype=atype.dtype,
+     device=device,
+ )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/task/invar_fitting.py` around lines 259 - 264, The padding value for atype_batch is set to -1 which causes PyTorch indexing (e.g., in self.forward when doing per-type lookups like bias_atom_e[atype] or embeddings) to wrap to the last type; change the sentinel to a valid non-negative index (e.g., 0 or a dedicated PAD type index) when creating atype_batch so lookups won't accidentally reference real types, and keep using valid_atom_mask to discard those rows downstream; update places that assume negative sentinel behavior to use the new PAD index (refer to atype_batch, self.forward, bias_atom_e, and valid_atom_mask).
deepmd/pt/model/network/graph_utils_flat.py (1)
41-89: 💤 Low value
LGTM.
Edge index construction (`n2e_index`/`n_ext2e_index`) and the angle index mapping via `edge_lookup` are consistent with the consumers in `repflows.py` and the dynamic-selection path. The `a_nlist_mask_3d` logic correctly requires both neighbors of an angle to be valid before producing a triplet. The single CPU sync via `n_edge = ... .item()` is necessary for `torch.arange`.
One small optional suggestion: an `assert a_sel <= nnei` near the top would surface mis-configured inputs as a clear error rather than a downstream broadcast failure (`DescrptDPA3` already enforces `e_sel >= a_sel`, so this is mostly defensive).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/network/graph_utils_flat.py` around lines 41 - 89, Add a defensive check to guard against mis-configured inputs by asserting that a_sel (from a_nlist_mask.shape[1]) does not exceed nnei (from nlist_flat.shape[1]) near the top of the function; insert something like an assert or raise ValueError after computing nnei and a_sel so that functions using n2e_index, edge_lookup_a, and the a_nlist_mask_3d broadcasts fail fast with a clear message instead of producing downstream broadcast errors.
deepmd/pt/model/model/make_model.py (3)
520-599: 💤 Low value
`forward_common_flat` is a pure passthrough to `forward_common_flat_native`.
The wrapper adds no behavior (same parameters, same return) and just forwards every argument. Two near-identical method definitions (with duplicated docstrings) is a maintenance hazard and adds friction for readers tracing the call. Either (a) inline `forward_common_flat_native` into `forward_common_flat` and drop the `_native` variant, or (b) document a clear reason for the split (e.g. JIT export boundary, planned non-native variant) so the duplication is justified.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/model/make_model.py` around lines 520 - 599, forward_common_flat is a pure passthrough to forward_common_flat_native with duplicated docstrings, creating maintenance overhead; either remove the redundant wrapper by merging forward_common_flat_native into forward_common_flat (rename/remove the _native variant and keep a single method and docstring) or add a concise comment/docstring to forward_common_flat explaining the explicit reason for the split (e.g., JIT/export boundary or future divergence). Update any call sites that reference forward_common_flat_native to use the single chosen method name (forward_common_flat) and ensure tests/type hints still pass.
296-309: 💤 Low value
Move the helper import to module level and tidy the validation order.
Two minor suggestions:
- `from deepmd.pt.utils.nlist import rebuild_extended_coord_from_flat_graph` is imported every call inside the hot forward path (line 296). Promote it to the existing nlist import block at the top of the file.
- The `requires_grad_(True)` work on `coord`/`box` at lines 280-282 happens before the field-presence validation at lines 283-309. If the validation raises, the clone+detach is wasted. Reorder so the validation runs first.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/model/make_model.py` around lines 296 - 309, Move the helper import for rebuild_extended_coord_from_flat_graph out of the hot forward path and add it to the module-level nlist import block with the other deepmd.pt.utils.nlist imports so it is imported once at load time; also change the validation order inside the flat mixed-batch branch so you check for presence of the required graph fields (mapping, extended_batch, extended_image, etc.) and raise the RuntimeError before mutating tensors, and only call coord.requires_grad_(True) / box.requires_grad_(True) (and any clone+detach) after the validation passes to avoid wasted work when validation fails.
311-314: 💤 Low value
Redundant assertions after the explicit None-check.
Lines 283-295 already verify `extended_atype`, `extended_batch`, `mapping`, `nlist`, `nlist_ext`, etc. are not None (otherwise the `else` raises), so the four `assert` statements at lines 311-314 are redundant at runtime. If they exist solely for type-narrowing under static checkers, a comment to that effect would help future readers. Otherwise they can be removed.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/model/make_model.py` around lines 311 - 314, The four assert statements for extended_atype, extended_batch, mapping, and nlist are redundant because the preceding explicit None-check/else already guarantees they are not None; either remove those assert lines or replace them with a short explanatory comment indicating they exist purely for static type-narrowing (for tools like mypy) so future readers understand their purpose; locate the checks around the make_model function where extended_atype, extended_batch, mapping, and nlist are validated and update accordingly.
deepmd/pt/train/training.py (1)
1690-1723: 💤 Low value
Minor structural cleanup in `get_data` for mixed-batch handling.
The current loop skips `batch`/`ptr` (lines 1697-1698) only to transfer them separately at lines 1722-1723. Since they are tensors, they would be handled correctly by the generic `not isinstance(..., list)` branch at line 1699-1701. You can drop the explicit skip and the post-loop transfer, then just extend `input_keys` when `is_mixed_batch` is true:
♻️ Proposed simplification
  for key in batch_data.keys():
      if key == "sid" or key == "fid" or "find_" in key:
          continue
-     # Skip batch and ptr for now, will handle them separately
-     elif key == "batch" or key == "ptr":
-         continue
      elif not isinstance(batch_data[key], list):
          if batch_data[key] is not None:
              batch_data[key] = batch_data[key].to(DEVICE, non_blocking=True)
@@
  # Mixed-nloc LMDB batches include precomputed flat-graph tensors.
  if is_mixed_batch:
      input_keys = input_keys + list(_FLAT_GRAPH_INPUT_KEYS)
-     batch_data["batch"] = batch_data["batch"].to(DEVICE, non_blocking=True)
-     batch_data["ptr"] = batch_data["ptr"].to(DEVICE, non_blocking=True)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/train/training.py` around lines 1690 - 1723, In get_data, remove the special-case skipping of "batch" and "ptr" inside the batch_data loop and delete the separate post-loop device transfers for batch_data["batch"] and batch_data["ptr"]; instead let the existing non-list tensor branch handle them (they'll be moved to DEVICE automatically), and only keep the is_mixed_batch detection so you can extend input_keys with list(_FLAT_GRAPH_INPUT_KEYS) when is_mixed_batch is true (symbols: get_data, batch_data, is_mixed_batch, input_keys, _FLAT_GRAPH_INPUT_KEYS, DEVICE).
deepmd/pt/model/model/ener_model.py (1)
92-176: ⚡ Quick win
Extract the shared output-assembly block to avoid maintaining two near-identical paths.
Lines 117-143 (flat path) and 153-174 (regular path) build `model_predict` in essentially the same way (`atom_energy`/`energy`, conditional `force`/`dforce`, virials, `mask`, `hessian`). Two parallel implementations are a maintenance hazard: future output keys or shape adjustments must be mirrored in both, and any drift becomes silent. Consider extracting a small helper like `_assemble_energy_predict(model_ret, do_atomic_virial)` and calling it from both branches, with the flat branch additionally adding `model_predict["updated_coord"] += coord` in the no-fitting-net case.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/model/ener_model.py` around lines 92 - 176, The two branches in forward (flat and regular) duplicate the same output-assembly logic for model_predict; extract that into a helper method (e.g. _assemble_energy_predict(model_ret, do_atomic_virial)) that encapsulates setting "atom_energy", "energy", choosing "force" vs "dforce" using do_grad_r("energy"), adding "virial" and "atom_virial" via do_grad_c("energy") and do_atomic_virial, including "mask" and "hessian" when present/when _hessian_enabled; then call this helper from both the flat-path and regular-path after receiving model_ret, and keep the flat/regular-specific behavior (the no-fitting-net case that does model_predict["updated_coord"] += coord) in the respective branch only. Ensure the helper references self.do_grad_r, self.do_grad_c, self._hessian_enabled and self.get_fitting_net() behavior remains consistent.
deepmd/pt/train/wrapper.py (1)
237-247: 💤 Low value
`natoms` semantics now diverge between flat and regular batches; document the contract.
For regular batches `atype.dim() > 1` and `natoms = atype.shape[-1]` is per-frame nloc, but for mixed batches `atype` is 1-D and `natoms = atype.shape[0]` is the total flattened atom count across all frames. Downstream loss code must distinguish these by inspecting `ptr` in `input_dict` to normalize correctly; that is a fragile, implicit contract spread across `wrapper.py`, `ener.py`, and `make_model.py`.
Consider either (a) renaming/splitting the argument so callees see `total_atoms` vs `nloc_per_frame` explicitly, or (b) adding a docstring/inline note here that clearly states the dual semantics and that the loss must branch on `ptr`. The current inline comment on lines 237-238 helps but does not surface the downstream consumer expectation.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/train/wrapper.py` around lines 237 - 247, The natoms value currently has dual semantics (per-frame nloc when atype.dim()>1 vs total flattened atom count when 1-D) which is implicit and fragile; fix by making the contract explicit: compute both nloc_per_frame (e.g., atype.shape[-1] when atype.dim()>1 else None) and total_atoms (atype.numel() or atype.shape[0]) in wrapper.py and change the call to self.loss[task_key](input_dict, self.model[task_key], label, nloc_per_frame=nloc_per_frame, total_atoms=total_atoms, learning_rate=cur_lr), then update all loss implementations (ener.py) and any factory in make_model.py to accept these new named args (or provide backward-compatible handling) so downstream code no longer must infer semantics from ptr.
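Option (a) could look roughly like the sketch below. The helper `describe_atom_counts` is hypothetical and only illustrates how the two meanings of `natoms` can be made explicit; how the values are threaded into the loss call is left to the actual implementation.

```python
import torch


def describe_atom_counts(atype):
    """Return explicit per-frame and total atom counts for a batch.

    Regular batch: atype has shape (nframes, nloc).
    Mixed batch:   atype is 1-D with all frames flattened together.
    """
    if atype.dim() > 1:
        nframes, nloc = atype.shape[0], atype.shape[-1]
        return {"nloc_per_frame": nloc, "total_atoms": nframes * nloc}
    # Flattened mixed batch: no single per-frame nloc exists.
    return {"nloc_per_frame": None, "total_atoms": atype.shape[0]}


regular = describe_atom_counts(torch.zeros((4, 5), dtype=torch.long))
mixed = describe_atom_counts(torch.zeros(13, dtype=torch.long))
```

Returning `None` for `nloc_per_frame` in the mixed case forces callers to handle the flat layout deliberately instead of reusing a misleading number.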
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In @.gitignore:
- Line 78: The .gitignore entry "test_mptraj/" conflicts with tracked files
committed in this PR (test_mptraj/lmdb_baseline.json and
test_mptraj/lmdb_mixed_batch.json); fix by either removing the directory-level
ignore line "test_mptraj/" or by adding explicit negation patterns for the
tracked files (e.g., "!test_mptraj/lmdb_baseline.json" and
"!test_mptraj/lmdb_mixed_batch.json") so future edits to those config files are
not hidden and git add won't require -f; update .gitignore accordingly and
ensure the two committed files remain tracked.
In `@deepmd/pt/loss/ener.py`:
- Around line 228-240: The mixed-batch normalization currently computes a single
scalar atom_norm via atom_norm = mean(1.0 / natoms_per_frame) which misweights
frames with different atom counts; instead compute per-frame normalizers and
apply them before reducing l2_ener_loss and l2_virial_loss: when is_mixed_batch
(detect via ptr and natoms_per_frame derived from ptr), build per-frame
atom_norms = 1.0 / natoms_per_frame.float() and multiply each frame's loss by
its corresponding atom_norm before taking the final mean/sum (or perform a
weighted reduction using these per-frame weights) so that l2_ener_loss and
l2_virial_loss are normalized per-frame correctly rather than using a single
averaged scalar.
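To see why the single averaged scalar misweights frames, here is a small worked example in plain Python. The two functions are illustrative stand-ins, assuming the loss has the form "squared energy error scaled by 1/natoms"; they are not the code in `ener.py`.

```python
def scalar_norm_loss(sq_err, natoms):
    """Current behavior as described: one averaged 1/natoms for the batch."""
    atom_norm = sum(1.0 / n for n in natoms) / len(natoms)
    return atom_norm * sum(sq_err) / len(sq_err)


def per_frame_norm_loss(sq_err, natoms):
    """Suggested behavior: scale each frame by its own 1/natoms, then average."""
    return sum(e / n for e, n in zip(sq_err, natoms)) / len(sq_err)


# Two frames with 1 and 4 atoms, derived from ptr as in the mixed-batch path.
ptr = [0, 1, 5]
natoms = [b - a for a, b in zip(ptr[:-1], ptr[1:])]
sq_err = [1.0, 4.0]  # squared energy error per frame

scalar_loss = scalar_norm_loss(sq_err, natoms)
frame_loss = per_frame_norm_loss(sq_err, natoms)
```

Here the scalar form gives 0.625 * 2.5 = 1.5625, while the per-frame form gives (1/1 + 4/4) / 2 = 1.0, so the large frame's error is overweighted by the small frame's normalizer. The two agree only when every frame has the same atom count.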
In `@deepmd/pt/model/atomic_model/dp_atomic_model.py`:
- Around line 378-390: The flat-path call to descriptor.forward_flat currently
passes fparam unmodified and thus skips the fallback logic used in
forward_atomic; update the flat-path before calling descriptor.forward_flat (the
block around descriptor.forward_flat) to mirror forward_atomic's behavior: if
self.add_chg_spin_ebd (or the equivalent flag) is True and fparam is None,
compute the default via self.get_default_fparam(...) using the same inputs/shape
logic used in forward_atomic and assign it to fparam, then pass that fparam into
descriptor.forward_flat so mixed-batch runs use the same default frame
parameters.
In `@deepmd/pt/model/model/__init__.py`:
- Around line 90-98: The hybrid-descriptor logic in
_set_default_descriptor_init_seed (and the analogous block later) always derives
child seeds from the passed-in default `seed` instead of the effective parent
config seed; update the logic to first determine the parent_seed =
params.get("seed", seed) (i.e., use params["seed"] when present, otherwise fall
back to the incoming seed) and then call child_seed(parent_seed, idx) when
recursing into each hybrid child (preserving the existing
isinstance(descriptor_params, dict) check and recursive call to
_set_default_descriptor_init_seed).
In `@deepmd/pt/model/model/make_model.py`:
- Around line 497-518: The flat-path branch in make_model.py quietly skips
atomic virial when do_grad_c("energy") is true and do_atomic_virial is true (the
pass under if do_atomic_virial), causing fit_ret to miss
atom_virial/energy_derv_c keys and silently breaking callers (e.g.,
EnergyModel.forward reading model_ret["energy_derv_c"]); update the branch in
the block handling energy_derv_c_redu so that if do_atomic_virial is true you
either (A) compute and populate the appropriate atomic virial / energy_derv_c
entries in fit_ret consistent with the reduced-to-atom shape used elsewhere (use
energy_redu, energy_derv_c_redu and ptr to derive per-atom contributions), or if
that is nontrivial, (B) raise a clear NotImplementedError (or emit a one-time
logged warning and raise) referencing do_atomic_virial and the flat-path so
callers know the feature is unsupported; ensure the error/warning mentions
make_model.make_model (or the surrounding function) and that fit_ret must
include the keys model_predict/EnergyModel.forward expects (e.g.,
"energy_derv_c" / atom_virial) to avoid silent failures.
- Around line 217-352: The flat forward path forward_common_flat_native is
missing the same input/output precision normalization used in forward_common;
call the model's _input_type_cast on incoming tensors (coord, box, fparam,
aparam, atype, batch, mapping, extended_* inputs as appropriate) at the start so
coord/box are promoted to global_pt_float_precision (and energy-related inputs
to global_pt_ener_float_precision if required) before doing
clone().detach().requires_grad_(True) or rebuilding extended_coord, and after
computing model_predict_lower run _output_type_cast to promote outputs
(especially energy_redu) back to global_pt_ener_float_precision; ensure you
reference the existing helpers _input_type_cast and _output_type_cast and keep
their casting behavior identical to forward_common.
In `@deepmd/pt/train/training.py`:
- Around line 264-299: When _data.mixed_batch is true but the model descriptor
lacks repflows, raise an immediate, clear error instead of constructing a
dataloader with graph_config=None; inside the mixed_batch branch (where
descriptor is obtained from model_for_graph.atomic_model.descriptor and
make_lmdb_mixed_batch_collate is called), check hasattr(descriptor, "repflows")
and if False raise a ValueError/RuntimeError with a message like
"mixed_batch=True requires a flat-graph capable descriptor with 'repflows' (used
by make_lmdb_mixed_batch_collate and
EnergyModel.forward/forward_common_flat_native) — disable mixed_batch or provide
a compatible descriptor." This ensures failure occurs at dataloader construction
rather than later in EnergyModel.forward.
In `@deepmd/pt/utils/lmdb_dataset.py`:
- Around line 42-49: The __all__ export list is not alphabetically sorted which
triggers Ruff RUF022; reorder the list items in the __all__ variable so they are
lexicographically sorted (e.g., "_collate_lmdb_batch",
"_collate_lmdb_mixed_batch", "LmdbDataset", "LmdbTestData", "is_lmdb",
"make_lmdb_mixed_batch_collate") while preserving the existing string entries,
quotes, and list structure so the symbol names (LmdbDataset, LmdbTestData,
_collate_lmdb_batch, _collate_lmdb_mixed_batch, make_lmdb_mixed_batch_collate,
is_lmdb) remain unchanged.
- Around line 247-255: Mixed batching currently replaces the shuffle-based
DataLoader and ignores computed block weights, so when mixed_batch is True you
must construct _inner_dataloader to honor _block_targets computed by
compute_block_targets() rather than using DataLoader(..., shuffle=True). Replace
the plain shuffle=True approach in the mixed_batch branch by creating and
passing a sampler that uses _block_targets (e.g.,
torch.utils.data.WeightedRandomSampler or an equivalent custom sampler) to
DataLoader, keep num_workers=0 and collate_fn=_collate_lmdb_mixed_batch, and
ensure _inner_dataloader is assigned the DataLoader configured with that sampler
so auto_prob_style/block weighting still affects sampling.
- Around line 51-61: The whitelisted atom-wise keys in
_ATOMWISE_MIXED_BATCH_KEYS are missing supported labels (e.g. "atom_ener",
"drdq"), causing those fields to be handled by collate_tensor_fn and break
mixed-batch logic; update the frozenset declaration named
_ATOMWISE_MIXED_BATCH_KEYS to include all atom-wise labels used elsewhere (at
minimum "atom_ener" and "drdq") and make the same addition to the duplicate
whitelist further down (the other frozenset at lines ~99-104) so mixed-batch
flattening consistently handles these fields instead of falling back to
collate_tensor_fn.
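For the sampler suggestion in the mixed_batch item above, a rough sketch with `torch.utils.data.WeightedRandomSampler` follows. The structure of `_block_targets` is not shown in this review, so `block_sizes` and `block_weights` are hypothetical inputs, and `make_block_weighted_loader` is an illustrative name rather than an existing function.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler


def make_block_weighted_loader(dataset, block_sizes, block_weights, batch_size):
    """Build a DataLoader whose sampling follows per-block weights.

    Assumes samples are stored block by block, so block i owns a
    contiguous range of block_sizes[i] indices.
    """
    per_sample = []
    for size, weight in zip(block_sizes, block_weights):
        # Spread each block's weight evenly over its samples.
        per_sample.extend([weight / size] * size)
    sampler = WeightedRandomSampler(
        weights=torch.tensor(per_sample, dtype=torch.double),
        num_samples=len(dataset),
        replacement=True,
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=0)


# Ten samples in two blocks; the first block gets zero weight.
dataset = TensorDataset(torch.arange(10))
loader = make_block_weighted_loader(dataset, [6, 4], [0.0, 1.0], batch_size=5)
drawn = torch.cat([batch[0] for batch in loader])
```

In the real mixed-batch branch the same `DataLoader` call would also pass the mixed-batch `collate_fn`; it is omitted here to keep the sketch self-contained.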
In `@deepmd/pt/utils/nlist.py`:
- Around line 246-259: The code currently inverts atom_cell (atom_cell =
cell[extended_batch]) per extended atom which is wasteful; instead compute the
inverse once per frame by calling torch.linalg.inv_ex on cell (after cell =
box.reshape(-1,3,3)) and then index the resulting rec_cell with extended_batch
(i.e., rec_cell_full, _ = torch.linalg.inv_ex(cell); rec_cell =
rec_cell_full[extended_batch]) so subsequent einsum uses the indexed per-atom
inverses without repeating the matrix inversion.
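The invert-once-then-index pattern can be sketched as below. The function name `per_atom_reciprocal_cell` is hypothetical; `torch.linalg.inv_ex` returns the inverse together with an `info` tensor, which this sketch ignores.

```python
import torch


def per_atom_reciprocal_cell(box, extended_batch):
    """Invert each frame's cell once, then gather per extended atom."""
    cell = box.reshape(-1, 3, 3)  # (nframes, 3, 3)
    rec_cell_full, _ = torch.linalg.inv_ex(cell)  # one inverse per frame
    return rec_cell_full[extended_batch]  # (n_extended_atoms, 3, 3)


# Two cubic frames with edge lengths 2 and 4; five extended atoms in total.
box = torch.stack([torch.eye(3) * 2.0, torch.eye(3) * 4.0]).reshape(2, 9)
extended_batch = torch.tensor([0, 0, 1, 1, 1])
rec_cell = per_atom_reciprocal_cell(box, extended_batch)
```

The number of matrix inversions drops from the extended atom count to the frame count, while the tensor handed to the subsequent `einsum` keeps the same shape.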
In `@test_mixed_batch.sh`:
- Around line 15-19: The script prints "Check mixed_batch_train.log for details"
but never writes that file; fix by redirecting the training command's
stdout/stderr to mixed_batch_train.log (e.g., change the dp --pt train
test_mptraj/lmdb_mixed_batch.json --skip-neighbor-stat invocation to append or
write both stdout and stderr into mixed_batch_train.log) or, if you prefer not
to create a log, update the echo lines to not reference mixed_batch_train.log;
modify the dp invocation or the echo accordingly to keep the message accurate.
In `@test_mptraj/lmdb_baseline.json`:
- Around line 195-205: The baseline config currently hardcodes
developer-specific absolute paths in the keys stat_file, training_data.systems,
and validation_data.systems which breaks portability; change these values to
reusable placeholders or environment-variable expansions (e.g. ${DATASET_PATH})
or point them to a small example dataset shipped in the repo (mirror other
example configs), update any code that loads the config to resolve those env
vars, and add a short note in the config or README explaining how
test_mixed_batch.sh sets or substitutes those variables before running.
In `@test_mptraj/lmdb_mixed_batch.json`:
- Around line 74-84: The fixture uses hard-coded absolute paths in stat_file and
LMDB system entries (stat_file, training_data.systems, validation_data.systems)
which prevents CI portability; update the test_mptraj/lmdb_mixed_batch.json
fixture to reference repo-managed test data or injected paths by replacing
absolute paths with relative paths inside the repo (e.g., a test-data directory)
or with placeholders/env variables read by the test runner, and ensure the test
harness supplies those fixtures before running so the config uses
repo-controlled files rather than machine-specific absolute paths.
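One possible way to resolve environment-variable placeholders in such a config is sketched below. This is an assumption about how a test runner could do it, not a description of how deepmd-kit loads configs: the helper `expand_env_paths`, the variable name `LMDB_MPTRAJ_PATH`, and the path value are hypothetical.

```python
import json
import os


def expand_env_paths(node):
    """Recursively expand ${VAR} placeholders in every string of a parsed config."""
    if isinstance(node, dict):
        return {key: expand_env_paths(value) for key, value in node.items()}
    if isinstance(node, list):
        return [expand_env_paths(value) for value in node]
    if isinstance(node, str):
        return os.path.expandvars(node)
    return node


# Hypothetical variable and path; a test script would export this before training.
os.environ["LMDB_MPTRAJ_PATH"] = "/data/mptraj.lmdb"

raw = """
{
  "training": {
    "training_data": {"systems": "${LMDB_MPTRAJ_PATH}", "batch_size": 8}
  }
}
"""
config = expand_env_paths(json.loads(raw))
```

Note that `os.path.expandvars` leaves unknown variables untouched, so a runner built this way would likely want an explicit check that no `${...}` placeholder survives expansion.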
---
Outside diff comments:
In `@deepmd/pt/loss/ener.py`:
- Around line 412-424: The generalized-force branch in ener.py assumes a single
natoms by reshaping force_pred, force_label and drdq with natoms * 3, which
breaks for mixed-size batches; instead, use frame-aware shapes: do not flatten
all frames into one dimension—reshape or iterate using per-frame atom counts (or
keep forces as (nframes, natoms_i, 3) and drdq as (nframes, natoms_i*3,
n_gcoord)) and compute gen_force and gen_force_label per-frame (or via batched
operations that use the original nframes dimension) so tensor dims align; also
apply find_drdq/pref_gf per-frame. Locate symbols: has_gf, drdq, find_drdq,
pref_gf, force_pred, force_label, force_reshape_nframes,
force_label_reshape_nframes, drdq_reshape, gen_force, gen_force_label and update
the reshaping/Einstein-summation to operate over the nframes dimension (or loop
over frames) instead of using natoms * 3 across the whole batch.
---
Nitpick comments:
In @.gitignore:
- Line 81: The .gitignore entry "deepmd-kit/" is too broad and may
unintentionally ignore legitimately nested repos; update the ignore rule that
currently uses deepmd-kit/ to a more specific pattern (for example scope it to
the repo root by changing to "/deepmd-kit/" or better yet target only
build/dist/artifact paths like "/deepmd-kit/build", "/deepmd-kit/dist", or
specific filenames/patterns such as "deepmd-kit/*.egg-info") so only the
intended artifact directories/files are ignored; locate the deepmd-kit/ entry in
the .gitignore and replace it with one or more narrow, explicit patterns.
In `@deepmd/pt/model/descriptor/repflows.py`:
- Around line 840-874: Add a short explanatory comment above the synthetic
batch=1 wrapping that states the disjoint-frame invariant: explain that the
temporary batching of tensors like node_ebd_batched, edge_ebd_batched,
h2_batched, angle_ebd_batched, nlist_batched, a_nlist_batched, etc. is safe
because RepFlowLayer.forward operates strictly per-atom and per-neighbor and the
LMDB collator guarantees nlist/a_nlist never reference atoms across different
frames (mapping/offsets keep frames disjoint); also warn that if
RepFlowLayer.forward (or any layer in self.layers) later introduces global ops
over the batch axis the synthetic batch will mix frames and break correctness,
so such changes must preserve the per-frame isolation or remove this batching
trick.
In `@deepmd/pt/model/model/ener_model.py`:
- Around line 92-176: The two branches in forward (flat and regular) duplicate
the same output-assembly logic for model_predict; extract that into a helper
method (e.g. _assemble_energy_predict(model_ret, do_atomic_virial)) that
encapsulates setting "atom_energy", "energy", choosing "force" vs "dforce" using
do_grad_r("energy"), adding "virial" and "atom_virial" via do_grad_c("energy")
and do_atomic_virial, including "mask" and "hessian" when present/when
_hessian_enabled; then call this helper from both the flat-path and regular-path
after receiving model_ret, and keep the flat/regular-specific behavior (the
no-fitting-net case that does model_predict["updated_coord"] += coord) in the
respective branch only. Ensure the helper references self.do_grad_r,
self.do_grad_c, self._hessian_enabled and self.get_fitting_net() behavior
remains consistent.
In `@deepmd/pt/model/model/make_model.py`:
- Around line 520-599: forward_common_flat is a pure passthrough to
forward_common_flat_native with duplicated docstrings, creating maintenance
overhead; either remove the redundant wrapper by merging
forward_common_flat_native into forward_common_flat (rename/remove the _native
variant and keep a single method and docstring) or add a concise
comment/docstring to forward_common_flat explaining the explicit reason for the
split (e.g., JIT/export boundary or future divergence). Update any call sites
that reference forward_common_flat_native to use the single chosen method name
(forward_common_flat) and ensure tests/type hints still pass.
- Around line 296-309: Move the helper import for
rebuild_extended_coord_from_flat_graph out of the hot forward path and add it to
the module-level nlist import block with the other deepmd.pt.utils.nlist imports
so it is imported once at load time; also change the validation order inside the
flat mixed-batch branch so you check for presence of the required graph fields
(mapping, extended_batch, extended_image, etc.) and raise the RuntimeError
before mutating tensors, and only call coord.requires_grad_(True) /
box.requires_grad_(True) (and any clone+detach) after the validation passes to
avoid wasted work when validation fails.
- Around line 311-314: The four assert statements for extended_atype,
extended_batch, mapping, and nlist are redundant because the preceding explicit
None-check/else already guarantees they are not None; either remove those assert
lines or replace them with a short explanatory comment indicating they exist
purely for static type-narrowing (for tools like mypy) so future readers
understand their purpose; locate the checks around the make_model function where
extended_atype, extended_batch, mapping, and nlist are validated and update
accordingly.
In `@deepmd/pt/model/network/graph_utils_flat.py`:
- Around line 41-89: Add a defensive check to guard against mis-configured
inputs by asserting that a_sel (from a_nlist_mask.shape[1]) does not exceed nnei
(from nlist_flat.shape[1]) near the top of the function; insert something like
an assert or raise ValueError after computing nnei and a_sel so that functions
using n2e_index, edge_lookup_a, and the a_nlist_mask_3d broadcasts fail fast
with a clear message instead of producing downstream broadcast errors.
In `@deepmd/pt/model/task/invar_fitting.py`:
- Around line 299-315: The tensor un-padding currently uses a shape heuristic in
the loop that builds result_flat; instead use the model's declared output types
from output_def() to decide which keys are atom-wise. Replace the shape-based
condition in the for key, value in result_batch loop with a lookup like
output_def()[key].is_atomic (or equivalent in your model API) and only apply
valid_atom_mask to values for keys declared atomic; leave frame-level or other
outputs untouched and handle non-tensor values as before (keep variable names
result_batch, result_flat, valid_atom_mask, and call output_def() to determine
atomic vs frame-level).
- Around line 259-264: The padding value for atype_batch is set to -1 which
causes PyTorch indexing (e.g., in self.forward when doing per-type lookups like
bias_atom_e[atype] or embeddings) to wrap to the last type; change the sentinel
to a valid non-negative index (e.g., 0 or a dedicated PAD type index) when
creating atype_batch so lookups won't accidentally reference real types, and
keep using valid_atom_mask to discard those rows downstream; update places that
assume negative sentinel behavior to use the new PAD index (refer to
atype_batch, self.forward, bias_atom_e, and valid_atom_mask).
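The wrap-around behaviour behind the `-1` padding concern can be seen with plain Python lists, which treat `-1` as "last element" just as the comment describes for per-type lookups. The bias values, the `PAD` index, and the variable names below are hypothetical and only illustrate the idea of pairing a valid padding index with a mask.

```python
bias_atom_e = [0.5, 1.5, 2.5]  # hypothetical per-type bias for three atom types
atype_padded = [0, 2, -1]      # the last slot is padding

# Naive lookup: the padded slot silently picks up the bias of the last real type.
looked_up = [bias_atom_e[t] for t in atype_padded]

# Safer variant: pad with a valid index and rely on a mask to drop padded rows.
PAD = 0
valid_atom_mask = [t >= 0 for t in atype_padded]
safe_types = [t if ok else PAD for t, ok in zip(atype_padded, valid_atom_mask)]
kept = [bias_atom_e[t] for t, ok in zip(safe_types, valid_atom_mask) if ok]
```

In the naive lookup the padded entry yields a plausible-looking value rather than an error, which is why the problem is easy to miss.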
In `@deepmd/pt/train/training.py`:
- Around line 1690-1723: In get_data, remove the special-case skipping of
"batch" and "ptr" inside the batch_data loop and delete the separate post-loop
device transfers for batch_data["batch"] and batch_data["ptr"]; instead let the
existing non-list tensor branch handle them (they'll be moved to DEVICE
automatically), and only keep the is_mixed_batch detection so you can extend
input_keys with list(_FLAT_GRAPH_INPUT_KEYS) when is_mixed_batch is true
(symbols: get_data, batch_data, is_mixed_batch, input_keys,
_FLAT_GRAPH_INPUT_KEYS, DEVICE).
In `@deepmd/pt/train/wrapper.py`:
- Around line 237-247: The natoms value currently has dual semantics (per-frame
nloc when atype.dim()>1 vs total flattened atom count when 1-D) which is
implicit and fragile; fix by making the contract explicit: compute both
nloc_per_frame (e.g., atype.shape[-1] when atype.dim()>1 else None) and
total_atoms (atype.numel() or atype.shape[0]) in wrapper.py and change the call
to self.loss[task_key](input_dict, self.model[task_key], label,
nloc_per_frame=nloc_per_frame, total_atoms=total_atoms, learning_rate=cur_lr),
then update all loss implementations (ener.py) and any factory in make_model.py
to accept these new named args (or provide backward-compatible handling) so
downstream code no longer must infer semantics from ptr.
In `@deepmd/pt/utils/nlist.py`:
- Around line 101-170: The loop is causing repeated device syncs via
int(ptr[frame_idx].item()) and per-frame kernel overhead in
extend_coord_with_ghosts_with_images and build_neighbor_list; fix by
materializing ptr on CPU once (e.g., ptr_cpu = ptr.cpu().tolist()) and use
ptr_cpu[frame_idx]/ptr_cpu[frame_idx+1] for start_idx/end_idx to eliminate
.item() calls, and where possible call
normalize_coord/extend_coord_with_ghosts_with_images/build_neighbor_list on
already-CPU tensors or add a batched/padded variant (similar to
extend_input_and_build_neighbor_list_with_images) to process frames in a
vectorized way to reduce Python per-frame overhead.
- Around line 262-281: Add a one-line doc comment inside get_central_ext_index
stating the invariant it relies on: that the extended layout produced by
_extend_coord_with_ghosts_impl places the unshifted image (shift (0,0,0)) first
for each frame and thus the first nloc extended atoms per frame are the local
atoms; mention that shift indices are sorted by L2 norm so the (0,0,0) image is
first and changing that ordering will break get_central_ext_index.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: b6ad87fb-4157-40de-8f34-510d8232bbef
📒 Files selected for processing (21)
- .gitignore
- deepmd/pt/entrypoints/main.py
- deepmd/pt/loss/ener.py
- deepmd/pt/model/atomic_model/dp_atomic_model.py
- deepmd/pt/model/descriptor/dpa3.py
- deepmd/pt/model/descriptor/env_mat.py
- deepmd/pt/model/descriptor/repflows.py
- deepmd/pt/model/model/__init__.py
- deepmd/pt/model/model/ener_model.py
- deepmd/pt/model/model/make_model.py
- deepmd/pt/model/network/graph_utils_flat.py
- deepmd/pt/model/task/invar_fitting.py
- deepmd/pt/train/training.py
- deepmd/pt/train/wrapper.py
- deepmd/pt/utils/lmdb_dataset.py
- deepmd/pt/utils/nlist.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_lmdb_dataloader.py
- test_mixed_batch.sh
- test_mptraj/lmdb_baseline.json
- test_mptraj/lmdb_mixed_batch.json
Signed-off-by: liwentao <liwt24@mails.tsinghua.edu.cn>
Actionable comments posted: 1
♻️ Duplicate comments (1)
deepmd/pt/loss/ener.py (1)
237-249: ⚠️ Potential issue | 🔴 Critical | 🏗️ Heavy lift
Mixed-batch normalization computes the wrong weighted average.
The current implementation averages `1/N_i` across frames first (line 247), then multiplies that scalar by the frame-averaged loss. This is not equivalent to applying each frame's `1/N_i` weight before averaging.
Example demonstrating the mismatch:
- Frame 1: N₁=10, error²=100
- Frame 2: N₂=100, error²=400
Current (incorrect): mean(1/10, 1/100) × mean(100, 400) = 0.055 × 250 = 13.75
Correct: mean((1/10)×100, (1/100)×400) = mean(10, 4) = 7.0
Impact: Energy and virial losses (MSE/MAE/Huber paths) and their metrics are all misweighted, breaking gradient correctness for mixed-batch training.
Required fix: Compute per-frame `atom_norms = 1.0 / natoms_per_frame` (and raise to `norm_exp` when needed), apply them element-wise to per-frame errors before the reduction (mean/Huber), then remove the post-reduction scalar multiplication by `atom_norm`.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/loss/ener.py` around lines 237 - 249, The mixed-batch path currently computes a scalar atom_norm by averaging 1/N_i (is_mixed_batch / ptr / natoms_per_frame -> atom_norm) and multiplies it after the loss reduction, which is incorrect; instead compute per-frame weights atom_norms = 1.0 / natoms_per_frame (and if a norm exponent is used, apply atom_norms = atom_norms.pow(norm_exp)), then apply these weights element-wise to the per-frame error terms before calling the reduction (mean/Huber) for energy and virial loss paths, and remove the later multiplication by the scalar atom_norm; adjust the branches that handle MSE/MAE/Huber so they accept per-frame-weighted errors rather than multiplying a scalar after reduction.
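The numbers in the mixed-batch normalization comment for `deepmd/pt/loss/ener.py` can be reproduced with a few lines of plain Python. This is only an arithmetic check under the comment's two-frame example, not the loss implementation; the `ptr` layout and variable names are chosen here for illustration.

```python
# Two frames with 10 and 100 atoms, laid out as cumulative offsets.
ptr = [0, 10, 110]
sq_err = [100.0, 400.0]  # per-frame squared energy error from the example

natoms_per_frame = [end - start for start, end in zip(ptr[:-1], ptr[1:])]
nframes = len(natoms_per_frame)

# Scalar normalization applied after the reduction (the behaviour flagged as wrong).
scalar_norm = sum(1.0 / n for n in natoms_per_frame) / nframes
scalar_after = scalar_norm * (sum(sq_err) / nframes)

# Per-frame normalization applied before the reduction (the requested behaviour).
per_frame = sum(err / n for err, n in zip(sq_err, natoms_per_frame)) / nframes
```

The two results differ by almost a factor of two in this example, and the gap depends on how strongly the per-frame error correlates with the frame size.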
🧹 Nitpick comments (2)
deepmd/pt/model/descriptor/dpa3.py (1)
613-613: ⚡ Quick win
Align `forward_flat` return type with actual `None`-capable values.
The signature says `dict[str, torch.Tensor]`, but `"rot_mat"`, `"g2"`, and `"h2"` may be `None`. This can mislead callers and type checks.
Proposed fix
- ) -> dict[str, torch.Tensor]:
+ ) -> dict[str, torch.Tensor | None]:
Also applies to: 704-718
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/model/descriptor/dpa3.py` at line 613, The return type of forward_flat is too narrow: some keys ("rot_mat", "g2", "h2") can be None, so update the signature to reflect Optional values (e.g. change -> dict[str, Optional[torch.Tensor]]) and import typing.Optional; apply the same change to the other similar function(s) noted (the second return signature around the 704–718 region) so callers and type checkers know those keys may be None.
deepmd/pt/loss/ener.py (1)
243-243: 💤 Low value
Remove unused variable.
`nframes` is computed but never used.
🧹 Proposed fix
- nframes = ptr.numel() - 1
  # Compute natoms for each frame
  natoms_per_frame = ptr[1:] - ptr[:-1]  # [nframes]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/loss/ener.py` at line 243, The variable nframes is computed but never used; remove the unused assignment "nframes = ptr.numel() - 1" (or replace it with a deliberate discard like "_ = ptr.numel() - 1" if a side-effect is required) in the function that contains the ptr reference so no dead local remains and linter warnings go away.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/utils/lmdb_dataset.py`:
- Line 102: The call to collate_tensor_fn in lmdb_dataset.py is failing because
collate_tensor_fn is not defined; import PyTorch's collate utility and wire it
up. Add an import for default_collate from torch.utils.data._utils.collate (or
the PyTorch collate function available in your env) and either replace
collate_tensor_fn calls with default_collate or add an alias like
collate_tensor_fn = default_collate at module scope so the line result[key] =
collate_tensor_fn(tensors) works; reference the symbol collate_tensor_fn in
lmdb_dataset.py to locate where to add the import/alias.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 9f5f5f91-6d22-4772-adac-e7c96dc0c041
📒 Files selected for processing (6)
- deepmd/pt/loss/ener.py
- deepmd/pt/model/descriptor/dpa3.py
- deepmd/pt/model/descriptor/repflows.py
- deepmd/pt/train/training.py
- deepmd/pt/utils/lmdb_dataset.py
- deepmd/utils/argcheck.py
🚧 Files skipped from review as they are similar to previous changes (3)
- deepmd/utils/argcheck.py
- deepmd/pt/train/training.py
- deepmd/pt/model/descriptor/repflows.py
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/train/training.py (1)
933-960: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win
Distributed path references `DDP` and `LOCAL_RANK` without imports.
This branch will fail with `NameError` in distributed runs at lines 934, 955-959.
💡 Proposed fix
  from deepmd.pt.utils.env import (
      DEVICE,
      JIT,
+     LOCAL_RANK,
      NUM_WORKERS,
      SAMPLER_RECORD,
  )
+ from torch.nn.parallel import (
+     DistributedDataParallel as DDP,
+ )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/train/training.py` around lines 933 - 960, The distributed branch references DDP, dist and LOCAL_RANK but they are not imported/defined; add the necessary imports and LOCAL_RANK definition: import torch.distributed as dist and import DistributedDataParallel as DDP (from torch.nn.parallel import DistributedDataParallel as DDP), and ensure LOCAL_RANK is defined (e.g. LOCAL_RANK = int(os.environ.get("LOCAL_RANK", "0")) or equivalent) before use so torch.cuda.set_device(LOCAL_RANK), dist.broadcast(...) and DDP(...) work correctly.
♻️ Duplicate comments (1)
deepmd/pt/utils/lmdb_dataset.py (1)
106-112: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win
`collate_tensor_fn` is undefined in mixed-batch collation.
`result[key] = collate_tensor_fn(tensors)` will raise `NameError` on the first frame-wise key that hits this branch.
💡 Proposed fix
  from torch.utils.data import (
      DataLoader,
      Dataset,
      Sampler,
  )
+ from torch.utils.data._utils.collate import (
+     collate_tensor_fn,
+ )
Verification script:
    #!/bin/bash
    # Verify unresolved symbol + missing import in this file.
    rg -n "collate_tensor_fn" deepmd/pt/utils/lmdb_dataset.py
    python - <<'PY'
    import ast, pathlib
    p = pathlib.Path("deepmd/pt/utils/lmdb_dataset.py")
    tree = ast.parse(p.read_text())
    imports = set()
    for n in ast.walk(tree):
        if isinstance(n, ast.ImportFrom):
            mod = n.module or ""
            for a in n.names:
                imports.add((mod, a.name))
    print(("torch.utils.data._utils.collate", "collate_tensor_fn") in imports)
    PY
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/utils/lmdb_dataset.py` around lines 106 - 112, The code calls an undefined symbol collate_tensor_fn in deepmd/pt/utils/lmdb_dataset.py (inside the mixed-batch collation branch that assigns result[key] when key not in _ATOMWISE_MIXED_BATCH_KEYS), which causes a NameError; fix by importing the proper collate helper from PyTorch (e.g., add from torch.utils.data._utils.collate import collate_tensor_fn at the top) or replace collate_tensor_fn with the correct public API (e.g., torch.utils.data.default_collate) and update the call in the block that builds result[key] so the collate function is defined and used consistently.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 0a52198d-56d6-4626-8400-09f27d5b022c
📒 Files selected for processing (8)
- deepmd/pt/loss/ener.py
- deepmd/pt/model/atomic_model/dp_atomic_model.py
- deepmd/pt/model/model/make_model.py
- deepmd/pt/train/training.py
- deepmd/pt/utils/lmdb_dataset.py
- doc/development/lmdb-mixed-system-batching.md
- source/tests/pt/test_lmdb_dataloader.py
- source/tests/pt/test_loss.py
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
deepmd/utils/argcheck.py (1)
3763-3813: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Clarify that `batch_size` becomes a frame count in mixed mode.
The new `mixed_batch` docs explain the flattened LMDB path, but they never say that `batch_size` is reinterpreted as the number of frames/systems in a mixed batch. That is the part most likely to change users' memory sizing and throughput expectations when they flip this flag.
Also applies to: 3858-3918
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/utils/argcheck.py` around lines 3763 - 3813, Update the documentation strings so users know that when the Argument "mixed_batch" (alias "mix_batch") is True the meaning of the "batch_size" Argument changes: it is interpreted as the number of frames/systems (i.e., frame count) in a mixed batch rather than the per-sample/atom count used in non-mixed mode. Modify the doc_mixed_batch text (and also the doc_batch_size string referenced by the "batch_size" Argument) to explicitly state this behavioral change and its implications for memory sizing and throughput so both locations (the mixed_batch Argument and the batch_size Argument) describe the reinterpreted unit.
deepmd/pt/loss/ener.py (1)
447-466: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Guard `drdq` loss when `ptr` indicates a mixed batch.
This branch still reshapes `force`/`drdq` with a single `natoms`, so it cannot preserve frame boundaries once a batch contains mixed `nloc`. In mixed mode this will either mis-shape tensors or compute generalized-force terms across the wrong atoms.
💡 Minimal safe fix
  if self.has_gf and "drdq" in label:
+     if is_mixed_batch:
+         raise NotImplementedError(
+             "Generalized force loss is not supported with mixed_batch inputs."
+         )
      drdq = label["drdq"]
      find_drdq = label.get("find_drdq", 0.0)
      pref_gf = pref_gf * find_drdq
      force_reshape_nframes = force_pred.reshape(-1, natoms * 3)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/loss/ener.py` around lines 447 - 466, The generalized-force block (has_gf branch using drdq, drdq_reshape, force_pred, force_label, natoms, numb_generalized_coord) assumes a single uniform natoms and thus breaks for mixed batches indicated by ptr; either skip computing l2_gen_force_loss when ptr shows mixed nloc or compute it per-frame using ptr slices so frame boundaries are preserved. Update the branch to check the batch pointer (ptr) / mixed-batch indicator before reshaping: if ptr indicates mixed sizes, loop over frames using ptr to slice force_pred, force_label and drdq and compute gen_force/gen_force_label per slice (accumulating weighted loss), otherwise keep the existing vectorized reshape path; ensure the guard uses the same symbol the code provides for the pointer and does not try to reshape using a single natoms when ptr denotes mixed frames.
♻️ Duplicate comments (1)
test_mptraj/lmdb_mixed_batch.json (1)
74-83: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Replace remaining absolute dataset paths with injectable placeholders.
Line 76 and Line 82 still hard-code machine-specific LMDB paths, so this fixture is not reproducible in CI/other dev environments. Please switch both to env placeholders (consistent with `lmdb_baseline.json`) or repo-managed test paths.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.
In `@test_mptraj/lmdb_mixed_batch.json` around lines 74 - 83, The fixture hard-codes machine-specific LMDB paths in the JSON keys training_data.systems and validation_data.systems; replace those absolute paths with injectable placeholders (e.g. "${LMDB_MPTRAJ_PATH}" and "${LMDB_WBM_PATH}" or the same placeholders used in lmdb_baseline.json) or point them to repo-managed test paths, keeping batch_size and mixed_batch entries unchanged so CI/dev environments can supply paths via env vars or test config.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: ee28998d-94e8-4b3f-9a65-611aef50db4f
📒 Files selected for processing (6)
- .gitignore
- deepmd/pt/loss/ener.py
- deepmd/utils/argcheck.py
- test_mixed_batch.sh
- test_mptraj/lmdb_baseline.json
- test_mptraj/lmdb_mixed_batch.json
✅ Files skipped from review due to trivial changes (2)
- .gitignore
- test_mixed_batch.sh
# Compute virial: dE/dh
if self.do_grad_c("energy"):
    nframes = ptr.numel() - 1
iProzd left a comment
- Implementations for other backends (at least the dpmodel backend) should be added.
- Tests for the new forward method should be added, such as universal tests in `source/tests/universal`, end-to-end training in `source/tests/pt/test_training.py`, and consistency tests in `source/tests/consistent`.
This modification should be removed.
This modification should be removed.
This modification should be removed.
This modification should be removed.
The doc should be written in English and placed as a feature introduction instead of a development doc.
Add mixed-nloc LMDB batching for PyTorch training
Summary
This PR adds optional `mixed_batch` support to the PyTorch LMDB training and validation data path.
When `mixed_batch` is enabled, one LMDB batch may contain frames with different atom counts / `nloc` values. Atom-wise fields are flattened across the batch, frame ownership is tracked with `batch` and `ptr`, and the model runs through a flat forward path without padding.
The default behavior is unchanged. `mixed_batch` is `false` by default, so LMDB batches still use the existing same-`nloc` sampler unless the option is explicitly enabled.
Motivation
LMDB batches previously required all frames in a batch to have the same `nloc`. This can make batching inefficient for datasets that mix systems with different atom counts. Mixed batching lets those frames train together while keeping atom indexing explicit in a flattened graph layout.
Main Changes
- … `mixed_batch` to PyTorch LMDB `training_data` and `validation_data`.
- … `mix_batch` as an alias for the same option.
- In `mixed_batch=true` mode, interpret `batch_size` as the number of frames/systems in one mixed batch.
- … `batch` and `ptr` for frame ownership.
Usage
    {
      "training": {
        "training_data": {
          "systems": "path/to/train.lmdb",
          "batch_size": 8,
          "mixed_batch": true
        },
        "validation_data": {
          "systems": "path/to/valid.lmdb",
          "batch_size": 8,
          "mixed_batch": true
        }
      }
    }
With `mixed_batch=true`, `batch_size` controls how many frames/systems are merged into one flattened batch.
With `mixed_batch=false`, the existing same-`nloc` LMDB batching behavior is preserved.
Tests
- … `batch`/`ptr` layout and flat graph precomputation.
- … `mix_batch` config alias.
- … `test_mixed_batch.sh` for end-to-end training coverage.
Scope
Summary by CodeRabbit
New Features
Improvements
Documentation
Tests