diff --git a/.gitignore b/.gitignore index 897a224371..993674b6c8 100644 --- a/.gitignore +++ b/.gitignore @@ -74,3 +74,8 @@ frozen_model.* # Test system directories system/ *.expected +temp/ +pkl/ +history/ +deepmd-kit/ +*.hdf5 diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 7b45c46333..e0a469787c 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -167,10 +167,12 @@ def _make_dp_loader_set( # LMDB path: single string → LmdbDataset if isinstance(training_systems, str) and is_lmdb(training_systems): auto_prob = training_dataset_params.get("auto_prob", None) + mixed_batch = training_dataset_params.get("mixed_batch", False) train_data_single = LmdbDataset( training_systems, model_params_single["type_map"], training_dataset_params["batch_size"], + mixed_batch=mixed_batch, auto_prob_style=auto_prob, ) if ( @@ -178,10 +180,12 @@ def _make_dp_loader_set( and isinstance(validation_systems, str) and is_lmdb(validation_systems) ): + val_mixed_batch = validation_dataset_params.get("mixed_batch", False) validation_data_single = LmdbDataset( validation_systems, model_params_single["type_map"], validation_dataset_params["batch_size"], + mixed_batch=val_mixed_batch, ) elif validation_systems is not None: validation_data_single = _make_dp_loader_set( diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py index 50d83a4ac9..6bbef5babf 100644 --- a/deepmd/pt/loss/ener.py +++ b/deepmd/pt/loss/ener.py @@ -238,11 +238,38 @@ def forward( more_loss = {} # more_loss['log_keys'] = [] # showed when validation on the fly # more_loss['test_keys'] = [] # showed when doing dp test - atom_norm = 1.0 / natoms - # Normalization exponent controls loss scaling with system size: - # - norm_exp=2 (intensive_ener_virial=True): loss uses 1/N² scaling, making it independent of system size - # - norm_exp=1 (intensive_ener_virial=False, legacy): loss uses 1/N scaling, which varies with system size + + # Detect mixed batch format + 
is_mixed_batch = "ptr" in input_dict and input_dict["ptr"] is not None + + atom_norms = None + if is_mixed_batch: + ptr = input_dict["ptr"] + natoms_per_frame = ptr[1:] - ptr[:-1] # [nframes] + atom_norms = 1.0 / natoms_per_frame.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) + atom_norm = None + else: + atom_norm = 1.0 / natoms norm_exp = 2 if self.intensive_ener_virial else 1 + + def get_frame_norm(value: torch.Tensor) -> torch.Tensor: + assert atom_norms is not None + return atom_norms.to(device=value.device, dtype=value.dtype).view( + [-1] + [1] * (value.dim() - 1) + ) + + def weighted_mean(value: torch.Tensor, power: int = 1) -> torch.Tensor: + if atom_norms is None: + assert atom_norm is not None + return value.mean() * (atom_norm**power) + return (value * get_frame_norm(value) ** power).mean() + + def normalized_rmse(diff: torch.Tensor) -> torch.Tensor: + if atom_norms is None: + assert atom_norm is not None + return torch.mean(torch.square(diff)).sqrt() * atom_norm + return torch.mean(torch.square(diff * get_frame_norm(diff))).sqrt() + if self.has_e and "energy" in model_pred and "energy" in label: energy_pred = model_pred["energy"] energy_label = label["energy"] @@ -261,35 +288,37 @@ def forward( energy_pred = torch.sum(atom_ener_coeff * atom_ener_pred, dim=1) find_energy = label.get("find_energy", 0.0) pref_e = pref_e * find_energy + diff_e = energy_pred - energy_label if self.loss_func == "mse": - l2_ener_loss = torch.mean(torch.square(energy_pred - energy_label)) + square_ener_diff = torch.square(diff_e) + l2_ener_loss = torch.mean(square_ener_diff) if not self.inference: more_loss["l2_ener_loss"] = self.display_if_exist( l2_ener_loss.detach(), find_energy ) if not self.use_huber: - loss += atom_norm**norm_exp * (pref_e * l2_ener_loss) + loss += pref_e * weighted_mean(square_ener_diff, norm_exp) else: + energy_norm = ( + atom_norm if atom_norms is None else get_frame_norm(energy_pred) + ) l_huber_loss = custom_huber_loss( - atom_norm * energy_pred, - atom_norm 
* energy_label, + energy_norm * energy_pred, + energy_norm * energy_label, delta=self._huber_delta_energy, ) loss += pref_e * l_huber_loss - rmse_e = l2_ener_loss.sqrt() * atom_norm + rmse_e = normalized_rmse(diff_e) more_loss["rmse_e"] = self.display_if_exist( rmse_e.detach(), find_energy ) # more_loss['log_keys'].append('rmse_e') elif self.loss_func == "mae": - l1_ener_loss = F.l1_loss( - energy_pred.reshape(-1), - energy_label.reshape(-1), - reduction="mean", - ) - loss += atom_norm * (pref_e * l1_ener_loss) + abs_ener_diff = torch.abs(diff_e) + mae_e = weighted_mean(abs_ener_diff) + loss += pref_e * mae_e more_loss["mae_e"] = self.display_if_exist( - l1_ener_loss.detach() * atom_norm, + mae_e.detach(), find_energy, ) # more_loss['log_keys'].append('rmse_e') @@ -298,9 +327,9 @@ def forward( f"Loss type {self.loss_func} is not implemented for energy loss." ) if mae: - mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm + mae_e = weighted_mean(torch.abs(diff_e)) more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy) - mae_e_all = torch.mean(torch.abs(energy_pred - energy_label)) + mae_e_all = torch.mean(torch.abs(diff_e)) more_loss["mae_e_all"] = self.display_if_exist( mae_e_all.detach(), find_energy ) @@ -417,6 +446,10 @@ def forward( ) if self.has_gf and "drdq" in label: + if is_mixed_batch: + raise NotImplementedError( + "Generalized force loss is not supported with mixed_batch=True yet." 
+ ) drdq = label["drdq"] find_drdq = label.get("find_drdq", 0.0) pref_gf = pref_gf * find_drdq @@ -446,33 +479,36 @@ def forward( pref_v = pref_v * find_virial diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9) if self.loss_func == "mse": - l2_virial_loss = torch.mean(torch.square(diff_v)) + square_virial_diff = torch.square(diff_v) + l2_virial_loss = torch.mean(square_virial_diff) if not self.inference: more_loss["l2_virial_loss"] = self.display_if_exist( l2_virial_loss.detach(), find_virial ) if not self.use_huber: - loss += atom_norm**norm_exp * (pref_v * l2_virial_loss) + loss += pref_v * weighted_mean(square_virial_diff, norm_exp) else: + virial = model_pred["virial"].reshape(-1, 9) + virial_label = label["virial"].reshape(-1, 9) + virial_norm = ( + atom_norm if atom_norms is None else get_frame_norm(virial) + ) l_huber_loss = custom_huber_loss( - atom_norm * model_pred["virial"].reshape(-1), - atom_norm * label["virial"].reshape(-1), + (virial_norm * virial).reshape(-1), + (virial_norm * virial_label).reshape(-1), delta=self._huber_delta_virial, ) loss += pref_v * l_huber_loss - rmse_v = l2_virial_loss.sqrt() * atom_norm + rmse_v = normalized_rmse(diff_v) more_loss["rmse_v"] = self.display_if_exist( rmse_v.detach(), find_virial ) elif self.loss_func == "mae": - l1_virial_loss = F.l1_loss( - label["virial"].reshape(-1), - model_pred["virial"].reshape(-1), - reduction="mean", - ) - loss += atom_norm * (pref_v * l1_virial_loss) + abs_virial_diff = torch.abs(diff_v) + mae_v = weighted_mean(abs_virial_diff) + loss += pref_v * mae_v more_loss["mae_v"] = self.display_if_exist( - l1_virial_loss.detach() * atom_norm, + mae_v.detach(), find_virial, ) else: @@ -480,7 +516,7 @@ def forward( f"Loss type {self.loss_func} is not implemented for virial loss." 
) if mae: - mae_v = torch.mean(torch.abs(diff_v)) * atom_norm + mae_v = weighted_mean(torch.abs(diff_v)) more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial) if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label: diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index efb2a532e5..e3a634c98f 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -318,6 +318,153 @@ def forward_atomic( ) return fit_ret + def forward_common_atomic_flat( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_batch: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + extended_ptr: torch.Tensor | None = None, + central_ext_index: torch.Tensor | None = None, + nlist_ext: torch.Tensor | None = None, + a_nlist: torch.Tensor | None = None, + a_nlist_ext: torch.Tensor | None = None, + nlist_mask: torch.Tensor | None = None, + a_nlist_mask: torch.Tensor | None = None, + edge_index: torch.Tensor | None = None, + angle_index: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Forward pass with flat batch format. + + Parameters + ---------- + extended_coord : torch.Tensor + Extended coordinates [total_extended_atoms, 3]. + extended_atype : torch.Tensor + Extended atom types [total_extended_atoms]. + extended_batch : torch.Tensor + Frame assignment for extended atoms [total_extended_atoms]. + nlist : torch.Tensor + Neighbor list [total_atoms, nnei]. + mapping : torch.Tensor + Extended atom -> local flat index mapping [total_extended_atoms]. + batch : torch.Tensor + Frame assignment for local atoms [total_atoms]. + ptr : torch.Tensor + Frame boundaries [nframes + 1]. + fparam : torch.Tensor | None + Frame parameters [nframes, ndf]. 
+ aparam : torch.Tensor | None + Atomic parameters [total_atoms, nda]. + central_ext_index : torch.Tensor | None + Extended-atom indices corresponding to local atoms. + nlist_ext, a_nlist_ext : torch.Tensor | None + Edge and angle neighbor lists indexing concatenated extended atoms. + nlist_mask, a_nlist_mask : torch.Tensor | None + Valid-neighbor masks for flat edge and angle neighbor lists. + edge_index, angle_index : torch.Tensor | None + Dynamic graph indices produced by the flat graph preprocessor. + + Returns + ------- + result_dict : dict[str, torch.Tensor] + Model predictions in flat format. + """ + if self.do_grad_r() or self.do_grad_c(): + extended_coord.requires_grad_(True) + + if ( + hasattr(self.fitting_net, "get_dim_fparam") + and self.fitting_net.get_dim_fparam() > 0 + and fparam is None + ): + default_fparam_tensor = self.fitting_net.get_default_fparam() + assert default_fparam_tensor is not None + fparam_input_for_des = torch.tile( + default_fparam_tensor.to(device=extended_coord.device).unsqueeze(0), + [ptr.numel() - 1, 1], + ) + else: + fparam_input_for_des = fparam + + # Descriptor and fitting both consume the flat atom layout. 
+ descriptor_out = self.descriptor.forward_flat( + extended_coord, + extended_atype, + extended_batch, + nlist, + mapping, + batch, + ptr, + fparam=fparam_input_for_des if self.add_chg_spin_ebd else None, + central_ext_index=central_ext_index, + nlist_ext=nlist_ext, + a_nlist=a_nlist, + a_nlist_ext=a_nlist_ext, + nlist_mask=nlist_mask, + a_nlist_mask=a_nlist_mask, + edge_index=edge_index, + angle_index=angle_index, + ) + + descriptor = descriptor_out.get("descriptor") + rot_mat = descriptor_out.get("rot_mat") + g2 = descriptor_out.get("g2") + h2 = descriptor_out.get("h2") + + if self.enable_eval_descriptor_hook: + self.eval_descriptor_list.append(descriptor.detach()) + + if central_ext_index is None: + from deepmd.pt.utils.nlist import ( + get_central_ext_index, + ) + + central_ext_index = get_central_ext_index(extended_batch, ptr) + atype = extended_atype[central_ext_index] + else: + atype = extended_atype[central_ext_index] + + fit_ret = self.fitting_net.forward_flat( + descriptor, + atype, + batch, + ptr, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + fit_ret = self.apply_out_stat(fit_ret, atype) + + atom_mask = self.make_atom_mask(atype).to(torch.int32) + if self.atom_excl is not None: + atom_mask *= self.atom_excl(atype.unsqueeze(0)).squeeze(0) + + for kk in fit_ret.keys(): + out_shape = fit_ret[kk].shape + out_shape2 = 1 + for ss in out_shape[1:]: + out_shape2 *= ss + fit_ret[kk] = ( + fit_ret[kk].reshape([out_shape[0], out_shape2]) * atom_mask[:, None] + ).view(out_shape) + fit_ret["mask"] = atom_mask + + if self.enable_eval_fitting_last_layer_hook: + if "middle_output" in fit_ret: + self.eval_fitting_last_layer_list.append( + fit_ret.pop("middle_output").detach() + ) + + return fit_ret + def compute_or_load_stat( self, sampled_func: Callable[[], list[dict]], diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index a5f79280fa..a10935d613 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ 
b/deepmd/pt/model/descriptor/dpa3.py @@ -592,6 +592,132 @@ def forward( sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if sw is not None else None, ) + def forward_flat( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_batch: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + fparam: torch.Tensor | None = None, + central_ext_index: torch.Tensor | None = None, + nlist_ext: torch.Tensor | None = None, + a_nlist: torch.Tensor | None = None, + a_nlist_ext: torch.Tensor | None = None, + nlist_mask: torch.Tensor | None = None, + a_nlist_mask: torch.Tensor | None = None, + edge_index: torch.Tensor | None = None, + angle_index: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Compute the descriptor with flat batch format. + + Parameters + ---------- + extended_coord : torch.Tensor + Extended coordinates [total_extended_atoms, 3]. + extended_atype : torch.Tensor + Extended atom types [total_extended_atoms]. + extended_batch : torch.Tensor + Frame assignment for extended atoms [total_extended_atoms]. + nlist : torch.Tensor + Neighbor list [total_atoms, nnei]. + mapping : torch.Tensor + Extended atom -> local flat index mapping [total_extended_atoms]. + batch : torch.Tensor + Frame assignment for local atoms [total_atoms]. + ptr : torch.Tensor + Frame boundaries [nframes + 1]. + fparam : torch.Tensor | None + Frame parameters [nframes, ndf]. + central_ext_index : torch.Tensor | None + Extended-atom indices corresponding to local atoms. + nlist_ext, a_nlist_ext : torch.Tensor | None + Edge and angle neighbor lists indexing concatenated extended atoms. + nlist_mask, a_nlist_mask : torch.Tensor | None + Valid-neighbor masks for flat edge and angle neighbor lists. + edge_index, angle_index : torch.Tensor | None + Dynamic graph indices produced by the flat graph preprocessor. 
+ + Returns + ------- + result : dict[str, torch.Tensor] + Dictionary containing: + - 'descriptor': [total_atoms, descriptor_dim] + - 'rot_mat': [total_atoms, e_dim, 3] or None + - 'g2': edge embedding or None + - 'h2': pair representation or None + """ + extended_coord = extended_coord.to(dtype=self.prec) + + # Flat batches embed all extended atoms, then gather central atoms. + node_ebd_ext = self.type_embedding( + extended_atype + ) # [total_extended_atoms, tebd_dim] + + if self.add_chg_spin_ebd: + assert fparam is not None + assert self.chg_embedding is not None + assert self.spin_embedding is not None + + # Expand frame-level charge/spin parameters to extended atoms. + charge = fparam[extended_batch, 0].to(dtype=torch.int64) + 100 + spin = fparam[extended_batch, 1].to(dtype=torch.int64) + chg_ebd = self.chg_embedding(charge) + spin_ebd = self.spin_embedding(spin) + sys_cs_embd = self.act( + self.mix_cs_mlp(torch.cat((chg_ebd, spin_ebd), dim=-1)) + ) + node_ebd_ext = node_ebd_ext + sys_cs_embd + + if central_ext_index is None: + from deepmd.pt.utils.nlist import ( + get_central_ext_index, + ) + + central_ext_index = get_central_ext_index(extended_batch, ptr) + node_ebd_inp = node_ebd_ext[central_ext_index] + + node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows.forward_flat( + nlist, + extended_coord, + extended_atype, + extended_batch, + node_ebd_ext, + mapping, + batch, + ptr, + central_ext_index=central_ext_index, + nlist_ext=nlist_ext, + a_nlist=a_nlist, + a_nlist_ext=a_nlist_ext, + nlist_mask=nlist_mask, + a_nlist_mask=a_nlist_mask, + edge_index=edge_index, + angle_index=angle_index, + ) + + if self.concat_output_tebd: + node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1) + + return { + "descriptor": node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + "rot_mat": ( + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) + if rot_mat is not None + else None + ), + "g2": ( + edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) + if edge_ebd is not None + else None + 
), + "h2": ( + h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if h2 is not None else None + ), + } + @classmethod def update_sel( cls, diff --git a/deepmd/pt/model/descriptor/env_mat.py b/deepmd/pt/model/descriptor/env_mat.py index 0ffdbb7dbb..5d73d784e3 100644 --- a/deepmd/pt/model/descriptor/env_mat.py +++ b/deepmd/pt/model/descriptor/env_mat.py @@ -90,3 +90,98 @@ def prod_env_mat( t_std = stddev[atype] # [n_atom, dim, 4 or 1] env_mat_se_a = (_env_mat_se_a - t_avg) / t_std return env_mat_se_a, diff, switch + + +def prod_env_mat_flat( + extended_coord_flat: torch.Tensor, + nlist_flat: torch.Tensor, + atype_flat: torch.Tensor, + mean: torch.Tensor, + stddev: torch.Tensor, + rcut: float, + rcut_smth: float, + radial_only: bool = False, + protection: float = 0.0, + use_exp_switch: bool = False, + coord_flat: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate smooth environment matrix in flat format. + + Parameters + ---------- + extended_coord_flat + Extended atom coordinates with shape ``[nall, 3]``. + nlist_flat + Neighbor list with shape ``[nloc, nnei]``. ``-1`` marks padding. + atype_flat + Central atom types with shape ``[nloc]``. + mean, stddev + Descriptor statistics with shape ``[ntypes, nnei, 4 or 1]``. + rcut, rcut_smth + Cutoff radius and smooth cutoff radius. + radial_only + Whether to return radial-only descriptors. + protection + Small positive value used in radial divisions. + use_exp_switch + Whether to use the exponential switch function. + coord_flat + Optional central atom coordinates with shape ``[nloc, 3]``. + + Returns + ------- + env_mat + Environment matrix with shape ``[nloc, nnei, 4 or 1]``. + diff + Difference vectors with shape ``[nloc, nnei, 3]``. + switch + Switch function values with shape ``[nloc, nnei, 1]``. 
+ """ + nloc, nnei = nlist_flat.shape + nall = extended_coord_flat.shape[0] + + mask = nlist_flat >= 0 + nlist_safe = torch.where(mask, nlist_flat, nall) + + # coord_l: [nloc, 1, 3] + if coord_flat is not None: + coord_l = coord_flat.view(nloc, 1, 3) + else: + coord_l = extended_coord_flat[:nloc].view(nloc, 1, 3) + + # Gather neighbor coordinates + index = nlist_safe.view(-1).unsqueeze(-1).expand(-1, 3) + coord_pad = torch.cat( + [extended_coord_flat, extended_coord_flat[-1:, :] + rcut], dim=0 + ) + coord_r = torch.gather(coord_pad, 0, index) + coord_r = coord_r.view(nloc, nnei, 3) + + # Compute differences and distances + diff = coord_r - coord_l + length = torch.linalg.norm(diff, dim=-1, keepdim=True) + length = length + ~mask.unsqueeze(-1) + + t0 = 1 / (length + protection) + t1 = diff / (length + protection) ** 2 + + weight = ( + compute_smooth_weight(length, rcut_smth, rcut) + if not use_exp_switch + else compute_exp_sw(length, rcut_smth, rcut) + ) + weight = weight * mask.unsqueeze(-1) + + if radial_only: + env_mat = t0 * weight + else: + env_mat = torch.cat([t0, t1], dim=-1) * weight + + diff = diff * mask.unsqueeze(-1) + + # Normalize by mean and stddev + t_avg = mean[atype_flat] # [nloc, nnei, 4] + t_std = stddev[atype_flat] # [nloc, nnei, 4] + env_mat = (env_mat - t_avg) / t_std + + return env_mat, diff, weight diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index e29fe01ac6..470f432200 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -681,6 +681,240 @@ def forward( return node_ebd, edge_ebd, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw + def forward_flat( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_batch: torch.Tensor, + extended_atype_embd: torch.Tensor, + mapping: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + central_ext_index: torch.Tensor | None = None, + nlist_ext: 
torch.Tensor | None = None, + a_nlist: torch.Tensor | None = None, + a_nlist_ext: torch.Tensor | None = None, + nlist_mask: torch.Tensor | None = None, + a_nlist_mask: torch.Tensor | None = None, + edge_index: torch.Tensor | None = None, + angle_index: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + """Forward pass for a precomputed flat graph batch. + + Parameters + ---------- + nlist : torch.Tensor + Neighbor list [total_atoms, nnei] with global extended indices. + extended_coord : torch.Tensor + Extended coordinates [total_extended_atoms, 3]. + extended_atype : torch.Tensor + Extended atom types [total_extended_atoms]. + extended_batch : torch.Tensor + Frame assignment for extended atoms [total_extended_atoms]. + extended_atype_embd : torch.Tensor + Type embeddings for extended atoms [total_extended_atoms, tebd_dim]. + mapping : torch.Tensor + Extended atom -> local flat index mapping [total_extended_atoms]. + batch : torch.Tensor + Frame assignment for local atoms [total_atoms]. + ptr : torch.Tensor + Frame boundaries [nframes + 1]. + central_ext_index + Extended-atom indices corresponding to local atoms. + nlist_ext, a_nlist_ext + Edge and angle neighbor lists indexing concatenated extended atoms. + nlist_mask, a_nlist_mask + Valid-neighbor masks for edge and angle neighbor lists. + edge_index, angle_index + Dynamic graph indices generated from the flat neighbor lists. + + Returns + ------- + node_ebd : torch.Tensor + Node embeddings [total_atoms, n_dim]. + edge_ebd : torch.Tensor | None + Edge embeddings. + h2 : torch.Tensor | None + Pair representation. + rot_mat : torch.Tensor | None + Rotation matrix [total_atoms, e_dim, 3]. + sw : torch.Tensor | None + Switch function. 
+ """ + from deepmd.pt.model.descriptor.env_mat import ( + prod_env_mat_flat, + ) + + nloc = batch.shape[0] + if ( + central_ext_index is None + or nlist_ext is None + or a_nlist is None + or a_nlist_ext is None + or nlist_mask is None + or a_nlist_mask is None + ): + raise RuntimeError( + "Repflows flat forward requires precomputed graph fields from collate_fn." + ) + coord_central = extended_coord[central_ext_index] + atype = extended_atype[central_ext_index] + + # Edge environment matrix in extended-atom index space. + dmatrix, diff, sw = prod_env_mat_flat( + extended_coord, + nlist_ext, + atype, + self.mean, + self.stddev, + self.e_rcut, + self.e_rcut_smth, + protection=self.env_protection, + use_exp_switch=self.use_exp_switch, + coord_flat=coord_central, + ) + + sw = torch.squeeze(sw, -1) + sw = sw.masked_fill(~nlist_mask, 0.0) + + # Angle environment matrix uses the angle cutoff and angle neighbor list. + _, a_diff, a_sw = prod_env_mat_flat( + extended_coord, + a_nlist_ext, + atype, + self.mean[:, : self.a_sel], + self.stddev[:, : self.a_sel], + self.a_rcut, + self.a_rcut_smth, + protection=self.env_protection, + use_exp_switch=self.use_exp_switch, + coord_flat=coord_central, + ) + + a_sw = torch.squeeze(a_sw, -1) + a_sw = a_sw.masked_fill(~a_nlist_mask, 0.0) + + # Node embedding for central atoms. + atype_embd = extended_atype_embd[central_ext_index] + assert list(atype_embd.shape) == [nloc, self.n_dim] + node_ebd = self.act(atype_embd) + + # Edge and angle embedding inputs. + edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1) + if self.edge_init_use_dist: + edge_input = safe_for_norm(diff, dim=-1, keepdim=True) + + # Angle input is the normalized cosine between neighbor directions. 
+ normalized_diff_i = a_diff / ( + safe_for_norm(a_diff, dim=-1, keepdim=True) + 1e-6 + ) + normalized_diff_j = torch.transpose(normalized_diff_i, 1, 2) + cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6) + angle_input = cosine_ij.unsqueeze(-1) / (torch.pi**0.5) + + if self.use_dynamic_sel: + if edge_index is None or angle_index is None: + raise RuntimeError( + "Dynamic flat forward requires precomputed edge_index and angle_index." + ) + # Flatten dynamic-selection tensors to match graph indices. + edge_input = edge_input[nlist_mask] + h2 = h2[nlist_mask] + sw = sw[nlist_mask] + a_nlist_mask_2d = a_nlist_mask[:, :, None] & a_nlist_mask[:, None, :] + angle_input = angle_input[a_nlist_mask_2d] + a_sw = (a_sw[:, :, None] * a_sw[:, None, :])[a_nlist_mask_2d] + else: + edge_index = torch.zeros([2, 1], device=nlist.device, dtype=nlist.dtype) + angle_index = torch.zeros([3, 1], device=nlist.device, dtype=nlist.dtype) + + # Edge and angle embeddings. + if not self.edge_init_use_dist: + edge_ebd = self.act(self.edge_embd(edge_input)) + else: + edge_ebd = self.edge_embd(edge_input) + angle_ebd = self.angle_embd(angle_input) + + # RepFlowLayer expects batched tensors. Use a synthetic one-frame batch + # while preserving flattened dynamic edge and angle tensors. 
+ node_ebd_batched = node_ebd.unsqueeze(0) # [1, nloc, n_dim] + edge_ebd_batched = ( + edge_ebd.unsqueeze(0) if not self.use_dynamic_sel else edge_ebd + ) + h2_batched = h2.unsqueeze(0) if not self.use_dynamic_sel else h2 + angle_ebd_batched = ( + angle_ebd.unsqueeze(0) if not self.use_dynamic_sel else angle_ebd + ) + nlist_batched = nlist.unsqueeze(0) # [1, nloc, nnei] + nlist_mask_batched = nlist_mask.unsqueeze(0) # [1, nloc, nnei] + sw_batched = sw.unsqueeze(0) if not self.use_dynamic_sel else sw + a_nlist_batched = a_nlist.unsqueeze(0) # [1, nloc, a_nnei] + a_nlist_mask_batched = a_nlist_mask.unsqueeze(0) # [1, nloc, a_nnei] + a_sw_batched = a_sw.unsqueeze(0) if not self.use_dynamic_sel else a_sw + + for ll in self.layers: + # Flat precomputed graphs already use local atom indexing here. + node_ebd_ext_batched = node_ebd_batched + + node_ebd_batched, edge_ebd_batched, angle_ebd_batched = ll.forward( + node_ebd_ext_batched, + edge_ebd_batched, + h2_batched, + angle_ebd_batched, + nlist_batched, + nlist_mask_batched, + sw_batched, + a_nlist_batched, + a_nlist_mask_batched, + a_sw_batched, + edge_index=edge_index, + angle_index=angle_index, + ) + + # Rotation matrix from final edge representation. 
+ if self.use_dynamic_sel: + h2g2 = RepFlowLayer._cal_hg_dynamic( + edge_ebd_batched, + h2_batched, + sw_batched, + owner=edge_index[0], + num_owner=nloc, + nb=1, + nloc=nloc, + scale_factor=(self.nnei / self.sel_reduce_factor) ** (-0.5), + ).squeeze(0) + else: + # Use batched versions for _cal_hg, then squeeze + h2g2 = RepFlowLayer._cal_hg( + edge_ebd_batched, + h2_batched, + nlist_mask_batched, + sw_batched, + ) + h2g2 = h2g2.squeeze(0) # Remove batch dimension + + # Remove batch dimension from outputs + node_ebd = node_ebd_batched.squeeze(0) + edge_ebd = ( + edge_ebd_batched.squeeze(0) + if not self.use_dynamic_sel + else edge_ebd_batched + ) + h2 = h2_batched.squeeze(0) if not self.use_dynamic_sel else h2_batched + sw = sw_batched.squeeze(0) if not self.use_dynamic_sel else sw_batched + + # [nloc, e_dim, 3] + rot_mat = torch.permute(h2g2, (0, 2, 1)) + + return node_ebd, edge_ebd, h2, rot_mat, sw + def compute_input_stats( self, merged: Callable[[], list[dict]] | list[dict], diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 24075412db..d824bba555 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -19,6 +19,9 @@ import numpy as np +from deepmd.dpmodel.utils.seed import ( + child_seed, +) from deepmd.pt.model.atomic_model import ( DPAtomicModel, PairTabAtomicModel, @@ -74,6 +77,27 @@ SpinModel, ) +DEFAULT_DESCRIPTOR_INIT_SEED = 1 +DEFAULT_FITTING_INIT_SEED = 2 + + +def _set_default_init_seed(params: dict[str, Any], seed: int | list[int]) -> None: + if params.get("seed") is None: + params["seed"] = seed + + +def _set_default_descriptor_init_seed( + params: dict[str, Any], seed: int | list[int] +) -> None: + if params.get("type") == "hybrid": + for idx, descriptor_params in enumerate(params.get("list", [])): + if isinstance(descriptor_params, dict): + _set_default_descriptor_init_seed( + descriptor_params, child_seed(seed, idx) + ) + return + _set_default_init_seed(params, seed) + def 
_get_standard_model_components(model_params: dict, ntypes: int) -> tuple: if "type_embedding" in model_params: @@ -83,9 +107,13 @@ def _get_standard_model_components(model_params: dict, ntypes: int) -> tuple: # descriptor model_params["descriptor"]["ntypes"] = ntypes model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"]) + _set_default_descriptor_init_seed( + model_params["descriptor"], DEFAULT_DESCRIPTOR_INIT_SEED + ) descriptor = BaseDescriptor(**model_params["descriptor"]) # fitting fitting_net = model_params.get("fitting_net", {}) + _set_default_init_seed(fitting_net, DEFAULT_FITTING_INIT_SEED) fitting_net["type"] = fitting_net.get("type", "ener") fitting_net["ntypes"] = descriptor.get_ntypes() fitting_net["type_map"] = copy.deepcopy(model_params["type_map"]) diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 1680d1e258..e505318447 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -72,7 +72,76 @@ def forward( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, + batch: torch.Tensor | None = None, + ptr: torch.Tensor | None = None, + extended_atype: torch.Tensor | None = None, + extended_batch: torch.Tensor | None = None, + extended_image: torch.Tensor | None = None, + extended_ptr: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + central_ext_index: torch.Tensor | None = None, + nlist: torch.Tensor | None = None, + nlist_ext: torch.Tensor | None = None, + a_nlist: torch.Tensor | None = None, + a_nlist_ext: torch.Tensor | None = None, + nlist_mask: torch.Tensor | None = None, + a_nlist_mask: torch.Tensor | None = None, + edge_index: torch.Tensor | None = None, + angle_index: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: + if not torch.jit.is_scripting() and batch is not None and ptr is not None: + model_ret = self.forward_common_flat( + coord=coord, + atype=atype, + 
batch=batch, + ptr=ptr, + box=box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + extended_atype=extended_atype, + extended_batch=extended_batch, + extended_image=extended_image, + extended_ptr=extended_ptr, + mapping=mapping, + central_ext_index=central_ext_index, + nlist=nlist, + nlist_ext=nlist_ext, + a_nlist=a_nlist, + a_nlist_ext=a_nlist_ext, + nlist_mask=nlist_mask, + a_nlist_mask=a_nlist_mask, + edge_index=edge_index, + angle_index=angle_index, + ) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + else: + if "dforce" in model_ret: + model_predict["force"] = model_ret["dforce"] + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze( + -2 + ) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret[ + "energy_derv_c" + ].squeeze(-2) + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + if self._hessian_enabled: + model_predict["hessian"] = model_ret[ + "energy_derv_r_derv_r" + ].squeeze(-3) + else: + model_predict = model_ret + model_predict["updated_coord"] += coord + return model_predict + model_ret = self.forward_common( coord, atype, @@ -87,14 +156,15 @@ def forward( model_predict["energy"] = model_ret["energy_redu"] if self.do_grad_r("energy"): model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + else: + if "dforce" in model_ret: + model_predict["force"] = model_ret["dforce"] if self.do_grad_c("energy"): model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze( -2 ) - else: - model_predict["force"] = model_ret["dforce"] if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] if self._hessian_enabled: @@ -102,6 +172,7 @@ def 
forward( else: model_predict = model_ret model_predict["updated_coord"] += coord + return model_predict @torch.jit.export diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 83e0209ad8..28a982ae00 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -214,6 +214,404 @@ def forward_common( model_predict = self._output_type_cast(model_predict, input_prec) return model_predict + def forward_common_flat_native( + self, + coord: torch.Tensor, + atype: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + extended_atype: torch.Tensor | None = None, + extended_batch: torch.Tensor | None = None, + extended_image: torch.Tensor | None = None, + extended_ptr: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + central_ext_index: torch.Tensor | None = None, + nlist: torch.Tensor | None = None, + nlist_ext: torch.Tensor | None = None, + a_nlist: torch.Tensor | None = None, + a_nlist_ext: torch.Tensor | None = None, + nlist_mask: torch.Tensor | None = None, + a_nlist_mask: torch.Tensor | None = None, + edge_index: torch.Tensor | None = None, + angle_index: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Forward pass for mixed-nloc batches with a precomputed flat graph. + + This path consumes graph tensors prepared by the LMDB collate function + and keeps atom-wise values flattened across frames. + + Parameters + ---------- + coord + Flattened atomic coordinates with shape [total_atoms, 3]. + atype + Flattened atomic types with shape [total_atoms]. + batch + Atom-to-frame assignment with shape [total_atoms]. + ptr + Frame boundaries with shape [nframes + 1]. + box + Simulation boxes with shape [nframes, 9]. + fparam + Frame parameters with shape [nframes, ndf]. 
+ aparam + Flattened atomic parameters with shape [total_atoms, nda]. + do_atomic_virial + Whether to calculate atomic virial. + + Returns + ------- + model_predict : dict[str, torch.Tensor] + Model predictions with flat format: + - atomwise outputs: [total_atoms, ...] + - frame-wise outputs: [nframes, ...] + + Notes + ----- + The precomputed graph fields are required for this path; missing + fields are treated as a data pipeline error. + """ + if do_atomic_virial: + raise NotImplementedError( + "Atomic virial is not implemented for flat mixed-batch forward." + ) + coord, box, fparam, aparam, input_prec = self._input_type_cast( + coord, box=box, fparam=fparam, aparam=aparam + ) + # Enable gradient tracking for coord and box if needed + if self.do_grad_r("energy"): + coord = coord.clone().detach().requires_grad_(True) + if self.do_grad_c("energy") and box is not None: + box = box.clone().detach().requires_grad_(True) + if ( + extended_atype is not None + and extended_batch is not None + and extended_image is not None + and mapping is not None + and nlist is not None + and nlist_ext is not None + and a_nlist is not None + and a_nlist_ext is not None + and nlist_mask is not None + and a_nlist_mask is not None + and central_ext_index is not None + ): + from deepmd.pt.utils.nlist import ( + rebuild_extended_coord_from_flat_graph, + ) + + extended_coord = rebuild_extended_coord_from_flat_graph( + coord, + box, + mapping, + extended_batch, + extended_image, + ) + else: + raise RuntimeError( + "Flat mixed-batch forward requires precomputed graph fields from " + "the LMDB collate_fn." + ) + # Pass flat extended coordinates directly to the atomic model. 
+ assert extended_atype is not None + assert extended_batch is not None + assert mapping is not None + assert nlist is not None + model_predict_lower = self.forward_common_lower_flat( + extended_coord, + extended_atype, + extended_batch, + nlist, + mapping, + batch, + ptr, + do_atomic_virial=do_atomic_virial, + fparam=fparam, + aparam=aparam, + extended_ptr=extended_ptr, + central_ext_index=central_ext_index, + nlist_ext=nlist_ext, + a_nlist=a_nlist, + a_nlist_ext=a_nlist_ext, + nlist_mask=nlist_mask, + a_nlist_mask=a_nlist_mask, + edge_index=edge_index, + angle_index=angle_index, + ) + + # Compute derivatives if needed + if self.do_grad_r("energy") or self.do_grad_c("energy"): + model_predict_lower = self._compute_derivatives_flat( + model_predict_lower, + extended_coord, + extended_atype, + extended_batch, + coord, + atype, + batch, + ptr, + box, + do_atomic_virial, + ) + + return self._output_type_cast(model_predict_lower, input_prec) + + def forward_common_lower_flat( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_batch: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + do_atomic_virial: bool = False, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + extended_ptr: torch.Tensor | None = None, + central_ext_index: torch.Tensor | None = None, + nlist_ext: torch.Tensor | None = None, + a_nlist: torch.Tensor | None = None, + a_nlist_ext: torch.Tensor | None = None, + nlist_mask: torch.Tensor | None = None, + a_nlist_mask: torch.Tensor | None = None, + edge_index: torch.Tensor | None = None, + angle_index: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Lower interface for flat batch format. + + Parameters + ---------- + extended_coord : torch.Tensor + Extended coordinates [total_extended_atoms, 3]. + extended_atype : torch.Tensor + Extended atom types [total_extended_atoms]. 
+ extended_batch : torch.Tensor + Frame assignment for extended atoms [total_extended_atoms]. + nlist : torch.Tensor + Neighbor list [total_atoms, nnei]. + mapping : torch.Tensor + Extended atom -> local flat index mapping [total_extended_atoms]. + batch : torch.Tensor + Frame assignment for local atoms [total_atoms]. + ptr : torch.Tensor + Frame boundaries [nframes + 1]. + do_atomic_virial : bool + Whether to compute atomic virial. + fparam : torch.Tensor | None + Frame parameters [nframes, ndf]. + aparam : torch.Tensor | None + Atomic parameters [total_atoms, nda]. + + Returns + ------- + model_predict : dict[str, torch.Tensor] + Model predictions in flat format. + """ + # The atomic model keeps atom-wise outputs in flat format. + model_ret = self.atomic_model.forward_common_atomic_flat( + extended_coord, + extended_atype, + extended_batch, + nlist, + mapping, + batch, + ptr, + fparam=fparam, + aparam=aparam, + extended_ptr=extended_ptr, + central_ext_index=central_ext_index, + nlist_ext=nlist_ext, + a_nlist=a_nlist, + a_nlist_ext=a_nlist_ext, + nlist_mask=nlist_mask, + a_nlist_mask=a_nlist_mask, + edge_index=edge_index, + angle_index=angle_index, + ) + + # Reduce atom-wise energy to frame-wise energy. + nframes = ptr.numel() - 1 + if "energy" in model_ret: + energy_atomic = model_ret["energy"] # [total_atoms, 1] + energy_redu = energy_atomic.new_zeros( + (nframes, energy_atomic.shape[-1]) + ) + energy_redu.index_add_(0, batch, energy_atomic) + model_ret["energy_redu"] = energy_redu + + return model_ret + + def _compute_derivatives_flat( + self, + fit_ret: dict[str, torch.Tensor], + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_batch: torch.Tensor, + coord: torch.Tensor, + atype: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + box: torch.Tensor | None, + do_atomic_virial: bool, + ) -> dict[str, torch.Tensor]: + """Compute force and virial derivatives for flat batch format. 
+ + Parameters + ---------- + fit_ret : dict[str, torch.Tensor] + Fitting network output with "energy" key [total_atoms, 1]. + extended_coord : torch.Tensor + Extended coordinates [total_extended_atoms, 3]. + extended_atype : torch.Tensor + Extended atom types [total_extended_atoms]. + extended_batch : torch.Tensor + Frame assignment for extended atoms [total_extended_atoms]. + coord : torch.Tensor + Original coordinates [total_atoms, 3]. + atype : torch.Tensor + Original atom types [total_atoms]. + batch : torch.Tensor + Frame assignment for original atoms [total_atoms]. + ptr : torch.Tensor + Frame boundaries [nframes + 1]. + box : torch.Tensor | None + Simulation boxes [nframes, 9]. + do_atomic_virial : bool + Whether to compute atomic virial. + + Returns + ------- + model_predict : dict[str, torch.Tensor] + Model predictions with derivatives in flat format. + """ + # Force is the negative gradient of the total atomic energy. + if self.do_grad_r("energy"): + energy_atomic = fit_ret["energy"] # [total_atoms, 1] + + energy_derv_r = torch.autograd.grad( + outputs=energy_atomic.sum(), + inputs=coord, + create_graph=True, + retain_graph=True, + )[0] # [total_atoms, 3] + + fit_ret["energy_derv_r"] = -energy_derv_r.unsqueeze( + -2 + ) # [total_atoms, 1, 3] + # Also provide dforce field for compatibility with EnergyModel.forward() + fit_ret["dforce"] = -energy_derv_r # [total_atoms, 3] + + # Compute virial: dE/dh + if self.do_grad_c("energy"): + nframes = ptr.numel() - 1 + energy_redu = fit_ret["energy_redu"] # [nframes, 1] + + if box is not None: + energy_derv_c_redu = torch.autograd.grad( + outputs=energy_redu.sum(), + inputs=box, + create_graph=True, + retain_graph=True, + )[0] # [nframes, 9] + + fit_ret["energy_derv_c_redu"] = energy_derv_c_redu.unsqueeze( + 1 + ) # [nframes, 1, 9] + + if do_atomic_virial: + raise NotImplementedError( + "Atomic virial is not implemented for flat mixed-batch " + "forward." 
+ ) + + return fit_ret + + def forward_common_flat( + self, + coord: torch.Tensor, + atype: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + extended_atype: torch.Tensor | None = None, + extended_batch: torch.Tensor | None = None, + extended_image: torch.Tensor | None = None, + extended_ptr: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + central_ext_index: torch.Tensor | None = None, + nlist: torch.Tensor | None = None, + nlist_ext: torch.Tensor | None = None, + a_nlist: torch.Tensor | None = None, + a_nlist_ext: torch.Tensor | None = None, + nlist_mask: torch.Tensor | None = None, + a_nlist_mask: torch.Tensor | None = None, + edge_index: torch.Tensor | None = None, + angle_index: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Forward pass for flat mixed-nloc batch. + + This method consumes the precomputed flat graph produced by LMDB + collation and returns the same output keys as the regular path. + + Parameters + ---------- + coord + Flattened atomic coordinates with shape [total_atoms, 3]. + atype + Flattened atomic types with shape [total_atoms]. + batch + Atom-to-frame assignment with shape [total_atoms]. + ptr + Frame boundaries with shape [nframes + 1]. + box + Simulation boxes with shape [nframes, 9]. + fparam + Frame parameters with shape [nframes, ndf]. + aparam + Flattened atomic parameters with shape [total_atoms, nda]. + do_atomic_virial + Whether to calculate atomic virial. + + Returns + ------- + model_predict : dict[str, torch.Tensor] + Model predictions with flat format: + - atomwise outputs: [total_atoms, ...] + - frame-wise outputs: [nframes, ...] 
+ """ + return self.forward_common_flat_native( + coord, + atype, + batch, + ptr, + box, + fparam, + aparam, + do_atomic_virial, + extended_atype=extended_atype, + extended_batch=extended_batch, + extended_image=extended_image, + extended_ptr=extended_ptr, + mapping=mapping, + central_ext_index=central_ext_index, + nlist=nlist, + nlist_ext=nlist_ext, + a_nlist=a_nlist, + a_nlist_ext=a_nlist_ext, + nlist_mask=nlist_mask, + a_nlist_mask=a_nlist_mask, + edge_index=edge_index, + angle_index=angle_index, + ) + def get_out_bias(self) -> torch.Tensor: return self.atomic_model.get_out_bias() diff --git a/deepmd/pt/model/network/graph_utils_flat.py b/deepmd/pt/model/network/graph_utils_flat.py new file mode 100644 index 0000000000..5d348c3104 --- /dev/null +++ b/deepmd/pt/model/network/graph_utils_flat.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import torch + + +def get_graph_index_flat( + nlist_flat: torch.Tensor, + a_nlist_mask: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Get edge and angle graph indices for flat neighbor lists. + + Parameters + ---------- + nlist_flat : torch.Tensor + Neighbor list in flat format [total_atoms, nnei]. + Indices refer to positions in extended_coord_flat. + Padded neighbors are marked as -1. + a_nlist_mask : torch.Tensor + Valid angle-neighbor mask with shape [total_atoms, a_sel]. + + Returns + ------- + edge_index : torch.Tensor [2, n_edge] + ``edge_index[0]`` : n_edge + Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). + These are flat indices in range [0, total_atoms). + ``edge_index[1]`` : n_edge + Broadcast indices from extended node(j) to edge(ij). + These are flat indices in range [0, total_extended_atoms). + angle_index : torch.Tensor [3, n_angle] + ``angle_index[0]`` : n_angle + Broadcast indices from node(i) to angle(ijk). + These are flat indices in range [0, total_atoms). 
+ ``angle_index[1]`` : n_angle + Broadcast indices from edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). + These are edge indices in range [0, n_edge). + ``angle_index[2]`` : n_angle + Broadcast indices from edge(ik) to angle(ijk). + These are edge indices in range [0, n_edge). + """ + total_atoms = nlist_flat.shape[0] + nnei = nlist_flat.shape[1] + device = nlist_flat.device + dtype = nlist_flat.dtype + a_sel = a_nlist_mask.shape[1] + + # Create mask for valid neighbors (not -1) + nlist_mask = nlist_flat >= 0 # [total_atoms, nnei] + + # Count edges + n_edge = nlist_mask.sum().item() + + # Angle mask: both neighbors must be valid + a_nlist_mask_3d = a_nlist_mask[:, :, None] & a_nlist_mask[:, None, :] + + # 1. Build edge_index + # n2e_index: for each edge, which local atom does it belong to + atom_indices = torch.arange( + total_atoms, dtype=dtype, device=device + ) # [total_atoms] + n2e_index = atom_indices[:, None].expand(-1, nnei)[nlist_mask] # [n_edge] + + # n_ext2e_index: for each edge, which extended atom is the neighbor + n_ext2e_index = nlist_flat[nlist_mask] # [n_edge] + + edge_index = torch.stack([n2e_index, n_ext2e_index], dim=0) # [2, n_edge] + + # 2. 
Build angle_index + # n2a_index: for each angle, which local atom does it belong to + n2a_index = atom_indices[:, None, None].expand(-1, a_sel, a_sel)[a_nlist_mask_3d] + + # Create edge_id mapping: (atom_idx, neighbor_idx) -> edge_id + edge_id = torch.arange(n_edge, dtype=dtype, device=device) + edge_lookup = torch.full((total_atoms, nnei), -1, dtype=dtype, device=device) + edge_lookup[nlist_mask] = edge_id + # Only consider first a_sel neighbors for angles + edge_lookup_a = edge_lookup[:, :a_sel] # [total_atoms, a_sel] + + # eij2a_index: for each angle (i,j,k), the edge id of (i,j) + edge_lookup_ij = edge_lookup_a[:, :, None].expand(-1, -1, a_sel) + eij2a_index = edge_lookup_ij[a_nlist_mask_3d] # [n_angle] + + # eik2a_index: for each angle (i,j,k), the edge id of (i,k) + edge_lookup_ik = edge_lookup_a[:, None, :].expand(-1, a_sel, -1) + eik2a_index = edge_lookup_ik[a_nlist_mask_3d] # [n_angle] + + angle_index = torch.stack([n2a_index, eij2a_index, eik2a_index], dim=0) + return edge_index, angle_index diff --git a/deepmd/pt/model/task/invar_fitting.py b/deepmd/pt/model/task/invar_fitting.py index a8953fcd2b..4773ad88a7 100644 --- a/deepmd/pt/model/task/invar_fitting.py +++ b/deepmd/pt/model/task/invar_fitting.py @@ -197,5 +197,121 @@ def forward( ) return result + def forward_flat( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Forward pass with flat batch format. + + Parameters + ---------- + descriptor : torch.Tensor + Descriptor [total_atoms, descriptor_dim]. + atype : torch.Tensor + Atom types [total_atoms]. + batch : torch.Tensor + Frame assignment [total_atoms]. + ptr : torch.Tensor + Frame boundaries [nframes + 1]. + gr : torch.Tensor | None + Rotation matrix [total_atoms, e_dim, 3]. 
+ g2 : torch.Tensor | None + Edge embedding. + h2 : torch.Tensor | None + Pair representation. + fparam : torch.Tensor | None + Frame parameters [nframes, ndf]. + aparam : torch.Tensor | None + Atomic parameters [total_atoms, nda]. + + Returns + ------- + result : dict[str, torch.Tensor] + Model predictions in flat atom format. Atom-wise outputs are + flattened back to ``[total_atoms, ...]`` after the regular fitting + network runs on a padded dense batch. + """ + device = descriptor.device + batch = batch.to(device=device, dtype=torch.long) + ptr = ptr.to(device=device, dtype=torch.long) + atype = atype.to(device=device) + + nframes = ptr.numel() - 1 + total_atoms = descriptor.shape[0] + atom_counts = ptr[1:] - ptr[:-1] + max_nloc = int(atom_counts.max().item()) + flat_index = torch.arange(total_atoms, dtype=torch.long, device=device) + local_index = flat_index - ptr[batch] + + descriptor_batch = torch.zeros( + (nframes, max_nloc, descriptor.shape[1]), + dtype=descriptor.dtype, + device=device, + ) + atype_batch = torch.full( + (nframes, max_nloc), + -1, + dtype=atype.dtype, + device=device, + ) + gr_batch = None + if gr is not None: + gr_batch = torch.zeros( + (nframes, max_nloc, *gr.shape[1:]), + dtype=gr.dtype, + device=device, + ) + aparam_batch = None + if aparam is not None: + aparam_batch = torch.zeros( + (nframes, max_nloc, *aparam.shape[1:]), + dtype=aparam.dtype, + device=device, + ) + + descriptor_batch[batch, local_index] = descriptor + atype_batch[batch, local_index] = atype + if gr is not None: + assert gr_batch is not None + gr_batch[batch, local_index] = gr + if aparam is not None: + assert aparam_batch is not None + aparam_batch[batch, local_index] = aparam + + result_batch = self.forward( + descriptor_batch, + atype_batch, + gr=gr_batch, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam_batch, + ) + + valid_atom_mask = torch.arange( + max_nloc, dtype=torch.long, device=device + ).unsqueeze(0) < atom_counts.unsqueeze(1) + result_flat: dict[str, 
torch.Tensor] = {} + for key, value in result_batch.items(): + if ( + isinstance(value, torch.Tensor) + and value.dim() >= 2 + and value.shape[0] == nframes + and value.shape[1] == max_nloc + ): + result_flat[key] = value[valid_atom_mask] + else: + result_flat[key] = value + + return result_flat + # make jit happy with torch 2.0.0 exclude_types: list[int] diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 012cdb3a65..0db423b11c 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -91,6 +91,7 @@ LmdbDataset, _collate_lmdb_batch, _SameNlocBatchSamplerTorch, + make_lmdb_mixed_batch_collate, ) from deepmd.pt.utils.stat import ( make_stat_input, @@ -112,6 +113,7 @@ get_optimizer_state_dict, set_optimizer_state_dict, ) +from torch.nn.parallel import DistributedDataParallel as DDP try: from torch.distributed.fsdp import ( @@ -122,7 +124,6 @@ from torch.distributed.optim import ( ZeroRedundancyOptimizer, ) -from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import ( DataLoader, ) @@ -133,6 +134,25 @@ log = logging.getLogger(__name__) +_FLAT_GRAPH_INPUT_KEYS = ( + "batch", + "ptr", + "extended_atype", + "extended_batch", + "extended_image", + "extended_ptr", + "mapping", + "central_ext_index", + "nlist", + "nlist_ext", + "a_nlist", + "a_nlist_ext", + "nlist_mask", + "a_nlist_mask", + "edge_index", + "angle_index", +) + class Trainer: def __init__( @@ -249,6 +269,7 @@ def get_data_loader( _training_data: DpLoaderSet | LmdbDataset, _validation_data: DpLoaderSet | LmdbDataset | None, _training_params: dict[str, Any], + _task_key: str = "Default", ) -> tuple[ DataLoader, Generator[Any, None, None], @@ -259,19 +280,54 @@ def get_data_loader( def get_dataloader_and_iter_lmdb( _data: LmdbDataset, ) -> tuple[DataLoader, Generator[Any, None, None]]: + _shuffle = _training_params.get("shuffle", True) + _seed = _training_params.get("seed", training_params.get("seed", 42)) + if _seed is None: + _seed = 
42 + if _data.mixed_batch: - # TODO [mixed_batch=True]: Replace SameNlocBatchSampler with - # RandomSampler(replacement=False) + padding collate_fn. - # Changes needed: - # 1. _collate_lmdb_batch: pad coord/force/atype to max_nloc, - # add "atom_mask" bool tensor (nframes, max_nloc) - # 2. Use RandomSampler(_data, replacement=False) as sampler - # 3. Use fixed batch_size in DataLoader (not batch_sampler) - # 4. Model forward: apply atom_mask to descriptor/fitting - # 5. Loss: mask out padded atoms in force loss - raise NotImplementedError( - "mixed_batch=True training is not yet supported." + from torch.utils.data import ( + RandomSampler, + SequentialSampler, ) + + if _shuffle: + generator = torch.Generator() + generator.manual_seed(_seed) + _sampler = RandomSampler( + _data, replacement=False, generator=generator + ) + else: + _sampler = SequentialSampler(_data) + + model_for_graph = ( + self.model[_task_key] if self.multi_task else self.model + ) + descriptor = model_for_graph.atomic_model.descriptor + if not hasattr(descriptor, "repflows"): + raise ValueError( + "mixed_batch=True currently requires a flat-graph " + "capable descriptor, for example DPA3/RepFlow." + ) + graph_config = { + "rcut": descriptor.get_rcut(), + "sel": descriptor.get_sel(), + "a_rcut": descriptor.repflows.a_rcut, + "a_sel": descriptor.repflows.a_sel, + "mixed_types": descriptor.mixed_types(), + } + + _dataloader = DataLoader( + _data, + batch_size=_data.batch_size, + sampler=_sampler, + num_workers=0, + collate_fn=make_lmdb_mixed_batch_collate(graph_config), + pin_memory=(DEVICE != "cpu"), + ) + _data_iter = cycle_iterator(_dataloader) + return _dataloader, _data_iter + # mixed_batch=False: group frames by nloc, each batch same nloc. # SameNlocBatchSampler yields list[int] per batch, all same nloc. # Auto batch_size is computed per-nloc-group inside the sampler. 
@@ -290,14 +346,15 @@ def get_dataloader_and_iter_lmdb( _data._reader, rank=self.rank, world_size=self.world_size, - shuffle=True, - seed=_training_params.get("seed", None), + shuffle=_shuffle, + seed=_seed, block_targets=_block_targets, ) else: _inner_sampler = SameNlocBatchSampler( _data._reader, - shuffle=True, + shuffle=_shuffle, + seed=_seed, block_targets=_block_targets, ) @@ -555,6 +612,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: training_data[model_key], validation_data[model_key], training_params["data_dict"][model_key], + _task_key=model_key, ) training_data[model_key].print_summary( @@ -1284,6 +1342,7 @@ def step(_step_id: int, task_key: str = "Default") -> None: input_dict, label_dict, log_dict = self.get_data( is_train=True, task_key=task_key ) + if SAMPLER_RECORD: print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" fout1.write(print_str) @@ -2008,6 +2067,7 @@ def save_ema_model( def get_data( self, is_train: bool = True, task_key: str = "Default" ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + if is_train: iterator = self.training_data else: @@ -2017,16 +2077,25 @@ def get_data( if iterator is None: return {}, {}, {} batch_data = next(iterator) + + # Detect mixed batch format (has 'batch' and 'ptr' fields) + is_mixed_batch = "batch" in batch_data and "ptr" in batch_data + for key in batch_data.keys(): - if key == "sid" or key == "fid" or key == "box" or "find_" in key: + if key == "sid" or key == "fid" or "find_" in key: + continue + # Skip batch and ptr for now, will handle them separately + elif key == "batch" or key == "ptr": continue elif not isinstance(batch_data[key], list): if batch_data[key] is not None: batch_data[key] = batch_data[key].to(DEVICE, non_blocking=True) else: batch_data[key] = [ - item.to(DEVICE, non_blocking=True) for item in batch_data[key] + item.to(DEVICE, non_blocking=True) if item is not None else None + for item in batch_data[key] ] + # we may need a better way to 
classify which are inputs and which are labels # now wrapper only supports the following inputs: input_keys = [ @@ -2037,6 +2106,13 @@ def get_data( "fparam", "aparam", ] + + # Mixed-nloc LMDB batches include precomputed flat-graph tensors. + if is_mixed_batch: + input_keys = input_keys + list(_FLAT_GRAPH_INPUT_KEYS) + batch_data["batch"] = batch_data["batch"].to(DEVICE, non_blocking=True) + batch_data["ptr"] = batch_data["ptr"].to(DEVICE, non_blocking=True) + input_dict = dict.fromkeys(input_keys) label_dict = {} for item_key in batch_data: diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index ddb4a4323d..925f92b428 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -165,6 +165,22 @@ def forward( do_atomic_virial: bool = False, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, + batch: torch.Tensor | None = None, + ptr: torch.Tensor | None = None, + extended_atype: torch.Tensor | None = None, + extended_batch: torch.Tensor | None = None, + extended_image: torch.Tensor | None = None, + extended_ptr: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + central_ext_index: torch.Tensor | None = None, + nlist: torch.Tensor | None = None, + nlist_ext: torch.Tensor | None = None, + a_nlist: torch.Tensor | None = None, + a_nlist_ext: torch.Tensor | None = None, + nlist_mask: torch.Tensor | None = None, + a_nlist_mask: torch.Tensor | None = None, + edge_index: torch.Tensor | None = None, + angle_index: torch.Tensor | None = None, ) -> tuple[Any, Any, Any]: if not self.multi_task: task_key = "Default" @@ -180,6 +196,30 @@ def forward( "fparam": fparam, "aparam": aparam, } + + # Mixed-nloc LMDB batches carry a precomputed flat graph. 
+ if batch is not None and ptr is not None: + input_dict.update( + { + "batch": batch, + "ptr": ptr, + "extended_atype": extended_atype, + "extended_batch": extended_batch, + "extended_image": extended_image, + "extended_ptr": extended_ptr, + "mapping": mapping, + "central_ext_index": central_ext_index, + "nlist": nlist, + "nlist_ext": nlist_ext, + "a_nlist": a_nlist, + "a_nlist_ext": a_nlist_ext, + "nlist_mask": nlist_mask, + "a_nlist_mask": a_nlist_mask, + "edge_index": edge_index, + "angle_index": angle_index, + } + ) + has_spin = getattr(self.model[task_key], "has_spin", False) if callable(has_spin): has_spin = has_spin() @@ -194,7 +234,9 @@ def forward( model_pred[k] = model_pred[k] + v return model_pred, None, None else: - natoms = atype.shape[-1] + # For mixed batch, natoms is the total flattened atoms + # For regular batch, natoms is per-frame atoms + natoms = atype.shape[-1] if atype.dim() > 1 else atype.shape[0] model_pred, loss, more_loss = self.loss[task_key]( input_dict, self.model[task_key], diff --git a/deepmd/pt/utils/lmdb_dataset.py b/deepmd/pt/utils/lmdb_dataset.py index 067b420da9..e4ef1b5c9c 100644 --- a/deepmd/pt/utils/lmdb_dataset.py +++ b/deepmd/pt/utils/lmdb_dataset.py @@ -3,6 +3,7 @@ import logging from collections.abc import ( + Callable, Iterator, ) from typing import ( @@ -15,6 +16,9 @@ Dataset, Sampler, ) +from torch.utils.data._utils.collate import ( + collate_tensor_fn, +) from deepmd.dpmodel.utils.lmdb_data import ( LmdbDataReader, @@ -30,17 +34,134 @@ log = logging.getLogger(__name__) +FrameDict = dict[str, Any] +BatchDict = dict[str, Any] +GraphConfig = dict[str, Any] +MixedBatchCollate = Callable[[list[FrameDict]], BatchDict] + # Re-export for backward compatibility __all__ = [ "LmdbDataset", "LmdbTestData", "_collate_lmdb_batch", + "_collate_lmdb_mixed_batch", "is_lmdb", + "make_lmdb_mixed_batch_collate", ] +_ATOMWISE_MIXED_BATCH_KEYS = frozenset( + { + "aparam", + "atom_dos", + "atom_ener", + "atom_ener_coeff", + "atom_pref", + 
"atomic_weight", + "atype", + "coord", + "drdq", + "force", + "force_mag", + "hessian", + "spin", + } +) + + +def _collate_lmdb_mixed_batch(batch: list[FrameDict]) -> BatchDict: + """Collate mixed-nloc frames into flattened atom-wise tensors. -def _collate_lmdb_batch(batch: list[dict[str, Any]]) -> dict[str, Any]: - """Collate a list of frame dicts into a torch batch dict. + Atom-wise fields are concatenated across frames and accompanied by: + + - ``batch``: flattened atom-to-frame assignment with shape ``[total_atoms]``. + - ``ptr``: prefix-sum atom offsets with shape ``[nframes + 1]``. + + Frame-wise fields such as ``energy`` and ``box`` keep the usual batch + dimension via ``torch.stack``. The returned ``sid`` keeps the historical + LMDB collate shape, namely a CPU tensor with shape ``[1]``. + """ + with torch.device("cpu"): + atype_list = [torch.as_tensor(item["atype"]) for item in batch] + counts = torch.tensor( + [int(item.shape[0]) for item in atype_list], + dtype=torch.long, + device=torch.device("cpu"), + ) + ptr = torch.cat( + [ + torch.zeros(1, dtype=torch.long, device=counts.device), + torch.cumsum(counts, dim=0), + ], + dim=0, + ) + atom_batch = torch.repeat_interleave( + torch.arange(len(batch), dtype=torch.long, device=counts.device), + counts, + ) + + example = batch[0] + result: BatchDict = {} + for key in example: + if "find_" in key: + result[key] = batch[0][key] + elif key == "fid": + result[key] = [d[key] for d in batch] + elif key == "type": + continue + elif batch[0][key] is None: + result[key] = None + else: + with torch.device("cpu"): + tensors = [torch.as_tensor(d[key]) for d in batch] + if key in _ATOMWISE_MIXED_BATCH_KEYS: + result[key] = torch.cat(tensors, dim=0) + else: + result[key] = collate_tensor_fn(tensors) + result["batch"] = atom_batch + result["ptr"] = ptr + result["sid"] = torch.tensor([0], dtype=torch.long, device="cpu") + return result + + +def make_lmdb_mixed_batch_collate( + graph_config: GraphConfig | None = None, +) -> 
MixedBatchCollate: + """Build a collate function for flattened mixed-nloc LMDB batches. + + When ``graph_config`` is provided, the collate function also precomputes the + extended image, neighbor lists, masks, edge index, and angle index consumed + by the flat DPA3 forward path. ``graph_config`` is expected to contain + ``rcut``, ``sel``, ``a_rcut``, ``a_sel``, and ``mixed_types``. + """ + + def collate(batch: list[FrameDict]) -> BatchDict: + result = _collate_lmdb_mixed_batch(batch) + if graph_config is None: + return result + from deepmd.pt.utils.nlist import ( + build_precomputed_flat_graph, + ) + + graph_data = build_precomputed_flat_graph( + result["coord"], + result["atype"], + result["batch"], + result["ptr"], + graph_config["rcut"], + graph_config["sel"], + graph_config["a_rcut"], + graph_config["a_sel"], + mixed_types=graph_config["mixed_types"], + box=result.get("box"), + ) + result.update(graph_data) + return result + + return collate + + +def _collate_lmdb_batch(batch: list[FrameDict]) -> BatchDict: + """Collate a list of frame dicts into a batch dict. Pre-converts per-frame numpy arrays to CPU torch tensors (zero-copy when dtype matches) and delegates stacking to the backend-agnostic @@ -48,18 +169,14 @@ def _collate_lmdb_batch(batch: list[dict[str, Any]]) -> dict[str, Any]: collate yields a torch dict (``sid`` becomes a torch tensor automatically via ``array_api_compat``). - All frames in the batch must have the same nloc (enforced by - SameNlocBatchSampler when mixed_batch=False). For mixed_batch=True, - raises NotImplementedError. + For mixed_batch=True, this function would need padding + mask. + Mixed-nloc batches are flattened atom-wise and augmented with ``batch`` and + ``ptr`` to preserve frame ownership. 
""" if len(batch) > 1: atypes = [d.get("atype") for d in batch if d.get("atype") is not None] if atypes and any(len(a) != len(atypes[0]) for a in atypes): - raise NotImplementedError( - "mixed_batch collation (frames with different atom counts " - "in the same batch) is not yet supported. " - "Padding + mask in collate_fn needed." - ) + return _collate_lmdb_mixed_batch(batch) with torch.device("cpu"): torch_frames: list[dict[str, Any]] = [] @@ -133,13 +250,7 @@ def __init__( self._reader = LmdbDataReader( lmdb_path, type_map, batch_size, mixed_batch=mixed_batch ) - - if mixed_batch: - # Future: DataLoader with padding collate_fn - raise NotImplementedError( - "mixed_batch=True is not yet supported. " - "Requires padding + mask in collate_fn." - ) + self._batch_sampler: _SameNlocBatchSamplerTorch | None = None # Compute block_targets from auto_prob_style if provided self._block_targets = None @@ -149,27 +260,42 @@ def __init__( self._reader.nsystems, self._reader.system_nframes, ) - if self._block_targets is not None: + if self._block_targets: + if mixed_batch: + raise NotImplementedError( + "auto_prob_style/block weighting is not supported with " + "mixed_batch=True yet." 
+ ) log.info( f"LMDB auto_prob: {len(self._block_targets)} blocks, " f"nsystems={self._reader.nsystems}" ) - # Same-nloc batching: use SameNlocBatchSampler - sampler = SameNlocBatchSampler( - self._reader, - shuffle=True, - block_targets=self._block_targets, - ) - self._batch_sampler = _SameNlocBatchSamplerTorch(sampler) - - with torch.device("cpu"): - self._inner_dataloader = DataLoader( - self, - batch_sampler=self._batch_sampler, - num_workers=0, - collate_fn=_collate_lmdb_batch, + if mixed_batch: + with torch.device("cpu"): + self._inner_dataloader = DataLoader( + self, + batch_size=self._reader.batch_size, + shuffle=True, + num_workers=0, + collate_fn=_collate_lmdb_mixed_batch, + ) + else: + # Same-nloc batching: use SameNlocBatchSampler + sampler = SameNlocBatchSampler( + self._reader, + shuffle=True, + block_targets=self._block_targets, ) + self._batch_sampler = _SameNlocBatchSamplerTorch(sampler) + + with torch.device("cpu"): + self._inner_dataloader = DataLoader( + self, + batch_sampler=self._batch_sampler, + num_workers=0, + collate_fn=_collate_lmdb_batch, + ) # Per-nloc-group dataloaders for make_stat_input. # Each group gets its own DataLoader so torch.cat in stat collection @@ -326,7 +452,9 @@ def batch_sizes(self) -> list[int]: @property def systems(self) -> list: """One 'system' per nloc group for stat collection compatibility.""" - return [self] * len(self._nloc_dataloaders) + if self._nloc_dataloaders: + return [self] * len(self._nloc_dataloaders) + return [self] @property def dataloaders(self) -> list: @@ -335,8 +463,12 @@ def dataloaders(self) -> list: Each dataloader yields batches with uniform nloc, so torch.cat in stat collection only concatenates same-shape tensors. 
""" - return self._nloc_dataloaders + if self._nloc_dataloaders: + return self._nloc_dataloaders + return [self._inner_dataloader] @property def sampler_list(self) -> list: + if self._batch_sampler is None: + return [] return [self._batch_sampler] diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index 7f74e65f26..07d372ad4e 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -10,6 +10,8 @@ to_face_distance, ) +FlatGraphData = dict[str, torch.Tensor] + def extend_input_and_build_neighbor_list( coord: torch.Tensor, @@ -44,6 +46,307 @@ def extend_input_and_build_neighbor_list( return extended_coord, extended_atype, mapping, nlist +def build_precomputed_flat_graph( + coord: torch.Tensor, + atype: torch.Tensor, + batch: torch.Tensor, + ptr: torch.Tensor, + rcut: float, + sel: list[int], + a_rcut: float, + a_sel: int, + mixed_types: bool = False, + box: torch.Tensor | None = None, +) -> FlatGraphData: + """Build graph tensors for flattened mixed-nloc LMDB batches. + + Parameters + ---------- + coord + Flattened local coordinates with shape ``[total_atoms, 3]``. + atype + Flattened local atom types with shape ``[total_atoms]``. + batch + Local atom-to-frame assignment with shape ``[total_atoms]``. + ptr + Prefix-sum local atom offsets with shape ``[nframes + 1]``. + rcut, sel + Edge cutoff and neighbor selection used by the descriptor. + a_rcut, a_sel + Angle cutoff and maximum angle-neighbor count. + mixed_types + Whether neighbor selection ignores atom types. + box + Optional flattened cell tensor with shape ``[nframes, 9]``. + + Returns + ------- + FlatGraphData + Dictionary consumed by the flat forward path. ``*_ext`` neighbor lists + index into the concatenated extended atoms, while ``nlist`` and + ``a_nlist`` map neighbors back to flattened local atom indices. 
+ """ + device = coord.device + nframes = ptr.numel() - 1 + extended_coords_list = [] + extended_atypes_list = [] + extended_batches_list = [] + extended_images_list = [] + extended_to_atom_list = [] + nlists_ext_list = [] + central_indices_list = [] + extended_ptr = torch.zeros(nframes + 1, dtype=torch.long, device=device) + extended_offset = 0 + + for frame_idx in range(nframes): + start_idx = int(ptr[frame_idx].item()) + end_idx = int(ptr[frame_idx + 1].item()) + nloc = end_idx - start_idx + frame_coord = coord[start_idx:end_idx].reshape(1, nloc, 3) + frame_atype = atype[start_idx:end_idx].reshape(1, nloc) + frame_box = box[frame_idx : frame_idx + 1] if box is not None else None + + if frame_box is not None: + box_device = frame_box.to(device, non_blocking=True) + coord_normalized = normalize_coord( + frame_coord, + box_device.reshape(1, 3, 3), + ) + else: + box_device = None + coord_normalized = frame_coord.clone() + + ( + frame_extended_coord, + frame_extended_atype, + frame_mapping, + frame_extended_image, + ) = extend_coord_with_ghosts_with_images( + coord_normalized.reshape(1, -1), + frame_atype, + box_device, + rcut, + frame_box, + ) + frame_nlist_ext = build_neighbor_list( + frame_extended_coord, + frame_extended_atype, + nloc, + rcut, + sel, + distinguish_types=(not mixed_types), + ) + + frame_extended_coord = frame_extended_coord.view(-1, 3) + frame_extended_atype = frame_extended_atype.view(-1) + frame_mapping = frame_mapping.view(-1) + frame_extended_image = frame_extended_image.view(-1, 3) + frame_nlist_ext = frame_nlist_ext.view(nloc, -1) + nall_frame = frame_extended_coord.shape[0] + + central_indices_list.append( + torch.arange( + extended_offset, + extended_offset + nloc, + dtype=torch.long, + device=device, + ) + ) + nlists_ext_list.append( + torch.where( + frame_nlist_ext >= 0, + frame_nlist_ext + extended_offset, + frame_nlist_ext, + ) + ) + extended_coords_list.append(frame_extended_coord) + extended_atypes_list.append(frame_extended_atype) + 
extended_batches_list.append( + torch.full((nall_frame,), frame_idx, dtype=torch.long, device=device) + ) + extended_images_list.append(frame_extended_image) + extended_to_atom_list.append(frame_mapping + start_idx) + extended_offset += nall_frame + extended_ptr[frame_idx + 1] = extended_offset + + extended_coord = torch.cat(extended_coords_list, dim=0) + extended_atype = torch.cat(extended_atypes_list, dim=0) + extended_batch = torch.cat(extended_batches_list, dim=0) + extended_image = torch.cat(extended_images_list, dim=0) + mapping = torch.cat(extended_to_atom_list, dim=0) + central_ext_index = torch.cat(central_indices_list, dim=0) + nlist_ext = torch.cat(nlists_ext_list, dim=0) + nlist_mask = nlist_ext >= 0 + + nall = extended_coord.shape[0] + nlist_ext_clamped = torch.clamp(nlist_ext, min=0, max=nall - 1) + nlist = torch.where( + nlist_mask, + mapping[nlist_ext_clamped], + torch.tensor(-1, dtype=nlist_ext.dtype, device=device), + ) + + coord_central = extended_coord[central_ext_index] + coord_pad = torch.cat([extended_coord, extended_coord[-1:, :] + rcut], dim=0) + nlist_safe = torch.where( + nlist_mask, + nlist_ext, + torch.tensor(nall, dtype=nlist_ext.dtype, device=device), + ) + index = nlist_safe.view(-1).unsqueeze(-1).expand(-1, 3) + coord_nei = torch.gather(coord_pad, 0, index).view(nlist_ext.shape[0], -1, 3) + dist = torch.linalg.norm(coord_nei - coord_central[:, None, :], dim=-1) + a_dist_mask = (dist[:, :a_sel] < a_rcut) & nlist_mask[:, :a_sel] + a_nlist_ext = torch.where( + a_dist_mask, + nlist_ext[:, :a_sel], + torch.tensor(-1, dtype=nlist_ext.dtype, device=device), + ) + a_nlist_mask = a_nlist_ext >= 0 + a_nlist_ext_clamped = torch.clamp(a_nlist_ext, min=0, max=nall - 1) + a_nlist = torch.where( + a_nlist_mask, + mapping[a_nlist_ext_clamped], + torch.tensor(-1, dtype=nlist_ext.dtype, device=device), + ) + + from deepmd.pt.model.network.graph_utils_flat import ( + get_graph_index_flat, + ) + + edge_index, angle_index = get_graph_index_flat( + 
nlist, + a_nlist_mask, + ) + return { + "extended_atype": extended_atype, + "extended_batch": extended_batch, + "extended_image": extended_image, + "extended_ptr": extended_ptr, + "mapping": mapping, + "central_ext_index": central_ext_index, + "nlist": nlist, + "nlist_ext": nlist_ext, + "a_nlist": a_nlist, + "a_nlist_ext": a_nlist_ext, + "nlist_mask": nlist_mask, + "a_nlist_mask": a_nlist_mask, + "edge_index": edge_index, + "angle_index": angle_index, + } + + +def rebuild_extended_coord_from_flat_graph( + coord: torch.Tensor, + box: torch.Tensor | None, + mapping: torch.Tensor, + extended_batch: torch.Tensor, + extended_image: torch.Tensor, +) -> torch.Tensor: + """Reconstruct extended coordinates from precomputed flat graph metadata. + + ``mapping`` maps each extended atom to its source local atom. When ``box`` + is available, ``extended_image`` is applied after wrapping the source local + coordinate back into the corresponding periodic cell. + """ + if box is None: + return coord[mapping] + cell = box.reshape(-1, 3, 3) + atom_cell = cell[extended_batch] + rec_cell, _ = torch.linalg.inv_ex(atom_cell) + coord_inter = torch.einsum("ni,nij->nj", coord[mapping], rec_cell) + coord_wrapped = torch.einsum( + "ni,nij->nj", + torch.remainder(coord_inter, 1.0), + atom_cell, + ) + image = extended_image.to(dtype=box.dtype, device=box.device) + shift_vec = torch.einsum("ni,nij->nj", image, atom_cell) + return coord_wrapped + shift_vec + + +def get_central_ext_index( + extended_batch: torch.Tensor, + ptr: torch.Tensor, +) -> torch.Tensor: + """Return extended-atom indices corresponding to local atoms.""" + nframes = ptr.numel() - 1 + extended_counts = torch.bincount(extended_batch, minlength=nframes) + extended_ptr = torch.cat( + [ + torch.zeros(1, dtype=extended_counts.dtype, device=extended_counts.device), + torch.cumsum(extended_counts, dim=0), + ] + ) + extended_index = torch.arange( + extended_batch.shape[0], + dtype=extended_batch.dtype, + device=extended_batch.device, + 
) + frame_local_index = extended_index - extended_ptr[extended_batch] + nloc_per_frame = (ptr[1:] - ptr[:-1]).to(extended_batch.device) + central_mask = frame_local_index < nloc_per_frame[extended_batch] + return torch.nonzero(central_mask, as_tuple=False).view(-1) + + +def extend_input_and_build_neighbor_list_with_images( + coord: torch.Tensor, + atype: torch.Tensor, + rcut: float, + sel: list[int], + mixed_types: bool = False, + box: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Like ``extend_input_and_build_neighbor_list`` but also returns lattice images. + + This helper is intended for sidecar graph precomputation workflows that need a + stable, replayable description of how extended atoms are generated without + changing the existing training path. + + Returns + ------- + extended_coord + Extended coordinates with shape ``[nf, nall, 3]``. + extended_atype + Extended atom types with shape ``[nf, nall]``. + mapping + Extended atom -> local atom index mapping with shape ``[nf, nall]``. + extended_image + Integer lattice image for each extended atom with shape ``[nf, nall, 3]``. + nlist + Neighbor list with shape ``[nf, nloc, nnei]``. 
+ """ + nframes, nloc = atype.shape[:2] + if box is not None: + box_gpu = box.to(coord.device, non_blocking=True) + coord_normalized = normalize_coord( + coord.view(nframes, nloc, 3), + box_gpu.reshape(nframes, 3, 3), + ) + else: + box_gpu = None + coord_normalized = coord.clone() + extended_coord, extended_atype, mapping, extended_image = ( + extend_coord_with_ghosts_with_images( + coord_normalized, + atype, + box_gpu, + rcut, + box, + ) + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + distinguish_types=(not mixed_types), + ) + extended_coord = extended_coord.view(nframes, -1, 3) + return extended_coord, extended_atype, mapping, extended_image, nlist + + def build_neighbor_list( coord: torch.Tensor, atype: torch.Tensor, @@ -438,9 +741,54 @@ def extend_coord_with_ghosts( mapping extended index to the local index """ + extend_coord, extend_atype, extend_aidx, _ = _extend_coord_with_ghosts_impl( + coord, + atype, + cell, + rcut, + cell_cpu=cell_cpu, + return_image=False, + ) + return extend_coord, extend_atype, extend_aidx + + +def extend_coord_with_ghosts_with_images( + coord: torch.Tensor, + atype: torch.Tensor, + cell: torch.Tensor | None, + rcut: float, + cell_cpu: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Extend coordinates and additionally return the integer lattice image. + + The returned image tensor records which periodic image each extended atom + belongs to. This is useful for sidecar graph serialization where extended + coordinates should be recoverable from the original local coordinates and + the simulation cell. 
+ """ + extend_coord, extend_atype, extend_aidx, extend_image = ( + _extend_coord_with_ghosts_impl( + coord, + atype, + cell, + rcut, + cell_cpu=cell_cpu, + return_image=True, + ) + ) + return extend_coord, extend_atype, extend_aidx, extend_image + + +def _extend_coord_with_ghosts_impl( + coord: torch.Tensor, + atype: torch.Tensor, + cell: torch.Tensor | None, + rcut: float, + cell_cpu: torch.Tensor | None = None, + return_image: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: device = coord.device nf, nloc = atype.shape - # int64 for index aidx = torch.tile( torch.arange(nloc, device=device, dtype=torch.int64).unsqueeze(0), [nf, 1] ) @@ -449,18 +797,17 @@ def extend_coord_with_ghosts( extend_coord = coord.clone() extend_atype = atype.clone() extend_aidx = aidx.clone() + if return_image: + extend_image = torch.zeros((nf, nloc, 3), device=device, dtype=torch.int64) + else: + extend_image = torch.empty((0,), device=device, dtype=torch.int64) else: coord = coord.view([nf, nloc, 3]) cell = cell.view([nf, 3, 3]) cell_cpu = cell_cpu.view([nf, 3, 3]) if cell_cpu is not None else cell - # nf x 3 to_face = to_face_distance(cell_cpu) - # nf x 3 - # *2: ghost copies on + and - directions - # +1: central cell nbuff = torch.ceil(rcut / to_face).to(torch.int64) - # 3 - nbuff = torch.amax(nbuff, dim=0) # faster than torch.max + nbuff = torch.amax(nbuff, dim=0) nbuff_cpu = nbuff.cpu() xi = torch.arange( -nbuff_cpu[0], nbuff_cpu[0] + 1, 1, device="cpu", dtype=torch.int64 @@ -477,20 +824,24 @@ def extend_coord_with_ghosts( xyz = xyz + zi.view(1, 1, -1, 1) * eye_3[2] xyz = xyz.view(-1, 3) xyz = xyz.to(device=device, non_blocking=True) - # ns x 3 shift_idx = xyz[torch.argsort(torch.linalg.norm(xyz, dim=-1))] + # Convert shift_idx to the same dtype as cell to avoid type mismatch + shift_idx = shift_idx.to(dtype=cell.dtype) ns, _ = shift_idx.shape nall = ns * nloc - # nf x ns x 3 shift_vec = torch.einsum("sd,fdk->fsk", shift_idx, cell) - # nf x ns x 
nloc x 3 extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :] - # nf x ns x nloc extend_atype = torch.tile(atype.unsqueeze(-2), [1, ns, 1]) - # nf x ns x nloc extend_aidx = torch.tile(aidx.unsqueeze(-2), [1, ns, 1]) - return ( - extend_coord.reshape([nf, nall * 3]).to(device), - extend_atype.view([nf, nall]).to(device), - extend_aidx.view([nf, nall]).to(device), - ) + if return_image: + extend_image = torch.tile(shift_idx.view(1, ns, 1, 3), [nf, 1, nloc, 1]) + else: + extend_image = torch.empty((0,), device=device, dtype=torch.int64) + extend_coord_out = extend_coord.reshape([nf, nall * 3]).to(device) + extend_atype_out = extend_atype.view([nf, nall]).to(device) + extend_aidx_out = extend_aidx.view([nf, nall]).to(device) + if return_image: + extend_image_out = extend_image.view([nf, nall, 3]).to(device) + else: + extend_image_out = extend_image + return extend_coord_out, extend_atype_out, extend_aidx_out, extend_image_out diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 0364b24695..37eb2c73f9 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3760,6 +3760,15 @@ def training_data_args() -> list[ "specifying the probability of each system." ) + doc_mixed_batch = ( + "Whether to enable LMDB mixed-batch training with different numbers of atoms " + "per frame. When set to True, the PyTorch LMDB dataloader flattens atom-wise " + "fields and precomputes graph indices in the collate function. In this mode, " + "`batch_size` is the number of frames/systems per batch rather than the total " + "atom count. " + "The alias `mix_batch` is accepted. Default is False." 
+ ) + args = [ Argument( "systems", [list[str], str], optional=False, default=".", doc=doc_systems @@ -3796,6 +3805,14 @@ def training_data_args() -> list[ doc=doc_sys_probs, alias=["sys_weights"], ), + Argument( + "mixed_batch", + bool, + optional=True, + default=False, + alias=["mix_batch"], + doc=doc_mixed_batch + doc_only_pt_supported, + ), ] doc_training_data = "Configurations of training data." @@ -3840,6 +3857,14 @@ def validation_data_args() -> list[ "specifying the probability of each system." ) doc_numb_btch = "An integer that specifies the number of batches to be sampled for each validation period." + doc_mixed_batch = ( + "Whether to enable LMDB mixed-batch validation with different numbers of atoms " + "per frame. When set to True, the PyTorch LMDB dataloader flattens atom-wise " + "fields and precomputes graph indices in the collate function. In this mode, " + "`batch_size` is the number of frames/systems per batch rather than the total " + "atom count. " + "The alias `mix_batch` is accepted. Default is False." 
+ ) args = [ Argument( @@ -3887,6 +3912,14 @@ def validation_data_args() -> list[ "numb_batch", ], ), + Argument( + "mixed_batch", + bool, + optional=True, + default=False, + alias=["mix_batch"], + doc=doc_mixed_batch + doc_only_pt_supported, + ), ] doc_validation_data = ( diff --git a/doc/development/lmdb-mixed-system-batching.md b/doc/development/lmdb-mixed-system-batching.md new file mode 100644 index 0000000000..6049aa68cc --- /dev/null +++ b/doc/development/lmdb-mixed-system-batching.md @@ -0,0 +1,92 @@ +# LMDB 不同大小 System 的 Batch 拼接实现 + +本文记录 `deepmd-kit-lmdb` 当前 PyTorch 训练中,LMDB 数据如何把不同 `nloc` 的 frame/system 拼成一个 batch,并进入 mixed-batch forward。 + +## 核心思路 + +- 默认 `mixed_batch=False` 时,不真正混合不同大小的 system:`SameNlocBatchSampler` 先按 `nloc` 分组,每个 batch 内 frame 的原子数相同,然后走普通 `torch.stack` collate。 +- `mixed_batch=True` 时,batch 内允许不同 `nloc`。实现方式不是 padding 原始输入,而是把 atom-wise 字段按原子维度展平拼接: + - `coord`: `[sum(nloc_i), 3]` + - `atype`: `[sum(nloc_i)]` + - `force` / `aparam` 等 atom-wise 字段同样按第 0 维 `torch.cat` + - `energy` / `box` / `fparam` / `virial` 等 frame-wise 字段仍按 frame 维 `torch.stack` +- 额外生成两个索引张量保留 frame 边界: + - `batch`: `[total_atoms]`,每个原子属于第几个 frame + - `ptr`: `[nframes + 1]`,前缀和边界,例如 `[0, nloc_0, nloc_0+nloc_1, ...]` + +因此,第 `i` 个 frame 的局部原子范围可以通过 `coord[ptr[i]:ptr[i + 1]]` 还原。 + +## Flat Graph 预处理 + +当 descriptor 是 DPA3/Repflows 路径时,训练侧会从模型 descriptor 取出 `rcut`、`sel`、`a_rcut`、`a_sel`、`mixed_types`,传给 `make_lmdb_mixed_batch_collate(graph_config)`。 + +collate 阶段会调用 `build_precomputed_flat_graph(...)`,逐个 frame 做邻居图预处理: + +1. 通过 `ptr` 切出单个 frame 的 `coord/atype/box`。 +1. 对单个 frame 调用 ghost 扩展和 neighbor list 构建。 +1. 用 `extended_offset` 把每个 frame 的扩展原子索引平移到全 batch 的 flat index 空间。 +1. 
拼接得到: + - `extended_atype`, `extended_batch`, `extended_image`, `extended_ptr` + - `mapping`: extended atom -> 原始 flat local atom + - `central_ext_index`: local atom 在 extended atom 列表里的位置 + - `nlist_ext`, `a_nlist_ext`: 指向 extended atom 的邻居表 + - `nlist`, `a_nlist`: 映射回 local flat atom 的邻居表 + - `nlist_mask`, `a_nlist_mask`, `edge_index`, `angle_index` + +`extended_image` 和 `mapping` 在 forward 中用于从原始 `coord/box` 重建可求导的 `extended_coord`,保证 force/virial 的 autograd 仍连接到原始输入。 + +## 调用关系 + +```text +Trainer.get_dataloader_and_iter_lmdb + -> DataLoader( + dataset=LmdbDataset, + batch_size=_data.batch_size, + sampler=RandomSampler/SequentialSampler, + collate_fn=make_lmdb_mixed_batch_collate(graph_config), + ) + -> LmdbDataset.__getitem__ + -> LmdbDataReader.__getitem__ + 读取单个 frame,并把 coord/atype/force/box/energy 等整理成标准形状 + -> _collate_lmdb_mixed_batch + atom-wise 字段 torch.cat,frame-wise 字段 torch.stack + 生成 batch 和 ptr + -> build_precomputed_flat_graph # graph_config 存在时 + 生成 flat graph 相关字段 + +Trainer.get_data + -> 检测 batch_data 中是否有 batch/ptr + -> 把 _FLAT_GRAPH_INPUT_KEYS 加入 input_keys + -> 将 flat graph 字段移动到 DEVICE + +ModelWrapper.forward + -> input_dict.update(flat graph fields) + +EnergyModel.forward + -> batch 和 ptr 非空时进入 forward_common_flat + -> forward_common_flat_native + -> rebuild_extended_coord_from_flat_graph + -> forward_common_lower_flat + -> DPAtomicModel.forward_common_atomic_flat + -> descriptor.forward_flat + -> DPA3.forward_flat + -> Repflows.forward_flat + -> fitting_net.forward_flat + -> energy_atomic.index_add_(0, batch, ...) 
得到 energy_redu + -> _compute_derivatives_flat # 需要 force/virial 时 +``` + +## 关键文件 + +- `deepmd/pt/utils/lmdb_dataset.py`: LMDB PyTorch dataset 和 mixed-batch collate。 +- `deepmd/pt/utils/nlist.py`: flat graph 预计算和 extended coord 重建。 +- `deepmd/pt/train/training.py`: mixed-batch DataLoader 创建、flat graph 输入字段搬运。 +- `deepmd/pt/train/wrapper.py`: 把 flat graph 字段传入模型。 +- `deepmd/pt/model/model/ener_model.py`: 检测 `batch/ptr` 并进入 flat forward。 +- `deepmd/pt/model/model/make_model.py`: flat forward、frame-wise energy reduction、导数计算。 +- `deepmd/pt/model/atomic_model/dp_atomic_model.py`: flat atomic model forward。 +- `deepmd/pt/model/descriptor/dpa3.py` 和 `deepmd/pt/model/descriptor/repflows.py`: 消费预计算 flat graph。 + +## 小结 + +当前 mixed-batch 的本质是 **flat concatenation + `batch/ptr` 边界索引 + collate 阶段预计算 flat graph**。这样可以在同一个 batch 中放入不同原子数的 system,同时避免在主模型输入层对 `coord/atype` 做全局 padding。 diff --git a/source/tests/pt/test_lmdb_dataloader.py b/source/tests/pt/test_lmdb_dataloader.py index ebb505706d..f077a38f09 100644 --- a/source/tests/pt/test_lmdb_dataloader.py +++ b/source/tests/pt/test_lmdb_dataloader.py @@ -24,6 +24,10 @@ from deepmd.pt.utils.lmdb_dataset import ( LmdbDataset, _collate_lmdb_batch, + make_lmdb_mixed_batch_collate, +) +from deepmd.utils.argcheck import ( + training_data_args, ) from deepmd.utils.data import ( DataRequirementItem, @@ -208,6 +212,16 @@ def test_mixed_type(self, lmdb_dir): ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) assert ds.mixed_type is True + def test_mixed_batch_init(self, multi_nloc_lmdb): + ds = LmdbDataset( + multi_nloc_lmdb, + type_map=["O", "H"], + batch_size=3, + mixed_batch=True, + ) + assert ds.mixed_batch is True + assert len(ds.dataloaders) == 1 + # ============================================================ # Trainer compatibility interface @@ -353,6 +367,108 @@ def test_collate_none_values(self): ] assert _collate_lmdb_batch(frames)["box"] is None + def test_collate_mixed_nloc_flattens_atomwise(self): + rng = 
np.random.default_rng(7) + frames = [ + { + "coord": rng.standard_normal((2, 3)), + "atype": np.array([0, 1], dtype=np.int64), + "force": rng.standard_normal((2, 3)), + "atom_ener": rng.standard_normal((2, 1)), + "drdq": rng.standard_normal((2, 6)), + "energy": np.array([1.0]), + "box": np.arange(9, dtype=np.float64), + "find_energy": 1.0, + "fid": 3, + }, + { + "coord": rng.standard_normal((3, 3)), + "atype": np.array([1, 0, 1], dtype=np.int64), + "force": rng.standard_normal((3, 3)), + "atom_ener": rng.standard_normal((3, 1)), + "drdq": rng.standard_normal((3, 6)), + "energy": np.array([2.0]), + "box": np.arange(9, dtype=np.float64) + 10.0, + "find_energy": 1.0, + "fid": 9, + }, + ] + batch = _collate_lmdb_batch(frames) + assert batch["coord"].shape == (5, 3) + assert batch["atype"].shape == (5,) + assert batch["force"].shape == (5, 3) + assert batch["atom_ener"].shape == (5, 1) + assert batch["drdq"].shape == (5, 6) + assert batch["energy"].shape == (2, 1) + assert batch["box"].shape == (2, 9) + torch.testing.assert_close( + batch["batch"], torch.tensor([0, 0, 1, 1, 1], device="cpu") + ) + torch.testing.assert_close(batch["ptr"], torch.tensor([0, 2, 5], device="cpu")) + assert batch["fid"] == [3, 9] + + def test_mixed_batch_collate_precomputes_graph(self): + frames = [ + { + "coord": np.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0]]), + "atype": np.array([0, 0], dtype=np.int64), + "force": np.zeros((2, 3)), + "energy": np.array([0.0]), + "box": np.eye(3).reshape(9), + "find_energy": 1.0, + "fid": 0, + }, + { + "coord": np.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0], [0.0, 0.2, 0.0]]), + "atype": np.array([0, 0, 0], dtype=np.int64), + "force": np.zeros((3, 3)), + "energy": np.array([1.0]), + "box": np.eye(3).reshape(9), + "find_energy": 1.0, + "fid": 1, + }, + ] + collate = make_lmdb_mixed_batch_collate( + { + "rcut": 0.8, + "sel": [4], + "a_rcut": 0.8, + "a_sel": 4, + "mixed_types": True, + } + ) + batch = collate(frames) + assert batch["coord"].shape == (5, 3) + 
torch.testing.assert_close(batch["ptr"], torch.tensor([0, 2, 5], device="cpu")) + for key in ( + "extended_atype", + "extended_batch", + "extended_image", + "extended_ptr", + "mapping", + "central_ext_index", + "nlist", + "nlist_ext", + "a_nlist", + "a_nlist_ext", + "nlist_mask", + "a_nlist_mask", + "edge_index", + "angle_index", + ): + assert key in batch + assert batch["nlist"].shape[0] == 5 + assert batch["edge_index"].shape[0] == 2 + assert batch["angle_index"].shape[0] == 3 + + def test_mix_batch_arg_alias(self): + arg = training_data_args() + normalized = arg.normalize_value( + {"systems": "train.lmdb", "batch_size": 2, "mix_batch": True}, + trim_pattern="_*", + ) + assert normalized["mixed_batch"] is True + # ============================================================ # Type map remapping (PT-specific: LmdbDataset) @@ -601,6 +717,15 @@ def test_dataset_auto_prob_passthrough(self, auto_prob_lmdb): ) assert ds._block_targets is not None + def test_dataset_auto_prob_default_mixed_batch(self, auto_prob_lmdb): + ds = LmdbDataset( + auto_prob_lmdb, + type_map=["O", "H"], + batch_size=4, + mixed_batch=True, + ) + assert ds._block_targets is None + def test_dataset_auto_prob_none(self, auto_prob_lmdb): ds = LmdbDataset(auto_prob_lmdb, type_map=["O", "H"], batch_size=4) assert ds._block_targets is None @@ -624,6 +749,16 @@ def test_dataset_auto_prob_iteration(self, auto_prob_lmdb): count = sum(len(batch) for batch in ds._batch_sampler) assert count > 300 # expanded + def test_dataset_auto_prob_mixed_batch_raises(self, auto_prob_lmdb): + with pytest.raises(NotImplementedError, match="mixed_batch=True"): + LmdbDataset( + auto_prob_lmdb, + type_map=["O", "H"], + batch_size=4, + mixed_batch=True, + auto_prob_style="prob_sys_size;0:1:0.5;1:3:0.5", + ) + class TestMergeLmdbSystemIds: """Test merge_lmdb propagates frame_system_ids.""" diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index 2519111357..af4094bab0 100644 --- 
a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -16,6 +16,9 @@ EnergySpinLoss, EnergyStdLoss, ) +from deepmd.pt.utils import ( + env, +) from deepmd.pt.utils.dataset import ( DeepmdDataSetForLoader, ) @@ -370,6 +373,97 @@ def fake_model(): self.assertTrue(np.isnan(pt_more_loss_absent[f"l2_{key}_loss"])) +class TestEnerStdLossMixedBatch(unittest.TestCase): + def test_per_frame_energy_and_virial_normalization(self) -> None: + loss_obj = EnergyStdLoss( + starter_learning_rate=1.0, + start_pref_e=1.0, + limit_pref_e=1.0, + start_pref_v=1.0, + limit_pref_v=1.0, + ) + + energy_pred = torch.tensor( + [[10.0], [200.0]], dtype=torch.float64, device=env.DEVICE + ) + energy_label = torch.zeros_like(energy_pred) + virial_pred = torch.stack( + [ + torch.full((9,), 10.0, dtype=torch.float64, device=env.DEVICE), + torch.full((9,), 200.0, dtype=torch.float64, device=env.DEVICE), + ] + ) + virial_label = torch.zeros_like(virial_pred) + + def fake_model(**kwargs): + return { + "energy": energy_pred, + "virial": virial_pred, + } + + _, loss, _ = loss_obj( + { + "ptr": torch.tensor([0, 10, 110], dtype=torch.long, device=env.DEVICE), + }, + fake_model, + { + "energy": energy_label, + "find_energy": 1.0, + "virial": virial_label, + "find_virial": 1.0, + }, + natoms=0, + learning_rate=1.0, + ) + + expected_per_term = ( + torch.tensor( + [10.0**2 / 10, 200.0**2 / 100], + dtype=torch.float64, + device=env.DEVICE, + ) + .mean() + .to(loss.dtype) + ) + torch.testing.assert_close(loss, expected_per_term * 2.0) + + def test_generalized_force_rejected(self) -> None: + loss_obj = EnergyStdLoss( + starter_learning_rate=1.0, + start_pref_f=1.0, + limit_pref_f=1.0, + start_pref_gf=1.0, + limit_pref_gf=1.0, + numb_generalized_coord=2, + ) + + def fake_model(**kwargs): + return { + "force": torch.zeros((2, 3), dtype=torch.float64, device=env.DEVICE), + } + + with self.assertRaisesRegex( + NotImplementedError, + "Generalized force loss is not supported with mixed_batch=True yet.", 
+ ): + loss_obj( + {"ptr": torch.tensor([0, 2], dtype=torch.long, device=env.DEVICE)}, + fake_model, + { + "force": torch.zeros( + (2, 3), dtype=torch.float64, device=env.DEVICE + ), + "drdq": torch.zeros( + (1, 12), dtype=torch.float64, device=env.DEVICE + ), + "find_force": 1.0, + "find_drdq": 1.0, + }, + natoms=2, + learning_rate=1.0, + ) + + class TestEnerStdLossAePfGf(LossCommonTest): def setUp(self) -> None: self.start_lr = 1.1 diff --git a/test_mixed_batch.sh b/test_mixed_batch.sh new file mode 100755 index 0000000000..85eef37d76 --- /dev/null +++ b/test_mixed_batch.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Test script for mixed batch training with LMDB + +set -e + +repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$repo_root" + +echo "=== Testing Mixed Batch Training ===" +echo "" + +echo "Starting training with mixed_batch=True..." +echo "" + +dp --pt train test_mptraj/lmdb_mixed_batch.json --skip-neighbor-stat >mixed_batch_train.log 2>&1 + +echo "" +echo "=== Training completed ===" +echo "Check mixed_batch_train.log for details" diff --git a/test_mptraj/lmdb_baseline.json b/test_mptraj/lmdb_baseline.json new file mode 100644 index 0000000000..812d4a92ed --- /dev/null +++ b/test_mptraj/lmdb_baseline.json @@ -0,0 +1,215 @@ +{ + "_comment": "based on baseline_exp/input.json, with full_validation (PR #5336) enabled", + "model": { + "type_map": [ + "H", + "He", + "Li", + "Be", + "B", + "C", + "N", + "O", + "F", + "Ne", + "Na", + "Mg", + "Al", + "Si", + "P", + "S", + "Cl", + "Ar", + "K", + "Ca", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Ga", + "Ge", + "As", + "Se", + "Br", + "Kr", + "Rb", + "Sr", + "Y", + "Zr", + "Nb", + "Mo", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "In", + "Sn", + "Sb", + "Te", + "I", + "Xe", + "Cs", + "Ba", + "La", + "Ce", + "Pr", + "Nd", + "Pm", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "Yb", + "Lu", + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + 
"Au", + "Hg", + "Tl", + "Pb", + "Bi", + "Po", + "At", + "Rn", + "Fr", + "Ra", + "Ac", + "Th", + "Pa", + "U", + "Np", + "Pu", + "Am", + "Cm", + "Bk", + "Cf", + "Es", + "Fm", + "Md", + "No", + "Lr", + "Rf", + "Db", + "Sg", + "Bh", + "Hs", + "Mt", + "Ds", + "Rg", + "Cn", + "Nh", + "Fl", + "Mc", + "Lv", + "Ts", + "Og" + ], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 256, + "e_dim": 128, + "a_dim": 64, + "nlayers": 3, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 1200, + "a_rcut": 4.5, + "a_rcut_smth": 4.0, + "a_sel": 300, + "axis_neuron": 4, + "fix_stat_std": 0.3, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": true, + "update_angle": true, + "smooth_edge_update": true, + "edge_init_use_dist": true, + "use_dynamic_sel": true, + "sel_reduce_factor": 10.0, + "use_exp_switch": true, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const" + }, + "activation_function": "silut:10.0", + "use_tebd_bias": false, + "precision": "float32", + "concat_output_tebd": false + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "precision": "float32", + "activation_function": "silut:10.0", + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 1e-05, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.2, + "limit_pref_e": 20, + "start_pref_f": 100, + "limit_pref_f": 20, + "start_pref_v": 0.02, + "limit_pref_v": 1, + "_comment": " that's all" + }, + "optimizer": { + "type": "AdamW", + "adam_beta1": 0.9, + "adam_beta2": 0.999, + "weight_decay": 0.001 + }, + "training": { + "stat_file": "${DEEPMD_TEST_STAT_FILE}", + "training_data": { + "systems": "${DEEPMD_TEST_LMDB_TRAIN}", + "batch_size": "auto:128", + "_comment": "that's all" + }, + "validation_data": { + "systems": "${DEEPMD_TEST_LMDB_VALID}", + "batch_size": 1, + 
"_comment": "that's all" + }, + "numb_steps": 1000000, + "gradient_max_norm": 5.0, + "seed": 10, + "max_ckpt_keep": 10000000, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 10000, + "_comment": "that's all" + } +} diff --git a/test_mptraj/lmdb_mixed_batch.json b/test_mptraj/lmdb_mixed_batch.json new file mode 100644 index 0000000000..93f6fd1b83 --- /dev/null +++ b/test_mptraj/lmdb_mixed_batch.json @@ -0,0 +1,203 @@ +{ + "_comment": "Test config for mixed batch training with LMDB", + "model": { + "type_map": [ + "H", + "He", + "Li", + "Be", + "B", + "C", + "N", + "O", + "F", + "Ne", + "Na", + "Mg", + "Al", + "Si", + "P", + "S", + "Cl", + "Ar", + "K", + "Ca", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Ga", + "Ge", + "As", + "Se", + "Br", + "Kr", + "Rb", + "Sr", + "Y", + "Zr", + "Nb", + "Mo", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "In", + "Sn", + "Sb", + "Te", + "I", + "Xe", + "Cs", + "Ba", + "La", + "Ce", + "Pr", + "Nd", + "Pm", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "Yb", + "Lu", + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "Hg", + "Tl", + "Pb", + "Bi", + "Po", + "At", + "Rn", + "Fr", + "Ra", + "Ac", + "Th", + "Pa", + "U", + "Np", + "Pu", + "Am", + "Cm", + "Bk", + "Cf", + "Es", + "Fm", + "Md", + "No", + "Lr", + "Rf", + "Db", + "Sg", + "Bh", + "Hs", + "Mt", + "Ds", + "Rg", + "Cn", + "Nh", + "Fl", + "Mc", + "Lv", + "Ts", + "Og" + ], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 256, + "e_dim": 128, + "a_dim": 64, + "nlayers": 3, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 1200, + "a_rcut": 4.5, + "a_rcut_smth": 4.0, + "a_sel": 300, + "axis_neuron": 4, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": true, + "update_angle": true, + "smooth_edge_update": true, + "edge_use_dist": true, + "use_dynamic_sel": true, + "sel_reduce_factor": 10.0, + "use_exp_switch": true, + "update_style": "res_residual", + 
"update_residual": 0.1, + "update_residual_init": "const" + }, + "activation_function": "silut:10.0", + "use_tebd_bias": false, + "precision": "float32", + "concat_output_tebd": false + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "precision": "float32", + "activation_function": "silut:10.0" + } + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 1e-05 + }, + "loss": { + "type": "ener", + "start_pref_e": 0.2, + "limit_pref_e": 20, + "start_pref_f": 100, + "limit_pref_f": 20, + "start_pref_v": 0.02, + "limit_pref_v": 1 + }, + "training": { + "stat_file": "./MP_traj_v024_alldata_mixu.hdf5", + "training_data": { + "systems": "/mnt/data_nas/zhangd/data/lmdb_data/mptraj_v024.lmdb", + "batch_size": 8, + "mixed_batch": true, + "_comment": "Enable mixed_batch mode for testing. NOTE(review): no shuffle option is set in this config; confirm that shuffling is actually disabled if reproducibility is required" + }, + "validation_data": { + "systems": "/mnt/data_nas/zhangd/data/lmdb_data/wbm.lmdb", + "batch_size": 8, + "mixed_batch": true + }, + "numb_steps": 1000, + "gradient_max_norm": 5.0, + "seed": 42, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100 + } +}