diff --git a/pina/__init__.py b/pina/__init__.py index 2cbe7f3bb..0d38804fe 100644 --- a/pina/__init__.py +++ b/pina/__init__.py @@ -1,8 +1,14 @@ -"""Module for the Pina library.""" +""" +PINA: Physics-Informed Neural Analysis. + +A specialized framework for Scientific Machine Learning (SciML), providing +tools for Physics-Informed Neural Networks (PINNs), Neural Operators, +and data-driven physical modeling. +""" __all__ = [ - "Trainer", "LabelTensor", + "Trainer", "Condition", "PinaDataModule", "Graph", @@ -10,9 +16,9 @@ "MultiSolverInterface", ] -from .label_tensor import LabelTensor -from .graph import Graph -from .solver import SolverInterface, MultiSolverInterface -from .trainer import Trainer -from .condition.condition import Condition -from .data import PinaDataModule +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.graph import Graph +from pina._src.solver.solver import SolverInterface, MultiSolverInterface +from pina._src.core.trainer import Trainer +from pina._src.condition.condition import Condition +from pina._src.data.data_module import PinaDataModule diff --git a/pina/_src/__init__.py b/pina/_src/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/_src/adaptive_function/__init__.py b/pina/_src/adaptive_function/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/adaptive_function/adaptive_function.py b/pina/_src/adaptive_function/adaptive_function.py similarity index 99% rename from pina/adaptive_function/adaptive_function.py rename to pina/_src/adaptive_function/adaptive_function.py index e6f86a549..21f45fd1e 100644 --- a/pina/adaptive_function/adaptive_function.py +++ b/pina/_src/adaptive_function/adaptive_function.py @@ -1,8 +1,10 @@ """Module for the Adaptive Functions.""" import torch -from ..utils import check_consistency -from .adaptive_function_interface import AdaptiveActivationFunctionInterface +from pina._src.core.utils import check_consistency +from pina._src.adaptive_function.adaptive_function_interface import ( + AdaptiveActivationFunctionInterface, +) class AdaptiveReLU(AdaptiveActivationFunctionInterface): diff --git a/pina/adaptive_function/adaptive_function_interface.py b/pina/_src/adaptive_function/adaptive_function_interface.py similarity index 98% rename from pina/adaptive_function/adaptive_function_interface.py rename to pina/_src/adaptive_function/adaptive_function_interface.py index a655fdbd7..d73382cb6 100644 --- a/pina/adaptive_function/adaptive_function_interface.py +++ b/pina/_src/adaptive_function/adaptive_function_interface.py @@ -2,7 +2,7 @@ from abc import ABCMeta import torch -from ..utils import check_consistency, is_function +from pina._src.core.utils import check_consistency, is_function class AdaptiveActivationFunctionInterface(torch.nn.Module, metaclass=ABCMeta): diff --git a/pina/_src/callback/__init__.py b/pina/_src/callback/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/_src/callback/optim/__init__.py b/pina/_src/callback/optim/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/callback/optim/switch_optimizer.py b/pina/_src/callback/optim/switch_optimizer.py similarity index 96% rename from pina/callback/optim/switch_optimizer.py rename to pina/_src/callback/optim/switch_optimizer.py index 3072b7c2e..4f6f0be09 100644 --- a/pina/callback/optim/switch_optimizer.py +++ b/pina/_src/callback/optim/switch_optimizer.py @@ -1,8 +1,8 @@ """Module for the SwitchOptimizer callback.""" from 
lightning.pytorch.callbacks import Callback -from ...optim import TorchOptimizer -from ...utils import check_consistency +from pina._src.optim.torch_optimizer import TorchOptimizer +from pina._src.core.utils import check_consistency class SwitchOptimizer(Callback): diff --git a/pina/callback/optim/switch_scheduler.py b/pina/_src/callback/optim/switch_scheduler.py similarity index 95% rename from pina/callback/optim/switch_scheduler.py rename to pina/_src/callback/optim/switch_scheduler.py index 3641f4ee4..bd4920bba 100644 --- a/pina/callback/optim/switch_scheduler.py +++ b/pina/_src/callback/optim/switch_scheduler.py @@ -1,8 +1,8 @@ """Module for the SwitchScheduler callback.""" from lightning.pytorch.callbacks import Callback -from ...optim import TorchScheduler -from ...utils import check_consistency, check_positive_integer +from pina._src.optim.torch_scheduler import TorchScheduler +from pina._src.core.utils import check_consistency, check_positive_integer class SwitchScheduler(Callback): diff --git a/pina/_src/callback/processing/__init__.py b/pina/_src/callback/processing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/callback/processing/metric_tracker.py b/pina/_src/callback/processing/metric_tracker.py similarity index 100% rename from pina/callback/processing/metric_tracker.py rename to pina/_src/callback/processing/metric_tracker.py diff --git a/pina/callback/processing/normalizer_data_callback.py b/pina/_src/callback/processing/normalizer_data_callback.py similarity index 97% rename from pina/callback/processing/normalizer_data_callback.py rename to pina/_src/callback/processing/normalizer_data_callback.py index 4d85a7d9a..2524f5765 100644 --- a/pina/callback/processing/normalizer_data_callback.py +++ b/pina/_src/callback/processing/normalizer_data_callback.py @@ -2,10 +2,10 @@ import torch from lightning.pytorch import Callback -from ...label_tensor import LabelTensor -from ...utils import check_consistency, is_function -from ...condition import InputTargetCondition -from ...data.dataset import PinaGraphDataset +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.utils import check_consistency, is_function +from pina._src.condition.condition import InputTargetCondition +from pina._src.data.dataset import PinaGraphDataset class NormalizerDataCallback(Callback): diff --git a/pina/callback/processing/pina_progress_bar.py b/pina/_src/callback/processing/pina_progress_bar.py similarity index 98% rename from pina/callback/processing/pina_progress_bar.py rename to pina/_src/callback/processing/pina_progress_bar.py index 4c322a5e8..90c34f8cc 100644 --- a/pina/callback/processing/pina_progress_bar.py +++ b/pina/_src/callback/processing/pina_progress_bar.py @@ -4,7 +4,7 @@ from lightning.pytorch.callbacks.progress.progress_bar import ( get_standard_metrics, ) -from pina.utils import check_consistency +from pina._src.core.utils import check_consistency class PINAProgressBar(TQDMProgressBar): diff --git a/pina/_src/callback/refinement/__init__.py b/pina/_src/callback/refinement/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/callback/refinement/r3_refinement.py b/pina/_src/callback/refinement/r3_refinement.py similarity index 93% rename from pina/callback/refinement/r3_refinement.py rename to pina/_src/callback/refinement/r3_refinement.py index 863dedfc1..b8bcc7285 100644 --- a/pina/callback/refinement/r3_refinement.py +++ b/pina/_src/callback/refinement/r3_refinement.py @@ -1,10 +1,12 @@ """Module for the 
R3Refinement callback.""" import torch -from .refinement_interface import RefinementInterface -from ...label_tensor import LabelTensor -from ...utils import check_consistency -from ...loss import LossInterface +from pina._src.callback.refinement.refinement_interface import ( + RefinementInterface, +) +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.utils import check_consistency +from pina._src.loss.loss_interface import LossInterface class R3Refinement(RefinementInterface): diff --git a/pina/callback/refinement/refinement_interface.py b/pina/_src/callback/refinement/refinement_interface.py similarity index 97% rename from pina/callback/refinement/refinement_interface.py rename to pina/_src/callback/refinement/refinement_interface.py index adc6e4e7c..83ca8d8be 100644 --- a/pina/callback/refinement/refinement_interface.py +++ b/pina/_src/callback/refinement/refinement_interface.py @@ -5,8 +5,10 @@ from abc import ABCMeta, abstractmethod from lightning.pytorch import Callback -from ...utils import check_consistency -from ...solver.physics_informed_solver import PINNInterface +from pina._src.core.utils import check_consistency +from pina._src.solver.physics_informed_solver.pinn_interface import ( + PINNInterface, +) class RefinementInterface(Callback, metaclass=ABCMeta): diff --git a/pina/_src/condition/__init__.py b/pina/_src/condition/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/condition/condition.py b/pina/_src/condition/condition.py similarity index 95% rename from pina/condition/condition.py rename to pina/_src/condition/condition.py index ad8764c9f..db2a666d8 100644 --- a/pina/condition/condition.py +++ b/pina/_src/condition/condition.py @@ -1,9 +1,11 @@ """Module for the Condition class.""" -from .data_condition import DataCondition -from .domain_equation_condition import DomainEquationCondition -from .input_equation_condition import InputEquationCondition -from .input_target_condition import InputTargetCondition +from pina._src.condition.data_condition import DataCondition +from pina._src.condition.domain_equation_condition import ( + DomainEquationCondition, +) +from pina._src.condition.input_equation_condition import InputEquationCondition +from pina._src.condition.input_target_condition import InputTargetCondition class Condition: diff --git a/pina/condition/condition_interface.py b/pina/_src/condition/condition_interface.py similarity index 98% rename from pina/condition/condition_interface.py rename to pina/_src/condition/condition_interface.py index b0264517c..509ac2fc3 100644 --- a/pina/condition/condition_interface.py +++ b/pina/_src/condition/condition_interface.py @@ -2,8 +2,8 @@ from abc import ABCMeta from torch_geometric.data import Data -from ..label_tensor import LabelTensor -from ..graph import Graph +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.graph import Graph class ConditionInterface(metaclass=ABCMeta): diff --git a/pina/condition/data_condition.py b/pina/_src/condition/data_condition.py similarity index 96% rename from pina/condition/data_condition.py rename to pina/_src/condition/data_condition.py index 5f5e7d36b..ec6da762c 100644 --- a/pina/condition/data_condition.py +++ b/pina/_src/condition/data_condition.py @@ -2,9 +2,9 @@ import torch from torch_geometric.data import Data -from .condition_interface import ConditionInterface -from ..label_tensor import LabelTensor -from ..graph import Graph +from pina._src.condition.condition_interface import ConditionInterface +from 
pina._src.core.label_tensor import LabelTensor
+from pina._src.core.graph import Graph
 class DataCondition(ConditionInterface):
diff --git a/pina/condition/domain_equation_condition.py b/pina/_src/condition/domain_equation_condition.py
similarity index 89%
rename from pina/condition/domain_equation_condition.py
rename to pina/_src/condition/domain_equation_condition.py
index 3565c0b41..93e76892a 100644
--- a/pina/condition/domain_equation_condition.py
+++ b/pina/_src/condition/domain_equation_condition.py
@@ -1,9 +1,9 @@
 """Module for the DomainEquationCondition class."""
-from .condition_interface import ConditionInterface
-from ..utils import check_consistency
-from ..domain import DomainInterface
-from ..equation.equation_interface import EquationInterface
+from pina._src.condition.condition_interface import ConditionInterface
+from pina._src.core.utils import check_consistency
+from pina._src.domain.domain_interface import DomainInterface
+from pina._src.equation.equation_interface import EquationInterface
 class DomainEquationCondition(ConditionInterface):
diff --git a/pina/condition/input_equation_condition.py b/pina/_src/condition/input_equation_condition.py
similarity index 95%
rename from pina/condition/input_equation_condition.py
rename to pina/_src/condition/input_equation_condition.py
index d32597894..636d8b9f8 100644
--- a/pina/condition/input_equation_condition.py
+++ b/pina/_src/condition/input_equation_condition.py
@@ -1,10 +1,10 @@
 """Module for the InputEquationCondition class and its subclasses."""
-from .condition_interface import ConditionInterface
-from ..label_tensor import LabelTensor
-from ..graph import Graph
-from ..utils import check_consistency
-from ..equation.equation_interface import EquationInterface
+from pina._src.condition.condition_interface import ConditionInterface
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.graph import Graph
+from pina._src.core.utils import check_consistency
+from pina._src.equation.equation_interface import EquationInterface
 class InputEquationCondition(ConditionInterface):
diff --git a/pina/condition/input_target_condition.py b/pina/_src/condition/input_target_condition.py
similarity index 98%
rename from pina/condition/input_target_condition.py
rename to pina/_src/condition/input_target_condition.py
index 07b07bb7b..e1392ed75 100644
--- a/pina/condition/input_target_condition.py
+++ b/pina/_src/condition/input_target_condition.py
@@ -4,9 +4,9 @@
 import torch
 from torch_geometric.data import Data
-from ..label_tensor import LabelTensor
-from ..graph import Graph
-from .condition_interface import ConditionInterface
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.graph import Graph
+from pina._src.condition.condition_interface import ConditionInterface
 class InputTargetCondition(ConditionInterface):
diff --git a/pina/_src/core/__init__.py b/pina/_src/core/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/core/graph.py b/pina/_src/core/graph.py
new file mode 100644
index 000000000..3c72051ec
--- /dev/null
+++ b/pina/_src/core/graph.py
@@ -0,0 +1,421 @@
+"""Module to build Graph objects and perform operations on them."""
+
+import torch
+from torch_geometric.data import Data, Batch
+from torch_geometric.utils import to_undirected
+from torch_geometric.utils.loop import remove_self_loops
+from pina._src.core.label_tensor import LabelTensor
+from pina._src.core.utils import check_consistency, is_function
+
+
+class Graph(Data):
+    """
+    Extends the :class:`~torch_geometric.data.Data` class to include
+    additional checks and functionalities.
+    """
+
+    def __new__(
+        cls,
+        **kwargs,
+    ):
+        """
+        Create a new instance of the :class:`~pina.graph.Graph` class by
+        checking the consistency of the input data and storing the
+        attributes.
+
+        :param dict kwargs: Parameters used to initialize the
+            :class:`~pina.graph.Graph` object.
+        :return: A new instance of the :class:`~pina.graph.Graph` class.
+        :rtype: Graph
+        """
+        # create class instance
+        instance = Data.__new__(cls)
+
+        # check the consistency of the types defined in __init__; the others
+        # are not checked (as in the PyG Data object)
+        instance._check_type_consistency(**kwargs)
+
+        return instance
+
+    def __init__(
+        self,
+        x=None,
+        edge_index=None,
+        pos=None,
+        edge_attr=None,
+        undirected=False,
+        **kwargs,
+    ):
+        """
+        Initialize the object by setting the node features, edge index,
+        edge attributes, and positions. The edge index is preprocessed to
+        make the graph undirected if required. For more details, see the
+        :class:`~torch_geometric.data.Data` class.
+
+        :param x: Optional tensor of node features ``(N, F)``, where ``F``
+            is the number of features per node.
+        :type x: torch.Tensor | LabelTensor
+        :param torch.Tensor edge_index: A tensor of shape ``(2, E)``
+            representing the indices of the graph's edges.
+        :param pos: A tensor of shape ``(N, D)`` representing the positions
+            of ``N`` points in ``D``-dimensional space.
+        :type pos: torch.Tensor | LabelTensor
+        :param edge_attr: Optional tensor of edge features ``(E, F')``,
+            where ``F'`` is the number of features per edge.
+        :type edge_attr: torch.Tensor | LabelTensor
+        :param bool undirected: Whether to make the graph undirected.
+        :param dict kwargs: Additional keyword arguments passed to the
+            :class:`~torch_geometric.data.Data` class constructor.
+        """
+        # preprocessing (keep the returned tensor: the method is not in-place)
+        edge_index = self._preprocess_edge_index(edge_index, undirected)
+
+        # calling init
+        super().__init__(
+            x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos, **kwargs
+        )
+
+    def _check_type_consistency(self, **kwargs):
+        """
+        Check the consistency of the types of the input data.
+
+        :param dict kwargs: Attributes to be checked for consistency.
+        """
+        # default types, specified in cls.__new__; by default they are None,
+        # and values specified in **kwargs override them
+        x, pos, edge_index, edge_attr = None, None, None, None
+        if "pos" in kwargs:
+            pos = kwargs["pos"]
+            self._check_pos_consistency(pos)
+        if "edge_index" in kwargs:
+            edge_index = kwargs["edge_index"]
+            self._check_edge_index_consistency(edge_index)
+        if "x" in kwargs:
+            x = kwargs["x"]
+            self._check_x_consistency(x, pos)
+        if "edge_attr" in kwargs:
+            edge_attr = kwargs["edge_attr"]
+            self._check_edge_attr_consistency(edge_attr, edge_index)
+        if "undirected" in kwargs:
+            undirected = kwargs["undirected"]
+            check_consistency(undirected, bool)
+
+    @staticmethod
+    def _check_pos_consistency(pos):
+        """
+        Check if the position tensor is consistent.
+
+        :param torch.Tensor pos: The position tensor.
+        :raises ValueError: If the position tensor is not consistent.
+        """
+        if pos is not None:
+            check_consistency(pos, (torch.Tensor, LabelTensor))
+            if pos.ndim != 2:
+                raise ValueError("pos must be a 2D tensor.")
+
+    @staticmethod
+    def _check_edge_index_consistency(edge_index):
+        """
+        Check if the edge index is consistent.
+
+        :param torch.Tensor edge_index: The edge index tensor.
+        :raises ValueError: If the edge index tensor is not consistent.
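+
+        Example (illustrative sketch, with a hypothetical wrongly-shaped
+        index tensor):
+
+        >>> import torch
+        >>> Graph._check_edge_index_consistency(torch.zeros(3, 4))
+        Traceback (most recent call last):
+        ...
+        ValueError: edge_index must have shape [2, num_edges].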
+ """ + check_consistency(edge_index, (torch.Tensor, LabelTensor)) + if edge_index.ndim != 2: + raise ValueError("edge_index must be a 2D tensor.") + if edge_index.size(0) != 2: + raise ValueError("edge_index must have shape [2, num_edges].") + + @staticmethod + def _check_edge_attr_consistency(edge_attr, edge_index): + """ + Check if the edge attribute tensor is consistent in type and shape + with the edge index. + + :param edge_attr: The edge attribute tensor. + :type edge_attr: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: The edge index tensor. + :raises ValueError: If the edge attribute tensor is not consistent. + """ + if edge_attr is not None: + check_consistency(edge_attr, (torch.Tensor, LabelTensor)) + if edge_attr.ndim != 2: + raise ValueError("edge_attr must be a 2D tensor.") + if edge_attr.size(0) != edge_index.size(1): + raise ValueError( + "edge_attr must have shape " + "[num_edges, num_edge_features], expected " + f"num_edges {edge_index.size(1)} " + f"got {edge_attr.size(0)}." + ) + + @staticmethod + def _check_x_consistency(x, pos=None): + """ + Check if the input tensor x is consistent with the position tensor + `pos`. + + :param x: The input tensor. + :type x: torch.Tensor | LabelTensor + :param pos: The position tensor. + :type pos: torch.Tensor | LabelTensor + :raises ValueError: If the input tensor is not consistent. + """ + if x is not None: + check_consistency(x, (torch.Tensor, LabelTensor)) + if x.ndim != 2: + raise ValueError("x must be a 2D tensor.") + if pos is not None: + if x.size(0) != pos.size(0): + raise ValueError("Inconsistent number of nodes.") + + @staticmethod + def _preprocess_edge_index(edge_index, undirected): + """ + Preprocess the edge index to make the graph undirected (if required). + + :param torch.Tensor edge_index: The edge index. + :param bool undirected: Whether the graph is undirected. + :return: The preprocessed edge index. + :rtype: torch.Tensor + """ + if undirected: + edge_index = to_undirected(edge_index) + return edge_index + + def extract(self, labels, attr="x"): + """ + Perform extraction of labels from the attribute specified by `attr`. + + :param labels: Labels to extract + :type labels: list[str] | tuple[str] | str | dict + :return: Batch object with extraction performed on x + :rtype: PinaBatch + """ + # Extract labels from LabelTensor object + tensor = getattr(self, attr).extract(labels) + # Set the extracted tensor as the new attribute + setattr(self, attr, tensor) + return self + + +class GraphBuilder: + """ + A class that allows an easy definition of :class:`Graph` instances. + """ + + def __new__( + cls, + pos, + edge_index, + x=None, + edge_attr=False, + custom_edge_func=None, + loop=True, + **kwargs, + ): + """ + Compute the edge attributes and create a new instance of the + :class:`~pina.graph.Graph` class. + + :param pos: A tensor of shape ``(N, D)`` representing the positions of + ``N`` points in ``D``-dimensional space. + :type pos: torch.Tensor or LabelTensor + :param edge_index: A tensor of shape ``(2, E)`` representing the indices + of the graph's edges. + :type edge_index: torch.Tensor + :param x: Optional tensor of node features of shape ``(N, F)``, where + ``F`` is the number of features per node. + :type x: torch.Tensor | LabelTensor, optional + :param bool edge_attr: Whether to compute the edge attributes. + :param custom_edge_func: A custom function to compute edge attributes. + If provided, overrides ``edge_attr``. 
+ :type custom_edge_func: Callable, optional + :param bool loop: Whether to include self-loops. + :param kwargs: Additional keyword arguments passed to the + :class:`~pina.graph.Graph` class constructor. + :return: A :class:`~pina.graph.Graph` instance constructed using the + provided information. + :rtype: Graph + """ + if not loop: + edge_index = remove_self_loops(edge_index)[0] + edge_attr = cls._create_edge_attr( + pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr + ) + return Graph( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + pos=pos, + **kwargs, + ) + + @staticmethod + def _create_edge_attr(pos, edge_index, edge_attr, func): + """ + Create the edge attributes based on the input parameters. + + :param pos: Positions of the points. + :type pos: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: Edge indices. + :param bool edge_attr: Whether to compute the edge attributes. + :param Callable func: Function to compute the edge attributes. + :raises ValueError: If ``func`` is not a function. + :return: The edge attributes. + :rtype: torch.Tensor | LabelTensor | None + """ + check_consistency(edge_attr, bool) + if edge_attr: + if is_function(func): + return func(pos, edge_index) + raise ValueError("custom_edge_func must be a function.") + return None + + @staticmethod + def _build_edge_attr(pos, edge_index): + """ + Default function to compute the edge attributes. + + :param pos: Positions of the points. + :type pos: torch.Tensor | LabelTensor + :param torch.Tensor edge_index: Edge indices. + :return: The edge attributes. + :rtype: torch.Tensor + """ + return ( + (pos[edge_index[0]] - pos[edge_index[1]]) + .abs() + .as_subclass(torch.Tensor) + ) + + +class RadiusGraph(GraphBuilder): + """ + Extends the :class:`~pina.graph.GraphBuilder` class to compute + ``edge_index`` based on a radius. Each point is connected to all the points + within the radius. + """ + + def __new__(cls, pos, radius, **kwargs): + """ + Instantiate the :class:`~pina.graph.Graph` class by computing the + ``edge_index`` based on the radius provided. + + :param pos: A tensor of shape ``(N, D)`` representing the positions of + ``N`` points in ``D``-dimensional space. + :type pos: torch.Tensor | LabelTensor + :param float radius: The radius within which points are connected. + :param dict kwargs: The additional keyword arguments to be passed to + :class:`GraphBuilder` and :class:`Graph` classes. + :return: A :class:`~pina.graph.Graph` instance with the computed + ``edge_index``. + :rtype: Graph + """ + edge_index = cls.compute_radius_graph(pos, radius) + return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) + + @staticmethod + def compute_radius_graph(points, radius): + """ + Computes the ``edge_index`` based on the radius. Each point is connected + to all the points within the radius. + + :param points: A tensor of shape ``(N, D)`` representing the positions + of ``N`` points in ``D``-dimensional space. + :type points: torch.Tensor | LabelTensor + :param float radius: The radius within which points are connected. + :return: A tensor of shape ``(2, E)``, with ``E`` number of edges, + representing the edge indices of the graph. + :rtype: torch.Tensor + """ + dist = torch.cdist(points, points, p=2) + return ( + torch.nonzero(dist <= radius, as_tuple=False) + .t() + .as_subclass(torch.Tensor) + ) + + +class KNNGraph(GraphBuilder): + """ + Extends the :class:`~pina.graph.GraphBuilder` class to compute + ``edge_index`` based on a K-nearest neighbors algorithm. 
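+
+    Example (illustrative sketch; note that each point has zero distance
+    from itself, so the point itself is counted among its ``neighbours``):
+
+    >>> import torch
+    >>> pos = torch.rand(10, 2)
+    >>> graph = KNNGraph(pos=pos, neighbours=3)
+    >>> graph.edge_index.shape
+    torch.Size([2, 30])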
+ """ + + def __new__(cls, pos, neighbours, **kwargs): + """ + Instantiate the :class:`~pina.graph.Graph` class by computing the + ``edge_index`` based on the K-nearest neighbors algorithm. + + :param pos: A tensor of shape ``(N, D)`` representing the positions of + ``N`` points in ``D``-dimensional space. + :type pos: torch.Tensor | LabelTensor + :param int neighbours: The number of nearest neighbors to consider when + building the graph. + :param dict kwargs: The additional keyword arguments to be passed to + :class:`GraphBuilder` and :class:`Graph` classes. + + :return: A :class:`~pina.graph.Graph` instance with the computed + ``edge_index``. + :rtype: Graph + """ + + edge_index = cls.compute_knn_graph(pos, neighbours) + return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) + + @staticmethod + def compute_knn_graph(points, neighbours): + """ + Computes the ``edge_index`` based on the K-nearest neighbors algorithm. + + :param points: A tensor of shape ``(N, D)`` representing the positions + of ``N`` points in ``D``-dimensional space. + :type points: torch.Tensor | LabelTensor + :param int neighbours: The number of nearest neighbors to consider when + building the graph. + :return: A tensor of shape ``(2, E)``, with ``E`` number of edges, + representing the edge indices of the graph. + :rtype: torch.Tensor + """ + dist = torch.cdist(points, points, p=2) + knn_indices = torch.topk(dist, k=neighbours, largest=False).indices + row = torch.arange(points.size(0)).repeat_interleave(neighbours) + col = knn_indices.flatten() + return torch.stack([row, col], dim=0).as_subclass(torch.Tensor) + + +class LabelBatch(Batch): + """ + Extends the :class:`~torch_geometric.data.Batch` class to include + :class:`~pina.label_tensor.LabelTensor` objects. + """ + + @classmethod + def from_data_list(cls, data_list): + """ + Create a Batch object from a list of :class:`~torch_geometric.data.Data` + or :class:`~pina.graph.Graph` objects. + + :param data_list: List of :class:`~torch_geometric.data.Data` or + :class:`~pina.graph.Graph` objects. + :type data_list: list[Data] | list[Graph] + :return: A :class:`~torch_geometric.data.Batch` object containing + the input data. + :rtype: :class:`~torch_geometric.data.Batch` + """ + # Store the labels of Data/Graph objects (all data have the same labels) + # If the data do not contain labels, labels is an empty dictionary, + # therefore the labels are not stored + labels = { + k: v.labels + for k, v in data_list[0].items() + if isinstance(v, LabelTensor) + } + + # Create a Batch object from the list of Data objects + batch = super().from_data_list(data_list) + + # Put the labels back in the Batch object + for k, v in labels.items(): + batch[k].labels = v + return batch diff --git a/pina/label_tensor.py b/pina/_src/core/label_tensor.py similarity index 99% rename from pina/label_tensor.py rename to pina/_src/core/label_tensor.py index 535954d23..41bccc6fc 100644 --- a/pina/label_tensor.py +++ b/pina/_src/core/label_tensor.py @@ -541,7 +541,7 @@ def _update_single_label(self, index, dim): return new_dof def __getitem__(self, index): - """ " + """ Override the __getitem__ method to handle the labels of the :class:`~pina.label_tensor.LabelTensor` instance. 
It first performs __getitem__ operation on the :class:`torch.Tensor` part of the instance, diff --git a/pina/_src/core/operator.py b/pina/_src/core/operator.py new file mode 100644 index 000000000..8ed28c3a6 --- /dev/null +++ b/pina/_src/core/operator.py @@ -0,0 +1,482 @@ +"""Module for vectorized differential operators implementation. + +Differential operators are used to define differential problems and are +implemented to run efficiently on various accelerators, including CPU, GPU, TPU, +and MPS. + +Each differential operator takes the following inputs: +- A tensor on which the operator is applied. +- A tensor with respect to which the operator is computed. +- The names of the output variables for which the operator is evaluated. +- The names of the variables with respect to which the operator is computed. + +Each differential operator has its fast version, which performs no internal +checks on input and output tensors. For these methods, the user is always +required to specify both ``components`` and ``d`` as lists of strings. +""" + +import torch +from pina._src.core.label_tensor import LabelTensor + + +def _check_values(output_, input_, components, d): + """ + Perform checks on arguments of differential operators. + + :param LabelTensor output_: The output tensor on which the operator is + computed. + :param LabelTensor input_: The input tensor with respect to which the + operator is computed. + :param components: The names of the output variables for which to compute + the operator. It must be a subset of the output labels. + If ``None``, all output variables are considered. Default is ``None``. + :type components: str | list[str] + :param d: The names of the input variables with respect to which the + operator is computed. It must be a subset of the input labels. + If ``None``, all input variables are considered. Default is ``None``. + :type d: str | list[str] + :raises TypeError: If the input tensor is not a LabelTensor. + :raises TypeError: If the output tensor is not a LabelTensor. + :raises RuntimeError: If derivative labels are missing from the ``input_``. + :raises RuntimeError: If component labels are missing from the ``output_``. + :return: The components and d lists. + :rtype: tuple[list[str], list[str]] + """ + # Check if the input is a LabelTensor + if not isinstance(input_, LabelTensor): + raise TypeError("Input must be a LabelTensor.") + + # Check if the output is a LabelTensor + if not isinstance(output_, LabelTensor): + raise TypeError("Output must be a LabelTensor.") + + # If no labels are provided, use all labels + d = d or input_.labels + components = components or output_.labels + + # Convert to list if not already + d = d if isinstance(d, list) else [d] + components = components if isinstance(components, list) else [components] + + # Check if all labels are present in the input tensor + if not all(di in input_.labels for di in d): + raise RuntimeError("Derivative labels missing from input tensor.") + + # Check if all labels are present in the output tensor + if not all(c in output_.labels for c in components): + raise RuntimeError("Component label missing from output tensor.") + + return components, d + + +def _scalar_grad(output_, input_, d): + """ + Compute the gradient of a scalar-valued ``output_``. + + :param LabelTensor output_: The output tensor on which the gradient is + computed. It must be a column tensor. + :param LabelTensor input_: The input tensor with respect to which the + gradient is computed. 
+ :param list[str] d: The names of the input variables with respect to + which the gradient is computed. It must be a subset of the input + labels. If ``None``, all input variables are considered. + :return: The computed gradient tensor. + :rtype: LabelTensor + """ + grad_out = torch.autograd.grad( + outputs=output_, + inputs=input_, + grad_outputs=torch.ones_like(output_), + create_graph=True, + retain_graph=True, + allow_unused=True, + )[0] + + return grad_out[..., [input_.labels.index(i) for i in d]] + + +def _scalar_laplacian(output_, input_, d): + """ + Compute the laplacian of a scalar-valued ``output_``. + + :param LabelTensor output_: The output tensor on which the laplacian is + computed. It must be a column tensor. + :param LabelTensor input_: The input tensor with respect to which the + laplacian is computed. + :param list[str] d: The names of the input variables with respect to + which the laplacian is computed. It must be a subset of the input + labels. If ``None``, all input variables are considered. + :return: The computed laplacian tensor. + :rtype: LabelTensor + """ + first_grad = fast_grad( + output_=output_, input_=input_, components=output_.labels, d=d + ) + second_grad = fast_grad( + output_=first_grad, input_=input_, components=first_grad.labels, d=d + ) + labels_to_extract = [f"d{c}d{d_}" for c, d_ in zip(first_grad.labels, d)] + return torch.sum( + second_grad.extract(labels_to_extract), dim=-1, keepdim=True + ) + + +def fast_grad(output_, input_, components, d): + """ + Compute the gradient of the ``output_`` with respect to the ``input``. + + Unlike ``grad``, this function performs no internal checks on input and + output tensors. The user is required to specify both ``components`` and + ``d`` as lists of strings. It is designed to enhance computation speed. + + This operator supports both vector-valued and scalar-valued functions with + one or multiple input coordinates. + + :param LabelTensor output_: The output tensor on which the gradient is + computed. + :param LabelTensor input_: The input tensor with respect to which the + gradient is computed. + :param list[str] components: The names of the output variables for which to + compute the gradient. It must be a subset of the output labels. + :param list[str] d: The names of the input variables with respect to which + the gradient is computed. It must be a subset of the input labels. + :return: The computed gradient tensor. + :rtype: LabelTensor + """ + # Scalar gradient + if output_.shape[-1] == 1: + return LabelTensor( + _scalar_grad(output_=output_, input_=input_, d=d), + labels=[f"d{output_.labels[0]}d{i}" for i in d], + ) + + # Vector gradient + grads = torch.cat( + [ + _scalar_grad(output_=output_.extract(c), input_=input_, d=d) + for c in components + ], + dim=-1, + ) + + return LabelTensor( + grads, labels=[f"d{c}d{i}" for c in components for i in d] + ) + + +def fast_div(output_, input_, components, d): + """ + Compute the divergence of the ``output_`` with respect to ``input``. + + Unlike ``div``, this function performs no internal checks on input and + output tensors. The user is required to specify both ``components`` and + ``d`` as lists of strings. It is designed to enhance computation speed. + + This operator supports vector-valued functions with multiple input + coordinates. + + :param LabelTensor output_: The output tensor on which the divergence is + computed. + :param LabelTensor input_: The input tensor with respect to which the + divergence is computed. 
+ :param list[str] components: The names of the output variables for which to + compute the divergence. It must be a subset of the output labels. + :param list[str] d: The names of the input variables with respect to which + the divergence is computed. It must be a subset of the input labels. + :rtype: LabelTensor + """ + grad_out = fast_grad( + output_=output_, input_=input_, components=components, d=d + ) + tensors_to_sum = [ + grad_out.extract(f"d{c}d{d_}") for c, d_ in zip(components, d) + ] + + return LabelTensor.summation(tensors_to_sum) + + +def fast_laplacian(output_, input_, components, d, method="std"): + """ + Compute the laplacian of the ``output_`` with respect to ``input``. + + Unlike ``laplacian``, this function performs no internal checks on input and + output tensors. The user is required to specify both ``components`` and + ``d`` as lists of strings. It is designed to enhance computation speed. + + This operator supports both vector-valued and scalar-valued functions with + one or multiple input coordinates. + + :param LabelTensor output_: The output tensor on which the laplacian is + computed. + :param LabelTensor input_: The input tensor with respect to which the + laplacian is computed. + :param list[str] components: The names of the output variables for which to + compute the laplacian. It must be a subset of the output labels. + :param list[str] d: The names of the input variables with respect to which + the laplacian is computed. It must be a subset of the input labels. + :param str method: The method used to compute the Laplacian. Available + methods are ``std`` and ``divgrad``. The ``std`` method computes the + trace of the Hessian matrix, while the ``divgrad`` method computes the + divergence of the gradient. Default is ``std``. + :return: The computed laplacian tensor. + :rtype: LabelTensor + :raises ValueError: If the passed method is neither ``std`` nor ``divgrad``. + """ + # Scalar laplacian + if output_.shape[-1] == 1: + return LabelTensor( + _scalar_laplacian(output_=output_, input_=input_, d=d), + labels=[f"dd{c}" for c in components], + ) + + # Initialize the result tensor and its labels + labels = [f"dd{c}" for c in components] + result = torch.empty( + input_.shape[0], len(components), device=output_.device + ) + + # Vector laplacian + if method == "std": + result = torch.cat( + [ + _scalar_laplacian( + output_=output_.extract(c), input_=input_, d=d + ) + for c in components + ], + dim=-1, + ) + + elif method == "divgrad": + grads = fast_grad( + output_=output_, input_=input_, components=components, d=d + ) + result = torch.cat( + [ + fast_div( + output_=grads, + input_=input_, + components=[f"d{c}d{i}" for i in d], + d=d, + ) + for c in components + ], + dim=-1, + ) + + else: + raise ValueError( + "Invalid method. Available methods are ``std`` and ``divgrad``." + ) + + return LabelTensor(result, labels=labels) + + +def fast_advection(output_, input_, velocity_field, components, d): + """ + Perform the advection operation on the ``output_`` with respect to the + ``input``. This operator supports vector-valued functions with multiple + input coordinates. + + Unlike ``advection``, this function performs no internal checks on input and + output tensors. The user is required to specify both ``components`` and + ``d`` as lists of strings. It is designed to enhance computation speed. + + :param LabelTensor output_: The output tensor on which the advection is + computed. It includes both the velocity and the quantity to be advected. 
+ :param LabelTensor input_: the input tensor with respect to which advection + is computed. + :param list[str] velocity_field: The name of the output variables used as + velocity field. It must be chosen among the output labels. + :param list[str] components: The names of the output variables for which to + compute the advection. It must be a subset of the output labels. + :param list[str] d: The names of the input variables with respect to which + the advection is computed. It must be a subset of the input labels. + :return: The computed advection tensor. + :rtype: LabelTensor + """ + # Add a dimension to the velocity field for following operations + velocity = output_.extract(velocity_field).unsqueeze(-1) + + # Compute the gradient + grads = fast_grad( + output_=output_, input_=input_, components=components, d=d + ) + + # Reshape into [..., len(filter_components), len(d)] + tmp = grads.reshape(*output_.shape[:-1], len(components), len(d)) + + # Transpose to [..., len(d), len(filter_components)] + tmp = tmp.transpose(-1, -2) + + adv = (tmp * velocity).sum(dim=tmp.tensor.ndim - 2) + return LabelTensor(adv, labels=[f"adv_{c}" for c in components]) + + +def grad(output_, input_, components=None, d=None): + """ + Compute the gradient of the ``output_`` with respect to the ``input``. + + This operator supports both vector-valued and scalar-valued functions with + one or multiple input coordinates. + + :param LabelTensor output_: The output tensor on which the gradient is + computed. + :param LabelTensor input_: The input tensor with respect to which the + gradient is computed. + :param components: The names of the output variables for which to compute + the gradient. It must be a subset of the output labels. + If ``None``, all output variables are considered. Default is ``None``. + :type components: str | list[str] + :param d: The names of the input variables with respect to which the + gradient is computed. It must be a subset of the input labels. + If ``None``, all input variables are considered. Default is ``None``. + :type d: str | list[str] + :raises TypeError: If the input tensor is not a LabelTensor. + :raises TypeError: If the output tensor is not a LabelTensor. + :raises RuntimeError: If derivative labels are missing from the ``input_``. + :raises RuntimeError: If component labels are missing from the ``output_``. + :return: The computed gradient tensor. + :rtype: LabelTensor + """ + components, d = _check_values( + output_=output_, input_=input_, components=components, d=d + ) + return fast_grad(output_=output_, input_=input_, components=components, d=d) + + +def div(output_, input_, components=None, d=None): + """ + Compute the divergence of the ``output_`` with respect to ``input``. + + This operator supports vector-valued functions with multiple input + coordinates. + + :param LabelTensor output_: The output tensor on which the divergence is + computed. + :param LabelTensor input_: The input tensor with respect to which the + divergence is computed. + :param components: The names of the output variables for which to compute + the divergence. It must be a subset of the output labels. + If ``None``, all output variables are considered. Default is ``None``. + :type components: str | list[str] + :param d: The names of the input variables with respect to which the + divergence is computed. It must be a subset of the input labels. + If ``None``, all input variables are considered. Default is ``None``. 
+    :type d: str | list[str]
+    :raises TypeError: If the input tensor is not a LabelTensor.
+    :raises TypeError: If the output tensor is not a LabelTensor.
+    :raises ValueError: If the lengths of ``components`` and ``d`` do not
+        match.
+    :return: The computed divergence tensor.
+    :rtype: LabelTensor
+    """
+    components, d = _check_values(
+        output_=output_, input_=input_, components=components, d=d
+    )
+
+    # Components and d must be of the same length
+    if len(components) != len(d):
+        raise ValueError(
+            "Divergence requires components and d to be of the same length."
+        )
+
+    return fast_div(output_=output_, input_=input_, components=components, d=d)
+
+
+def laplacian(output_, input_, components=None, d=None, method="std"):
+    """
+    Compute the laplacian of the ``output_`` with respect to ``input``.
+
+    This operator supports both vector-valued and scalar-valued functions
+    with one or multiple input coordinates.
+
+    :param LabelTensor output_: The output tensor on which the laplacian is
+        computed.
+    :param LabelTensor input_: The input tensor with respect to which the
+        laplacian is computed.
+    :param components: The names of the output variables for which to
+        compute the laplacian. It must be a subset of the output labels.
+        If ``None``, all output variables are considered. Default is
+        ``None``.
+    :type components: str | list[str]
+    :param d: The names of the input variables with respect to which
+        the laplacian is computed. It must be a subset of the input labels.
+        If ``None``, all input variables are considered. Default is
+        ``None``.
+    :type d: str | list[str]
+    :param str method: The method used to compute the Laplacian. Available
+        methods are ``std`` and ``divgrad``. The ``std`` method computes the
+        trace of the Hessian matrix, while the ``divgrad`` method computes
+        the divergence of the gradient. Default is ``std``.
+    :raises TypeError: If the input tensor is not a LabelTensor.
+    :raises TypeError: If the output tensor is not a LabelTensor.
+    :raises ValueError: If the passed method is neither ``std`` nor
+        ``divgrad``.
+    :return: The computed laplacian tensor.
+    :rtype: LabelTensor
+    """
+    components, d = _check_values(
+        output_=output_, input_=input_, components=components, d=d
+    )
+
+    return fast_laplacian(
+        output_=output_,
+        input_=input_,
+        components=components,
+        d=d,
+        method=method,
+    )
+
+
+def advection(output_, input_, velocity_field, components=None, d=None):
+    """
+    Perform the advection operation on the ``output_`` with respect to the
+    ``input``. This operator supports vector-valued functions with multiple
+    input coordinates.
+
+    :param LabelTensor output_: The output tensor on which the advection is
+        computed. It includes both the velocity and the quantity to be
+        advected.
+    :param LabelTensor input_: The input tensor with respect to which
+        advection is computed.
+    :param velocity_field: The name of the output variables used as velocity
+        field. It must be chosen among the output labels.
+    :type velocity_field: str | list[str]
+    :param components: The names of the output variables for which to
+        compute the advection. It must be a subset of the output labels.
+        If ``None``, all output variables are considered. Default is
+        ``None``.
+    :type components: str | list[str]
+    :param d: The names of the input variables with respect to which the
+        advection is computed. It must be a subset of the input labels.
+        If ``None``, all input variables are considered. Default is
+        ``None``.
+    :type d: str | list[str]
+    :raises TypeError: If the input tensor is not a LabelTensor.
+    :raises TypeError: If the output tensor is not a LabelTensor.
+    :raises RuntimeError: If the velocity field is not a subset of the
+        output labels.
+    :raises RuntimeError: If the dimensionality of the velocity field does
+        not match that of the input tensor.
+    :return: The computed advection tensor.
+    :rtype: LabelTensor
+    """
+    components, d = _check_values(
+        output_=output_, input_=input_, components=components, d=d
+    )
+
+    # Map velocity_field to a list if it is a string
+    if isinstance(velocity_field, str):
+        velocity_field = [velocity_field]
+
+    # Check if all the velocity_field labels are present in the output labels
+    if not all(vi in output_.labels for vi in velocity_field):
+        raise RuntimeError("Velocity labels missing from output tensor.")
+
+    # Check if the velocity has the same dimensionality as the input tensor
+    if len(velocity_field) != len(d):
+        raise RuntimeError(
+            "Velocity dimensionality does not match input dimensionality."
+        )
+
+    return fast_advection(
+        output_=output_,
+        input_=input_,
+        velocity_field=velocity_field,
+        components=components,
+        d=d,
+    )
diff --git a/pina/_src/core/trainer.py b/pina/_src/core/trainer.py
new file mode 100644
index 000000000..7500be537
--- /dev/null
+++ b/pina/_src/core/trainer.py
@@ -0,0 +1,367 @@
+"""Module for the Trainer."""
+
+import sys
+import warnings
+import torch
+import lightning
+from pina._src.core.utils import check_consistency, custom_warning_format
+from pina._src.data.data_module import PinaDataModule
+from pina._src.solver.supervised_solver.supervised_solver_interface import (
+    SolverInterface,
+)
+from pina._src.solver.physics_informed_solver.pinn_interface import (
+    PINNInterface,
+)
+
+# set the warning for compile options
+warnings.formatwarning = custom_warning_format
+warnings.filterwarnings("always", category=UserWarning)
+
+
+class Trainer(lightning.pytorch.Trainer):
+    """
+    PINA custom Trainer class to extend the standard Lightning functionality.
+
+    This class enables specific features or behaviors required by the PINA
+    framework. It modifies the standard :class:`lightning.pytorch.Trainer`
+    class to better support the training process in PINA.
+    """
+
+    def __init__(
+        self,
+        solver,
+        batch_size=None,
+        train_size=1.0,
+        test_size=0.0,
+        val_size=0.0,
+        compile=None,
+        repeat=None,
+        automatic_batching=None,
+        num_workers=None,
+        pin_memory=None,
+        shuffle=None,
+        **kwargs,
+    ):
+        """
+        Initialization of the :class:`Trainer` class.
+
+        :param SolverInterface solver: A
+            :class:`~pina.solver.solver.SolverInterface` solver used to
+            solve a :class:`~pina.problem.abstract_problem.AbstractProblem`.
+        :param int batch_size: The number of samples per batch to load.
+            If ``None``, all samples are loaded and data is not batched.
+            Default is ``None``.
+        :param float train_size: The percentage of elements to include in
+            the training dataset. Default is ``1.0``.
+        :param float test_size: The percentage of elements to include in the
+            test dataset. Default is ``0.0``.
+        :param float val_size: The percentage of elements to include in the
+            validation dataset. Default is ``0.0``.
+        :param bool compile: If ``True``, the model is compiled before
+            training. Default is ``False``. For Windows users, it is always
+            disabled. Not supported for Python versions 3.14 and above.
+        :param bool repeat: Whether to repeat the dataset data in each
+            condition during training.
+            For further details, see the
+            :class:`~pina.data.data_module.PinaDataModule` class. Default
+            is ``False``.
+        :param bool automatic_batching: If ``True``, automatic PyTorch
+            batching is performed, otherwise the items are retrieved from
+            the dataset all at once. For further details, see the
+            :class:`~pina.data.data_module.PinaDataModule` class. Default
+            is ``False``.
+        :param int num_workers: The number of worker threads for data
+            loading. Default is ``0`` (serial loading).
+        :param bool pin_memory: Whether to use pinned memory for faster data
+            transfer to GPU. Default is ``False``.
+        :param bool shuffle: Whether to shuffle the data during training.
+            Default is ``True``.
+        :param dict kwargs: Additional keyword arguments that specify the
+            training setup. These can be selected from the
+            `pytorch-lightning Trainer API
+            <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.
+        """
+        # check consistency for init types
+        self._check_input_consistency(
+            solver=solver,
+            train_size=train_size,
+            test_size=test_size,
+            val_size=val_size,
+            repeat=repeat,
+            automatic_batching=automatic_batching,
+            compile=compile,
+        )
+        pin_memory, num_workers, shuffle, batch_size = (
+            self._check_consistency_and_set_defaults(
+                pin_memory, num_workers, shuffle, batch_size
+            )
+        )
+
+        # inference mode set to false when validating/testing PINNs,
+        # otherwise the gradient is not tracked and optimization_cycle fails
+        if isinstance(solver, PINNInterface):
+            kwargs["inference_mode"] = False
+
+        # Logging depends on the batch size; when batch_size is None,
+        # log_every_n_steps should be zero
+        if batch_size is None:
+            kwargs["log_every_n_steps"] = 0
+        else:
+            kwargs.setdefault("log_every_n_steps", 50)  # lightning default
+
+        # Setting default kwargs, overriding lightning defaults
+        kwargs.setdefault("enable_progress_bar", True)
+
+        super().__init__(**kwargs)
+
+        # checking compilation and automatic batching
+        # compilation disabled for Windows and for Python 3.14+; warn only
+        # if the user explicitly requested it
+        if (
+            compile is None
+            or sys.platform == "win32"
+            or sys.version_info >= (3, 14)
+        ):
+            if compile:
+                warnings.warn(
+                    "Compilation is disabled for Python 3.14+ and for "
+                    "Windows.",
+                    UserWarning,
+                )
+            compile = False
+
+        repeat = repeat if repeat is not None else False
+
+        automatic_batching = (
+            automatic_batching if automatic_batching is not None else False
+        )
+
+        # set attributes
+        self.compile = compile
+        self.solver = solver
+        self.batch_size = batch_size
+        self._move_to_device()
+        self.data_module = None
+        self._create_datamodule(
+            train_size=train_size,
+            test_size=test_size,
+            val_size=val_size,
+            batch_size=batch_size,
+            repeat=repeat,
+            automatic_batching=automatic_batching,
+            pin_memory=pin_memory,
+            num_workers=num_workers,
+            shuffle=shuffle,
+        )
+
+        # logging
+        self.logging_kwargs = {
+            "sync_dist": bool(
+                len(self._accelerator_connector._parallel_devices) > 1
+            ),
+            "on_step": bool(kwargs["log_every_n_steps"] > 0),
+            "prog_bar": bool(kwargs["enable_progress_bar"]),
+            "on_epoch": True,
+        }
+
+    def _move_to_device(self):
+        """
+        Moves the ``unknown_parameters`` of an instance of
+        :class:`~pina.problem.abstract_problem.AbstractProblem` to the
+        :class:`Trainer` device.
+ """ + device = self._accelerator_connector._parallel_devices[0] + # move parameters to device + pb = self.solver.problem + if hasattr(pb, "unknown_parameters"): + for key in pb.unknown_parameters: + pb.unknown_parameters[key] = torch.nn.Parameter( + pb.unknown_parameters[key].data.to(device) + ) + + def _create_datamodule( + self, + train_size, + test_size, + val_size, + batch_size, + repeat, + automatic_batching, + pin_memory, + num_workers, + shuffle, + ): + """ + This method is designed to handle the creation of a data module when + resampling is needed during training. Instead of manually defining and + modifying the trainer's dataloaders, this method is called to + automatically configure the data module. + + :param float train_size: The percentage of elements to include in the + training dataset. + :param float test_size: The percentage of elements to include in the + test dataset. + :param float val_size: The percentage of elements to include in the + validation dataset. + :param int batch_size: The number of samples per batch to load. + :param bool repeat: Whether to repeat the dataset data in each + condition during training. + :param bool automatic_batching: Whether to perform automatic batching + with PyTorch. + :param bool pin_memory: Whether to use pinned memory for faster data + transfer to GPU. + :param int num_workers: The number of worker threads for data loading. + :param bool shuffle: Whether to shuffle the data during training. + :raises RuntimeError: If not all conditions are sampled. + """ + if not self.solver.problem.are_all_domains_discretised: + error_message = "\n".join( + [ + f"""{" " * 13} ---> Domain {key} { + "sampled" if key in self.solver.problem.discretised_domains + else + "not sampled"}""" + for key in self.solver.problem.domains.keys() + ] + ) + raise RuntimeError( + "Cannot create Trainer if not all conditions " + "are sampled. The Trainer got the following:\n" + f"{error_message}" + ) + self.data_module = PinaDataModule( + self.solver.problem, + train_size=train_size, + test_size=test_size, + val_size=val_size, + batch_size=batch_size, + repeat=repeat, + automatic_batching=automatic_batching, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + ) + + def train(self, **kwargs): + """ + Manage the training process of the solver. + + :param dict kwargs: Additional keyword arguments. See `pytorch-lightning + Trainer API `_ + for details. + """ + return super().fit(self.solver, datamodule=self.data_module, **kwargs) + + def test(self, **kwargs): + """ + Manage the test process of the solver. + + :param dict kwargs: Additional keyword arguments. See `pytorch-lightning + Trainer API `_ + for details. + """ + return super().test(self.solver, datamodule=self.data_module, **kwargs) + + @property + def solver(self): + """ + Get the solver. + + :return: The solver. + :rtype: SolverInterface + """ + return self._solver + + @solver.setter + def solver(self, solver): + """ + Set the solver. + + :param SolverInterface solver: The solver to set. + """ + self._solver = solver + + @staticmethod + def _check_input_consistency( + solver, + train_size, + test_size, + val_size, + repeat, + automatic_batching, + compile, + ): + """ + Verifies the consistency of the parameters for the solver configuration. + + :param SolverInterface solver: The solver. + :param float train_size: The percentage of elements to include in the + training dataset. + :param float test_size: The percentage of elements to include in the + test dataset. 
+ :param float val_size: The percentage of elements to include in the + validation dataset. + :param bool repeat: Whether to repeat the dataset data in each + condition during training. + :param bool automatic_batching: Whether to perform automatic batching + with PyTorch. + :param bool compile: If ``True``, the model is compiled before training. + """ + + check_consistency(solver, SolverInterface) + check_consistency(train_size, float) + check_consistency(test_size, float) + check_consistency(val_size, float) + if repeat is not None: + check_consistency(repeat, bool) + if automatic_batching is not None: + check_consistency(automatic_batching, bool) + if compile is not None: + check_consistency(compile, bool) + + @staticmethod + def _check_consistency_and_set_defaults( + pin_memory, num_workers, shuffle, batch_size + ): + """ + Checks the consistency of input parameters and sets default values + for missing or invalid parameters. + + :param bool pin_memory: Whether to use pinned memory for faster data + transfer to GPU. + :param int num_workers: The number of worker threads for data loading. + :param bool shuffle: Whether to shuffle the data during training. + :param int batch_size: The number of samples per batch to load. + """ + if pin_memory is not None: + check_consistency(pin_memory, bool) + else: + pin_memory = False + if num_workers is not None: + check_consistency(num_workers, int) + else: + num_workers = 0 + if shuffle is not None: + check_consistency(shuffle, bool) + else: + shuffle = True + if batch_size is not None: + check_consistency(batch_size, int) + return pin_memory, num_workers, shuffle, batch_size + + @property + def compile(self): + """ + Whether compilation is required or not. + + :return: ``True`` if compilation is required, ``False`` otherwise. + :rtype: bool + """ + return self._compile + + @compile.setter + def compile(self, value): + """ + Setting the value of compile. + + :param bool value: Whether compilation is required or not. 
+ """ + check_consistency(value, bool) + self._compile = value diff --git a/pina/type_checker.py b/pina/_src/core/type_checker.py similarity index 100% rename from pina/type_checker.py rename to pina/_src/core/type_checker.py diff --git a/pina/utils.py b/pina/_src/core/utils.py similarity index 99% rename from pina/utils.py rename to pina/_src/core/utils.py index efc48424e..ea70ed944 100644 --- a/pina/utils.py +++ b/pina/_src/core/utils.py @@ -4,7 +4,7 @@ from functools import reduce import torch -from .label_tensor import LabelTensor +from pina._src.core.label_tensor import LabelTensor # Codacy error unused parameters diff --git a/pina/_src/data/__init__.py b/pina/_src/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/data/data_module.py b/pina/_src/data/data_module.py similarity index 99% rename from pina/data/data_module.py rename to pina/_src/data/data_module.py index 52b52a3fa..664d200f8 100644 --- a/pina/data/data_module.py +++ b/pina/_src/data/data_module.py @@ -10,8 +10,8 @@ from torch_geometric.data import Data from torch.utils.data import DataLoader, SequentialSampler, RandomSampler from torch.utils.data.distributed import DistributedSampler -from ..label_tensor import LabelTensor -from .dataset import PinaDatasetFactory, PinaTensorDataset +from pina._src.core.label_tensor import LabelTensor +from pina._src.data.dataset import PinaDatasetFactory, PinaTensorDataset class DummyDataloader: diff --git a/pina/data/dataset.py b/pina/_src/data/dataset.py similarity index 99% rename from pina/data/dataset.py rename to pina/_src/data/dataset.py index 62e3913d8..bf2f168e4 100644 --- a/pina/data/dataset.py +++ b/pina/_src/data/dataset.py @@ -3,7 +3,7 @@ from abc import abstractmethod, ABC from torch.utils.data import Dataset from torch_geometric.data import Data -from ..graph import Graph, LabelBatch +from pina._src.core.graph import Graph, LabelBatch class PinaDatasetFactory: diff --git a/pina/_src/domain/__init__.py b/pina/_src/domain/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/domain/base_domain.py b/pina/_src/domain/base_domain.py similarity index 97% rename from pina/domain/base_domain.py rename to pina/_src/domain/base_domain.py index c7bef9700..3316fabfd 100644 --- a/pina/domain/base_domain.py +++ b/pina/_src/domain/base_domain.py @@ -2,8 +2,8 @@ from copy import deepcopy from abc import ABCMeta -from .domain_interface import DomainInterface -from ..utils import check_consistency, check_positive_integer +from pina._src.domain.domain_interface import DomainInterface +from pina._src.core.utils import check_consistency, check_positive_integer class BaseDomain(DomainInterface, metaclass=ABCMeta): diff --git a/pina/domain/base_operation.py b/pina/_src/domain/base_operation.py similarity index 97% rename from pina/domain/base_operation.py rename to pina/_src/domain/base_operation.py index 8261ae431..ff83e1551 100644 --- a/pina/domain/base_operation.py +++ b/pina/_src/domain/base_operation.py @@ -2,9 +2,9 @@ from copy import deepcopy from abc import ABCMeta -from .operation_interface import OperationInterface -from .base_domain import BaseDomain -from ..utils import check_consistency +from pina._src.domain.operation_interface import OperationInterface +from pina._src.domain.base_domain import BaseDomain +from pina._src.core.utils import check_consistency class BaseOperation(OperationInterface, BaseDomain, metaclass=ABCMeta): diff --git a/pina/domain/cartesian_domain.py b/pina/_src/domain/cartesian_domain.py similarity index 
97% rename from pina/domain/cartesian_domain.py rename to pina/_src/domain/cartesian_domain.py index 3333a8fc3..089e3377c 100644 --- a/pina/domain/cartesian_domain.py +++ b/pina/_src/domain/cartesian_domain.py @@ -1,10 +1,10 @@ """Module for the Cartesian Domain.""" import torch -from .base_domain import BaseDomain -from .union import Union -from ..utils import torch_lhs, chebyshev_roots, check_consistency -from ..label_tensor import LabelTensor +from pina._src.domain.base_domain import BaseDomain +from pina._src.domain.union import Union +from pina._src.core.utils import torch_lhs, chebyshev_roots, check_consistency +from pina._src.core.label_tensor import LabelTensor class CartesianDomain(BaseDomain): diff --git a/pina/domain/difference.py b/pina/_src/domain/difference.py similarity index 96% rename from pina/domain/difference.py rename to pina/_src/domain/difference.py index 76807b035..ce87920e5 100644 --- a/pina/domain/difference.py +++ b/pina/_src/domain/difference.py @@ -1,8 +1,8 @@ """Module for the Difference operation.""" -from .base_operation import BaseOperation -from ..label_tensor import LabelTensor -from ..utils import check_consistency +from pina._src.domain.base_operation import BaseOperation +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.utils import check_consistency class Difference(BaseOperation): diff --git a/pina/domain/domain_interface.py b/pina/_src/domain/domain_interface.py similarity index 100% rename from pina/domain/domain_interface.py rename to pina/_src/domain/domain_interface.py diff --git a/pina/domain/ellipsoid_domain.py b/pina/_src/domain/ellipsoid_domain.py similarity index 98% rename from pina/domain/ellipsoid_domain.py rename to pina/_src/domain/ellipsoid_domain.py index ecb08e37c..402ec29a8 100644 --- a/pina/domain/ellipsoid_domain.py +++ b/pina/_src/domain/ellipsoid_domain.py @@ -2,9 +2,9 @@ from copy import deepcopy import torch -from .base_domain import BaseDomain -from ..label_tensor import LabelTensor -from ..utils import check_consistency +from pina._src.domain.base_domain import BaseDomain +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.utils import check_consistency class EllipsoidDomain(BaseDomain): diff --git a/pina/domain/exclusion.py b/pina/_src/domain/exclusion.py similarity index 97% rename from pina/domain/exclusion.py rename to pina/_src/domain/exclusion.py index 59205f3a8..914e17086 100644 --- a/pina/domain/exclusion.py +++ b/pina/_src/domain/exclusion.py @@ -1,9 +1,9 @@ """Module for the Exclusion set-operation.""" import random -from .base_operation import BaseOperation -from ..label_tensor import LabelTensor -from ..utils import check_consistency +from pina._src.domain.base_operation import BaseOperation +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.utils import check_consistency class Exclusion(BaseOperation): diff --git a/pina/domain/intersection.py b/pina/_src/domain/intersection.py similarity index 96% rename from pina/domain/intersection.py rename to pina/_src/domain/intersection.py index 105575df1..1b004556e 100644 --- a/pina/domain/intersection.py +++ b/pina/_src/domain/intersection.py @@ -1,9 +1,9 @@ """Module for the Intersection operation.""" import random -from .base_operation import BaseOperation -from ..label_tensor import LabelTensor -from ..utils import check_consistency +from pina._src.domain.base_operation import BaseOperation +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.utils import check_consistency class 
Intersection(BaseOperation): diff --git a/pina/domain/operation_interface.py b/pina/_src/domain/operation_interface.py similarity index 92% rename from pina/domain/operation_interface.py rename to pina/_src/domain/operation_interface.py index 9be458972..357556105 100644 --- a/pina/domain/operation_interface.py +++ b/pina/_src/domain/operation_interface.py @@ -1,7 +1,7 @@ """Module for the Operation Interface.""" from abc import ABCMeta, abstractmethod -from .domain_interface import DomainInterface +from pina._src.domain.domain_interface import DomainInterface class OperationInterface(DomainInterface, metaclass=ABCMeta): diff --git a/pina/domain/simplex_domain.py b/pina/_src/domain/simplex_domain.py similarity index 98% rename from pina/domain/simplex_domain.py rename to pina/_src/domain/simplex_domain.py index 9e3a3e58f..5dff002ce 100644 --- a/pina/domain/simplex_domain.py +++ b/pina/_src/domain/simplex_domain.py @@ -2,9 +2,9 @@ from copy import deepcopy import torch -from .base_domain import BaseDomain -from ..label_tensor import LabelTensor -from ..utils import check_consistency +from pina._src.domain.base_domain import BaseDomain +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.utils import check_consistency class SimplexDomain(BaseDomain): diff --git a/pina/domain/union.py b/pina/_src/domain/union.py similarity index 95% rename from pina/domain/union.py rename to pina/_src/domain/union.py index df094bb82..eff137df3 100644 --- a/pina/domain/union.py +++ b/pina/_src/domain/union.py @@ -1,9 +1,9 @@ """Module for the Union operation.""" import random -from .base_operation import BaseOperation -from ..label_tensor import LabelTensor -from ..utils import check_consistency +from pina._src.domain.base_operation import BaseOperation +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.utils import check_consistency class Union(BaseOperation): diff --git a/pina/_src/equation/__init__.py b/pina/_src/equation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/equation/equation.py b/pina/_src/equation/equation.py similarity index 97% rename from pina/equation/equation.py rename to pina/_src/equation/equation.py index 057c6bcf5..a1d67628c 100644 --- a/pina/equation/equation.py +++ b/pina/_src/equation/equation.py @@ -1,7 +1,7 @@ """Module for the Equation.""" import inspect -from .equation_interface import EquationInterface +from pina._src.equation.equation_interface import EquationInterface class Equation(EquationInterface): diff --git a/pina/equation/equation_factory.py b/pina/_src/equation/equation_factory.py similarity index 99% rename from pina/equation/equation_factory.py rename to pina/_src/equation/equation_factory.py index 01560d6c1..c001d1461 100644 --- a/pina/equation/equation_factory.py +++ b/pina/_src/equation/equation_factory.py @@ -2,9 +2,9 @@ from typing import Callable import torch -from .equation import Equation -from ..operator import grad, div, laplacian -from ..utils import check_consistency +from pina._src.equation.equation import Equation +from pina._src.core.operator import grad, div, laplacian +from pina._src.core.utils import check_consistency class FixedValue(Equation): # pylint: disable=R0903 diff --git a/pina/equation/equation_interface.py b/pina/_src/equation/equation_interface.py similarity index 100% rename from pina/equation/equation_interface.py rename to pina/_src/equation/equation_interface.py diff --git a/pina/equation/system_equation.py b/pina/_src/equation/system_equation.py similarity 
index 96% rename from pina/equation/system_equation.py rename to pina/_src/equation/system_equation.py index 3e8550d9b..a9920a955 100644 --- a/pina/equation/system_equation.py +++ b/pina/_src/equation/system_equation.py @@ -1,9 +1,9 @@ """Module for the System of Equation.""" import torch -from .equation_interface import EquationInterface -from .equation import Equation -from ..utils import check_consistency +from pina._src.equation.equation_interface import EquationInterface +from pina._src.equation.equation import Equation +from pina._src.core.utils import check_consistency class SystemEquation(EquationInterface): diff --git a/pina/_src/loss/__init__.py b/pina/_src/loss/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/loss/linear_weighting.py b/pina/_src/loss/linear_weighting.py similarity index 94% rename from pina/loss/linear_weighting.py rename to pina/_src/loss/linear_weighting.py index 9049b52fa..e50d5151c 100644 --- a/pina/loss/linear_weighting.py +++ b/pina/_src/loss/linear_weighting.py @@ -1,7 +1,7 @@ """Module for the LinearWeighting class.""" -from ..loss import WeightingInterface -from ..utils import check_consistency, check_positive_integer +from pina._src.loss.weighting_interface import WeightingInterface +from pina._src.core.utils import check_consistency, check_positive_integer class LinearWeighting(WeightingInterface): diff --git a/pina/loss/loss_interface.py b/pina/_src/loss/loss_interface.py similarity index 100% rename from pina/loss/loss_interface.py rename to pina/_src/loss/loss_interface.py diff --git a/pina/loss/lp_loss.py b/pina/_src/loss/lp_loss.py similarity index 95% rename from pina/loss/lp_loss.py rename to pina/_src/loss/lp_loss.py index f535a5b6f..b2047d945 100644 --- a/pina/loss/lp_loss.py +++ b/pina/_src/loss/lp_loss.py @@ -1,9 +1,8 @@ """Module for the LpLoss class.""" import torch - -from ..utils import check_consistency -from .loss_interface import LossInterface +from pina._src.loss.loss_interface import LossInterface +from pina._src.core.utils import check_consistency class LpLoss(LossInterface): diff --git a/pina/loss/ntk_weighting.py b/pina/_src/loss/ntk_weighting.py similarity index 95% rename from pina/loss/ntk_weighting.py rename to pina/_src/loss/ntk_weighting.py index fe1c4fc6a..96c89fc3a 100644 --- a/pina/loss/ntk_weighting.py +++ b/pina/_src/loss/ntk_weighting.py @@ -1,8 +1,8 @@ """Module for Neural Tangent Kernel Class""" import torch -from .weighting_interface import WeightingInterface -from ..utils import check_consistency, in_range +from pina._src.loss.weighting_interface import WeightingInterface +from pina._src.core.utils import check_consistency, in_range class NeuralTangentKernelWeighting(WeightingInterface): diff --git a/pina/loss/power_loss.py b/pina/_src/loss/power_loss.py similarity index 95% rename from pina/loss/power_loss.py rename to pina/_src/loss/power_loss.py index 1edbf4f86..67986a988 100644 --- a/pina/loss/power_loss.py +++ b/pina/_src/loss/power_loss.py @@ -2,8 +2,8 @@ import torch -from ..utils import check_consistency -from .loss_interface import LossInterface +from pina._src.loss.loss_interface import LossInterface +from pina._src.core.utils import check_consistency class PowerLoss(LossInterface): diff --git a/pina/loss/scalar_weighting.py b/pina/_src/loss/scalar_weighting.py similarity index 93% rename from pina/loss/scalar_weighting.py rename to pina/_src/loss/scalar_weighting.py index 692c4937b..c97b037f9 100644 --- a/pina/loss/scalar_weighting.py +++ 
b/pina/_src/loss/scalar_weighting.py @@ -1,7 +1,7 @@ """Module for the Scalar Weighting.""" -from .weighting_interface import WeightingInterface -from ..utils import check_consistency +from pina._src.loss.weighting_interface import WeightingInterface +from pina._src.core.utils import check_consistency class ScalarWeighting(WeightingInterface): diff --git a/pina/loss/self_adaptive_weighting.py b/pina/_src/loss/self_adaptive_weighting.py similarity index 96% rename from pina/loss/self_adaptive_weighting.py rename to pina/_src/loss/self_adaptive_weighting.py index c796d359f..8a91f98f5 100644 --- a/pina/loss/self_adaptive_weighting.py +++ b/pina/_src/loss/self_adaptive_weighting.py @@ -1,7 +1,7 @@ """Module for Self-Adaptive Weighting class.""" import torch -from .weighting_interface import WeightingInterface +from pina._src.loss.weighting_interface import WeightingInterface class SelfAdaptiveWeighting(WeightingInterface): diff --git a/pina/loss/weighting_interface.py b/pina/_src/loss/weighting_interface.py similarity index 98% rename from pina/loss/weighting_interface.py rename to pina/_src/loss/weighting_interface.py index bc34c3181..5e75e0aaa 100644 --- a/pina/loss/weighting_interface.py +++ b/pina/_src/loss/weighting_interface.py @@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod from typing import final -from ..utils import check_positive_integer, is_function +from pina._src.core.utils import check_positive_integer, is_function _AGGREGATE_METHODS = {"sum": sum, "mean": lambda x: sum(x) / len(x)} diff --git a/pina/_src/model/__init__.py b/pina/_src/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/model/average_neural_operator.py b/pina/_src/model/average_neural_operator.py similarity index 100% rename from pina/model/average_neural_operator.py rename to pina/_src/model/average_neural_operator.py diff --git a/pina/_src/model/block/__init__.py b/pina/_src/model/block/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/model/block/average_neural_operator_block.py b/pina/_src/model/block/average_neural_operator_block.py similarity index 97% rename from pina/model/block/average_neural_operator_block.py rename to pina/_src/model/block/average_neural_operator_block.py index 91379abeb..4b5af8081 100644 --- a/pina/model/block/average_neural_operator_block.py +++ b/pina/_src/model/block/average_neural_operator_block.py @@ -2,7 +2,7 @@ import torch from torch import nn -from ...utils import check_consistency +from pina._src.core.utils import check_consistency class AVNOBlock(nn.Module): diff --git a/pina/model/block/convolution.py b/pina/_src/model/block/convolution.py similarity index 98% rename from pina/model/block/convolution.py rename to pina/_src/model/block/convolution.py index 666f66a66..bfe7054af 100644 --- a/pina/model/block/convolution.py +++ b/pina/_src/model/block/convolution.py @@ -2,8 +2,8 @@ from abc import ABCMeta, abstractmethod import torch -from .stride import Stride -from .utils_convolution import optimizing +from pina._src.model.block.stride import Stride +from pina._src.model.block.utils_convolution import optimizing class BaseContinuousConv(torch.nn.Module, metaclass=ABCMeta): diff --git a/pina/model/block/convolution_2d.py b/pina/_src/model/block/convolution_2d.py similarity index 98% rename from pina/model/block/convolution_2d.py rename to pina/_src/model/block/convolution_2d.py index 825ae613b..935bb0afa 100644 --- a/pina/model/block/convolution_2d.py +++ b/pina/_src/model/block/convolution_2d.py @@ -1,9 +1,9 
@@ -"""Module for the Continuous Convolution class.""" +"""Module for the Continuous 2D Convolution class.""" import torch -from .convolution import BaseContinuousConv -from .utils_convolution import check_point, map_points_ -from .integral import Integral +from pina._src.model.block.convolution import BaseContinuousConv +from pina._src.model.block.utils_convolution import check_point, map_points_ +from pina._src.model.block.integral import Integral class ContinuousConvBlock(BaseContinuousConv): diff --git a/pina/model/block/embedding.py b/pina/_src/model/block/embedding.py similarity index 99% rename from pina/model/block/embedding.py rename to pina/_src/model/block/embedding.py index 1e44ec143..f9f05c119 100644 --- a/pina/model/block/embedding.py +++ b/pina/_src/model/block/embedding.py @@ -1,7 +1,7 @@ """Modules for the the Embedding blocks.""" import torch -from pina.utils import check_consistency +from pina._src.core.utils import check_consistency class PeriodicBoundaryEmbedding(torch.nn.Module): diff --git a/pina/model/block/fourier_block.py b/pina/_src/model/block/fourier_block.py similarity index 98% rename from pina/model/block/fourier_block.py rename to pina/_src/model/block/fourier_block.py index 2983c840a..2510320ec 100644 --- a/pina/model/block/fourier_block.py +++ b/pina/_src/model/block/fourier_block.py @@ -2,9 +2,9 @@ import torch from torch import nn -from ...utils import check_consistency +from pina._src.core.utils import check_consistency -from .spectral import ( +from pina._src.model.block.spectral import ( SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D, diff --git a/pina/model/block/gno_block.py b/pina/_src/model/block/gno_block.py similarity index 100% rename from pina/model/block/gno_block.py rename to pina/_src/model/block/gno_block.py diff --git a/pina/model/block/integral.py b/pina/_src/model/block/integral.py similarity index 100% rename from pina/model/block/integral.py rename to pina/_src/model/block/integral.py diff --git a/pina/model/block/low_rank_block.py b/pina/_src/model/block/low_rank_block.py similarity index 98% rename from pina/model/block/low_rank_block.py rename to pina/_src/model/block/low_rank_block.py index 1e8925d95..ad67b4dca 100644 --- a/pina/model/block/low_rank_block.py +++ b/pina/_src/model/block/low_rank_block.py @@ -2,7 +2,7 @@ import torch -from ...utils import check_consistency +from pina._src.core.utils import check_consistency class LowRankBlock(torch.nn.Module): diff --git a/pina/_src/model/block/message_passing/__init__.py b/pina/_src/model/block/message_passing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/model/block/message_passing/deep_tensor_network_block.py b/pina/_src/model/block/message_passing/deep_tensor_network_block.py similarity index 98% rename from pina/model/block/message_passing/deep_tensor_network_block.py rename to pina/_src/model/block/message_passing/deep_tensor_network_block.py index a2de3097a..ed19578b7 100644 --- a/pina/model/block/message_passing/deep_tensor_network_block.py +++ b/pina/_src/model/block/message_passing/deep_tensor_network_block.py @@ -2,7 +2,7 @@ import torch from torch_geometric.nn import MessagePassing -from ....utils import check_positive_integer +from pina._src.core.utils import check_positive_integer class DeepTensorNetworkBlock(MessagePassing): diff --git a/pina/model/block/message_passing/en_equivariant_network_block.py b/pina/_src/model/block/message_passing/en_equivariant_network_block.py similarity index 98% rename from 
pina/model/block/message_passing/en_equivariant_network_block.py rename to pina/_src/model/block/message_passing/en_equivariant_network_block.py index b8057b0f1..28a197230 100644 --- a/pina/model/block/message_passing/en_equivariant_network_block.py +++ b/pina/_src/model/block/message_passing/en_equivariant_network_block.py @@ -3,8 +3,8 @@ import torch from torch_geometric.nn import MessagePassing from torch_geometric.utils import degree -from ....utils import check_positive_integer, check_consistency -from ....model import FeedForward +from pina._src.core.utils import check_positive_integer, check_consistency +from pina._src.model.feed_forward import FeedForward class EnEquivariantNetworkBlock(MessagePassing): diff --git a/pina/model/block/message_passing/equivariant_graph_neural_operator_block.py b/pina/_src/model/block/message_passing/equivariant_graph_neural_operator_block.py similarity index 97% rename from pina/model/block/message_passing/equivariant_graph_neural_operator_block.py rename to pina/_src/model/block/message_passing/equivariant_graph_neural_operator_block.py index f6c739203..8a0f30aed 100644 --- a/pina/model/block/message_passing/equivariant_graph_neural_operator_block.py +++ b/pina/_src/model/block/message_passing/equivariant_graph_neural_operator_block.py @@ -1,8 +1,10 @@ """Module for the Equivariant Graph Neural Operator block.""" import torch -from ....utils import check_positive_integer -from .en_equivariant_network_block import EnEquivariantNetworkBlock +from pina._src.core.utils import check_positive_integer +from pina._src.model.block.message_passing.en_equivariant_network_block import ( + EnEquivariantNetworkBlock, +) class EquivariantGraphNeuralOperatorBlock(torch.nn.Module): diff --git a/pina/model/block/message_passing/interaction_network_block.py b/pina/_src/model/block/message_passing/interaction_network_block.py similarity index 98% rename from pina/model/block/message_passing/interaction_network_block.py rename to pina/_src/model/block/message_passing/interaction_network_block.py index 7c6eb03f6..06fb39406 100644 --- a/pina/model/block/message_passing/interaction_network_block.py +++ b/pina/_src/model/block/message_passing/interaction_network_block.py @@ -2,8 +2,8 @@ import torch from torch_geometric.nn import MessagePassing -from ....utils import check_positive_integer -from ....model import FeedForward +from pina._src.core.utils import check_positive_integer +from pina._src.model.feed_forward import FeedForward class InteractionNetworkBlock(MessagePassing): diff --git a/pina/model/block/message_passing/radial_field_network_block.py b/pina/_src/model/block/message_passing/radial_field_network_block.py similarity index 97% rename from pina/model/block/message_passing/radial_field_network_block.py rename to pina/_src/model/block/message_passing/radial_field_network_block.py index ef621b10e..ede0fb645 100644 --- a/pina/model/block/message_passing/radial_field_network_block.py +++ b/pina/_src/model/block/message_passing/radial_field_network_block.py @@ -3,8 +3,8 @@ import torch from torch_geometric.nn import MessagePassing from torch_geometric.utils import remove_self_loops -from ....utils import check_positive_integer -from ....model import FeedForward +from pina._src.core.utils import check_positive_integer +from pina._src.model.feed_forward import FeedForward class RadialFieldNetworkBlock(MessagePassing): diff --git a/pina/model/block/orthogonal.py b/pina/_src/model/block/orthogonal.py similarity index 98% rename from pina/model/block/orthogonal.py rename 
to pina/_src/model/block/orthogonal.py index cd45b3c72..24021ada6 100644 --- a/pina/model/block/orthogonal.py +++ b/pina/_src/model/block/orthogonal.py @@ -1,7 +1,7 @@ """Module for the Orthogonal Block class.""" import torch -from ...utils import check_consistency +from pina._src.core.utils import check_consistency class OrthogonalBlock(torch.nn.Module): diff --git a/pina/model/block/pirate_network_block.py b/pina/_src/model/block/pirate_network_block.py similarity index 97% rename from pina/model/block/pirate_network_block.py rename to pina/_src/model/block/pirate_network_block.py index cfeb8410e..752f81901 100644 --- a/pina/model/block/pirate_network_block.py +++ b/pina/_src/model/block/pirate_network_block.py @@ -1,7 +1,7 @@ """Module for the PirateNet block class.""" import torch -from ...utils import check_consistency, check_positive_integer +from pina._src.core.utils import check_consistency, check_positive_integer class PirateNetBlock(torch.nn.Module): diff --git a/pina/model/block/pod_block.py b/pina/_src/model/block/pod_block.py similarity index 100% rename from pina/model/block/pod_block.py rename to pina/_src/model/block/pod_block.py diff --git a/pina/model/block/rbf_block.py b/pina/_src/model/block/rbf_block.py similarity index 99% rename from pina/model/block/rbf_block.py rename to pina/_src/model/block/rbf_block.py index 8001381bc..061e43109 100644 --- a/pina/model/block/rbf_block.py +++ b/pina/_src/model/block/rbf_block.py @@ -4,7 +4,7 @@ import warnings from itertools import combinations_with_replacement import torch -from ...utils import check_consistency +from pina._src.core.utils import check_consistency def linear(r): diff --git a/pina/model/block/residual.py b/pina/_src/model/block/residual.py similarity index 98% rename from pina/model/block/residual.py rename to pina/_src/model/block/residual.py index f109ce03d..d1e8134cc 100644 --- a/pina/model/block/residual.py +++ b/pina/_src/model/block/residual.py @@ -2,7 +2,7 @@ import torch from torch import nn -from ...utils import check_consistency +from pina._src.core.utils import check_consistency class ResidualBlock(nn.Module): diff --git a/pina/model/block/spectral.py b/pina/_src/model/block/spectral.py similarity index 99% rename from pina/model/block/spectral.py rename to pina/_src/model/block/spectral.py index aae915a42..fd5f48f6a 100644 --- a/pina/model/block/spectral.py +++ b/pina/_src/model/block/spectral.py @@ -2,7 +2,7 @@ import torch from torch import nn -from ...utils import check_consistency +from pina._src.core.utils import check_consistency ######## 1D Spectral Convolution ########### diff --git a/pina/model/block/stride.py b/pina/_src/model/block/stride.py similarity index 98% rename from pina/model/block/stride.py rename to pina/_src/model/block/stride.py index 2a26faf07..e802cddc0 100644 --- a/pina/model/block/stride.py +++ b/pina/_src/model/block/stride.py @@ -5,7 +5,7 @@ class Stride: """ - Stride class for continous convolution. + Stride class for continuous convolution. 
""" def __init__(self, dict_): diff --git a/pina/model/block/utils_convolution.py b/pina/_src/model/block/utils_convolution.py similarity index 100% rename from pina/model/block/utils_convolution.py rename to pina/_src/model/block/utils_convolution.py diff --git a/pina/model/deeponet.py b/pina/_src/model/deeponet.py similarity index 100% rename from pina/model/deeponet.py rename to pina/_src/model/deeponet.py diff --git a/pina/model/equivariant_graph_neural_operator.py b/pina/_src/model/equivariant_graph_neural_operator.py similarity index 100% rename from pina/model/equivariant_graph_neural_operator.py rename to pina/_src/model/equivariant_graph_neural_operator.py diff --git a/pina/model/feed_forward.py b/pina/_src/model/feed_forward.py similarity index 100% rename from pina/model/feed_forward.py rename to pina/_src/model/feed_forward.py diff --git a/pina/model/fourier_neural_operator.py b/pina/_src/model/fourier_neural_operator.py similarity index 100% rename from pina/model/fourier_neural_operator.py rename to pina/_src/model/fourier_neural_operator.py diff --git a/pina/model/graph_neural_operator.py b/pina/_src/model/graph_neural_operator.py similarity index 100% rename from pina/model/graph_neural_operator.py rename to pina/_src/model/graph_neural_operator.py diff --git a/pina/model/kernel_neural_operator.py b/pina/_src/model/kernel_neural_operator.py similarity index 100% rename from pina/model/kernel_neural_operator.py rename to pina/_src/model/kernel_neural_operator.py diff --git a/pina/model/low_rank_neural_operator.py b/pina/_src/model/low_rank_neural_operator.py similarity index 100% rename from pina/model/low_rank_neural_operator.py rename to pina/_src/model/low_rank_neural_operator.py diff --git a/pina/model/multi_feed_forward.py b/pina/_src/model/multi_feed_forward.py similarity index 100% rename from pina/model/multi_feed_forward.py rename to pina/_src/model/multi_feed_forward.py diff --git a/pina/model/pirate_network.py b/pina/_src/model/pirate_network.py similarity index 100% rename from pina/model/pirate_network.py rename to pina/_src/model/pirate_network.py diff --git a/pina/model/sindy.py b/pina/_src/model/sindy.py similarity index 100% rename from pina/model/sindy.py rename to pina/_src/model/sindy.py diff --git a/pina/model/spline.py b/pina/_src/model/spline.py similarity index 100% rename from pina/model/spline.py rename to pina/_src/model/spline.py diff --git a/pina/model/spline_surface.py b/pina/_src/model/spline_surface.py similarity index 100% rename from pina/model/spline_surface.py rename to pina/_src/model/spline_surface.py diff --git a/pina/_src/optim/__init__.py b/pina/_src/optim/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/optim/optimizer_interface.py b/pina/_src/optim/optimizer_interface.py similarity index 100% rename from pina/optim/optimizer_interface.py rename to pina/_src/optim/optimizer_interface.py diff --git a/pina/optim/scheduler_interface.py b/pina/_src/optim/scheduler_interface.py similarity index 100% rename from pina/optim/scheduler_interface.py rename to pina/_src/optim/scheduler_interface.py diff --git a/pina/optim/torch_optimizer.py b/pina/_src/optim/torch_optimizer.py similarity index 92% rename from pina/optim/torch_optimizer.py rename to pina/_src/optim/torch_optimizer.py index 7163c295e..f01d3b3cb 100644 --- a/pina/optim/torch_optimizer.py +++ b/pina/_src/optim/torch_optimizer.py @@ -2,8 +2,8 @@ import torch -from ..utils import check_consistency -from .optimizer_interface import Optimizer +from 
pina._src.core.utils import check_consistency +from pina._src.optim.optimizer_interface import Optimizer class TorchOptimizer(Optimizer): diff --git a/pina/optim/torch_scheduler.py b/pina/_src/optim/torch_scheduler.py similarity index 90% rename from pina/optim/torch_scheduler.py rename to pina/_src/optim/torch_scheduler.py index ff12300a1..bf9927836 100644 --- a/pina/optim/torch_scheduler.py +++ b/pina/_src/optim/torch_scheduler.py @@ -7,9 +7,9 @@ _LRScheduler as LRScheduler, ) # torch < 2.0 -from ..utils import check_consistency -from .optimizer_interface import Optimizer -from .scheduler_interface import Scheduler +from pina._src.core.utils import check_consistency +from pina._src.optim.optimizer_interface import Optimizer +from pina._src.optim.scheduler_interface import Scheduler class TorchScheduler(Scheduler): diff --git a/pina/_src/problem/__init__.py b/pina/_src/problem/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/problem/abstract_problem.py b/pina/_src/problem/abstract_problem.py similarity index 98% rename from pina/problem/abstract_problem.py rename to pina/_src/problem/abstract_problem.py index 441356def..9280d48d8 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/_src/problem/abstract_problem.py @@ -58,15 +58,10 @@ def collected_data(self): if not self.are_all_domains_discretised: warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=RuntimeWarning) - warning_message = "\n".join( - [ - f"""{" " * 13} ---> Domain {key} { + warning_message = "\n".join([f"""{" " * 13} ---> Domain {key} { "sampled" if key in self.discretised_domains else - "not sampled"}""" - for key in self.domains - ] - ) + "not sampled"}""" for key in self.domains]) warnings.warn( "Some of the domains are still not sampled. Consider calling " "problem.discretise_domain function for all domains before " diff --git a/pina/problem/inverse_problem.py b/pina/_src/problem/inverse_problem.py similarity index 100% rename from pina/problem/inverse_problem.py rename to pina/_src/problem/inverse_problem.py diff --git a/pina/problem/parametric_problem.py b/pina/_src/problem/parametric_problem.py similarity index 100% rename from pina/problem/parametric_problem.py rename to pina/_src/problem/parametric_problem.py diff --git a/pina/problem/spatial_problem.py b/pina/_src/problem/spatial_problem.py similarity index 100% rename from pina/problem/spatial_problem.py rename to pina/_src/problem/spatial_problem.py diff --git a/pina/problem/time_dependent_problem.py b/pina/_src/problem/time_dependent_problem.py similarity index 100% rename from pina/problem/time_dependent_problem.py rename to pina/_src/problem/time_dependent_problem.py diff --git a/pina/_src/problem/zoo/__init__.py b/pina/_src/problem/zoo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/problem/zoo/acoustic_wave.py b/pina/_src/problem/zoo/acoustic_wave.py similarity index 85% rename from pina/problem/zoo/acoustic_wave.py rename to pina/_src/problem/zoo/acoustic_wave.py index b4b2035a4..44db8eb96 100644 --- a/pina/problem/zoo/acoustic_wave.py +++ b/pina/_src/problem/zoo/acoustic_wave.py @@ -1,13 +1,14 @@ """Formulation of the acoustic wave problem.""" import torch -from ... 
import Condition -from ...problem import SpatialProblem, TimeDependentProblem -from ...utils import check_consistency -from ...domain import CartesianDomain -from ...equation import ( - Equation, - SystemEquation, +from pina._src.condition.condition import Condition +from pina._src.problem.spatial_problem import SpatialProblem +from pina._src.problem.time_dependent_problem import TimeDependentProblem +from pina._src.core.utils import check_consistency +from pina._src.domain.cartesian_domain import CartesianDomain +from pina._src.equation.equation import Equation +from pina._src.equation.system_equation import SystemEquation +from pina._src.equation.equation_factory import ( FixedValue, FixedGradient, AcousticWave, diff --git a/pina/problem/zoo/advection.py b/pina/_src/problem/zoo/advection.py similarity index 84% rename from pina/problem/zoo/advection.py rename to pina/_src/problem/zoo/advection.py index c709b9632..3067ce8bf 100644 --- a/pina/problem/zoo/advection.py +++ b/pina/_src/problem/zoo/advection.py @@ -1,11 +1,13 @@ """Formulation of the advection problem.""" import torch -from ... import Condition -from ...problem import SpatialProblem, TimeDependentProblem -from ...equation import Equation, Advection -from ...utils import check_consistency -from ...domain import CartesianDomain +from pina._src.condition.condition import Condition +from pina._src.problem.spatial_problem import SpatialProblem +from pina._src.problem.time_dependent_problem import TimeDependentProblem +from pina._src.equation.equation import Equation +from pina._src.equation.equation_factory import Advection +from pina._src.core.utils import check_consistency +from pina._src.domain.cartesian_domain import CartesianDomain def initial_condition(input_, output_): diff --git a/pina/problem/zoo/allen_cahn.py b/pina/_src/problem/zoo/allen_cahn.py similarity index 84% rename from pina/problem/zoo/allen_cahn.py rename to pina/_src/problem/zoo/allen_cahn.py index 900d5cf33..125a10304 100644 --- a/pina/problem/zoo/allen_cahn.py +++ b/pina/_src/problem/zoo/allen_cahn.py @@ -1,11 +1,14 @@ """Formulation of the Allen Cahn problem.""" import torch -from ... import Condition -from ...problem import SpatialProblem, TimeDependentProblem -from ...equation import Equation, AllenCahn -from ...utils import check_consistency -from ...domain import CartesianDomain + +from pina._src.condition.condition import Condition +from pina._src.problem.spatial_problem import SpatialProblem +from pina._src.problem.time_dependent_problem import TimeDependentProblem +from pina._src.equation.equation import Equation +from pina._src.equation.equation_factory import AllenCahn +from pina._src.core.utils import check_consistency +from pina._src.domain.cartesian_domain import CartesianDomain def initial_condition(input_, output_): diff --git a/pina/problem/zoo/diffusion_reaction.py b/pina/_src/problem/zoo/diffusion_reaction.py similarity index 88% rename from pina/problem/zoo/diffusion_reaction.py rename to pina/_src/problem/zoo/diffusion_reaction.py index fd02b8368..443ff49c5 100644 --- a/pina/problem/zoo/diffusion_reaction.py +++ b/pina/_src/problem/zoo/diffusion_reaction.py @@ -1,11 +1,13 @@ """Formulation of the diffusion-reaction problem.""" import torch -from ... 
import Condition -from ...equation import Equation, FixedValue, DiffusionReaction -from ...problem import SpatialProblem, TimeDependentProblem -from ...utils import check_consistency -from ...domain import CartesianDomain +from pina._src.condition.condition import Condition +from pina._src.equation.equation import Equation +from pina._src.equation.equation_factory import FixedValue, DiffusionReaction +from pina._src.problem.spatial_problem import SpatialProblem +from pina._src.problem.time_dependent_problem import TimeDependentProblem +from pina._src.core.utils import check_consistency +from pina._src.domain.cartesian_domain import CartesianDomain def initial_condition(input_, output_): diff --git a/pina/problem/zoo/helmholtz.py b/pina/_src/problem/zoo/helmholtz.py similarity index 87% rename from pina/problem/zoo/helmholtz.py rename to pina/_src/problem/zoo/helmholtz.py index f7f288627..f59bfdf1e 100644 --- a/pina/problem/zoo/helmholtz.py +++ b/pina/_src/problem/zoo/helmholtz.py @@ -1,11 +1,12 @@ """Formulation of the Helmholtz problem.""" import torch -from ... import Condition -from ...equation import FixedValue, Helmholtz -from ...utils import check_consistency -from ...domain import CartesianDomain -from ...problem import SpatialProblem + +from pina._src.condition.condition import Condition +from pina._src.equation.equation_factory import FixedValue, Helmholtz +from pina._src.problem.spatial_problem import SpatialProblem +from pina._src.core.utils import check_consistency +from pina._src.domain.cartesian_domain import CartesianDomain class HelmholtzProblem(SpatialProblem): diff --git a/pina/problem/zoo/inverse_poisson_2d_square.py b/pina/_src/problem/zoo/inverse_poisson_2d_square.py similarity index 90% rename from pina/problem/zoo/inverse_poisson_2d_square.py rename to pina/_src/problem/zoo/inverse_poisson_2d_square.py index 17f30ae14..19628cae0 100644 --- a/pina/problem/zoo/inverse_poisson_2d_square.py +++ b/pina/_src/problem/zoo/inverse_poisson_2d_square.py @@ -4,13 +4,17 @@ import requests import torch from io import BytesIO -from ... import Condition -from ... 
import LabelTensor -from ...operator import laplacian -from ...domain import CartesianDomain -from ...equation import Equation, FixedValue -from ...problem import SpatialProblem, InverseProblem -from ...utils import custom_warning_format, check_consistency + + +from pina._src.condition.condition import Condition +from pina._src.equation.equation import Equation +from pina._src.equation.equation_factory import FixedValue +from pina._src.problem.spatial_problem import SpatialProblem +from pina._src.problem.inverse_problem import InverseProblem +from pina._src.domain.cartesian_domain import CartesianDomain +from pina._src.core.label_tensor import LabelTensor +from pina._src.core.operator import laplacian +from pina._src.core.utils import custom_warning_format, check_consistency warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=ResourceWarning) diff --git a/pina/problem/zoo/poisson_2d_square.py b/pina/_src/problem/zoo/poisson_2d_square.py similarity index 86% rename from pina/problem/zoo/poisson_2d_square.py rename to pina/_src/problem/zoo/poisson_2d_square.py index 5de38b301..12b365666 100644 --- a/pina/problem/zoo/poisson_2d_square.py +++ b/pina/_src/problem/zoo/poisson_2d_square.py @@ -1,10 +1,11 @@ """Formulation of the Poisson problem in a square domain.""" import torch -from ...equation import FixedValue, Poisson -from ...problem import SpatialProblem -from ...domain import CartesianDomain -from ... import Condition + +from pina._src.condition.condition import Condition +from pina._src.equation.equation_factory import FixedValue, Poisson +from pina._src.problem.spatial_problem import SpatialProblem +from pina._src.domain.cartesian_domain import CartesianDomain def forcing_term(input_): diff --git a/pina/problem/zoo/supervised_problem.py b/pina/_src/problem/zoo/supervised_problem.py similarity index 93% rename from pina/problem/zoo/supervised_problem.py rename to pina/_src/problem/zoo/supervised_problem.py index 61a49c0cb..81fb18a44 100644 --- a/pina/problem/zoo/supervised_problem.py +++ b/pina/_src/problem/zoo/supervised_problem.py @@ -1,7 +1,7 @@ """Formulation of a Supervised Problem in PINA.""" -from ..abstract_problem import AbstractProblem -from ... 
import Condition +from pina._src.problem.abstract_problem import AbstractProblem +from pina._src.condition.condition import Condition class SupervisedProblem(AbstractProblem): diff --git a/pina/_src/solver/__init__.py b/pina/_src/solver/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/_src/solver/ensemble_solver/__init__.py b/pina/_src/solver/ensemble_solver/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/solver/ensemble_solver/ensemble_pinn.py b/pina/_src/solver/ensemble_solver/ensemble_pinn.py similarity index 96% rename from pina/solver/ensemble_solver/ensemble_pinn.py rename to pina/_src/solver/ensemble_solver/ensemble_pinn.py index 33d929ad2..002d4bd53 100644 --- a/pina/solver/ensemble_solver/ensemble_pinn.py +++ b/pina/_src/solver/ensemble_solver/ensemble_pinn.py @@ -2,9 +2,11 @@ import torch -from .ensemble_solver_interface import DeepEnsembleSolverInterface -from ..physics_informed_solver import PINNInterface -from ...problem import InverseProblem +from pina._src.solver.ensemble_solver.ensemble_solver_interface import ( + DeepEnsembleSolverInterface, +) +from pina._src.solver.physics_informed_solver import PINNInterface +from pina._src.problem import InverseProblem class DeepEnsemblePINN(PINNInterface, DeepEnsembleSolverInterface): diff --git a/pina/solver/ensemble_solver/ensemble_solver_interface.py b/pina/_src/solver/ensemble_solver/ensemble_solver_interface.py similarity index 98% rename from pina/solver/ensemble_solver/ensemble_solver_interface.py rename to pina/_src/solver/ensemble_solver/ensemble_solver_interface.py index 6d874e1bf..7b87e28f1 100644 --- a/pina/solver/ensemble_solver/ensemble_solver_interface.py +++ b/pina/_src/solver/ensemble_solver/ensemble_solver_interface.py @@ -1,8 +1,8 @@ """Module for the DeepEnsemble solver interface.""" import torch -from ..solver import MultiSolverInterface -from ...utils import check_consistency +from pina._src.solver.solver import MultiSolverInterface +from pina._src.core.utils import check_consistency class DeepEnsembleSolverInterface(MultiSolverInterface): diff --git a/pina/solver/ensemble_solver/ensemble_supervised.py b/pina/_src/solver/ensemble_solver/ensemble_supervised.py similarity index 95% rename from pina/solver/ensemble_solver/ensemble_supervised.py rename to pina/_src/solver/ensemble_solver/ensemble_supervised.py index e4837ccdb..ea6f7edde 100644 --- a/pina/solver/ensemble_solver/ensemble_supervised.py +++ b/pina/_src/solver/ensemble_solver/ensemble_supervised.py @@ -1,7 +1,11 @@ """Module for the DeepEnsemble supervised solver.""" -from .ensemble_solver_interface import DeepEnsembleSolverInterface -from ..supervised_solver import SupervisedSolverInterface +from pina._src.solver.ensemble_solver.ensemble_solver_interface import ( + DeepEnsembleSolverInterface, +) +from pina._src.solver.supervised_solver.supervised_solver_interface import ( + SupervisedSolverInterface, +) class DeepEnsembleSupervisedSolver( diff --git a/pina/solver/garom.py b/pina/_src/solver/garom.py similarity index 97% rename from pina/solver/garom.py rename to pina/_src/solver/garom.py index 372eeddfa..3f499abd1 100644 --- a/pina/solver/garom.py +++ b/pina/_src/solver/garom.py @@ -2,10 +2,11 @@ import torch from torch.nn.modules.loss import _Loss -from .solver import MultiSolverInterface -from ..condition import InputTargetCondition -from ..utils import check_consistency -from ..loss import LossInterface, PowerLoss +from pina._src.solver.solver import MultiSolverInterface +from 
pina._src.condition.input_target_condition import InputTargetCondition +from pina._src.core.utils import check_consistency +from pina._src.loss.loss_interface import LossInterface +from pina._src.loss.power_loss import PowerLoss class GAROM(MultiSolverInterface): diff --git a/pina/_src/solver/physics_informed_solver/__init__.py b/pina/_src/solver/physics_informed_solver/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pina/solver/physics_informed_solver/causal_pinn.py b/pina/_src/solver/physics_informed_solver/causal_pinn.py similarity index 97% rename from pina/solver/physics_informed_solver/causal_pinn.py rename to pina/_src/solver/physics_informed_solver/causal_pinn.py index ab085be2d..e7e97392b 100644 --- a/pina/solver/physics_informed_solver/causal_pinn.py +++ b/pina/_src/solver/physics_informed_solver/causal_pinn.py @@ -2,9 +2,9 @@ import torch -from ...problem import TimeDependentProblem -from .pinn import PINN -from ...utils import check_consistency +from pina._src.problem.time_dependent_problem import TimeDependentProblem +from pina._src.solver.physics_informed_solver.pinn import PINN +from pina._src.core.utils import check_consistency class CausalPINN(PINN): diff --git a/pina/solver/physics_informed_solver/competitive_pinn.py b/pina/_src/solver/physics_informed_solver/competitive_pinn.py similarity index 97% rename from pina/solver/physics_informed_solver/competitive_pinn.py rename to pina/_src/solver/physics_informed_solver/competitive_pinn.py index 5375efba1..287e0fd8d 100644 --- a/pina/solver/physics_informed_solver/competitive_pinn.py +++ b/pina/_src/solver/physics_informed_solver/competitive_pinn.py @@ -3,9 +3,11 @@ import copy import torch -from ...problem import InverseProblem -from .pinn_interface import PINNInterface -from ..solver import MultiSolverInterface +from pina._src.problem.inverse_problem import InverseProblem +from pina._src.solver.physics_informed_solver.pinn_interface import ( + PINNInterface, +) +from pina._src.solver.solver import MultiSolverInterface class CompetitivePINN(PINNInterface, MultiSolverInterface): diff --git a/pina/solver/physics_informed_solver/gradient_pinn.py b/pina/_src/solver/physics_informed_solver/gradient_pinn.py similarity index 96% rename from pina/solver/physics_informed_solver/gradient_pinn.py rename to pina/_src/solver/physics_informed_solver/gradient_pinn.py index 0de431c41..9583c3025 100644 --- a/pina/solver/physics_informed_solver/gradient_pinn.py +++ b/pina/_src/solver/physics_informed_solver/gradient_pinn.py @@ -2,9 +2,9 @@ import torch -from .pinn import PINN -from ...operator import grad -from ...problem import SpatialProblem +from pina._src.solver.physics_informed_solver.pinn import PINN +from pina._src.core.operator import grad +from pina._src.problem.spatial_problem import SpatialProblem class GradientPINN(PINN): diff --git a/pina/solver/physics_informed_solver/pinn.py b/pina/_src/solver/physics_informed_solver/pinn.py similarity index 95% rename from pina/solver/physics_informed_solver/pinn.py rename to pina/_src/solver/physics_informed_solver/pinn.py index 914d01451..dbea8cbe3 100644 --- a/pina/solver/physics_informed_solver/pinn.py +++ b/pina/_src/solver/physics_informed_solver/pinn.py @@ -2,9 +2,11 @@ import torch -from .pinn_interface import PINNInterface -from ..solver import SingleSolverInterface -from ...problem import InverseProblem +from pina._src.solver.physics_informed_solver.pinn_interface import ( + PINNInterface, +) +from pina._src.solver.solver import SingleSolverInterface +from 
pina._src.problem.inverse_problem import InverseProblem class PINN(PINNInterface, SingleSolverInterface): diff --git a/pina/solver/physics_informed_solver/pinn_interface.py b/pina/_src/solver/physics_informed_solver/pinn_interface.py similarity index 96% rename from pina/solver/physics_informed_solver/pinn_interface.py rename to pina/_src/solver/physics_informed_solver/pinn_interface.py index 65a0dd78f..60330372a 100644 --- a/pina/solver/physics_informed_solver/pinn_interface.py +++ b/pina/_src/solver/physics_informed_solver/pinn_interface.py @@ -4,11 +4,11 @@ import warnings import torch -from ...utils import custom_warning_format -from ..supervised_solver import SupervisedSolverInterface -from ...condition import ( - InputTargetCondition, - InputEquationCondition, +from pina._src.core.utils import custom_warning_format +from pina._src.solver.supervised_solver import SupervisedSolverInterface +from pina._src.condition.input_target_condition import InputTargetCondition +from pina._src.condition.input_equation_condition import InputEquationCondition +from pina._src.condition.domain_equation_condition import ( DomainEquationCondition, ) diff --git a/pina/solver/physics_informed_solver/rba_pinn.py b/pina/_src/solver/physics_informed_solver/rba_pinn.py similarity index 99% rename from pina/solver/physics_informed_solver/rba_pinn.py rename to pina/_src/solver/physics_informed_solver/rba_pinn.py index 5c8d50fed..7e7deda0a 100644 --- a/pina/solver/physics_informed_solver/rba_pinn.py +++ b/pina/_src/solver/physics_informed_solver/rba_pinn.py @@ -2,8 +2,8 @@ import torch -from .pinn import PINN -from ...utils import check_consistency +from pina._src.solver.physics_informed_solver.pinn import PINN +from pina._src.core.utils import check_consistency class RBAPINN(PINN): diff --git a/pina/solver/physics_informed_solver/self_adaptive_pinn.py b/pina/_src/solver/physics_informed_solver/self_adaptive_pinn.py similarity index 98% rename from pina/solver/physics_informed_solver/self_adaptive_pinn.py rename to pina/_src/solver/physics_informed_solver/self_adaptive_pinn.py index b1d2a2cb4..ee7f281e6 100644 --- a/pina/solver/physics_informed_solver/self_adaptive_pinn.py +++ b/pina/_src/solver/physics_informed_solver/self_adaptive_pinn.py @@ -1,12 +1,13 @@ """Module for the Self-Adaptive PINN solver.""" -from copy import deepcopy import torch -from ...utils import check_consistency -from ...problem import InverseProblem -from ..solver import MultiSolverInterface -from .pinn_interface import PINNInterface +from pina._src.core.utils import check_consistency +from pina._src.problem.inverse_problem import InverseProblem +from pina._src.solver.solver import MultiSolverInterface +from pina._src.solver.physics_informed_solver.pinn_interface import ( + PINNInterface, +) class Weights(torch.nn.Module): diff --git a/pina/solver/solver.py b/pina/_src/solver/solver.py similarity index 97% rename from pina/solver/solver.py rename to pina/_src/solver/solver.py index 57a28a8a7..d6abd493b 100644 --- a/pina/solver/solver.py +++ b/pina/_src/solver/solver.py @@ -5,11 +5,15 @@ import torch from torch._dynamo import OptimizedModule -from ..problem import AbstractProblem, InverseProblem -from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler -from ..loss import WeightingInterface -from ..loss.scalar_weighting import _NoWeighting -from ..utils import check_consistency, labelize_forward +from pina._src.problem.abstract_problem import AbstractProblem +from pina._src.problem.inverse_problem import InverseProblem +from 
pina._src.optim.optimizer_interface import Optimizer +from pina._src.optim.scheduler_interface import Scheduler +from pina._src.optim.torch_optimizer import TorchOptimizer +from pina._src.optim.torch_scheduler import TorchScheduler +from pina._src.loss.weighting_interface import WeightingInterface +from pina._src.loss.scalar_weighting import _NoWeighting +from pina._src.core.utils import check_consistency, labelize_forward class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): diff --git a/pina/solver/supervised_solver/__init__.py b/pina/_src/solver/supervised_solver/__init__.py similarity index 100% rename from pina/solver/supervised_solver/__init__.py rename to pina/_src/solver/supervised_solver/__init__.py diff --git a/pina/solver/supervised_solver/reduced_order_model.py b/pina/_src/solver/supervised_solver/reduced_order_model.py similarity index 97% rename from pina/solver/supervised_solver/reduced_order_model.py rename to pina/_src/solver/supervised_solver/reduced_order_model.py index 727f438e2..d9830d766 100644 --- a/pina/solver/supervised_solver/reduced_order_model.py +++ b/pina/_src/solver/supervised_solver/reduced_order_model.py @@ -1,8 +1,10 @@ """Module for the Reduced Order Model solver""" import torch -from .supervised_solver_interface import SupervisedSolverInterface -from ..solver import SingleSolverInterface +from pina._src.solver.supervised_solver.supervised_solver_interface import ( + SupervisedSolverInterface, +) +from pina._src.solver.solver import SingleSolverInterface class ReducedOrderModelSolver(SupervisedSolverInterface, SingleSolverInterface): diff --git a/pina/solver/supervised_solver/supervised.py b/pina/_src/solver/supervised_solver/supervised.py similarity index 95% rename from pina/solver/supervised_solver/supervised.py rename to pina/_src/solver/supervised_solver/supervised.py index 70cd8fe4b..65d438c01 100644 --- a/pina/solver/supervised_solver/supervised.py +++ b/pina/_src/solver/supervised_solver/supervised.py @@ -1,7 +1,9 @@ """Module for the Supervised solver.""" -from .supervised_solver_interface import SupervisedSolverInterface -from ..solver import SingleSolverInterface +from pina._src.solver.supervised_solver.supervised_solver_interface import ( + SupervisedSolverInterface, +) +from pina._src.solver.solver import SingleSolverInterface class SupervisedSolver(SupervisedSolverInterface, SingleSolverInterface): diff --git a/pina/solver/supervised_solver/supervised_solver_interface.py b/pina/_src/solver/supervised_solver/supervised_solver_interface.py similarity index 92% rename from pina/solver/supervised_solver/supervised_solver_interface.py rename to pina/_src/solver/supervised_solver/supervised_solver_interface.py index 97070ce8f..030fc3f82 100644 --- a/pina/solver/supervised_solver/supervised_solver_interface.py +++ b/pina/_src/solver/supervised_solver/supervised_solver_interface.py @@ -5,10 +5,10 @@ import torch from torch.nn.modules.loss import _Loss -from ..solver import SolverInterface -from ...utils import check_consistency -from ...loss.loss_interface import LossInterface -from ...condition import InputTargetCondition +from pina._src.solver.solver import SolverInterface +from pina._src.core.utils import check_consistency +from pina._src.loss.loss_interface import LossInterface +from pina._src.condition.input_target_condition import InputTargetCondition class SupervisedSolverInterface(SolverInterface): diff --git a/pina/adaptive_function/__init__.py b/pina/adaptive_function/__init__.py index d53c5f368..9047be94a 100644 --- 
a/pina/adaptive_function/__init__.py +++ b/pina/adaptive_function/__init__.py @@ -1,4 +1,10 @@ -"""Adaptive Activation Functions Module.""" +"""Adaptive activation functions with learnable parameters. + +This module provides implementations of standard activation functions (ReLU, +SiLU, Tanh, etc.) augmented with trainable weights, as well as specialized +functions like SIREN, designed to improve convergence in PINNs and Neural +Operators. +""" __all__ = [ "AdaptiveActivationFunctionInterface", @@ -16,7 +22,7 @@ "AdaptiveExp", ] -from .adaptive_function import ( +from pina._src.adaptive_function.adaptive_function import ( AdaptiveReLU, AdaptiveSigmoid, AdaptiveTanh, @@ -30,4 +36,6 @@ AdaptiveSIREN, AdaptiveExp, ) -from .adaptive_function_interface import AdaptiveActivationFunctionInterface +from pina._src.adaptive_function.adaptive_function_interface import ( + AdaptiveActivationFunctionInterface, +) diff --git a/pina/callback/__init__.py b/pina/callback/__init__.py index 92da661cb..2f6d5a0a2 100644 --- a/pina/callback/__init__.py +++ b/pina/callback/__init__.py @@ -1,4 +1,9 @@ -"""Module for the Pina Callbacks.""" +"""Training callbacks for PINA lifecycle management. + +This module provides specialized callbacks for training Scientific Machine +Learning models, including adaptive sample refinement (R3), optimizer +switching logic, and data normalization utilities. +""" __all__ = [ "SwitchOptimizer", @@ -9,9 +14,11 @@ "R3Refinement", ] -from .optim.switch_optimizer import SwitchOptimizer -from .optim.switch_scheduler import SwitchScheduler -from .processing.normalizer_data_callback import NormalizerDataCallback -from .processing.pina_progress_bar import PINAProgressBar -from .processing.metric_tracker import MetricTracker -from .refinement import R3Refinement +from pina._src.callback.optim.switch_optimizer import SwitchOptimizer +from pina._src.callback.optim.switch_scheduler import SwitchScheduler +from pina._src.callback.processing.normalizer_data_callback import ( + NormalizerDataCallback, +) +from pina._src.callback.processing.pina_progress_bar import PINAProgressBar +from pina._src.callback.processing.metric_tracker import MetricTracker +from pina._src.callback.refinement.r3_refinement import R3Refinement diff --git a/pina/callback/refinement/__init__.py b/pina/callback/refinement/__init__.py deleted file mode 100644 index 396fcabaa..000000000 --- a/pina/callback/refinement/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -Module for Pina Refinement callbacks. -""" - -__all__ = [ - "RefinementInterface", - "R3Refinement", -] - -from .refinement_interface import RefinementInterface -from .r3_refinement import R3Refinement diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 4e57811fb..696567fa8 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -1,4 +1,10 @@ -"""Module for PINA Conditions classes.""" +"""Conditions for defining physics and data constraints. + +This module provides the interface and implementations for binding mathematical +equations, experimental data, and neural network targets to specific spatial +domains or graph structures. It supports various input-target mappings including +tensor-based, graph-based, and equation-based constraints. 
+""" __all__ = [ "Condition", @@ -17,22 +23,24 @@ "TensorDataCondition", ] -from .condition_interface import ConditionInterface -from .condition import Condition -from .domain_equation_condition import DomainEquationCondition -from .input_target_condition import ( +from pina._src.condition.condition_interface import ConditionInterface +from pina._src.condition.condition import Condition +from pina._src.condition.domain_equation_condition import ( + DomainEquationCondition, +) +from pina._src.condition.input_target_condition import ( InputTargetCondition, TensorInputTensorTargetCondition, TensorInputGraphTargetCondition, GraphInputTensorTargetCondition, GraphInputGraphTargetCondition, ) -from .input_equation_condition import ( +from pina._src.condition.input_equation_condition import ( InputEquationCondition, InputTensorEquationCondition, InputGraphEquationCondition, ) -from .data_condition import ( +from pina._src.condition.data_condition import ( DataCondition, GraphDataCondition, TensorDataCondition, diff --git a/pina/data/__init__.py b/pina/data/__init__.py index 70e100011..164a6c7aa 100644 --- a/pina/data/__init__.py +++ b/pina/data/__init__.py @@ -1,7 +1,12 @@ -"""Module for data, data module, and dataset.""" +"""Data management utilities for PINA. + +This module provides specialized Dataset and DataModule implementations +designed to handle physical coordinates, experimental observations, and +graph-structured data within the PINA training pipeline. +""" __all__ = ["PinaDataModule", "PinaDataset"] -from .data_module import PinaDataModule -from .dataset import PinaDataset +from pina._src.data.data_module import PinaDataModule +from pina._src.data.dataset import PinaDataset diff --git a/pina/domain/__init__.py b/pina/domain/__init__.py index 57999f4d8..6782563db 100644 --- a/pina/domain/__init__.py +++ b/pina/domain/__init__.py @@ -1,4 +1,9 @@ -"""Module to create and handle domains.""" +"""Geometry and domain definitions for spatial sampling. + +This module provides tools for defining the physical space of a problem, +including primitive shapes (Cartesian, Ellipsoid, Simplex) and set-theoretic +operations (Union, Intersection, etc.) for building complex geometries. +""" __all__ = [ "DomainInterface", @@ -13,13 +18,13 @@ "Exclusion", ] -from .domain_interface import DomainInterface -from .base_domain import BaseDomain -from .cartesian_domain import CartesianDomain -from .ellipsoid_domain import EllipsoidDomain -from .simplex_domain import SimplexDomain -from .operation_interface import OperationInterface -from .union import Union -from .intersection import Intersection -from .difference import Difference -from .exclusion import Exclusion +from pina._src.domain.domain_interface import DomainInterface +from pina._src.domain.base_domain import BaseDomain +from pina._src.domain.cartesian_domain import CartesianDomain +from pina._src.domain.ellipsoid_domain import EllipsoidDomain +from pina._src.domain.simplex_domain import SimplexDomain +from pina._src.domain.operation_interface import OperationInterface +from pina._src.domain.union import Union +from pina._src.domain.intersection import Intersection +from pina._src.domain.difference import Difference +from pina._src.domain.exclusion import Exclusion diff --git a/pina/equation/__init__.py b/pina/equation/__init__.py index 87a33554b..551099af6 100644 --- a/pina/equation/__init__.py +++ b/pina/equation/__init__.py @@ -1,4 +1,10 @@ -"""Module to define equations and systems of equations.""" +"""Mathematical equations and physical laws. 
+ +This module provides a framework for defining differential equations, +boundary conditions, and complex systems of equations. It includes +pre-defined physical models such as Poisson, Laplace, and Wave equations, +along with factories for common derivative-based constraints. +""" __all__ = [ "SystemEquation", @@ -16,8 +22,8 @@ "AcousticWave", ] -from .equation import Equation -from .equation_factory import ( +from pina._src.equation.equation import Equation +from pina._src.equation.equation_factory import ( FixedFlux, FixedGradient, FixedLaplacian, @@ -30,4 +36,4 @@ Poisson, AcousticWave, ) -from .system_equation import SystemEquation +from pina._src.equation.system_equation import SystemEquation diff --git a/pina/graph.py b/pina/graph.py index 201f37a24..04c6374f5 100644 --- a/pina/graph.py +++ b/pina/graph.py @@ -1,421 +1,13 @@ -"""Module to build Graph objects and perform operations on them.""" +"""Public API for Graph connectivity and neighborhood logic. -import torch -from torch_geometric.data import Data, Batch -from torch_geometric.utils import to_undirected -from torch_geometric.utils.loop import remove_self_loops -from .label_tensor import LabelTensor -from .utils import check_consistency, is_function +This module exposes core graph types used to define spatial relationships +between points, such as fixed-radius and k-nearest neighbor (KNN) structures. +""" +from pina._src.core.graph import Graph, RadiusGraph, KNNGraph -class Graph(Data): - """ - Extends :class:`~torch_geometric.data.Data` class to include additional - checks and functionlities. - """ - - def __new__( - cls, - **kwargs, - ): - """ - Create a new instance of the :class:`~pina.graph.Graph` class by - checking the consistency of the input data and storing the attributes. - - :param dict kwargs: Parameters used to initialize the - :class:`~pina.graph.Graph` object. - :return: A new instance of the :class:`~pina.graph.Graph` class. - :rtype: Graph - """ - # create class instance - instance = Data.__new__(cls) - - # check the consistency of types defined in __init__, the others are not - # checked (as in pyg Data object) - instance._check_type_consistency(**kwargs) - - return instance - - def __init__( - self, - x=None, - edge_index=None, - pos=None, - edge_attr=None, - undirected=False, - **kwargs, - ): - """ - Initialize the object by setting the node features, edge index, - edge attributes, and positions. The edge index is preprocessed to make - the graph undirected if required. For more details, see the - :meth:`torch_geometric.data.Data` - - :param x: Optional tensor of node features ``(N, F)`` where ``F`` is the - number of features per node. - :type x: torch.Tensor, LabelTensor - :param torch.Tensor edge_index: A tensor of shape ``(2, E)`` - representing the indices of the graph's edges. - :param pos: A tensor of shape ``(N, D)`` representing the positions of - ``N`` points in ``D``-dimensional space. - :type pos: torch.Tensor | LabelTensor - :param edge_attr: Optional tensor of edge_featured ``(E, F')`` where - ``F'`` is the number of edge features - :type edge_attr: torch.Tensor | LabelTensor - :param bool undirected: Whether to make the graph undirected - :param dict kwargs: Additional keyword arguments passed to the - :class:`~torch_geometric.data.Data` class constructor. 
- """ - # preprocessing - self._preprocess_edge_index(edge_index, undirected) - - # calling init - super().__init__( - x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos, **kwargs - ) - - def _check_type_consistency(self, **kwargs): - """ - Check the consistency of the types of the input data. - - :param dict kwargs: Attributes to be checked for consistency. - """ - # default types, specified in cls.__new__, by default they are Nont - # if specified in **kwargs they get override - x, pos, edge_index, edge_attr = None, None, None, None - if "pos" in kwargs: - pos = kwargs["pos"] - self._check_pos_consistency(pos) - if "edge_index" in kwargs: - edge_index = kwargs["edge_index"] - self._check_edge_index_consistency(edge_index) - if "x" in kwargs: - x = kwargs["x"] - self._check_x_consistency(x, pos) - if "edge_attr" in kwargs: - edge_attr = kwargs["edge_attr"] - self._check_edge_attr_consistency(edge_attr, edge_index) - if "undirected" in kwargs: - undirected = kwargs["undirected"] - check_consistency(undirected, bool) - - @staticmethod - def _check_pos_consistency(pos): - """ - Check if the position tensor is consistent. - :param torch.Tensor pos: The position tensor. - :raises ValueError: If the position tensor is not consistent. - """ - if pos is not None: - check_consistency(pos, (torch.Tensor, LabelTensor)) - if pos.ndim != 2: - raise ValueError("pos must be a 2D tensor.") - - @staticmethod - def _check_edge_index_consistency(edge_index): - """ - Check if the edge index is consistent. - - :param torch.Tensor edge_index: The edge index tensor. - :raises ValueError: If the edge index tensor is not consistent. - """ - check_consistency(edge_index, (torch.Tensor, LabelTensor)) - if edge_index.ndim != 2: - raise ValueError("edge_index must be a 2D tensor.") - if edge_index.size(0) != 2: - raise ValueError("edge_index must have shape [2, num_edges].") - - @staticmethod - def _check_edge_attr_consistency(edge_attr, edge_index): - """ - Check if the edge attribute tensor is consistent in type and shape - with the edge index. - - :param edge_attr: The edge attribute tensor. - :type edge_attr: torch.Tensor | LabelTensor - :param torch.Tensor edge_index: The edge index tensor. - :raises ValueError: If the edge attribute tensor is not consistent. - """ - if edge_attr is not None: - check_consistency(edge_attr, (torch.Tensor, LabelTensor)) - if edge_attr.ndim != 2: - raise ValueError("edge_attr must be a 2D tensor.") - if edge_attr.size(0) != edge_index.size(1): - raise ValueError( - "edge_attr must have shape " - "[num_edges, num_edge_features], expected " - f"num_edges {edge_index.size(1)} " - f"got {edge_attr.size(0)}." - ) - - @staticmethod - def _check_x_consistency(x, pos=None): - """ - Check if the input tensor x is consistent with the position tensor - `pos`. - - :param x: The input tensor. - :type x: torch.Tensor | LabelTensor - :param pos: The position tensor. - :type pos: torch.Tensor | LabelTensor - :raises ValueError: If the input tensor is not consistent. - """ - if x is not None: - check_consistency(x, (torch.Tensor, LabelTensor)) - if x.ndim != 2: - raise ValueError("x must be a 2D tensor.") - if pos is not None: - if x.size(0) != pos.size(0): - raise ValueError("Inconsistent number of nodes.") - - @staticmethod - def _preprocess_edge_index(edge_index, undirected): - """ - Preprocess the edge index to make the graph undirected (if required). - - :param torch.Tensor edge_index: The edge index. - :param bool undirected: Whether the graph is undirected. 
- :return: The preprocessed edge index. - :rtype: torch.Tensor - """ - if undirected: - edge_index = to_undirected(edge_index) - return edge_index - - def extract(self, labels, attr="x"): - """ - Perform extraction of labels from the attribute specified by `attr`. - - :param labels: Labels to extract - :type labels: list[str] | tuple[str] | str | dict - :return: Batch object with extraction performed on x - :rtype: PinaBatch - """ - # Extract labels from LabelTensor object - tensor = getattr(self, attr).extract(labels) - # Set the extracted tensor as the new attribute - setattr(self, attr, tensor) - return self - - -class GraphBuilder: - """ - A class that allows an easy definition of :class:`Graph` instances. - """ - - def __new__( - cls, - pos, - edge_index, - x=None, - edge_attr=False, - custom_edge_func=None, - loop=True, - **kwargs, - ): - """ - Compute the edge attributes and create a new instance of the - :class:`~pina.graph.Graph` class. - - :param pos: A tensor of shape ``(N, D)`` representing the positions of - ``N`` points in ``D``-dimensional space. - :type pos: torch.Tensor or LabelTensor - :param edge_index: A tensor of shape ``(2, E)`` representing the indices - of the graph's edges. - :type edge_index: torch.Tensor - :param x: Optional tensor of node features of shape ``(N, F)``, where - ``F`` is the number of features per node. - :type x: torch.Tensor | LabelTensor, optional - :param bool edge_attr: Whether to compute the edge attributes. - :param custom_edge_func: A custom function to compute edge attributes. - If provided, overrides ``edge_attr``. - :type custom_edge_func: Callable, optional - :param bool loop: Whether to include self-loops. - :param kwargs: Additional keyword arguments passed to the - :class:`~pina.graph.Graph` class constructor. - :return: A :class:`~pina.graph.Graph` instance constructed using the - provided information. - :rtype: Graph - """ - if not loop: - edge_index = remove_self_loops(edge_index)[0] - edge_attr = cls._create_edge_attr( - pos, edge_index, edge_attr, custom_edge_func or cls._build_edge_attr - ) - return Graph( - x=x, - edge_index=edge_index, - edge_attr=edge_attr, - pos=pos, - **kwargs, - ) - - @staticmethod - def _create_edge_attr(pos, edge_index, edge_attr, func): - """ - Create the edge attributes based on the input parameters. - - :param pos: Positions of the points. - :type pos: torch.Tensor | LabelTensor - :param torch.Tensor edge_index: Edge indices. - :param bool edge_attr: Whether to compute the edge attributes. - :param Callable func: Function to compute the edge attributes. - :raises ValueError: If ``func`` is not a function. - :return: The edge attributes. - :rtype: torch.Tensor | LabelTensor | None - """ - check_consistency(edge_attr, bool) - if edge_attr: - if is_function(func): - return func(pos, edge_index) - raise ValueError("custom_edge_func must be a function.") - return None - - @staticmethod - def _build_edge_attr(pos, edge_index): - """ - Default function to compute the edge attributes. - - :param pos: Positions of the points. - :type pos: torch.Tensor | LabelTensor - :param torch.Tensor edge_index: Edge indices. - :return: The edge attributes. - :rtype: torch.Tensor - """ - return ( - (pos[edge_index[0]] - pos[edge_index[1]]) - .abs() - .as_subclass(torch.Tensor) - ) - - -class RadiusGraph(GraphBuilder): - """ - Extends the :class:`~pina.graph.GraphBuilder` class to compute - ``edge_index`` based on a radius. Each point is connected to all the points - within the radius. 
- """ - - def __new__(cls, pos, radius, **kwargs): - """ - Instantiate the :class:`~pina.graph.Graph` class by computing the - ``edge_index`` based on the radius provided. - - :param pos: A tensor of shape ``(N, D)`` representing the positions of - ``N`` points in ``D``-dimensional space. - :type pos: torch.Tensor | LabelTensor - :param float radius: The radius within which points are connected. - :param dict kwargs: The additional keyword arguments to be passed to - :class:`GraphBuilder` and :class:`Graph` classes. - :return: A :class:`~pina.graph.Graph` instance with the computed - ``edge_index``. - :rtype: Graph - """ - edge_index = cls.compute_radius_graph(pos, radius) - return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) - - @staticmethod - def compute_radius_graph(points, radius): - """ - Computes the ``edge_index`` based on the radius. Each point is connected - to all the points within the radius. - - :param points: A tensor of shape ``(N, D)`` representing the positions - of ``N`` points in ``D``-dimensional space. - :type points: torch.Tensor | LabelTensor - :param float radius: The radius within which points are connected. - :return: A tensor of shape ``(2, E)``, with ``E`` number of edges, - representing the edge indices of the graph. - :rtype: torch.Tensor - """ - dist = torch.cdist(points, points, p=2) - return ( - torch.nonzero(dist <= radius, as_tuple=False) - .t() - .as_subclass(torch.Tensor) - ) - - -class KNNGraph(GraphBuilder): - """ - Extends the :class:`~pina.graph.GraphBuilder` class to compute - ``edge_index`` based on a K-nearest neighbors algorithm. - """ - - def __new__(cls, pos, neighbours, **kwargs): - """ - Instantiate the :class:`~pina.graph.Graph` class by computing the - ``edge_index`` based on the K-nearest neighbors algorithm. - - :param pos: A tensor of shape ``(N, D)`` representing the positions of - ``N`` points in ``D``-dimensional space. - :type pos: torch.Tensor | LabelTensor - :param int neighbours: The number of nearest neighbors to consider when - building the graph. - :param dict kwargs: The additional keyword arguments to be passed to - :class:`GraphBuilder` and :class:`Graph` classes. - - :return: A :class:`~pina.graph.Graph` instance with the computed - ``edge_index``. - :rtype: Graph - """ - - edge_index = cls.compute_knn_graph(pos, neighbours) - return super().__new__(cls, pos=pos, edge_index=edge_index, **kwargs) - - @staticmethod - def compute_knn_graph(points, neighbours): - """ - Computes the ``edge_index`` based on the K-nearest neighbors algorithm. - - :param points: A tensor of shape ``(N, D)`` representing the positions - of ``N`` points in ``D``-dimensional space. - :type points: torch.Tensor | LabelTensor - :param int neighbours: The number of nearest neighbors to consider when - building the graph. - :return: A tensor of shape ``(2, E)``, with ``E`` number of edges, - representing the edge indices of the graph. - :rtype: torch.Tensor - """ - dist = torch.cdist(points, points, p=2) - knn_indices = torch.topk(dist, k=neighbours, largest=False).indices - row = torch.arange(points.size(0)).repeat_interleave(neighbours) - col = knn_indices.flatten() - return torch.stack([row, col], dim=0).as_subclass(torch.Tensor) - - -class LabelBatch(Batch): - """ - Extends the :class:`~torch_geometric.data.Batch` class to include - :class:`~pina.label_tensor.LabelTensor` objects. 
- """ - - @classmethod - def from_data_list(cls, data_list): - """ - Create a Batch object from a list of :class:`~torch_geometric.data.Data` - or :class:`~pina.graph.Graph` objects. - - :param data_list: List of :class:`~torch_geometric.data.Data` or - :class:`~pina.graph.Graph` objects. - :type data_list: list[Data] | list[Graph] - :return: A :class:`~torch_geometric.data.Batch` object containing - the input data. - :rtype: :class:`~torch_geometric.data.Batch` - """ - # Store the labels of Data/Graph objects (all data have the same labels) - # If the data do not contain labels, labels is an empty dictionary, - # therefore the labels are not stored - labels = { - k: v.labels - for k, v in data_list[0].items() - if isinstance(v, LabelTensor) - } - - # Create a Batch object from the list of Data objects - batch = super().from_data_list(data_list) - - # Put the labels back in the Batch object - for k, v in labels.items(): - batch[k].labels = v - return batch +__all__ = [ + "Graph", + "RadiusGraph", + "KNNGraph", +] diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index d91cf7ab0..83ad5ef7e 100644 --- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -1,4 +1,10 @@ -"""Module for loss functions and weighting functions.""" +"""Loss functions and balancing strategies for multi-objective optimization. + +This module provides standard error metrics (Lp, Power loss) and sophisticated +weighting schemes designed to balance residual, boundary, and data-driven loss +terms, including dynamic methods like Neural Tangent Kernel (NTK) and +self-adaptive weighting. +""" __all__ = [ "LossInterface", @@ -11,11 +17,11 @@ "LinearWeighting", ] -from .loss_interface import LossInterface -from .power_loss import PowerLoss -from .lp_loss import LpLoss -from .weighting_interface import WeightingInterface -from .scalar_weighting import ScalarWeighting -from .ntk_weighting import NeuralTangentKernelWeighting -from .self_adaptive_weighting import SelfAdaptiveWeighting -from .linear_weighting import LinearWeighting +from pina._src.loss.loss_interface import LossInterface +from pina._src.loss.power_loss import PowerLoss +from pina._src.loss.lp_loss import LpLoss +from pina._src.loss.weighting_interface import WeightingInterface +from pina._src.loss.scalar_weighting import ScalarWeighting +from pina._src.loss.ntk_weighting import NeuralTangentKernelWeighting +from pina._src.loss.self_adaptive_weighting import SelfAdaptiveWeighting +from pina._src.loss.linear_weighting import LinearWeighting diff --git a/pina/model/__init__.py b/pina/model/__init__.py index 05ccc6c8c..0310eef5c 100644 --- a/pina/model/__init__.py +++ b/pina/model/__init__.py @@ -16,18 +16,21 @@ "PirateNet", "EquivariantGraphNeuralOperator", "SINDy", + "SplineSurface", ] -from .feed_forward import FeedForward, ResidualFeedForward -from .multi_feed_forward import MultiFeedForward -from .deeponet import DeepONet, MIONet -from .fourier_neural_operator import FNO, FourierIntegralKernel -from .kernel_neural_operator import KernelNeuralOperator -from .average_neural_operator import AveragingNeuralOperator -from .low_rank_neural_operator import LowRankNeuralOperator -from .spline import Spline -from .spline_surface import SplineSurface -from .graph_neural_operator import GraphNeuralOperator -from .pirate_network import PirateNet -from .equivariant_graph_neural_operator import EquivariantGraphNeuralOperator -from .sindy import SINDy +from pina._src.model.feed_forward import FeedForward, ResidualFeedForward +from pina._src.model.multi_feed_forward 
import MultiFeedForward +from pina._src.model.deeponet import DeepONet, MIONet +from pina._src.model.fourier_neural_operator import FNO, FourierIntegralKernel +from pina._src.model.kernel_neural_operator import KernelNeuralOperator +from pina._src.model.average_neural_operator import AveragingNeuralOperator +from pina._src.model.low_rank_neural_operator import LowRankNeuralOperator +from pina._src.model.spline import Spline +from pina._src.model.spline_surface import SplineSurface +from pina._src.model.graph_neural_operator import GraphNeuralOperator +from pina._src.model.pirate_network import PirateNet +from pina._src.model.equivariant_graph_neural_operator import ( + EquivariantGraphNeuralOperator, +) +from pina._src.model.sindy import SINDy diff --git a/pina/model/block/__init__.py b/pina/model/block/__init__.py index 08b313387..88bfd9e43 100644 --- a/pina/model/block/__init__.py +++ b/pina/model/block/__init__.py @@ -1,4 +1,10 @@ -"""Module for the building blocks of the neural models.""" +"""Architectural primitives and building blocks. + +This module provides a comprehensive collection of neural network components, +ranging from standard units (Residual, Enhanced Linear) to specialized layers +for Scientific Machine Learning, including Neural Operator blocks (FNO, GNO, +AVNO), spectral convolutions, and coordinate embeddings (Fourier Features). +""" __all__ = [ "ContinuousConvBlock", @@ -21,19 +27,26 @@ "PirateNetBlock", ] -from .convolution_2d import ContinuousConvBlock -from .residual import ResidualBlock, EnhancedLinear -from .spectral import ( +from pina._src.model.block.convolution_2d import ContinuousConvBlock +from pina._src.model.block.residual import ResidualBlock, EnhancedLinear +from pina._src.model.block.spectral import ( SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D, ) -from .fourier_block import FourierBlock1D, FourierBlock2D, FourierBlock3D -from .pod_block import PODBlock -from .orthogonal import OrthogonalBlock -from .embedding import PeriodicBoundaryEmbedding, FourierFeatureEmbedding -from .average_neural_operator_block import AVNOBlock -from .low_rank_block import LowRankBlock -from .rbf_block import RBFBlock -from .gno_block import GNOBlock -from .pirate_network_block import PirateNetBlock +from pina._src.model.block.fourier_block import ( + FourierBlock1D, + FourierBlock2D, + FourierBlock3D, +) +from pina._src.model.block.pod_block import PODBlock +from pina._src.model.block.orthogonal import OrthogonalBlock +from pina._src.model.block.embedding import ( + PeriodicBoundaryEmbedding, + FourierFeatureEmbedding, +) +from pina._src.model.block.average_neural_operator_block import AVNOBlock +from pina._src.model.block.low_rank_block import LowRankBlock +from pina._src.model.block.rbf_block import RBFBlock +from pina._src.model.block.gno_block import GNOBlock +from pina._src.model.block.pirate_network_block import PirateNetBlock diff --git a/pina/model/block/message_passing.py b/pina/model/block/message_passing.py new file mode 100644 index 000000000..652e9dbde --- /dev/null +++ b/pina/model/block/message_passing.py @@ -0,0 +1,25 @@ +"""Module for the message passing blocks of the graph neural models.""" + +__all__ = [ + "InteractionNetworkBlock", + "DeepTensorNetworkBlock", + "EnEquivariantNetworkBlock", + "RadialFieldNetworkBlock", + "EquivariantGraphNeuralOperatorBlock", +] + +from pina._src.model.block.message_passing.interaction_network_block import ( + InteractionNetworkBlock, +) +from 
pina._src.model.block.message_passing.deep_tensor_network_block import ( + DeepTensorNetworkBlock, +) +from pina._src.model.block.message_passing.en_equivariant_network_block import ( + EnEquivariantNetworkBlock, +) +from pina._src.model.block.message_passing.radial_field_network_block import ( + RadialFieldNetworkBlock, +) +from pina._src.model.block.message_passing.equivariant_graph_neural_operator_block import ( + EquivariantGraphNeuralOperatorBlock, +) diff --git a/pina/model/block/message_passing/__init__.py b/pina/model/block/message_passing/__init__.py deleted file mode 100644 index 202e1fde4..000000000 --- a/pina/model/block/message_passing/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Module for the message passing blocks of the graph neural models.""" - -__all__ = [ - "InteractionNetworkBlock", - "DeepTensorNetworkBlock", - "EnEquivariantNetworkBlock", - "RadialFieldNetworkBlock", - "EquivariantGraphNeuralOperatorBlock", -] - -from .interaction_network_block import InteractionNetworkBlock -from .deep_tensor_network_block import DeepTensorNetworkBlock -from .en_equivariant_network_block import EnEquivariantNetworkBlock -from .radial_field_network_block import RadialFieldNetworkBlock -from .equivariant_graph_neural_operator_block import ( - EquivariantGraphNeuralOperatorBlock, -) diff --git a/pina/operator.py b/pina/operator.py index bf2351bce..fcd214804 100644 --- a/pina/operator.py +++ b/pina/operator.py @@ -1,483 +1,29 @@ -""" -Module for vectorized differential operators implementation. - -Differential operators are used to define differential problems and are -implemented to run efficiently on various accelerators, including CPU, GPU, TPU, -and MPS. - -Each differential operator takes the following inputs: -- A tensor on which the operator is applied. -- A tensor with respect to which the operator is computed. -- The names of the output variables for which the operator is evaluated. -- The names of the variables with respect to which the operator is computed. +"""A public API for differential operators and automatic differentiation utilities. -Each differential operator has its fast version, which performs no internal -checks on input and output tensors. For these methods, the user is always -required to specify both ``components`` and ``d`` as lists of strings. +This module provides standard vector calculus operators (gradient, divergence, +laplacian, advection) implemented using automatic differentiation. It includes +both high-level general operators and optimized 'fast' variants for improved +computational efficiency during training. """ -import torch -from .label_tensor import LabelTensor - - -def _check_values(output_, input_, components, d): - """ - Perform checks on arguments of differential operators. - - :param LabelTensor output_: The output tensor on which the operator is - computed. - :param LabelTensor input_: The input tensor with respect to which the - operator is computed. - :param components: The names of the output variables for which to compute - the operator. It must be a subset of the output labels. - If ``None``, all output variables are considered. Default is ``None``. - :type components: str | list[str] - :param d: The names of the input variables with respect to which the - operator is computed. It must be a subset of the input labels. - If ``None``, all input variables are considered. Default is ``None``. - :type d: str | list[str] - :raises TypeError: If the input tensor is not a LabelTensor. 
- :raises TypeError: If the output tensor is not a LabelTensor. - :raises RuntimeError: If derivative labels are missing from the ``input_``. - :raises RuntimeError: If component labels are missing from the ``output_``. - :return: The components and d lists. - :rtype: tuple[list[str], list[str]] - """ - # Check if the input is a LabelTensor - if not isinstance(input_, LabelTensor): - raise TypeError("Input must be a LabelTensor.") - - # Check if the output is a LabelTensor - if not isinstance(output_, LabelTensor): - raise TypeError("Output must be a LabelTensor.") - - # If no labels are provided, use all labels - d = d or input_.labels - components = components or output_.labels - - # Convert to list if not already - d = d if isinstance(d, list) else [d] - components = components if isinstance(components, list) else [components] - - # Check if all labels are present in the input tensor - if not all(di in input_.labels for di in d): - raise RuntimeError("Derivative labels missing from input tensor.") - - # Check if all labels are present in the output tensor - if not all(c in output_.labels for c in components): - raise RuntimeError("Component label missing from output tensor.") - - return components, d - - -def _scalar_grad(output_, input_, d): - """ - Compute the gradient of a scalar-valued ``output_``. - - :param LabelTensor output_: The output tensor on which the gradient is - computed. It must be a column tensor. - :param LabelTensor input_: The input tensor with respect to which the - gradient is computed. - :param list[str] d: The names of the input variables with respect to - which the gradient is computed. It must be a subset of the input - labels. If ``None``, all input variables are considered. - :return: The computed gradient tensor. - :rtype: LabelTensor - """ - grad_out = torch.autograd.grad( - outputs=output_, - inputs=input_, - grad_outputs=torch.ones_like(output_), - create_graph=True, - retain_graph=True, - allow_unused=True, - )[0] - - return grad_out[..., [input_.labels.index(i) for i in d]] - - -def _scalar_laplacian(output_, input_, d): - """ - Compute the laplacian of a scalar-valued ``output_``. - - :param LabelTensor output_: The output tensor on which the laplacian is - computed. It must be a column tensor. - :param LabelTensor input_: The input tensor with respect to which the - laplacian is computed. - :param list[str] d: The names of the input variables with respect to - which the laplacian is computed. It must be a subset of the input - labels. If ``None``, all input variables are considered. - :return: The computed laplacian tensor. - :rtype: LabelTensor - """ - first_grad = fast_grad( - output_=output_, input_=input_, components=output_.labels, d=d - ) - second_grad = fast_grad( - output_=first_grad, input_=input_, components=first_grad.labels, d=d - ) - labels_to_extract = [f"d{c}d{d_}" for c, d_ in zip(first_grad.labels, d)] - return torch.sum( - second_grad.extract(labels_to_extract), dim=-1, keepdim=True - ) - - -def fast_grad(output_, input_, components, d): - """ - Compute the gradient of the ``output_`` with respect to the ``input``. - - Unlike ``grad``, this function performs no internal checks on input and - output tensors. The user is required to specify both ``components`` and - ``d`` as lists of strings. It is designed to enhance computation speed. - - This operator supports both vector-valued and scalar-valued functions with - one or multiple input coordinates. 
- - :param LabelTensor output_: The output tensor on which the gradient is - computed. - :param LabelTensor input_: The input tensor with respect to which the - gradient is computed. - :param list[str] components: The names of the output variables for which to - compute the gradient. It must be a subset of the output labels. - :param list[str] d: The names of the input variables with respect to which - the gradient is computed. It must be a subset of the input labels. - :return: The computed gradient tensor. - :rtype: LabelTensor - """ - # Scalar gradient - if output_.shape[-1] == 1: - return LabelTensor( - _scalar_grad(output_=output_, input_=input_, d=d), - labels=[f"d{output_.labels[0]}d{i}" for i in d], - ) - - # Vector gradient - grads = torch.cat( - [ - _scalar_grad(output_=output_.extract(c), input_=input_, d=d) - for c in components - ], - dim=-1, - ) - - return LabelTensor( - grads, labels=[f"d{c}d{i}" for c in components for i in d] - ) - - -def fast_div(output_, input_, components, d): - """ - Compute the divergence of the ``output_`` with respect to ``input``. - - Unlike ``div``, this function performs no internal checks on input and - output tensors. The user is required to specify both ``components`` and - ``d`` as lists of strings. It is designed to enhance computation speed. - - This operator supports vector-valued functions with multiple input - coordinates. - - :param LabelTensor output_: The output tensor on which the divergence is - computed. - :param LabelTensor input_: The input tensor with respect to which the - divergence is computed. - :param list[str] components: The names of the output variables for which to - compute the divergence. It must be a subset of the output labels. - :param list[str] d: The names of the input variables with respect to which - the divergence is computed. It must be a subset of the input labels. - :rtype: LabelTensor - """ - grad_out = fast_grad( - output_=output_, input_=input_, components=components, d=d - ) - tensors_to_sum = [ - grad_out.extract(f"d{c}d{d_}") for c, d_ in zip(components, d) - ] - - return LabelTensor.summation(tensors_to_sum) - - -def fast_laplacian(output_, input_, components, d, method="std"): - """ - Compute the laplacian of the ``output_`` with respect to ``input``. - - Unlike ``laplacian``, this function performs no internal checks on input and - output tensors. The user is required to specify both ``components`` and - ``d`` as lists of strings. It is designed to enhance computation speed. - - This operator supports both vector-valued and scalar-valued functions with - one or multiple input coordinates. - - :param LabelTensor output_: The output tensor on which the laplacian is - computed. - :param LabelTensor input_: The input tensor with respect to which the - laplacian is computed. - :param list[str] components: The names of the output variables for which to - compute the laplacian. It must be a subset of the output labels. - :param list[str] d: The names of the input variables with respect to which - the laplacian is computed. It must be a subset of the input labels. - :param str method: The method used to compute the Laplacian. Available - methods are ``std`` and ``divgrad``. The ``std`` method computes the - trace of the Hessian matrix, while the ``divgrad`` method computes the - divergence of the gradient. Default is ``std``. - :return: The computed laplacian tensor. - :rtype: LabelTensor - :raises ValueError: If the passed method is neither ``std`` nor ``divgrad``. 
- """ - # Scalar laplacian - if output_.shape[-1] == 1: - return LabelTensor( - _scalar_laplacian(output_=output_, input_=input_, d=d), - labels=[f"dd{c}" for c in components], - ) - - # Initialize the result tensor and its labels - labels = [f"dd{c}" for c in components] - result = torch.empty( - input_.shape[0], len(components), device=output_.device - ) - - # Vector laplacian - if method == "std": - result = torch.cat( - [ - _scalar_laplacian( - output_=output_.extract(c), input_=input_, d=d - ) - for c in components - ], - dim=-1, - ) - - elif method == "divgrad": - grads = fast_grad( - output_=output_, input_=input_, components=components, d=d - ) - result = torch.cat( - [ - fast_div( - output_=grads, - input_=input_, - components=[f"d{c}d{i}" for i in d], - d=d, - ) - for c in components - ], - dim=-1, - ) - - else: - raise ValueError( - "Invalid method. Available methods are ``std`` and ``divgrad``." - ) - - return LabelTensor(result, labels=labels) - - -def fast_advection(output_, input_, velocity_field, components, d): - """ - Perform the advection operation on the ``output_`` with respect to the - ``input``. This operator supports vector-valued functions with multiple - input coordinates. - - Unlike ``advection``, this function performs no internal checks on input and - output tensors. The user is required to specify both ``components`` and - ``d`` as lists of strings. It is designed to enhance computation speed. - - :param LabelTensor output_: The output tensor on which the advection is - computed. It includes both the velocity and the quantity to be advected. - :param LabelTensor input_: the input tensor with respect to which advection - is computed. - :param list[str] velocity_field: The name of the output variables used as - velocity field. It must be chosen among the output labels. - :param list[str] components: The names of the output variables for which to - compute the advection. It must be a subset of the output labels. - :param list[str] d: The names of the input variables with respect to which - the advection is computed. It must be a subset of the input labels. - :return: The computed advection tensor. - :rtype: LabelTensor - """ - # Add a dimension to the velocity field for following operations - velocity = output_.extract(velocity_field).unsqueeze(-1) - - # Compute the gradient - grads = fast_grad( - output_=output_, input_=input_, components=components, d=d - ) - - # Reshape into [..., len(filter_components), len(d)] - tmp = grads.reshape(*output_.shape[:-1], len(components), len(d)) - - # Transpose to [..., len(d), len(filter_components)] - tmp = tmp.transpose(-1, -2) - - adv = (tmp * velocity).sum(dim=tmp.tensor.ndim - 2) - return LabelTensor(adv, labels=[f"adv_{c}" for c in components]) - - -def grad(output_, input_, components=None, d=None): - """ - Compute the gradient of the ``output_`` with respect to the ``input``. - - This operator supports both vector-valued and scalar-valued functions with - one or multiple input coordinates. - - :param LabelTensor output_: The output tensor on which the gradient is - computed. - :param LabelTensor input_: The input tensor with respect to which the - gradient is computed. - :param components: The names of the output variables for which to compute - the gradient. It must be a subset of the output labels. - If ``None``, all output variables are considered. Default is ``None``. - :type components: str | list[str] - :param d: The names of the input variables with respect to which the - gradient is computed. 
It must be a subset of the input labels. - If ``None``, all input variables are considered. Default is ``None``. - :type d: str | list[str] - :raises TypeError: If the input tensor is not a LabelTensor. - :raises TypeError: If the output tensor is not a LabelTensor. - :raises RuntimeError: If derivative labels are missing from the ``input_``. - :raises RuntimeError: If component labels are missing from the ``output_``. - :return: The computed gradient tensor. - :rtype: LabelTensor - """ - components, d = _check_values( - output_=output_, input_=input_, components=components, d=d - ) - return fast_grad(output_=output_, input_=input_, components=components, d=d) - - -def div(output_, input_, components=None, d=None): - """ - Compute the divergence of the ``output_`` with respect to ``input``. - - This operator supports vector-valued functions with multiple input - coordinates. - - :param LabelTensor output_: The output tensor on which the divergence is - computed. - :param LabelTensor input_: The input tensor with respect to which the - divergence is computed. - :param components: The names of the output variables for which to compute - the divergence. It must be a subset of the output labels. - If ``None``, all output variables are considered. Default is ``None``. - :type components: str | list[str] - :param d: The names of the input variables with respect to which the - divergence is computed. It must be a subset of the input labels. - If ``None``, all input variables are considered. Default is ``None``. - :type components: str | list[str] - :raises TypeError: If the input tensor is not a LabelTensor. - :raises TypeError: If the output tensor is not a LabelTensor. - :raises ValueError: If the length of ``components`` and ``d`` do not match. - :return: The computed divergence tensor. - :rtype: LabelTensor - """ - components, d = _check_values( - output_=output_, input_=input_, components=components, d=d - ) - - # Components and d must be of the same length - if len(components) != len(d): - raise ValueError( - "Divergence requires components and d to be of the same length." - ) - - return fast_div(output_=output_, input_=input_, components=components, d=d) - - -def laplacian(output_, input_, components=None, d=None, method="std"): - """ - Compute the laplacian of the ``output_`` with respect to ``input``. - - This operator supports both vector-valued and scalar-valued functions with - one or multiple input coordinates. - - :param LabelTensor output_: The output tensor on which the laplacian is - computed. - :param LabelTensor input_: The input tensor with respect to which the - laplacian is computed. - :param components: The names of the output variables for which to - compute the laplacian. It must be a subset of the output labels. - If ``None``, all output variables are considered. Default is ``None``. - :type components: str | list[str] - :param d: The names of the input variables with respect to which - the laplacian is computed. It must be a subset of the input labels. - If ``None``, all input variables are considered. Default is ``None``. - :type d: str | list[str] - :param str method: The method used to compute the Laplacian. Available - methods are ``std`` and ``divgrad``. The ``std`` method computes the - trace of the Hessian matrix, while the ``divgrad`` method computes the - divergence of the gradient. Default is ``std``. - :raises TypeError: If the input tensor is not a LabelTensor. - :raises TypeError: If the output tensor is not a LabelTensor. 
- :raises ValueError: If the passed method is neither ``std`` nor ``divgrad``. - :return: The computed laplacian tensor. - :rtype: LabelTensor - """ - components, d = _check_values( - output_=output_, input_=input_, components=components, d=d - ) - - return fast_laplacian( - output_=output_, - input_=input_, - components=components, - d=d, - method=method, - ) - - -def advection(output_, input_, velocity_field, components=None, d=None): - """ - Perform the advection operation on the ``output_`` with respect to the - ``input``. This operator supports vector-valued functions with multiple - input coordinates. - - :param LabelTensor output_: The output tensor on which the advection is - computed. It includes both the velocity and the quantity to be advected. - :param LabelTensor input_: the input tensor with respect to which advection - is computed. - :param velocity_field: The name of the output variables used as velocity - field. It must be chosen among the output labels. - :type velocity_field: str | list[str] - :param components: The names of the output variables for which to compute - the advection. It must be a subset of the output labels. - If ``None``, all output variables are considered. Default is ``None``. - :type components: str | list[str] - :param d: The names of the input variables with respect to which the - advection is computed. It must be a subset of the input labels. - If ``None``, all input variables are considered. Default is ``None``. - :type d: str | list[str] - :raises TypeError: If the input tensor is not a LabelTensor. - :raises TypeError: If the output tensor is not a LabelTensor. - :raises RuntimeError: If the velocity field is not a subset of the output - labels. - :raises RuntimeError: If the dimensionality of the velocity field does not - match that of the input tensor. - :return: The computed advection tensor. - :rtype: LabelTensor - """ - components, d = _check_values( - output_=output_, input_=input_, components=components, d=d - ) - - # Map velocity_field to a list if it is a string - if isinstance(velocity_field, str): - velocity_field = [velocity_field] - - # Check if all the velocity_field labels are present in the output labels - if not all(vi in output_.labels for vi in velocity_field): - raise RuntimeError("Velocity labels missing from output tensor.") - - # Check if the velocity has the same dimensionality as the input tensor - if len(velocity_field) != len(d): - raise RuntimeError( - "Velocity dimensionality does not match input dimensionality." 
- ) - - return fast_advection( - output_=output_, - input_=input_, - velocity_field=velocity_field, - components=components, - d=d, - ) +from pina._src.core.operator import ( + grad, + fast_grad, + fast_div, + fast_laplacian, + fast_advection, + div, + laplacian, + advection, +) + +__all__ = [ + "grad", + "fast_grad", + "fast_div", + "fast_laplacian", + "fast_advection", + "div", + "laplacian", + "advection", +] diff --git a/pina/optim/__init__.py b/pina/optim/__init__.py index 8266c8ca1..682b6225e 100644 --- a/pina/optim/__init__.py +++ b/pina/optim/__init__.py @@ -7,7 +7,7 @@ "TorchScheduler", ] -from .optimizer_interface import Optimizer -from .torch_optimizer import TorchOptimizer -from .scheduler_interface import Scheduler -from .torch_scheduler import TorchScheduler +from pina._src.optim.optimizer_interface import Optimizer +from pina._src.optim.torch_optimizer import TorchOptimizer +from pina._src.optim.scheduler_interface import Scheduler +from pina._src.optim.torch_scheduler import TorchScheduler diff --git a/pina/problem/__init__.py b/pina/problem/__init__.py index e95f99703..b170bec21 100644 --- a/pina/problem/__init__.py +++ b/pina/problem/__init__.py @@ -8,8 +8,8 @@ "InverseProblem", ] -from .abstract_problem import AbstractProblem -from .spatial_problem import SpatialProblem -from .time_dependent_problem import TimeDependentProblem -from .parametric_problem import ParametricProblem -from .inverse_problem import InverseProblem +from pina._src.problem.abstract_problem import AbstractProblem +from pina._src.problem.spatial_problem import SpatialProblem +from pina._src.problem.time_dependent_problem import TimeDependentProblem +from pina._src.problem.parametric_problem import ParametricProblem +from pina._src.problem.inverse_problem import InverseProblem diff --git a/pina/problem/zoo.py b/pina/problem/zoo.py new file mode 100644 index 000000000..e5c23ae81 --- /dev/null +++ b/pina/problem/zoo.py @@ -0,0 +1,23 @@ +"""Module for implemented problems.""" + +__all__ = [ + "SupervisedProblem", + "HelmholtzProblem", + "AllenCahnProblem", + "AdvectionProblem", + "Poisson2DSquareProblem", + "DiffusionReactionProblem", + "InversePoisson2DSquareProblem", + "AcousticWaveProblem", +] + +from pina._src.problem.zoo.supervised_problem import SupervisedProblem +from pina._src.problem.zoo.helmholtz import HelmholtzProblem +from pina._src.problem.zoo.allen_cahn import AllenCahnProblem +from pina._src.problem.zoo.advection import AdvectionProblem +from pina._src.problem.zoo.poisson_2d_square import Poisson2DSquareProblem +from pina._src.problem.zoo.diffusion_reaction import DiffusionReactionProblem +from pina._src.problem.zoo.inverse_poisson_2d_square import ( + InversePoisson2DSquareProblem, +) +from pina._src.problem.zoo.acoustic_wave import AcousticWaveProblem diff --git a/pina/problem/zoo/__init__.py b/pina/problem/zoo/__init__.py deleted file mode 100644 index 73e3ad9b6..000000000 --- a/pina/problem/zoo/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Module for implemented problems.""" - -__all__ = [ - "SupervisedProblem", - "HelmholtzProblem", - "AllenCahnProblem", - "AdvectionProblem", - "Poisson2DSquareProblem", - "DiffusionReactionProblem", - "InversePoisson2DSquareProblem", - "AcousticWaveProblem", -] - -from .supervised_problem import SupervisedProblem -from .helmholtz import HelmholtzProblem -from .allen_cahn import AllenCahnProblem -from .advection import AdvectionProblem -from .poisson_2d_square import Poisson2DSquareProblem -from .diffusion_reaction import DiffusionReactionProblem 
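Since ``pina/operator.py`` is reduced above to a thin re-export of ``pina._src.core.operator``, the public call signatures should be unchanged. A minimal sketch of the expected behaviour (not part of the patch; the output labels follow the conventions of the removed implementation, assumed to be preserved in ``_src``):

import torch
from pina import LabelTensor
from pina.operator import grad, laplacian

# Labelled input points; autograd needs requires_grad=True.
pts = LabelTensor(torch.rand(8, 2, requires_grad=True), labels=["x", "y"])
# Scalar field u(x, y) = x^2 + y^2, wrapped with an output label.
u = LabelTensor(pts.pow(2).sum(dim=-1, keepdim=True), labels=["u"])

grad_u = grad(u, pts)      # LabelTensor with labels ["dudx", "dudy"]
lap_u = laplacian(u, pts)  # LabelTensor with labels ["ddu"], here constant 4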
-from .inverse_poisson_2d_square import InversePoisson2DSquareProblem -from .acoustic_wave import AcousticWaveProblem diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py index 43f18078f..327056035 100644 --- a/pina/solver/__init__.py +++ b/pina/solver/__init__.py @@ -1,4 +1,13 @@ -"""Module for the solver classes.""" +""" +Unified solvers for Physics-Informed and Data-Driven modeling. + +This module provides the high-level training orchestrators used to solve +differential equations and regression problems. It includes: +* **Physics-Informed Solvers**: Standard PINN, Gradient-enhanced (gPINN), Causal, + and Self-Adaptive variants. +* **Supervised Solvers**: For purely data-driven tasks and Reduced Order Modeling. +* **Ensemble Solvers**: For uncertainty quantification via Deep Ensembles. +""" __all__ = [ "SolverInterface", @@ -20,24 +29,29 @@ "GAROM", ] -from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface -from .physics_informed_solver import ( - PINNInterface, - PINN, - GradientPINN, - CausalPINN, +from pina._src.solver.solver import ( + SolverInterface, + SingleSolverInterface, + MultiSolverInterface, +) +from pina._src.solver.physics_informed_solver.pinn import PINNInterface, PINN +from pina._src.solver.physics_informed_solver.gradient_pinn import GradientPINN +from pina._src.solver.physics_informed_solver.causal_pinn import CausalPINN +from pina._src.solver.physics_informed_solver.competitive_pinn import ( CompetitivePINN, +) +from pina._src.solver.physics_informed_solver.self_adaptive_pinn import ( SelfAdaptivePINN, - RBAPINN, ) -from .supervised_solver import ( +from pina._src.solver.physics_informed_solver.rba_pinn import RBAPINN +from pina._src.solver.supervised_solver import ( SupervisedSolverInterface, SupervisedSolver, ReducedOrderModelSolver, ) -from .ensemble_solver import ( +from pina._src.solver.ensemble_solver import ( DeepEnsembleSolverInterface, DeepEnsembleSupervisedSolver, DeepEnsemblePINN, ) -from .garom import GAROM +from pina._src.solver.garom import GAROM diff --git a/pina/solver/ensemble_solver/__init__.py b/pina/solver/ensemble_solver/__init__.py deleted file mode 100644 index 0e4eab54b..000000000 --- a/pina/solver/ensemble_solver/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Module for the Ensemble solver classes.""" - -__all__ = [ - "DeepEnsembleSolverInterface", - "DeepEnsembleSupervisedSolver", - "DeepEnsemblePINN", -] - -from .ensemble_solver_interface import DeepEnsembleSolverInterface -from .ensemble_supervised import DeepEnsembleSupervisedSolver -from .ensemble_pinn import DeepEnsemblePINN diff --git a/pina/solver/physics_informed_solver/__init__.py b/pina/solver/physics_informed_solver/__init__.py deleted file mode 100644 index f0fb8ebcd..000000000 --- a/pina/solver/physics_informed_solver/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Module for the Physics-Informed solvers.""" - -__all__ = [ - "PINNInterface", - "PINN", - "GradientPINN", - "CausalPINN", - "CompetitivePINN", - "SelfAdaptivePINN", - "RBAPINN", -] - -from .pinn_interface import PINNInterface -from .pinn import PINN -from .rba_pinn import RBAPINN -from .causal_pinn import CausalPINN -from .gradient_pinn import GradientPINN -from .competitive_pinn import CompetitivePINN -from .self_adaptive_pinn import SelfAdaptivePINN diff --git a/pina/trainer.py b/pina/trainer.py index e92928d1e..5cd598b4c 100644 --- a/pina/trainer.py +++ b/pina/trainer.py
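With the per-subpackage ``__init__`` files deleted above, every solver name must now resolve through the rewritten ``pina/solver/__init__.py``. A short import smoke test, restricted to names that appear in the ``__all__`` lists of this patch:

# User code should keep importing from the public packages,
# never from pina._src directly.
from pina.solver import (
    PINN,
    CausalPINN,
    RBAPINN,
    SupervisedSolver,
    ReducedOrderModelSolver,
    DeepEnsemblePINN,
    GAROM,
)
from pina.model import FeedForward, SplineSurface
from pina.graph import Graph, RadiusGraph, KNNGraph
from pina.loss import LpLoss, NeuralTangentKernelWeighting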
@@ -1,362 +1,5 @@
-"""Module for the Trainer."""
+"""Public API for Trainer."""
+
-import sys
-import warnings
-import torch
-import lightning
-from .utils import check_consistency, custom_warning_format
-from .data import PinaDataModule
-from .solver import SolverInterface, PINNInterface
+from pina._src.core.trainer import Trainer
+
-# set the warning for compile options
-warnings.formatwarning = custom_warning_format
-warnings.filterwarnings("always", category=UserWarning)
-
-
-class Trainer(lightning.pytorch.Trainer):
-    """
-    PINA custom Trainer class to extend the standard Lightning functionality.
-
-    This class enables specific features or behaviors required by the PINA
-    framework. It modifies the standard
-    :class:`lightning.pytorch.Trainer`
-    class to better support the training process in PINA.
-    """
-
-    def __init__(
-        self,
-        solver,
-        batch_size=None,
-        train_size=1.0,
-        test_size=0.0,
-        val_size=0.0,
-        compile=None,
-        repeat=None,
-        automatic_batching=None,
-        num_workers=None,
-        pin_memory=None,
-        shuffle=None,
-        **kwargs,
-    ):
-        """
-        Initialization of the :class:`Trainer` class.
-
-        :param SolverInterface solver: A
-            :class:`~pina.solver.solver.SolverInterface` solver used to solve a
-            :class:`~pina.problem.abstract_problem.AbstractProblem`.
-        :param int batch_size: The number of samples per batch to load.
-            If ``None``, all samples are loaded and data is not batched.
-            Default is ``None``.
-        :param float train_size: The percentage of elements to include in the
-            training dataset. Default is ``1.0``.
-        :param float test_size: The percentage of elements to include in the
-            test dataset. Default is ``0.0``.
-        :param float val_size: The percentage of elements to include in the
-            validation dataset. Default is ``0.0``.
-        :param bool compile: If ``True``, the model is compiled before training.
-            Default is ``False``. For Windows users, it is always disabled. Not
-            supported for python version greater or equal than 3.14.
-        :param bool repeat: Whether to repeat the dataset data in each
-            condition during training. For further details, see the
-            :class:`~pina.data.data_module.PinaDataModule` class. Default is
-            ``False``.
-        :param bool automatic_batching: If ``True``, automatic PyTorch batching
-            is performed, otherwise the items are retrieved from the dataset
-            all at once. For further details, see the
-            :class:`~pina.data.data_module.PinaDataModule` class. Default is
-            ``False``.
-        :param int num_workers: The number of worker threads for data loading.
-            Default is ``0`` (serial loading).
-        :param bool pin_memory: Whether to use pinned memory for faster data
-            transfer to GPU. Default is ``False``.
-        :param bool shuffle: Whether to shuffle the data during training.
-            Default is ``True``.
-        :param dict kwargs: Additional keyword arguments that specify the
-            training setup. These can be selected from the `pytorch-lightning
-            Trainer API
-            <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_.
-        """
-        # check consistency for init types
-        self._check_input_consistency(
-            solver=solver,
-            train_size=train_size,
-            test_size=test_size,
-            val_size=val_size,
-            repeat=repeat,
-            automatic_batching=automatic_batching,
-            compile=compile,
-        )
-        pin_memory, num_workers, shuffle, batch_size = (
-            self._check_consistency_and_set_defaults(
-                pin_memory, num_workers, shuffle, batch_size
-            )
-        )
-
-        # inference mode set to false when validating/testing PINNs otherwise
-        # gradient is not tracked and optimization_cycle fails
-        if isinstance(solver, PINNInterface):
-            kwargs["inference_mode"] = False
-
-        # Logging depends on the batch size, when batch_size is None then
-        # log_every_n_steps should be zero
-        if batch_size is None:
-            kwargs["log_every_n_steps"] = 0
-        else:
-            kwargs.setdefault("log_every_n_steps", 50)  # default for lightning
-
-        # Setting default kwargs, overriding lightning defaults
-        kwargs.setdefault("enable_progress_bar", True)
-
-        super().__init__(**kwargs)
-
-        # checking compilation and automatic batching
-        # compilation disabled for Windows and for Python 3.14+
-        if (
-            compile is None
-            or sys.platform == "win32"
-            or sys.version_info >= (3, 14)
-        ):
-            compile = False
-            warnings.warn(
-                "Compilation is disabled for Python 3.14+ and for Windows.",
-                UserWarning,
-            )
-
-        repeat = repeat if repeat is not None else False
-
-        automatic_batching = (
-            automatic_batching if automatic_batching is not None else False
-        )
-
-        # set attributes
-        self.compile = compile
-        self.solver = solver
-        self.batch_size = batch_size
-        self._move_to_device()
-        self.data_module = None
-        self._create_datamodule(
-            train_size=train_size,
-            test_size=test_size,
-            val_size=val_size,
-            batch_size=batch_size,
-            repeat=repeat,
-            automatic_batching=automatic_batching,
-            pin_memory=pin_memory,
-            num_workers=num_workers,
-            shuffle=shuffle,
-        )
-
-        # logging
-        self.logging_kwargs = {
-            "sync_dist": bool(
-                len(self._accelerator_connector._parallel_devices) > 1
-            ),
-            "on_step": bool(kwargs["log_every_n_steps"] > 0),
-            "prog_bar": bool(kwargs["enable_progress_bar"]),
-            "on_epoch": True,
-        }
-
-    def _move_to_device(self):
-        """
-        Moves the ``unknown_parameters`` of an instance of
-        :class:`~pina.problem.abstract_problem.AbstractProblem` to the
-        :class:`Trainer` device.
-        """
-        device = self._accelerator_connector._parallel_devices[0]
-        # move parameters to device
-        pb = self.solver.problem
-        if hasattr(pb, "unknown_parameters"):
-            for key in pb.unknown_parameters:
-                pb.unknown_parameters[key] = torch.nn.Parameter(
-                    pb.unknown_parameters[key].data.to(device)
-                )
-
-    def _create_datamodule(
-        self,
-        train_size,
-        test_size,
-        val_size,
-        batch_size,
-        repeat,
-        automatic_batching,
-        pin_memory,
-        num_workers,
-        shuffle,
-    ):
-        """
-        This method is designed to handle the creation of a data module when
-        resampling is needed during training. Instead of manually defining and
-        modifying the trainer's dataloaders, this method is called to
-        automatically configure the data module.
-
-        :param float train_size: The percentage of elements to include in the
-            training dataset.
-        :param float test_size: The percentage of elements to include in the
-            test dataset.
-        :param float val_size: The percentage of elements to include in the
-            validation dataset.
-        :param int batch_size: The number of samples per batch to load.
-        :param bool repeat: Whether to repeat the dataset data in each
-            condition during training.
-        :param bool automatic_batching: Whether to perform automatic batching
-            with PyTorch.
-        :param bool pin_memory: Whether to use pinned memory for faster data
-            transfer to GPU.
-        :param int num_workers: The number of worker threads for data loading.
-        :param bool shuffle: Whether to shuffle the data during training.
-        :raises RuntimeError: If not all conditions are sampled.
-        """
-        if not self.solver.problem.are_all_domains_discretised:
-            error_message = "\n".join(
-                [
-                    f"""{" " * 13} ---> Domain {key} {
-                        "sampled" if key in self.solver.problem.discretised_domains
-                        else
-                        "not sampled"}"""
-                    for key in self.solver.problem.domains.keys()
-                ]
-            )
-            raise RuntimeError(
-                "Cannot create Trainer if not all conditions "
-                "are sampled. The Trainer got the following:\n"
-                f"{error_message}"
-            )
-        self.data_module = PinaDataModule(
-            self.solver.problem,
-            train_size=train_size,
-            test_size=test_size,
-            val_size=val_size,
-            batch_size=batch_size,
-            repeat=repeat,
-            automatic_batching=automatic_batching,
-            num_workers=num_workers,
-            pin_memory=pin_memory,
-            shuffle=shuffle,
-        )
-
-    def train(self, **kwargs):
-        """
-        Manage the training process of the solver.
-
-        :param dict kwargs: Additional keyword arguments. See `pytorch-lightning
-            Trainer API
-            <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_
-            for details.
-        """
-        return super().fit(self.solver, datamodule=self.data_module, **kwargs)
-
-    def test(self, **kwargs):
-        """
-        Manage the test process of the solver.
-
-        :param dict kwargs: Additional keyword arguments. See `pytorch-lightning
-            Trainer API
-            <https://lightning.ai/docs/pytorch/stable/common/trainer.html>`_
-            for details.
-        """
-        return super().test(self.solver, datamodule=self.data_module, **kwargs)
-
-    @property
-    def solver(self):
-        """
-        Get the solver.
-
-        :return: The solver.
-        :rtype: SolverInterface
-        """
-        return self._solver
-
-    @solver.setter
-    def solver(self, solver):
-        """
-        Set the solver.
-
-        :param SolverInterface solver: The solver to set.
-        """
-        self._solver = solver
-
-    @staticmethod
-    def _check_input_consistency(
-        solver,
-        train_size,
-        test_size,
-        val_size,
-        repeat,
-        automatic_batching,
-        compile,
-    ):
-        """
-        Verifies the consistency of the parameters for the solver configuration.
-
-        :param SolverInterface solver: The solver.
-        :param float train_size: The percentage of elements to include in the
-            training dataset.
-        :param float test_size: The percentage of elements to include in the
-            test dataset.
-        :param float val_size: The percentage of elements to include in the
-            validation dataset.
-        :param bool repeat: Whether to repeat the dataset data in each
-            condition during training.
-        :param bool automatic_batching: Whether to perform automatic batching
-            with PyTorch.
-        :param bool compile: If ``True``, the model is compiled before training.
-        """
-
-        check_consistency(solver, SolverInterface)
-        check_consistency(train_size, float)
-        check_consistency(test_size, float)
-        check_consistency(val_size, float)
-        if repeat is not None:
-            check_consistency(repeat, bool)
-        if automatic_batching is not None:
-            check_consistency(automatic_batching, bool)
-        if compile is not None:
-            check_consistency(compile, bool)
-
-    @staticmethod
-    def _check_consistency_and_set_defaults(
-        pin_memory, num_workers, shuffle, batch_size
-    ):
-        """
-        Checks the consistency of input parameters and sets default values
-        for missing or invalid parameters.
-
-        :param bool pin_memory: Whether to use pinned memory for faster data
-            transfer to GPU.
-        :param int num_workers: The number of worker threads for data loading.
-        :param bool shuffle: Whether to shuffle the data during training.
-        :param int batch_size: The number of samples per batch to load.
-        """
-        if pin_memory is not None:
-            check_consistency(pin_memory, bool)
-        else:
-            pin_memory = False
-        if num_workers is not None:
-            check_consistency(num_workers, int)
-        else:
-            num_workers = 0
-        if shuffle is not None:
-            check_consistency(shuffle, bool)
-        else:
-            shuffle = True
-        if batch_size is not None:
-            check_consistency(batch_size, int)
-        return pin_memory, num_workers, shuffle, batch_size
-
-    @property
-    def compile(self):
-        """
-        Whether compilation is required or not.
-
-        :return: ``True`` if compilation is required, ``False`` otherwise.
-        :rtype: bool
-        """
-        return self._compile
-
-    @compile.setter
-    def compile(self, value):
-        """
-        Setting the value of compile.
-
-        :param bool value: Whether compilation is required or not.
-        """
-        check_consistency(value, bool)
-        self._compile = value
+__all__ = ["Trainer"]
diff --git a/tests/test_adaptive_function.py b/tests/test_adaptive_function.py
index bce5059d7..fae547ffb 100644
--- a/tests/test_adaptive_function.py
+++ b/tests/test_adaptive_function.py
@@ -16,7 +16,6 @@
     AdaptiveExp,
 )
 
-
 adaptive_function = (
     AdaptiveReLU,
     AdaptiveSigmoid,
diff --git a/tests/test_block/test_low_rank_block.py b/tests/test_block/test_low_rank_block.py
index 0e6ddcb89..17f0dabd6 100644
--- a/tests/test_block/test_low_rank_block.py
+++ b/tests/test_block/test_low_rank_block.py
@@ -4,7 +4,6 @@
 from pina.model.block import LowRankBlock
 from pina import LabelTensor
 
-
 input_dimensions = 2
 embedding_dimenion = 1
 rank = 4
diff --git a/tests/test_callback/test_metric_tracker.py b/tests/test_callback/test_metric_tracker.py
index 062664b79..49b904885 100644
--- a/tests/test_callback/test_metric_tracker.py
+++ b/tests/test_callback/test_metric_tracker.py
@@ -4,7 +4,6 @@
 from pina.callback import MetricTracker
 from pina.problem.zoo import Poisson2DSquareProblem as Poisson
 
-
 # make the problem
 poisson_problem = Poisson()
 n = 10
diff --git a/tests/test_callback/test_pina_progress_bar.py b/tests/test_callback/test_pina_progress_bar.py
index ec7129852..8956ebaf0 100644
--- a/tests/test_callback/test_pina_progress_bar.py
+++ b/tests/test_callback/test_pina_progress_bar.py
@@ -4,7 +4,6 @@
 from pina.callback import PINAProgressBar
 from pina.problem.zoo import Poisson2DSquareProblem as Poisson
 
-
 # make the problem
 poisson_problem = Poisson()
 n = 10
diff --git a/tests/test_callback/test_r3_refinement.py b/tests/test_callback/test_r3_refinement.py
index 191266ee1..f8b9519e9 100644
--- a/tests/test_callback/test_r3_refinement.py
+++ b/tests/test_callback/test_r3_refinement.py
@@ -6,7 +6,6 @@
 from pina.problem.zoo import Poisson2DSquareProblem as Poisson
 from pina.callback import R3Refinement
 
-
 # make the problem
 poisson_problem = Poisson()
 poisson_problem.discretise_domain(10, "grid", domains="boundary")
diff --git a/tests/test_callback/test_switch_optimizer.py b/tests/test_callback/test_switch_optimizer.py
index 3383c792c..c7490a231 100644
--- a/tests/test_callback/test_switch_optimizer.py
+++ b/tests/test_callback/test_switch_optimizer.py
@@ -8,7 +8,6 @@
 from pina.callback import SwitchOptimizer
 from pina.problem.zoo import Poisson2DSquareProblem as Poisson
 
-
 # Define the problem
 problem = Poisson()
 problem.discretise_domain(10)
diff --git a/tests/test_callback/test_switch_scheduler.py b/tests/test_callback/test_switch_scheduler.py
index df91f0c59..36b177853 100644
--- a/tests/test_callback/test_switch_scheduler.py
+++ b/tests/test_callback/test_switch_scheduler.py
@@ -8,7 +8,6 @@
 from pina.callback import SwitchScheduler
 from pina.problem.zoo import Poisson2DSquareProblem as Poisson
 
-
 # Define the problem
 problem = Poisson()
 problem.discretise_domain(10)
diff --git a/tests/test_label_tensor/test_label_tensor.py b/tests/test_label_tensor/test_label_tensor.py
index 973864d0e..ca4ae2f1a 100644
--- a/tests/test_label_tensor/test_label_tensor.py
+++ b/tests/test_label_tensor/test_label_tensor.py
@@ -1,7 +1,7 @@
 import torch
 import pytest
 
-from pina.label_tensor import LabelTensor
+from pina import LabelTensor
 
 data = torch.rand((20, 3))
 labels_column = {1: {"name": "space", "dof": ["x", "y", "z"]}}
diff --git a/tests/test_model/test_average_neural_operator.py b/tests/test_model/test_average_neural_operator.py
index ded81c43d..4a7ecb44b 100644
--- a/tests/test_model/test_average_neural_operator.py
+++ b/tests/test_model/test_average_neural_operator.py
@@ -3,7 +3,6 @@
 from pina import LabelTensor
 import pytest
 
-
 batch_size = 15
 n_layers = 4
 embedding_dim = 24
diff --git a/tests/test_model/test_low_rank_neural_operator.py b/tests/test_model/test_low_rank_neural_operator.py
index 3702df91b..ba4b2fffe 100644
--- a/tests/test_model/test_low_rank_neural_operator.py
+++ b/tests/test_model/test_low_rank_neural_operator.py
@@ -3,7 +3,6 @@
 from pina import LabelTensor
 import pytest
 
-
 batch_size = 15
 n_layers = 4
 embedding_dim = 24
diff --git a/tests/test_model/test_spline.py b/tests/test_model/test_spline.py
index b47ea8d30..baff81940 100644
--- a/tests/test_model/test_spline.py
+++ b/tests/test_model/test_spline.py
@@ -5,7 +5,6 @@
 from pina.model import Spline
 from pina import LabelTensor
 
-
 # Utility quantities for testing
 order = torch.randint(3, 6, (1,)).item()
 n_ctrl_pts = torch.randint(order, order + 5, (1,)).item()
diff --git a/tests/test_model/test_spline_surface.py b/tests/test_model/test_spline_surface.py
index dee57173c..4cd6dc3aa 100644
--- a/tests/test_model/test_spline_surface.py
+++ b/tests/test_model/test_spline_surface.py
@@ -5,7 +5,6 @@
 from pina.operator import grad
 from pina import LabelTensor
 
-
 # Utility quantities for testing
 orders = [random.randint(3, 6) for _ in range(2)]
 n_ctrl_pts = random.randint(max(orders), max(orders) + 5)
diff --git a/tests/test_solver/test_competitive_pinn.py b/tests/test_solver/test_competitive_pinn.py
index 8f585f029..67902197a 100644
--- a/tests/test_solver/test_competitive_pinn.py
+++ b/tests/test_solver/test_competitive_pinn.py
@@ -16,7 +16,6 @@
 )
 from torch._dynamo.eval_frame import OptimizedModule
 
-
 # define problems
 problem = Poisson()
 problem.discretise_domain(10)
diff --git a/tests/test_solver/test_ensemble_pinn.py b/tests/test_solver/test_ensemble_pinn.py
index 50669f00e..e34ad3643 100644
--- a/tests/test_solver/test_ensemble_pinn.py
+++ b/tests/test_solver/test_ensemble_pinn.py
@@ -13,7 +13,6 @@
 from pina.problem.zoo import Poisson2DSquareProblem as Poisson
 from torch._dynamo.eval_frame import OptimizedModule
 
-
 # define problems
 problem = Poisson()
 problem.discretise_domain(10)
diff --git a/tests/test_solver/test_pinn.py b/tests/test_solver/test_pinn.py
index d726047ef..76094b473 100644
--- a/tests/test_solver/test_pinn.py
+++ b/tests/test_solver/test_pinn.py
@@ -16,7 +16,6 @@
 )
 from torch._dynamo.eval_frame import OptimizedModule
 
-
 # define problems
 problem = Poisson()
 problem.discretise_domain(10)
diff --git a/tests/test_solver/test_self_adaptive_pinn.py b/tests/test_solver/test_self_adaptive_pinn.py
index b2d1361ca..244f10d4f 100644
--- a/tests/test_solver/test_self_adaptive_pinn.py
+++ b/tests/test_solver/test_self_adaptive_pinn.py
@@ -16,7 +16,6 @@
 )
 from torch._dynamo.eval_frame import OptimizedModule
 
-
 # define problems
 problem = Poisson()
 problem.discretise_domain(10)
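The `test_label_tensor` change above swaps the now-private module path for the public root import, the pattern the rest of the suite already follows. A short sketch of that public `LabelTensor` usage (the `extract` call mirrors long-standing PINA usage and is an assumption here, not shown in this patch):

    import torch
    from pina import LabelTensor  # public path; pina.label_tensor is now private

    # Label the three columns of a batch of points, as in the test fixtures.
    data = torch.rand((20, 3))
    lt = LabelTensor(data, ["x", "y", "z"])

    # Columns can then be retrieved by name instead of by positional index.
    x_column = lt.extract(["x"])
    print(x_column.shape)  # torch.Size([20, 1])
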
diff --git a/tests/test_weighting/test_linear_weighting.py b/tests/test_weighting/test_linear_weighting.py
index a11952073..db5e8a9ac 100644
--- a/tests/test_weighting/test_linear_weighting.py
+++ b/tests/test_weighting/test_linear_weighting.py
@@ -6,7 +6,6 @@
 from pina.loss import LinearWeighting
 from pina.problem.zoo import Poisson2DSquareProblem
 
-
 # Initialize problem and model
 problem = Poisson2DSquareProblem()
 problem.discretise_domain(10)
diff --git a/tests/test_weighting/test_ntk_weighting.py b/tests/test_weighting/test_ntk_weighting.py
index 49442b9fb..f908ae538 100644
--- a/tests/test_weighting/test_ntk_weighting.py
+++ b/tests/test_weighting/test_ntk_weighting.py
@@ -5,7 +5,6 @@
 from pina.loss import NeuralTangentKernelWeighting
 from pina.problem.zoo import Poisson2DSquareProblem
 
-
 # Initialize problem and model
 problem = Poisson2DSquareProblem()
 problem.discretise_domain(10)
diff --git a/tests/test_weighting/test_scalar_weighting.py b/tests/test_weighting/test_scalar_weighting.py
index bbf71afde..395cdbcc0 100644
--- a/tests/test_weighting/test_scalar_weighting.py
+++ b/tests/test_weighting/test_scalar_weighting.py
@@ -6,7 +6,6 @@
 from pina.loss import ScalarWeighting
 from pina.problem.zoo import Poisson2DSquareProblem
 
-
 # Initialize problem and model
 problem = Poisson2DSquareProblem()
 problem.discretise_domain(50)
diff --git a/tests/test_weighting/test_self_adaptive_weighting.py b/tests/test_weighting/test_self_adaptive_weighting.py
index 066e8855e..e11aff14c 100644
--- a/tests/test_weighting/test_self_adaptive_weighting.py
+++ b/tests/test_weighting/test_self_adaptive_weighting.py
@@ -5,7 +5,6 @@
 from pina.loss import SelfAdaptiveWeighting
 from pina.problem.zoo import Poisson2DSquareProblem
 
-
 # Initialize problem and model
 problem = Poisson2DSquareProblem()
 problem.discretise_domain(10)
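Taken together, these test edits touch only imports and spacing: every suite now drives PINA through its public surface (`pina`, `pina.model`, `pina.solver`, `pina.loss`, `pina.callback`, `pina.problem.zoo`) while the implementation lives under `pina._src`. A hedged end-to-end sketch of that public surface, with model and solver constructor arguments assumed rather than taken from this patch:

    from pina import Trainer
    from pina.model import FeedForward
    from pina.solver import PINN
    from pina.problem.zoo import Poisson2DSquareProblem as Poisson

    # Sample every domain first: the Trainer raises a RuntimeError on
    # problems whose domains are not all discretised.
    problem = Poisson()
    problem.discretise_domain(10)

    # Constructor arguments below are illustrative assumptions.
    model = FeedForward(
        input_dimensions=len(problem.input_variables),
        output_dimensions=len(problem.output_variables),
    )
    solver = PINN(problem=problem, model=model)

    # batch_size=None loads the whole dataset at once and disables
    # step-wise logging (log_every_n_steps is forced to zero).
    trainer = Trainer(solver, max_epochs=5, batch_size=None, accelerator="cpu")
    trainer.train()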