Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This modification should be removed.

Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,8 @@ frozen_model.*
# Test system directories
system/
*.expected
temp/
pkl/
history/
deepmd-kit/
*.hdf5
4 changes: 4 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,25 @@ def _make_dp_loader_set(
# LMDB path: single string → LmdbDataset
if isinstance(training_systems, str) and is_lmdb(training_systems):
auto_prob = training_dataset_params.get("auto_prob", None)
mixed_batch = training_dataset_params.get("mixed_batch", False)
train_data_single = LmdbDataset(
training_systems,
model_params_single["type_map"],
training_dataset_params["batch_size"],
mixed_batch=mixed_batch,
auto_prob_style=auto_prob,
)
if (
validation_systems is not None
and isinstance(validation_systems, str)
and is_lmdb(validation_systems)
):
val_mixed_batch = validation_dataset_params.get("mixed_batch", False)
validation_data_single = LmdbDataset(
validation_systems,
model_params_single["type_map"],
validation_dataset_params["batch_size"],
mixed_batch=val_mixed_batch,
)
elif validation_systems is not None:
validation_data_single = _make_dp_loader_set(
Expand Down
98 changes: 67 additions & 31 deletions deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,38 @@ def forward(
more_loss = {}
# more_loss['log_keys'] = [] # showed when validation on the fly
# more_loss['test_keys'] = [] # showed when doing dp test
atom_norm = 1.0 / natoms
# Normalization exponent controls loss scaling with system size:
# - norm_exp=2 (intensive_ener_virial=True): loss uses 1/N² scaling, making it independent of system size
# - norm_exp=1 (intensive_ener_virial=False, legacy): loss uses 1/N scaling, which varies with system size

# Detect mixed batch format
is_mixed_batch = "ptr" in input_dict and input_dict["ptr"] is not None

atom_norms = None
if is_mixed_batch:
ptr = input_dict["ptr"]
natoms_per_frame = ptr[1:] - ptr[:-1] # [nframes]
atom_norms = 1.0 / natoms_per_frame.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
atom_norm = None
else:
atom_norm = 1.0 / natoms
Comment thread
coderabbitai[bot] marked this conversation as resolved.
norm_exp = 2 if self.intensive_ener_virial else 1

def get_frame_norm(value: torch.Tensor) -> torch.Tensor:
assert atom_norms is not None
return atom_norms.to(device=value.device, dtype=value.dtype).view(
[-1] + [1] * (value.dim() - 1)
)

def weighted_mean(value: torch.Tensor, power: int = 1) -> torch.Tensor:
if atom_norms is None:
assert atom_norm is not None
return value.mean() * (atom_norm**power)
return (value * get_frame_norm(value) ** power).mean()

def normalized_rmse(diff: torch.Tensor) -> torch.Tensor:
if atom_norms is None:
assert atom_norm is not None
return torch.mean(torch.square(diff)).sqrt() * atom_norm
return torch.mean(torch.square(diff * get_frame_norm(diff))).sqrt()

if self.has_e and "energy" in model_pred and "energy" in label:
energy_pred = model_pred["energy"]
energy_label = label["energy"]
Expand All @@ -261,35 +288,37 @@ def forward(
energy_pred = torch.sum(atom_ener_coeff * atom_ener_pred, dim=1)
find_energy = label.get("find_energy", 0.0)
pref_e = pref_e * find_energy
diff_e = energy_pred - energy_label
if self.loss_func == "mse":
l2_ener_loss = torch.mean(torch.square(energy_pred - energy_label))
square_ener_diff = torch.square(diff_e)
l2_ener_loss = torch.mean(square_ener_diff)
if not self.inference:
more_loss["l2_ener_loss"] = self.display_if_exist(
l2_ener_loss.detach(), find_energy
)
if not self.use_huber:
loss += atom_norm**norm_exp * (pref_e * l2_ener_loss)
loss += pref_e * weighted_mean(square_ener_diff, norm_exp)
else:
energy_norm = (
atom_norm if atom_norms is None else get_frame_norm(energy_pred)
)
l_huber_loss = custom_huber_loss(
atom_norm * energy_pred,
atom_norm * energy_label,
energy_norm * energy_pred,
energy_norm * energy_label,
delta=self._huber_delta_energy,
)
loss += pref_e * l_huber_loss
rmse_e = l2_ener_loss.sqrt() * atom_norm
rmse_e = normalized_rmse(diff_e)
more_loss["rmse_e"] = self.display_if_exist(
rmse_e.detach(), find_energy
)
# more_loss['log_keys'].append('rmse_e')
elif self.loss_func == "mae":
l1_ener_loss = F.l1_loss(
energy_pred.reshape(-1),
energy_label.reshape(-1),
reduction="mean",
)
loss += atom_norm * (pref_e * l1_ener_loss)
abs_ener_diff = torch.abs(diff_e)
mae_e = weighted_mean(abs_ener_diff)
loss += pref_e * mae_e
more_loss["mae_e"] = self.display_if_exist(
l1_ener_loss.detach() * atom_norm,
mae_e.detach(),
find_energy,
)
# more_loss['log_keys'].append('rmse_e')
Expand All @@ -298,9 +327,9 @@ def forward(
f"Loss type {self.loss_func} is not implemented for energy loss."
)
if mae:
mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
mae_e = weighted_mean(torch.abs(diff_e))
more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
mae_e_all = torch.mean(torch.abs(diff_e))
more_loss["mae_e_all"] = self.display_if_exist(
mae_e_all.detach(), find_energy
)
Expand Down Expand Up @@ -417,6 +446,10 @@ def forward(
)

if self.has_gf and "drdq" in label:
if is_mixed_batch:
raise NotImplementedError(
"Generalized force loss is not supported with mixed_batch=True yet."
)
drdq = label["drdq"]
find_drdq = label.get("find_drdq", 0.0)
pref_gf = pref_gf * find_drdq
Expand Down Expand Up @@ -446,41 +479,44 @@ def forward(
pref_v = pref_v * find_virial
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
if self.loss_func == "mse":
l2_virial_loss = torch.mean(torch.square(diff_v))
square_virial_diff = torch.square(diff_v)
l2_virial_loss = torch.mean(square_virial_diff)
if not self.inference:
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
if not self.use_huber:
loss += atom_norm**norm_exp * (pref_v * l2_virial_loss)
loss += pref_v * weighted_mean(square_virial_diff, norm_exp)
else:
virial = model_pred["virial"].reshape(-1, 9)
virial_label = label["virial"].reshape(-1, 9)
virial_norm = (
atom_norm if atom_norms is None else get_frame_norm(virial)
)
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["virial"].reshape(-1),
atom_norm * label["virial"].reshape(-1),
(virial_norm * virial).reshape(-1),
(virial_norm * virial_label).reshape(-1),
delta=self._huber_delta_virial,
)
loss += pref_v * l_huber_loss
rmse_v = l2_virial_loss.sqrt() * atom_norm
rmse_v = normalized_rmse(diff_v)
more_loss["rmse_v"] = self.display_if_exist(
rmse_v.detach(), find_virial
)
elif self.loss_func == "mae":
l1_virial_loss = F.l1_loss(
label["virial"].reshape(-1),
model_pred["virial"].reshape(-1),
reduction="mean",
)
loss += atom_norm * (pref_v * l1_virial_loss)
abs_virial_diff = torch.abs(diff_v)
mae_v = weighted_mean(abs_virial_diff)
loss += pref_v * mae_v
more_loss["mae_v"] = self.display_if_exist(
l1_virial_loss.detach() * atom_norm,
mae_v.detach(),
find_virial,
)
else:
raise NotImplementedError(
f"Loss type {self.loss_func} is not implemented for virial loss."
)
if mae:
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
mae_v = weighted_mean(torch.abs(diff_v))
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)

if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label:
Expand Down
147 changes: 147 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,153 @@ def forward_atomic(
)
return fit_ret

def forward_common_atomic_flat(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_batch: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor,
batch: torch.Tensor,
ptr: torch.Tensor,
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
extended_ptr: torch.Tensor | None = None,
central_ext_index: torch.Tensor | None = None,
nlist_ext: torch.Tensor | None = None,
a_nlist: torch.Tensor | None = None,
a_nlist_ext: torch.Tensor | None = None,
nlist_mask: torch.Tensor | None = None,
a_nlist_mask: torch.Tensor | None = None,
edge_index: torch.Tensor | None = None,
angle_index: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Forward pass with flat batch format.

Parameters
----------
extended_coord : torch.Tensor
Extended coordinates [total_extended_atoms, 3].
extended_atype : torch.Tensor
Extended atom types [total_extended_atoms].
extended_batch : torch.Tensor
Frame assignment for extended atoms [total_extended_atoms].
nlist : torch.Tensor
Neighbor list [total_atoms, nnei].
mapping : torch.Tensor
Extended atom -> local flat index mapping [total_extended_atoms].
batch : torch.Tensor
Frame assignment for local atoms [total_atoms].
ptr : torch.Tensor
Frame boundaries [nframes + 1].
fparam : torch.Tensor | None
Frame parameters [nframes, ndf].
aparam : torch.Tensor | None
Atomic parameters [total_atoms, nda].
central_ext_index : torch.Tensor | None
Extended-atom indices corresponding to local atoms.
nlist_ext, a_nlist_ext : torch.Tensor | None
Edge and angle neighbor lists indexing concatenated extended atoms.
nlist_mask, a_nlist_mask : torch.Tensor | None
Valid-neighbor masks for flat edge and angle neighbor lists.
edge_index, angle_index : torch.Tensor | None
Dynamic graph indices produced by the flat graph preprocessor.

Returns
-------
result_dict : dict[str, torch.Tensor]
Model predictions in flat format.
"""
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)

if (
hasattr(self.fitting_net, "get_dim_fparam")
and self.fitting_net.get_dim_fparam() > 0
and fparam is None
):
default_fparam_tensor = self.fitting_net.get_default_fparam()
assert default_fparam_tensor is not None
fparam_input_for_des = torch.tile(
default_fparam_tensor.to(device=extended_coord.device).unsqueeze(0),
[ptr.numel() - 1, 1],
)
else:
fparam_input_for_des = fparam

# Descriptor and fitting both consume the flat atom layout.
descriptor_out = self.descriptor.forward_flat(
extended_coord,
extended_atype,
extended_batch,
nlist,
mapping,
batch,
ptr,
fparam=fparam_input_for_des if self.add_chg_spin_ebd else None,
central_ext_index=central_ext_index,
nlist_ext=nlist_ext,
a_nlist=a_nlist,
a_nlist_ext=a_nlist_ext,
nlist_mask=nlist_mask,
a_nlist_mask=a_nlist_mask,
edge_index=edge_index,
angle_index=angle_index,
)

descriptor = descriptor_out.get("descriptor")
rot_mat = descriptor_out.get("rot_mat")
g2 = descriptor_out.get("g2")
h2 = descriptor_out.get("h2")

if self.enable_eval_descriptor_hook:
self.eval_descriptor_list.append(descriptor.detach())

if central_ext_index is None:
from deepmd.pt.utils.nlist import (
get_central_ext_index,
)

central_ext_index = get_central_ext_index(extended_batch, ptr)
atype = extended_atype[central_ext_index]
else:
atype = extended_atype[central_ext_index]

fit_ret = self.fitting_net.forward_flat(
descriptor,
atype,
batch,
ptr,
gr=rot_mat,
g2=g2,
h2=h2,
fparam=fparam,
aparam=aparam,
)
fit_ret = self.apply_out_stat(fit_ret, atype)

atom_mask = self.make_atom_mask(atype).to(torch.int32)
if self.atom_excl is not None:
atom_mask *= self.atom_excl(atype.unsqueeze(0)).squeeze(0)

for kk in fit_ret.keys():
out_shape = fit_ret[kk].shape
out_shape2 = 1
for ss in out_shape[1:]:
out_shape2 *= ss
fit_ret[kk] = (
fit_ret[kk].reshape([out_shape[0], out_shape2]) * atom_mask[:, None]
).view(out_shape)
fit_ret["mask"] = atom_mask

if self.enable_eval_fitting_last_layer_hook:
if "middle_output" in fit_ret:
self.eval_fitting_last_layer_list.append(
fit_ret.pop("middle_output").detach()
)

return fit_ret

def compute_or_load_stat(
self,
sampled_func: Callable[[], list[dict]],
Expand Down
Loading
Loading