diff --git a/data/tests/images/IC86lower_deepcore_test.npy b/data/tests/images/IC86lower_deepcore_test.npy new file mode 100644 index 000000000..178a09047 Binary files /dev/null and b/data/tests/images/IC86lower_deepcore_test.npy differ diff --git a/data/tests/images/IC86main_array_test.npy b/data/tests/images/IC86main_array_test.npy new file mode 100644 index 000000000..628cbfd71 Binary files /dev/null and b/data/tests/images/IC86main_array_test.npy differ diff --git a/data/tests/images/IC86upper_deepcore_test.npy b/data/tests/images/IC86upper_deepcore_test.npy new file mode 100644 index 000000000..24a3cc697 Binary files /dev/null and b/data/tests/images/IC86upper_deepcore_test.npy differ diff --git a/examples/04_training/09_train_cnn.py b/examples/04_training/09_train_cnn.py new file mode 100644 index 000000000..3b9f450df --- /dev/null +++ b/examples/04_training/09_train_cnn.py @@ -0,0 +1,324 @@ +"""Example of training a CNN Model.""" + +import os +from typing import Any, Dict, List, Optional + +from pytorch_lightning.loggers import WandbLogger +import torch +from torch.optim.adam import Adam + +from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR +from graphnet.data.constants import TRUTH +from graphnet.models import StandardModel +from graphnet.models.cnn import LCSC +from graphnet.models.data_representation import PercentileClusters +from graphnet.models.task.reconstruction import EnergyReconstruction +from graphnet.training.callbacks import PiecewiseLinearLR +from graphnet.training.loss_functions import LogCoshLoss +from graphnet.utilities.argparse import ArgumentParser +from graphnet.utilities.logging import Logger +from graphnet.data.dataset import SQLiteDataset +from graphnet.data.dataset import ParquetDataset +from graphnet.models.detector import ORCA150 +from torch_geometric.data import Batch +from graphnet.models.data_representation.images import ExamplePrometheusImage + +# Constants +features = ["sensor_id", "sensor_string_id", "t"] 
def main(
    path: str,
    pulsemap: str,
    target: str,
    truth_table: str,
    gpus: Optional[List[int]],
    max_epochs: int,
    early_stopping_patience: int,
    batch_size: int,
    num_workers: int,
    wandb: bool = False,
) -> None:
    """Train an LCSC CNN on image-represented events and save the results.

    The pipeline is: pulses -> per-module pixels (`PercentileClusters`)
    -> image tensors (`ExamplePrometheusImage`) -> `LCSC` backbone ->
    `EnergyReconstruction` task. After fitting, the predictions, the
    model state dict and the model config are written to
    `EXAMPLE_OUTPUT_DIR/train_cnn_model/<dataset>/lcsc_<target>_example`.

    Args:
        path: Path to the dataset file.
        pulsemap: Name of the pulsemap to use.
        target: Name of the truth feature used as regression target.
        truth_table: Name of the truth table.
        gpus: GPUs forwarded to `fit` and `predict_as_dataframe`.
        max_epochs: Maximum number of training epochs.
        early_stopping_patience: Patience forwarded to `model.fit`.
        batch_size: Batch size used by both dataloaders.
        num_workers: Number of workers used by both dataloaders.
        wandb: If True, Weights & Biases is used to track the run.
    """
    # Construct Logger
    logger = Logger()

    # Initialise Weights & Biases (W&B) run
    if wandb:
        # Make sure W&B output directory exists
        wandb_dir = "./wandb/"
        os.makedirs(wandb_dir, exist_ok=True)
        wandb_logger = WandbLogger(
            project="example-script",
            entity="graphnet-team",
            save_dir=wandb_dir,
            log_model=True,
        )

    logger.info(f"features: {features}")
    logger.info(f"truth: {truth}")

    # Configuration
    # NOTE(review): `dataset_reference` is recorded here (and logged to
    # W&B) but the dataset below is always built as a `SQLiteDataset`.
    config: Dict[str, Any] = {
        "path": path,
        "pulsemap": pulsemap,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "target": target,
        "early_stopping_patience": early_stopping_patience,
        "fit": {
            "gpus": gpus,
            "max_epochs": max_epochs,
        },
        "dataset_reference": (
            SQLiteDataset if path.endswith(".db") else ParquetDataset
        ),
    }

    archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_cnn_model")
    run_name = "lcsc_{}_example".format(config["target"])
    if wandb:
        # Log configuration to W&B
        wandb_logger.experiment.config.update(config)

    # Build pulses -> pixels -> image tensors via `ImageRepresentation`.
    # 1) `pixel_definition`: aggregates pulses (here per optical module).
    # 2) `grid_definition` (inside `ExamplePrometheusImage`): detector
    #    grid layouts and scatter into shaped tensors.
    # Multiple tensors per event are supported (e.g. IC86 main + DeepCore).

    # Here we use the PercentileClusters pixel definition, which
    # aggregates the light pulses that arrive at the same optical
    # module with percentiles.
    pixel_definition = PercentileClusters(
        cluster_on=["sensor_id", "sensor_string_id"],
        percentiles=[10, 50, 90],
        add_counts=True,
        input_feature_names=features,
    )

    # `ExamplePrometheusImage` wires a Prometheus `GridDefinition`
    # for the example layout.
    # It maps optical modules into the image
    # using the sensor_string_id and sensor_id
    # (number of the optical module).
    # The detector class standardizes the input features,
    # so that the features are in a ML friendly range.
    # For the mapping of the optical modules to the image it is
    # essential to not change the value of the sensor_id and
    # sensor_string_id. Therefore we need to make sure that
    # these features are not standardized, which is done by the
    # `replace_with_identity` argument of the detector.
    image_representation = ExamplePrometheusImage(
        detector=ORCA150(
            replace_with_identity=[
                "sensor_id",
                "sensor_string_id",
            ],
        ),
        pixel_definition=pixel_definition,
        input_feature_names=features,
        string_label="sensor_string_id",
        dom_number_label="sensor_id",
    )

    # Use SQLiteDataset to load in data
    # The input here depends on the dataset being used,
    # in this case the Prometheus dataset.
    dataset = SQLiteDataset(
        path=config["path"],
        pulsemaps=config["pulsemap"],
        truth_table=truth_table,
        features=features,
        truth=truth,
        data_representation=image_representation,
    )

    # Create the training and validation dataloaders.
    # NOTE: both loaders wrap the same dataset; this example does not
    # hold out a separate validation split.
    training_dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        collate_fn=Batch.from_data_list,
    )

    validation_dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=config["batch_size"],
        num_workers=config["num_workers"],
        collate_fn=Batch.from_data_list,
    )

    # Building model

    # LCSC spatial size must match this `ImageRepresentation`'s grid.
    image_size = image_representation.single_image_spatial_shape()

    # Define architecture of the backbone, in this example
    # the LCSC architecture from Alexander Harnisch is used.
    backbone = LCSC(
        num_input_features=image_representation.nb_outputs,
        out_put_dim=2,
        input_norm=True,
        num_conv_layers=5,
        conv_filters=[5, 10, 20, 40, 60],
        kernel_size=3,
        image_size=image_size,
        pooling_type=[
            "Avg",
            None,
            "Avg",
            None,
            "Avg",
        ],
        pooling_kernel_size=[
            [1, 1, 2],
            None,
            [2, 2, 2],
            None,
            [2, 2, 2],
        ],
        pooling_stride=[
            [1, 1, 2],
            None,
            [2, 2, 2],
            None,
            [2, 2, 2],
        ],
        num_fc_neurons=50,
        norm_list=True,
        norm_type="Batch",
    )
    # Define the task.
    # Here an energy reconstruction, with a LogCoshLoss function.
    # The target and prediction are transformed using the log10 function.
    # When inferring, the prediction is transformed back to the
    # original scale using 10^x.
    task = EnergyReconstruction(
        hidden_size=backbone.nb_outputs,
        target_labels=config["target"],
        loss_function=LogCoshLoss(),
        transform_prediction_and_target=lambda x: torch.log10(x),
        transform_inference=lambda x: torch.pow(10, x),
    )
    # Define the full model, which includes the backbone, task(s),
    # along with typical machine learning options such as
    # learning rate optimizers and schedulers.
    model = StandardModel(
        data_representation=image_representation,
        backbone=backbone,
        tasks=[task],
        optimizer_class=Adam,
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
        scheduler_class=PiecewiseLinearLR,
        scheduler_kwargs={
            "milestones": [
                0,
                len(training_dataloader) / 2,
                len(training_dataloader) * config["fit"]["max_epochs"],
            ],
            "factors": [1e-2, 1, 1e-02],
        },
        scheduler_config={
            "interval": "step",
        },
    )

    # Training model
    model.fit(
        training_dataloader,
        validation_dataloader,
        early_stopping_patience=config["early_stopping_patience"],
        logger=wandb_logger if wandb else None,
        **config["fit"],
    )

    # Get predictions
    additional_attributes = model.target_labels
    assert isinstance(additional_attributes, list)  # mypy

    results = model.predict_as_dataframe(
        validation_dataloader,
        additional_attributes=additional_attributes + ["event_no"],
        gpus=config["fit"]["gpus"],
    )

    # Save predictions and model to file.
    # `basename` keeps the folder name correct regardless of the path
    # separator; the input `path` argument is left untouched.
    db_name = os.path.basename(path).split(".")[0]
    output_dir = os.path.join(archive, db_name, run_name)
    logger.info(f"Writing results to {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    # Save results as .csv
    results.to_csv(os.path.join(output_dir, "cnn_results.csv"))

    # Save model config and state dict - Version safe save method.
    # This method of saving models is the safest way.
    model.save_state_dict(os.path.join(output_dir, "cnn_state_dict.pth"))
    model.save_config(os.path.join(output_dir, "cnn_model_config.yml"))
+""") + + parser.add_argument( + "--path", + help="Path to dataset file (default: %(default)s)", + default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db", + ) + + parser.add_argument( + "--pulsemap", + help="Name of pulsemap to use (default: %(default)s)", + default="total", + ) + + parser.add_argument( + "--target", + help=( + "Name of feature to use as regression target (default: " + "%(default)s)" + ), + default="total_energy", + ) + + parser.add_argument( + "--truth-table", + help="Name of truth table to be used (default: %(default)s)", + default="mc_truth", + ) + + parser.with_standard_arguments( + "gpus", + ("max-epochs", 1), + "early-stopping-patience", + ("batch-size", 16), + ("num-workers", 2), + ) + + parser.add_argument( + "--wandb", + action="store_true", + help="If True, Weights & Biases are used to track the experiment.", + ) + + args, unknown = parser.parse_known_args() + + main( + args.path, + args.pulsemap, + args.target, + args.truth_table, + args.gpus, + args.max_epochs, + args.early_stopping_patience, + args.batch_size, + args.num_workers, + args.wandb, + ) diff --git a/src/graphnet/constants.py b/src/graphnet/constants.py index 5e5717f78..a069fbfd8 100644 --- a/src/graphnet/constants.py +++ b/src/graphnet/constants.py @@ -21,6 +21,14 @@ TEST_PARQUET_DATA = os.path.join( TEST_DATA_DIR, "parquet", _test_dataset_name, "merged" ) +TEST_IMAGE_DIR = os.path.join(TEST_DATA_DIR, "images") +TEST_IC86MAIN_IMAGE = os.path.join(TEST_IMAGE_DIR, "IC86main_array_test.npy") +TEST_IC86LOWERDC_IMAGE = os.path.join( + TEST_IMAGE_DIR, "IC86lower_deepcore_test.npy" +) +TEST_IC86UPPERDC_IMAGE = os.path.join( + TEST_IMAGE_DIR, "IC86upper_deepcore_test.npy" +) # Example data EXAMPLE_DATA_DIR = os.path.join(DATA_DIR, "examples") diff --git a/src/graphnet/models/cnn/__init__.py b/src/graphnet/models/cnn/__init__.py new file mode 100644 index 000000000..d44dd9f83 --- /dev/null +++ b/src/graphnet/models/cnn/__init__.py @@ -0,0 +1,5 @@ +"""CNN-specific 
class CNN(Model):
    """Common parent of the core CNN models in graphnet.

    Keeps track of how many features go in and come out of the network
    and leaves the actual forward pass to the subclasses.
    """

    def __init__(self, nb_inputs: int, nb_outputs: int) -> None:
        """Construct `CNN`.

        Args:
            nb_inputs: Number of input features.
            nb_outputs: Number of output features.
        """
        super().__init__()
        self._nb_inputs, self._nb_outputs = nb_inputs, nb_outputs

    @property
    def nb_inputs(self) -> int:
        """Number of input features given at construction."""
        return self._nb_inputs

    @property
    def nb_outputs(self) -> int:
        """Number of output features given at construction."""
        return self._nb_outputs

    @abstractmethod
    def forward(self, data: Data) -> Tensor:
        """Run the learnable forward pass; implemented by subclasses."""
class InceptionBlock4(LightningModule):
    """Inception block with 4 parallel towers from Theo Glauch's DNN.

    Every tower maps `in_channels` to `out_channels` while keeping the
    spatial size, and the four tower outputs are concatenated along the
    channel dimension, giving `4 * out_channels` output channels.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        t0: int = 2,
        t1: int = 4,
        t2: int = 5,
        n_pool: int = 3,
    ):
        """Create a InceptionBlock4 module.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels of each tower.
            t0: Kernel length of the first tower, applied along each of
                the three axes in turn.
            t1: Kernel length of the second tower, applied along each of
                the three axes in turn.
            t2: Kernel length of the single `(1, 1, t2)` convolution.
            n_pool: Size of the pooling kernel. Must be odd.

        Raises:
            ValueError: If `n_pool` is not a positive odd integer.
        """
        super().__init__()

        # With stride 1 and padding `n_pool // 2` an even kernel grows
        # every spatial axis by one, so the pooled tower could no longer
        # be concatenated with the "same"-padded convolution towers.
        if n_pool < 1 or n_pool % 2 == 0:
            raise ValueError(
                "`n_pool` must be a positive odd integer so that the "
                f"pooling tower keeps the spatial size, not {n_pool}!"
            )

        # Attribute names and creation order are kept as they are part
        # of the state dict layout.
        self.tower0 = self._separable_tower(in_channels, out_channels, t0)
        self.tower1 = self._separable_tower(in_channels, out_channels, t1)

        self.tower4 = nn.Sequential(
            Conv3dBN(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1, t2),
                padding="same",
            ),
        )

        self.tower3 = nn.Sequential(
            nn.MaxPool3d(
                kernel_size=(n_pool, n_pool, n_pool),
                stride=(1, 1, 1),
                padding=(n_pool // 2, n_pool // 2, n_pool // 2),
            ),
            Conv3dBN(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1, 1),
                padding="same",
            ),
        )
        self.out_channels = out_channels * 4

    @staticmethod
    def _separable_tower(
        in_channels: int, out_channels: int, length: int
    ) -> nn.Sequential:
        """Stack three 1D convolutions of `length`, one per axis."""
        return nn.Sequential(
            Conv3dBN(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(length, 1, 1),
                padding="same",
            ),
            Conv3dBN(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=(1, length, 1),
                padding="same",
            ),
            Conv3dBN(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=(1, 1, length),
                padding="same",
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the InceptionBlock4."""
        # The pooling tower (`tower3`) is concatenated before `tower4`.
        ret = torch.cat(
            [
                self.tower0(x),
                self.tower1(x),
                self.tower3(x),
                self.tower4(x),
            ],
            dim=1,
        )
        return ret
class IceCubeDNN(CNN):
    """Implementation of the IceCube DNN by Theo Glauch.

    An inception-based 3D CNN originally used within IceCube. Based on
    the model from
    https://github.com/IceCubeOpenSource/i3deepice/tree/master
    """

    def __init__(
        self,
        nb_inputs: int = 15,
        nb_outputs: int = 16,
        image_size: Tuple[int, int, int] = (10, 10, 60),
        inception_out_channels: int = 18,
        inception_configs: List[Tuple[int, int, int]] = [
            (2, 5, 8),
            (2, 3, 7),
            (2, 4, 8),
            (3, 5, 9),
            (2, 8, 9),
        ],
        resnet_out_channels: int = 24,
        resnet_t2_pattern: List[int] = [3, 4, 5],
        num_resblocks1_repeats: int = 6,
        num_resblocks2_repeats: int = 6,
        avgpool1_size: Tuple[int, int, int] = (2, 2, 3),
        avgpool2_size: Tuple[int, int, int] = (1, 1, 2),
        avgpool3_size: Tuple[int, int, int] = (1, 1, 2),
        pointwise_channels: List[int] = [64, 4],
        mlp_hidden_sizes: List[int] = [120, 64],
    ) -> None:
        """Construct `IceCubeDNN`.

        Args:
            nb_inputs: Number of input features.
            nb_outputs: Number of output features.
            image_size: Spatial dimensions of the input image
                (height, width, depth).
            inception_out_channels: Output channels per tower in each
                inception block. Each block emits four times as many
                channels.
            inception_configs: List of (t0, t1, t2) kernel size tuples
                for each InceptionBlock4 layer.
            resnet_out_channels: Output channels per tower in each
                inception-resnet block. The residual addition requires
                the incoming channel count to be three times this
                value, i.e. `4 * inception_out_channels ==
                3 * resnet_out_channels` when inception blocks are used.
            resnet_t2_pattern: Pattern of t2 kernel sizes repeated in
                each group of resnet blocks.
            num_resblocks1_repeats: Number of times to repeat the
                resnet_t2_pattern in the first resnet stage.
            num_resblocks2_repeats: Number of times to repeat the
                resnet_t2_pattern in the second resnet stage.
            avgpool1_size: Kernel size for the first average pooling.
            avgpool2_size: Kernel size for the second average pooling.
            avgpool3_size: Kernel size for the third average pooling.
            pointwise_channels: Output channels for each 1x1x1
                convolution layer.
            mlp_hidden_sizes: Hidden layer sizes for the final MLP.
                The input size is computed from the preceding layers
                and the output size is nb_outputs.

        Raises:
            ValueError: If the channel counts do not allow the residual
                addition, or if `image_size` is too small for the three
                pooling layers.
        """
        super().__init__(nb_inputs, nb_outputs)

        # Inception blocks
        inception_blocks = []
        in_ch = nb_inputs
        for t0, t1, t2 in inception_configs:
            inception_blocks.append(
                InceptionBlock4(
                    in_channels=in_ch,
                    out_channels=inception_out_channels,
                    t0=t0,
                    t1=t1,
                    t2=t2,
                )
            )
            in_ch = inception_out_channels * 4
        self.inceptionblocks4 = nn.Sequential(*inception_blocks)

        # `InceptionResnet` adds its input to the concatenation of three
        # towers, so the channels entering the resnet stages have to
        # match `3 * resnet_out_channels`. Fail here with a readable
        # message instead of with a shape error in the first forward
        # pass. A single input channel is left alone because the
        # addition still broadcasts in that case.
        residual_ch = resnet_out_channels * 3
        has_resblocks = bool(resnet_t2_pattern) and (
            num_resblocks1_repeats > 0 or num_resblocks2_repeats > 0
        )
        if has_resblocks and in_ch not in (1, residual_ch):
            raise ValueError(
                f"The resnet stages receive {in_ch} channels but "
                f"3 * `resnet_out_channels` is {residual_ch}. With "
                "inception blocks, 4 * `inception_out_channels` must "
                "equal 3 * `resnet_out_channels`."
            )

        # All inception/resnet blocks use "same" padding, so spatial
        # dimensions only change at pooling layers.
        spatial = list(image_size)

        self.avgpool1 = nn.AvgPool3d(avgpool1_size)
        spatial = [s // p for s, p in zip(spatial, avgpool1_size)]
        self.bn1 = nn.BatchNorm3d(in_ch)

        # First resnet stage
        self.resblocks1, resnet_in_ch = self._make_resnet_stage(
            in_channels=in_ch,
            out_channels=resnet_out_channels,
            t2_pattern=resnet_t2_pattern,
            repeats=num_resblocks1_repeats,
        )

        self.avgpool2 = nn.AvgPool3d(avgpool2_size)
        spatial = [s // p for s, p in zip(spatial, avgpool2_size)]
        self.bn2 = nn.BatchNorm3d(resnet_in_ch)

        # Second resnet stage
        self.resblocks2, resnet_in_ch = self._make_resnet_stage(
            in_channels=resnet_in_ch,
            out_channels=resnet_out_channels,
            t2_pattern=resnet_t2_pattern,
            repeats=num_resblocks2_repeats,
        )

        # Pointwise 1x1x1 convolutions
        pointwise_layers: List[nn.Module] = []
        pw_in = resnet_in_ch
        for pw_out in pointwise_channels:
            pointwise_layers.append(
                nn.Conv3d(
                    in_channels=pw_in,
                    out_channels=pw_out,
                    kernel_size=(1, 1, 1),
                    padding=(0, 0, 0),
                )
            )
            pointwise_layers.append(nn.ReLU())
            pw_in = pw_out
        self.convs111 = nn.Sequential(*pointwise_layers)

        self.avgpool3 = nn.AvgPool3d(avgpool3_size)
        spatial = [s // p for s, p in zip(spatial, avgpool3_size)]

        # A zero here means one pooling kernel is larger than what is
        # left of the image, which the pooling layer would reject in
        # the forward pass.
        if min(spatial) < 1:
            raise ValueError(
                f"`image_size` {tuple(image_size)} is too small for the "
                f"pooling layers; remaining spatial size is {spatial}."
            )

        # MLP head
        # NOTE(review): the linear layers are stacked without any
        # activation in between. Left unchanged so existing state dicts
        # keep loading - confirm against the reference model.
        latent_dim = pw_in * spatial[0] * spatial[1] * spatial[2]
        mlp_sizes = [latent_dim] + mlp_hidden_sizes + [nb_outputs]
        self.mlps = nn.Sequential(
            *[
                nn.Linear(n_in, n_out)
                for n_in, n_out in zip(mlp_sizes[:-1], mlp_sizes[1:])
            ]
        )

        # Parameter-free; built once instead of on every forward call.
        self.flatten = nn.Flatten()

    @staticmethod
    def _make_resnet_stage(
        in_channels: int,
        out_channels: int,
        t2_pattern: List[int],
        repeats: int,
    ) -> Tuple[nn.Sequential, int]:
        """Build `repeats` x `t2_pattern` resnet blocks.

        Returns the stage and the number of channels leaving it.
        """
        blocks: List[nn.Module] = []
        channels = in_channels
        for _ in range(repeats):
            for t2 in t2_pattern:
                blocks.append(
                    InceptionResnet(
                        in_channels=channels,
                        out_channels=out_channels,
                        t2=t2,
                    )
                )
                channels = out_channels * 3
        return nn.Sequential(*blocks), channels

    def forward(self, data: Data) -> torch.Tensor:
        """Apply learnable forward pass in model."""
        assert len(data.x) == 1, "Only one image expected"
        x = data.x[0]
        x = self.inceptionblocks4(x)
        x = self.avgpool1(x)
        x = self.bn1(x)
        x = self.resblocks1(x)
        x = self.avgpool2(x)
        x = self.bn2(x)
        x = self.resblocks2(x)
        x = self.convs111(x)
        x = self.avgpool3(x)
        x = self.flatten(x)
        x = self.mlps(x)
        return x
+ """ + + def __init__( + self, + num_input_features: int, + out_put_dim: int = 2, + input_norm: bool = True, + num_conv_layers: int = 8, + conv_filters: List[int] = [50, 50, 50, 50, 50, 50, 50, 10], + kernel_size: Union[int, List[Union[int, List[int]]]] = 3, + padding: Union[str, int, List[Union[str, int]]] = "Same", + pooling_type: List[Union[None, str]] = [ + None, + "Avg", + None, + "Avg", + None, + "Avg", + None, + "Avg", + ], + pooling_kernel_size: List[Union[None, int, List[int]]] = [ + None, + [1, 1, 2], + None, + [2, 2, 2], + None, + [2, 2, 2], + None, + [2, 2, 2], + ], + pooling_stride: Union[int, List[Union[None, int, List[int]]]] = [ + None, + [1, 1, 2], + None, + [2, 2, 2], + None, + [2, 2, 2], + None, + [2, 2, 2], + ], + num_fc_neurons: int = 50, + norm_list: bool = True, + norm_type: str = "Batch", + image_size: Tuple[int, int, int] = (10, 10, 60), + ) -> None: + """Initialize the Lightning CNN signal classifier (LCSC). + + Args: + num_input_features: Number of input features. + out_put_dim: Number of output dimensions of final MLP. + Defaults to 2. + input_norm: Whether to apply normalization to the input. + Defaults to True. + num_conv_layers: Number of convolutional layers. + Defaults to 8. + conv_filters: List of number of convolutional + filters to use in hidden layers. + Defaults to [50, 50, 50, 50, 50, 50, 50, 50, 10]. + NOTE needs to have the length of `num_conv_layers`. + kernel_size: Size of the convolutional kernels. + Options are: + int: single integer for all dimensions + and all layers, + e.g. 3 would equal [3, 3, 3] for each layer. + list: list of integers specifying the kernel size, + for each layer for all dimensions equally, + e.g. [3, 5, 6] would equal [[3,3,3], [5,5,5], [6,6,6]]. + NOTE: needs to have the length of `num_conv_layers`. + If a list of lists is provided, each list will be used + for the corresponding layer as kernel size. + NOTE: If a list if passed it needs to have the length + of `num_conv_layers`. 
+ padding: Padding for the convolutional layers. + Options are: + 'Same' for same convolutional padding, + int: single integer for all dimensions and all layers, + e.g. 1 would equal [1, 1, 1]. + list: list of integers specifying the padding for each + dimension, for each layer equally, + e.g. [1, 2, 3]. + NOTE: If a list is passed it needs to have the length + of `num_conv_layers`. + Defaults to 'Same'. + pooling_type: List of pooling types for layers. + Options are + None : No pooling is used, + 'Avg' : Average pooling is used, + 'Max' : Max pooling is used + Defaults to [ + None, 'Avg', + None, 'Avg', + None, 'Avg', + None, 'Avg' + ]. + NOTE: the length of the list must be equal to + `num_conv_layers`. + pooling_kernel_size: List of pooling kernel sizes for each + layer. If an integer is provided, it will be used for + all layers. In case of a list the options for its + elements are: + list: list of integers for each dimension, e.g. [1, 1, 2]. + int: single integer for all dimensions, + e.g. 2 would equal [2, 2, 2]. + If None, no pooling is applied. + NOTE: If a list is passed it needs to have the length + of `num_conv_layers`. + Defaults to [ + None, [1, 1, 2], + None, [2, 2, 2], + None, [2, 2, 2], + None, [2, 2, 2] + ]. + pooling_stride: List of pooling strides for each layer. + If an integer is provided, it will be used for all layers. + In case of a list the options for its elements are: + list: list of integers for each dimension, e.g. [1, 1, 2]. + int: single integer for all dimensions, + e.g. 2 would equal [2, 2, 2]. + If None, no pooling is applied. + NOTE: If a list is passed it needs to have the length + of `num_conv_layers`. + Defaults to [ + None, [1, 1, 2], + None, [2, 2, 2], + None, [2, 2, 2], + None, [2, 2, 2] + ]. + num_fc_neurons: Number of neurons in the fully connected + layers. Defaults to 50. + norm_list: Whether to apply normalization for each + convolutional layer. If a boolean is provided, it will + be used for all layers. 
Defaults to True. + NOTE: If a list is passed it needs to have the length + of `num_conv_layers`. + norm_type: Type of normalization to use. + Options are 'Batch' or 'Instance'. + Defaults to 'Batch'. + image_size: Size of the input image in the format + (height, width, depth). + NOTE: Only needs to be changed if the input image is not + the standard IceCube 86 image size. + """ + super().__init__(nb_inputs=num_input_features, nb_outputs=out_put_dim) + + # Check and parse input parameters + conv_filters, kernel_size, padding = self._parse_conv_arguments( + num_conv_layers=num_conv_layers, + conv_filters=conv_filters, + kernel_size=kernel_size, + padding=padding, + ) + pooling_kernel_size, pooling_stride = self._parse_pooling_arguments( + num_conv_layers=num_conv_layers, + pooling_kernel_size=pooling_kernel_size, + pooling_stride=pooling_stride, + ) + self._norm_list, norm_class = self._parse_norm_arguments( + num_conv_layers=num_conv_layers, + num_input_features=num_input_features, + input_norm=input_norm, + norm_list=norm_list, + norm_type=norm_type, + ) + + # Set convolution, pooling, and normalization layers + self.input_norm = input_norm + dimensions = self._set_conv_layers( + num_conv_layers=num_conv_layers, + num_input_features=num_input_features, + image_size=image_size, + conv_filters=conv_filters, + kernel_size=kernel_size, + padding=padding, + pooling_type=pooling_type, + pooling_kernel_size=pooling_kernel_size, + pooling_stride=pooling_stride, + norm_class=norm_class, + ) + + # Set linear layers + latent_dim = ( + dimensions[0] * dimensions[1] * dimensions[2] * dimensions[3] + ) + self.flatten = torch.nn.Flatten() + self.fc1 = torch.nn.Linear(latent_dim, num_fc_neurons) + self.fc2 = torch.nn.Linear(num_fc_neurons, out_put_dim) + + def _parse_conv_arguments( + self, + num_conv_layers: int, + conv_filters: Union[int, List[int]], + kernel_size: Union[int, List[Union[int, List[int]]]], + padding: Union[str, int, List[Union[str, int]]], + ) -> 
Tuple[List[int], List, List]: + """Parse and validate convolution arguments. + + Args: + num_conv_layers: Number of convolutional layers. + conv_filters: Convolutional filters per layer. + kernel_size: Kernel sizes per layer. + padding: Padding per layer. + + Returns: + Parsed conv_filters, kernel_size, and padding as lists. + """ + if isinstance(conv_filters, int): + conv_filters = [conv_filters for _ in range(num_conv_layers)] + else: + if not isinstance(conv_filters, list): + raise TypeError( + f"`conv_filters` must be a " + f"list or an integer, not {type(conv_filters)}!" + ) + if len(conv_filters) != num_conv_layers: + raise ValueError( + f"`conv_filters` must have {num_conv_layers} " + f"elements, not {len(conv_filters)}!" + ) + + if isinstance(kernel_size, int): + kernel_size = [ # type: ignore[assignment] + [kernel_size, kernel_size, kernel_size] + for _ in range(num_conv_layers) + ] + else: + if not isinstance(kernel_size, list): + raise TypeError( + "`kernel_size` must be a list or an " + f"integer, not {type(kernel_size)}!" + ) + if len(kernel_size) != num_conv_layers: + raise ValueError( + f"`kernel_size` must have {num_conv_layers} " + f"elements, not {len(kernel_size)}!" + ) + + if isinstance(padding, int): + padding = [padding for _ in range(num_conv_layers)] + elif isinstance(padding, str): + if padding.lower() == "same": + padding = ["same" for _ in range(num_conv_layers)] + else: + raise ValueError( + "`padding` must be 'Same' or an integer, " + f"not {padding}!" + ) + else: + if not isinstance(padding, list): + raise TypeError( + f"`padding` must be a list or " + f"an integer, not {type(padding)}!" + ) + if len(padding) != num_conv_layers: + raise ValueError( + f"`padding` must have {num_conv_layers} " + f"elements, not {len(padding)}!" 
+ ) + + return conv_filters, kernel_size, padding + + def _parse_pooling_arguments( + self, + num_conv_layers: int, + pooling_kernel_size: Union[int, List[Union[None, int, List[int]]]], + pooling_stride: Union[int, List[Union[None, int, List[int]]]], + ) -> Tuple[List, List]: + """Parse and validate pooling arguments. + + Args: + num_conv_layers: Number of convolutional layers. + pooling_kernel_size: Pooling kernel sizes per layer. + pooling_stride: Pooling strides per layer. + + Returns: + Parsed pooling_kernel_size and pooling_stride as lists. + """ + if isinstance(pooling_kernel_size, int): + pooling_kernel_size = [ + pooling_kernel_size for _ in range(num_conv_layers) + ] + else: + if not isinstance(pooling_kernel_size, list): + raise TypeError( + "`pooling_kernel_size` must be a list or " + f"an integer, not {type(pooling_kernel_size)}!" + ) + if len(pooling_kernel_size) != num_conv_layers: + raise ValueError( + f"`pooling_kernel_size` must have " + f"{num_conv_layers} elements, not " + f"{len(pooling_kernel_size)}!" + ) + + if isinstance(pooling_stride, int): + pooling_stride = [pooling_stride for _ in range(num_conv_layers)] + else: + if not isinstance(pooling_stride, list): + raise TypeError( + "`pooling_stride` must be a list or an integer, " + f"not {type(pooling_stride)}!" + ) + if len(pooling_stride) != num_conv_layers: + raise ValueError( + f"`pooling_stride` must have {num_conv_layers} " + f"elements, not {len(pooling_stride)}!" + ) + + return pooling_kernel_size, pooling_stride + + def _parse_norm_arguments( + self, + num_conv_layers: int, + num_input_features: int, + input_norm: bool, + norm_list: Union[bool, List[bool]], + norm_type: str, + ) -> Tuple[List[bool], type]: + """Parse and validate normalization arguments. + + Args: + num_conv_layers: Number of convolutional layers. + num_input_features: Number of input features. + input_norm: Whether to apply input normalization. + norm_list: Per-layer normalization flags. 
+ norm_type: Type of normalization ('Batch' or 'Instance'). + + Returns: + Parsed norm_list and the normalization class. + """ + if isinstance(norm_list, bool): + parsed_norm_list = [norm_list for _ in range(num_conv_layers)] + else: + if not isinstance(norm_list, list): + raise TypeError( + "`norm_list` must be a list or a boolean, " + f"not {type(norm_list)}!" + ) + if len(norm_list) != num_conv_layers: + raise ValueError( + f"`norm_list` must have {num_conv_layers} " + f"elements, not {len(norm_list)}!" + ) + parsed_norm_list = norm_list + + if norm_type.lower() == "instance": + norm_class = torch.nn.InstanceNorm3d + if input_norm: + self.input_normal = torch.nn.InstanceNorm3d(num_input_features) + elif norm_type.lower() == "batch": + norm_class = torch.nn.BatchNorm3d + if input_norm: + self.input_normal = torch.nn.BatchNorm3d( + num_input_features, momentum=None, affine=False + ) + else: + raise ValueError( + "`norm_type` has to be 'instance' or " + f"'batch', not '{norm_type}'!" + ) + + return parsed_norm_list, norm_class + + def _set_conv_layers( + self, + num_conv_layers: int, + num_input_features: int, + image_size: Tuple[int, int, int], + conv_filters: List[int], + kernel_size: List, + padding: List, + pooling_type: List[Union[None, str]], + pooling_kernel_size: List, + pooling_stride: List, + norm_class: type, + ) -> List[int]: + """Build convolution, pooling, and normalization layers. + + Args: + num_conv_layers: Number of convolutional layers. + num_input_features: Number of input features. + image_size: Size of the input image (height, width, depth). + conv_filters: Convolutional filters per layer. + kernel_size: Kernel sizes per layer. + padding: Padding per layer. + pooling_type: Pooling type per layer. + pooling_kernel_size: Pooling kernel sizes per layer. + pooling_stride: Pooling strides per layer. + norm_class: Normalization class to use. + + Returns: + Output dimensions after all layers. 
+ """ + self.conv = torch.nn.ModuleList() + self.pool = torch.nn.ModuleList() + self.normal = torch.nn.ModuleList() + + dimensions: List[int] = [num_input_features, *image_size] + for i in range(num_conv_layers): + self.conv.append( + torch.nn.Conv3d( + dimensions[0], + conv_filters[i], + kernel_size=kernel_size[i], + padding=padding[i], + ) + ) + dimensions = self._calc_output_dimension( + dimensions, + conv_filters[i], + kernel_size[i], + padding[i], + ) + if pooling_type[i] is None or pooling_type[i] == "None": + self.pool.append(None) + elif pooling_type[i] == "Avg": + self.pool.append( + torch.nn.AvgPool3d( + kernel_size=pooling_kernel_size[i], + stride=pooling_stride[i], + ) + ) + dimensions = self._calc_output_dimension( + dimensions, + out_channels=dimensions[0], + kernel_size=pooling_kernel_size[i], + stride=pooling_stride[i], + ) + elif pooling_type[i] == "Max": + self.pool.append( + torch.nn.MaxPool3d( + kernel_size=pooling_kernel_size[i], + stride=pooling_stride[i], + ) + ) + dimensions = self._calc_output_dimension( + dimensions, + out_channels=dimensions[0], + kernel_size=pooling_kernel_size[i], + stride=pooling_stride[i], + ) + else: + raise ValueError( + "Pooling type must be 'None', 'Avg' or 'Max'!" + ) + if self._norm_list[i]: + self.normal.append(norm_class(dimensions[0])) + else: + self.normal.append(None) + + return dimensions + + def _calc_output_dimension( + self, + dimensions: List[int], + out_channels: int, + kernel_size: Union[None, int, List[int]], + padding: Union[str, int, List[int]] = 0, + stride: Union[None, int, List[int]] = 1, + ) -> List[int]: + """Calculate the output dimension after a CNN layers. + + Works for Conv3D, MaxPool3D and AvgPool3D layers. + + Args: + dimensions: Current dimensions of the input tensor. + (C,H,W,D) where C is the number of channels, + H is the height, W is the width and D is the depth. + out_channels: Number of output channels. + kernel_size: Size of the kernel. 
+ If an integer is provided, it will be used for all dimensions. + padding: Padding size. + If an integer is provided, it will be used for all dimensions. + If 'Same', the padding will be calculated to keep the + output size the same as the input size. + Defaults to 0. + stride: Stride size. + If an integer is provided, it will be used for all dimensions. + Defaults to 1. + + Returns: + New dimensions after the layer. + + NOTE: For the pooling layers, set out_channels equal to the + input channels. Since they do not change the number of channels. + """ + krnl_sz: int + if isinstance(padding, str): + if not padding.lower() == "same": + raise ValueError( + f"`padding` must be 'Same' or an integer, not {padding}!" + ) + dimensions[0] = out_channels + else: + for i in range(1, 4): + if isinstance(kernel_size, list): + krnl_sz = kernel_size[i - 1] + else: + assert isinstance(kernel_size, int) + krnl_sz = kernel_size + if isinstance(padding, list): + pad = padding[i - 1] + else: + pad = padding + if isinstance(stride, list): + strd = stride[i - 1] + else: + assert isinstance(stride, int) + strd = stride + dimensions[i] = (dimensions[i] + 2 * pad - krnl_sz) // strd + 1 + + return dimensions + + def forward(self, data: Data) -> torch.Tensor: + """Forward pass of the LCSC.""" + assert len(data.x) == 1, "Only a single image is expected" + x = data.x[0] + if self.input_norm: + x = self.input_normal(x) + for i in range(len(self.conv)): + x = self.conv[i](x) + if self.pool[i] is not None: + x = self.pool[i](x) + x = torch.nn.functional.elu(x) + if self.normal[i] is not None: + x = self.normal[i](x) + + x = self.flatten(x) + x = torch.nn.functional.elu(self.fc1(x)) + x = torch.nn.functional.elu(self.fc2(x)) + return x diff --git a/src/graphnet/models/data_representation/__init__.py b/src/graphnet/models/data_representation/__init__.py index 84dd64331..e72c2c712 100644 --- a/src/graphnet/models/data_representation/__init__.py +++ 
class ImageRepresentation(DataRepresentation):
    """Turn pulses into CNN image tensor(s) via pixels and a detector grid.

    Two collaborators do the work:

    * a :class:`~graphnet.models.data_representation.graphs.nodes.NodeDefinition`
      used as **pixel definition**, which is called on the pulse features
      and returns the pixel rows;
    * a :class:`GridDefinition`, which is called on the resulting ``Data``
      object and places the pixel rows into image tensor(s).

    The detector handed to the base class is ``grid_definition.detector``.
    """

    def __init__(
        self,
        pixel_definition: NodeDefinition,
        grid_definition: GridDefinition,
        input_feature_names: Optional[List[str]] = None,
        dtype: Optional[torch.dtype] = torch.float,
        perturbation_dict: Optional[Dict[str, float]] = None,
        seed: Optional[Union[int, Generator]] = None,
        add_inactive_sensors: bool = False,
        sensor_mask: Optional[List[int]] = None,
        string_mask: Optional[List[int]] = None,
    ):
        """Construct `ImageRepresentation`.

        Args:
            pixel_definition: Maps pulse-level features to pixel rows.
            grid_definition: Places pixel rows into image tensor(s); also
                supplies the detector.
            input_feature_names: Column names of the raw pulse data,
                forwarded to the base class.
            dtype: Dtype the image tensors are cast to.
            perturbation_dict: Forwarded to the base class.
            seed: Forwarded to the base class.
            add_inactive_sensors: Forwarded to the base class.
            sensor_mask: Forwarded to the base class.
            string_mask: Forwarded to the base class.

        Note:
            The columns produced by ``pixel_definition`` are handed to
            ``grid_definition`` as its image feature names, so both have
            to agree on them.
        """
        # Labels are never repeated for image data.
        super().__init__(
            detector=grid_definition.detector,
            input_feature_names=input_feature_names,
            dtype=dtype,
            perturbation_dict=perturbation_dict,
            seed=seed,
            add_inactive_sensors=add_inactive_sensors,
            sensor_mask=sensor_mask,
            string_mask=string_mask,
            repeat_labels=False,
        )
        self._pixel_definition = pixel_definition
        self._grid_definition = grid_definition

    @property
    def shape(self) -> List[List[int]]:
        """Shape(s) reported by the grid definition, one per image."""
        return self._grid_definition.shape

    def single_image_spatial_shape(self) -> Tuple[int, int, int]:
        """Return the last three entries of the only image layout.

        Raises:
            ValueError: If ``shape`` holds more or fewer than one layout,
                or if that layout does not have exactly four entries.
        """
        all_layouts = self.shape
        num_layouts = len(all_layouts)
        if num_layouts != 1:
            raise ValueError(
                "Expected a single-image data representation (one shape "
                f"entry), got {num_layouts}. For multi-image inputs, build "
                "the backbone explicitly for each tensor."
            )
        (layout,) = all_layouts
        if len(layout) != 4:
            raise ValueError(
                "Expected each image layout as "
                "[num_channels, height, width, depth]; "
                f"got {layout!r}."
            )
        _, height, width, depth = layout
        return (height, width, depth)

    def _set_output_feature_names(
        self, input_feature_names: List[str]
    ) -> List[str]:
        """Propagate feature names to the pixel and grid definitions."""
        pixels = self._pixel_definition
        grid = self._grid_definition
        pixels.set_output_feature_names(input_feature_names)
        grid._set_image_feature_names(pixels._output_feature_names)
        return grid.image_feature_names

    def forward(  # type: ignore
        self,
        input_features: np.ndarray,
        input_feature_names: List[str],
        truth_dicts: Optional[List[Dict[str, Any]]] = None,
        custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
        loss_weight_column: Optional[str] = None,
        loss_weight: Optional[float] = None,
        loss_weight_default_value: Optional[float] = None,
        data_path: Optional[str] = None,
    ) -> Data:
        """Build a ``Data`` object whose ``x`` is a list of image tensors.

        The base class builds the ``Data`` object; its ``x`` is then passed
        through the pixel definition and the grid definition. ``num_nodes``
        is set to the summed product of every image's sizes from axis 2
        onwards (presumably the spatial axes; confirm against the grid
        definition).
        """
        data = super().forward(
            input_features=input_features,
            input_feature_names=input_feature_names,
            truth_dicts=truth_dicts,
            custom_label_functions=custom_label_functions,
            loss_weight_column=loss_weight_column,
            loss_weight=loss_weight,
            loss_weight_default_value=loss_weight_default_value,
            data_path=data_path,
        )
        pixel_rows = self._pixel_definition(data.x)
        data.x = pixel_rows.type(self.dtype)
        data = self._grid_definition(data, self.output_feature_names)

        # A grid may return a bare tensor; always expose a list.
        if not isinstance(data.x, list):
            data.x = [data.x]

        voxel_counts = []
        for idx, image in enumerate(data.x):
            image = image.type(self.dtype)
            data.x[idx] = image
            voxel_counts.append(np.prod(list(image.size()[2:])))
        data.num_nodes = torch.tensor(np.sum(voxel_counts))
        return data
class IC86Image(ImageRepresentation):
    """IceCube-86 images (main array + optional DeepCore tensors)."""

    def __init__(
        self,
        pixel_definition: NodeDefinition,
        input_feature_names: List[str],
        include_lower_dc: bool = True,
        include_upper_dc: bool = True,
        string_label: str = "string",
        dom_number_label: str = "dom_number",
        dtype: Optional[torch.dtype] = torch.float,
        detector: Optional[Detector] = None,
        **kwargs: Any,
    ) -> None:
        """Construct `IC86Image`.

        Args:
            pixel_definition: Pulse → DOM row features (:class:`NodeDefinition`).
                Its output feature names are set from `input_feature_names`.
            input_feature_names: Raw input column names. Must contain
                `string_label` and `dom_number_label`.
            include_lower_dc: Include lower DeepCore grid.
            include_upper_dc: Include upper DeepCore grid.
            string_label: DOM string column in pixel rows.
            dom_number_label: DOM index column in pixel rows.
            dtype: Tensor dtype for images. Used by both the grid
                definition and the image representation.
            detector: ``IceCube86`` instance. When omitted, one is created
                with ``replace_with_identity=input_feature_names``.
            **kwargs: Forwarded to :class:`ImageRepresentation`.
                ``add_inactive_sensors`` is always ``False`` and must not
                be passed.
        """
        if detector is None:
            detector = IceCube86(
                replace_with_identity=input_feature_names,
            )
        else:
            assert isinstance(detector, IceCube86)

        pixel_definition.set_output_feature_names(input_feature_names)
        assert (
            string_label in input_feature_names
        ), f"String label '{string_label}' not in input feature names"
        assert (
            dom_number_label in input_feature_names
        ), f"DOM number label '{dom_number_label}' not in input feature names"

        grid_definition = IC86GridDefinition(
            detector=detector,
            dtype=dtype,
            string_label=string_label,
            dom_number_label=dom_number_label,
            pixel_feature_names=pixel_definition._output_feature_names,
            include_lower_dc=include_lower_dc,
            include_upper_dc=include_upper_dc,
        )

        # `dtype` has to reach the base class as well: its forward() casts
        # the final image tensors, so omitting it silently forced the
        # base-class default regardless of the requested dtype.
        super().__init__(
            pixel_definition=pixel_definition,
            grid_definition=grid_definition,
            input_feature_names=input_feature_names,
            dtype=dtype,
            add_inactive_sensors=False,
            **kwargs,
        )


class ExamplePrometheusImage(ImageRepresentation):
    """Example Prometheus-style single-image layout (tutorial scripts only)."""

    def __init__(
        self,
        pixel_definition: NodeDefinition,
        input_feature_names: List[str],
        string_label: str = "sensor_string_id",
        dom_number_label: str = "sensor_id",
        dtype: Optional[torch.dtype] = torch.float,
        detector: Optional[Detector] = None,
        **kwargs: Any,
    ) -> None:
        """Construct `ExamplePrometheusImage`.

        Args:
            pixel_definition: Pulse → sensor row features (:class:`NodeDefinition`).
                Its output feature names are set from `input_feature_names`.
            input_feature_names: Raw input column names. Must contain
                `string_label` and `dom_number_label`.
            string_label: String id column in pixel rows.
            dom_number_label: Sensor id column (passed to the grid
                definition as ``sensor_number_label``).
            dtype: Tensor dtype for images. Used by both the grid
                definition and the image representation.
            detector: Detector instance. When omitted, an ``ORCA150`` is
                created with ``replace_with_identity=input_feature_names``.
            **kwargs: Forwarded to :class:`ImageRepresentation`.
                ``add_inactive_sensors`` is always ``False`` and must not
                be passed.
        """
        if detector is None:
            detector = ORCA150(
                replace_with_identity=input_feature_names,
            )

        pixel_definition.set_output_feature_names(input_feature_names)
        assert (
            string_label in input_feature_names
        ), f"String label '{string_label}' not in input feature names"
        assert (
            dom_number_label in input_feature_names
        ), f"DOM number label '{dom_number_label}' not in input feature names"

        grid_definition = ExamplePrometheusGridDefinition(
            detector=detector,
            dtype=dtype,
            string_label=string_label,
            sensor_number_label=dom_number_label,
            pixel_feature_names=pixel_definition._output_feature_names,
        )

        # Forward `dtype` so the base class casts images to the requested
        # dtype instead of always using its own default.
        super().__init__(
            pixel_definition=pixel_definition,
            grid_definition=grid_definition,
            input_feature_names=input_feature_names,
            dtype=dtype,
            add_inactive_sensors=False,
            **kwargs,
        )
def build_ic86_cnn_mapping(
    string_label: str,
    dom_number_label: str,
) -> pd.DataFrame:
    """Create the lookup table from (string, DOM number) to voxel indices.

    Strings 1..78 take their first two indices from
    ``_IC86_STRING_TO_AX01`` and use ``dom_number - 1`` (as a float) for
    the third. Strings 79..86 use ``string - 79`` for the first index,
    ``dom_number - 1`` (DOMs 1..10) or ``dom_number - 11`` (DOMs 11..60)
    for the second, and ``_DC_AX2_SENTINEL`` for the third.

    Args:
        string_label: Column name for IceCube string number (1..86).
        dom_number_label: Column name for DOM number (1..60).

    Returns:
        DataFrame sorted and indexed by (``string_label``,
        ``dom_number_label``), keeping both as columns next to the voxel
        index columns ``mat_ax0``, ``mat_ax1``, and ``mat_ax2``.
    """
    key_columns = [string_label, dom_number_label]

    # Strings 1..78: one row per DOM, stacked along the third axis.
    main_rows: List[Tuple[int, int, int, int, float]] = [
        (string_id, dom, *_IC86_STRING_TO_AX01[string_id], float(dom - 1))
        for string_id in range(1, 79)
        for dom in range(1, 61)
    ]

    # Strings 79..86: DOMs 1..10 and 11..60 each restart the second
    # index at zero.
    deepcore_rows = [
        (
            string_id,
            dom,
            string_id - 79,
            dom - 1 if dom <= 10 else dom - 11,
            _DC_AX2_SENTINEL,
        )
        for string_id in range(79, 87)
        for dom in range(1, 61)
    ]

    table = pd.DataFrame(
        main_rows + deepcore_rows,
        columns=[*key_columns, MAT_AX0_COL, MAT_AX1_COL, MAT_AX2_COL],
    )
    table = table.sort_values(by=key_columns, ascending=[True, True])
    return table.set_index(key_columns, drop=False)
19), + (1, 42, 4, 3, 20), + (1, 43, 4, 3, 21), + (2, 44, 3, 5, 0), + (2, 45, 3, 5, 1), + (2, 46, 3, 5, 2), + (2, 47, 3, 5, 3), + (2, 48, 3, 5, 4), + (2, 49, 3, 5, 5), + (2, 50, 3, 5, 6), + (2, 51, 3, 5, 7), + (2, 52, 3, 5, 8), + (2, 53, 3, 5, 9), + (2, 54, 3, 5, 10), + (2, 55, 3, 5, 11), + (2, 56, 3, 5, 12), + (2, 57, 3, 5, 13), + (2, 58, 3, 5, 14), + (2, 59, 3, 5, 15), + (2, 60, 3, 5, 16), + (2, 61, 3, 5, 17), + (2, 62, 3, 5, 18), + (2, 63, 3, 5, 19), + (2, 64, 3, 5, 20), + (2, 65, 3, 5, 21), + (3, 66, 3, 3, 0), + (3, 67, 3, 3, 1), + (3, 68, 3, 3, 2), + (3, 69, 3, 3, 3), + (3, 70, 3, 3, 4), + (3, 71, 3, 3, 5), + (3, 72, 3, 3, 6), + (3, 73, 3, 3, 7), + (3, 74, 3, 3, 8), + (3, 75, 3, 3, 9), + (3, 76, 3, 3, 10), + (3, 77, 3, 3, 11), + (3, 78, 3, 3, 12), + (3, 79, 3, 3, 13), + (3, 80, 3, 3, 14), + (3, 81, 3, 3, 15), + (3, 82, 3, 3, 16), + (3, 83, 3, 3, 17), + (3, 84, 3, 3, 18), + (3, 85, 3, 3, 19), + (3, 86, 3, 3, 20), + (3, 87, 3, 3, 21), + (4, 88, 4, 4, 0), + (4, 89, 4, 4, 1), + (4, 90, 4, 4, 2), + (4, 91, 4, 4, 3), + (4, 92, 4, 4, 4), + (4, 93, 4, 4, 5), + (4, 94, 4, 4, 6), + (4, 95, 4, 4, 7), + (4, 96, 4, 4, 8), + (4, 97, 4, 4, 9), + (4, 98, 4, 4, 10), + (4, 99, 4, 4, 11), + (4, 100, 4, 4, 12), + (4, 101, 4, 4, 13), + (4, 102, 4, 4, 14), + (4, 103, 4, 4, 15), + (4, 104, 4, 4, 16), + (4, 105, 4, 4, 17), + (4, 106, 4, 4, 18), + (4, 107, 4, 4, 19), + (4, 108, 4, 4, 20), + (4, 109, 4, 4, 21), + (5, 110, 2, 4, 0), + (5, 111, 2, 4, 1), + (5, 112, 2, 4, 2), + (5, 113, 2, 4, 3), + (5, 114, 2, 4, 4), + (5, 115, 2, 4, 5), + (5, 116, 2, 4, 6), + (5, 117, 2, 4, 7), + (5, 118, 2, 4, 8), + (5, 119, 2, 4, 9), + (5, 120, 2, 4, 10), + (5, 121, 2, 4, 11), + (5, 122, 2, 4, 12), + (5, 123, 2, 4, 13), + (5, 124, 2, 4, 14), + (5, 125, 2, 4, 15), + (5, 126, 2, 4, 16), + (5, 127, 2, 4, 17), + (5, 128, 2, 4, 18), + (5, 129, 2, 4, 19), + (5, 130, 2, 4, 20), + (5, 131, 2, 4, 21), + (6, 132, 4, 2, 0), + (6, 133, 4, 2, 1), + (6, 134, 4, 2, 2), + (6, 135, 4, 2, 3), + (6, 136, 4, 2, 4), + (6, 
137, 4, 2, 5), + (6, 138, 4, 2, 6), + (6, 139, 4, 2, 7), + (6, 140, 4, 2, 8), + (6, 141, 4, 2, 9), + (6, 142, 4, 2, 10), + (6, 143, 4, 2, 11), + (6, 144, 4, 2, 12), + (6, 145, 4, 2, 13), + (6, 146, 4, 2, 14), + (6, 147, 4, 2, 15), + (6, 148, 4, 2, 16), + (6, 149, 4, 2, 17), + (6, 150, 4, 2, 18), + (6, 151, 4, 2, 19), + (6, 152, 4, 2, 20), + (6, 153, 4, 2, 21), + (7, 154, 4, 5, 0), + (7, 155, 4, 5, 1), + (7, 156, 4, 5, 2), + (7, 157, 4, 5, 3), + (7, 158, 4, 5, 4), + (7, 159, 4, 5, 5), + (7, 160, 4, 5, 6), + (7, 161, 4, 5, 7), + (7, 162, 4, 5, 8), + (7, 163, 4, 5, 9), + (7, 164, 4, 5, 10), + (7, 165, 4, 5, 11), + (7, 166, 4, 5, 12), + (7, 167, 4, 5, 13), + (7, 168, 4, 5, 14), + (7, 169, 4, 5, 15), + (7, 170, 4, 5, 16), + (7, 171, 4, 5, 17), + (7, 172, 4, 5, 18), + (7, 173, 4, 5, 19), + (7, 174, 4, 5, 20), + (7, 175, 4, 5, 21), + (8, 176, 2, 3, 0), + (8, 177, 2, 3, 1), + (8, 178, 2, 3, 2), + (8, 179, 2, 3, 3), + (8, 180, 2, 3, 4), + (8, 181, 2, 3, 5), + (8, 182, 2, 3, 6), + (8, 183, 2, 3, 7), + (8, 184, 2, 3, 8), + (8, 185, 2, 3, 9), + (8, 186, 2, 3, 10), + (8, 187, 2, 3, 11), + (8, 188, 2, 3, 12), + (8, 189, 2, 3, 13), + (8, 190, 2, 3, 14), + (8, 191, 2, 3, 15), + (8, 192, 2, 3, 16), + (8, 195, 2, 3, 19), + (8, 196, 2, 3, 20), + (8, 197, 2, 3, 21), + (9, 198, 5, 3, 0), + (9, 199, 5, 3, 1), + (9, 200, 5, 3, 2), + (9, 202, 5, 3, 4), + (9, 203, 5, 3, 5), + (9, 204, 5, 3, 6), + (9, 205, 5, 3, 7), + (9, 206, 5, 3, 8), + (9, 207, 5, 3, 9), + (9, 208, 5, 3, 10), + (9, 209, 5, 3, 11), + (9, 210, 5, 3, 12), + (9, 211, 5, 3, 13), + (9, 212, 5, 3, 14), + (9, 213, 5, 3, 15), + (9, 214, 5, 3, 16), + (9, 217, 5, 3, 19), + (9, 218, 5, 3, 20), + (9, 219, 5, 3, 21), + (10, 220, 2, 5, 0), + (10, 222, 2, 5, 2), + (10, 224, 2, 5, 4), + (10, 225, 2, 5, 5), + (10, 228, 2, 5, 8), + (10, 229, 2, 5, 9), + (10, 230, 2, 5, 10), + (10, 231, 2, 5, 11), + (10, 233, 2, 5, 13), + (10, 235, 2, 5, 15), + (10, 236, 2, 5, 16), + (10, 239, 2, 5, 19), + (10, 240, 2, 5, 20), + (10, 241, 2, 5, 21), + (11, 
242, 3, 2, 0), + (11, 244, 3, 2, 2), + (11, 245, 3, 2, 3), + (11, 246, 3, 2, 4), + (11, 247, 3, 2, 5), + (11, 250, 3, 2, 8), + (11, 252, 3, 2, 10), + (11, 253, 3, 2, 11), + (11, 255, 3, 2, 13), + (11, 256, 3, 2, 14), + (11, 257, 3, 2, 15), + (11, 258, 3, 2, 16), + (11, 261, 3, 2, 19), + (11, 263, 3, 2, 21), + (12, 264, 5, 5, 0), + (12, 266, 5, 5, 2), + (12, 268, 5, 5, 4), + (12, 269, 5, 5, 5), + (12, 271, 5, 5, 7), + (12, 272, 5, 5, 8), + (12, 274, 5, 5, 10), + (12, 275, 5, 5, 11), + (12, 277, 5, 5, 13), + (12, 279, 5, 5, 15), + (12, 280, 5, 5, 16), + (12, 283, 5, 5, 19), + (12, 285, 5, 5, 21), + (13, 286, 1, 4, 0), + (13, 288, 1, 4, 2), + (13, 289, 1, 4, 3), + (13, 290, 1, 4, 4), + (13, 291, 1, 4, 5), + (13, 294, 1, 4, 8), + (13, 296, 1, 4, 10), + (13, 297, 1, 4, 11), + (13, 299, 1, 4, 13), + (13, 301, 1, 4, 15), + (13, 302, 1, 4, 16), + (13, 305, 1, 4, 19), + (14, 308, 5, 2, 0), + (14, 310, 5, 2, 2), + (14, 312, 5, 2, 4), + (14, 313, 5, 2, 5), + (14, 316, 5, 2, 8), + (14, 319, 5, 2, 11), + (14, 321, 5, 2, 13), + (14, 322, 5, 2, 14), + (14, 323, 5, 2, 15), + (14, 324, 5, 2, 16), + (14, 327, 5, 2, 19), + (14, 329, 5, 2, 21), + (15, 330, 3, 6, 0), + (15, 332, 3, 6, 2), + (15, 333, 3, 6, 3), + (15, 334, 3, 6, 4), + (15, 335, 3, 6, 5), + (15, 338, 3, 6, 8), + (15, 341, 3, 6, 11), + (15, 343, 3, 6, 13), + (15, 345, 3, 6, 15), + (15, 346, 3, 6, 16), + (15, 349, 3, 6, 19), + (16, 352, 2, 2, 0), + (16, 354, 2, 2, 2), + (16, 356, 2, 2, 4), + (16, 357, 2, 2, 5), + (16, 360, 2, 2, 8), + (16, 362, 2, 2, 10), + (16, 363, 2, 2, 11), + (16, 365, 2, 2, 13), + (16, 366, 2, 2, 14), + (16, 367, 2, 2, 15), + (16, 368, 2, 2, 16), + (16, 369, 2, 2, 17), + (16, 371, 2, 2, 19), + (17, 374, 5, 4, 0), + (17, 376, 5, 4, 2), + (17, 378, 5, 4, 4), + (17, 379, 5, 4, 5), + (17, 382, 5, 4, 8), + (17, 385, 5, 4, 11), + (17, 387, 5, 4, 13), + (17, 388, 5, 4, 14), + (17, 389, 5, 4, 15), + (17, 390, 5, 4, 16), + (17, 393, 5, 4, 19), + (18, 396, 1, 6, 0), + (18, 398, 1, 6, 2), + (18, 400, 1, 6, 4), + 
(18, 401, 1, 6, 5), + (18, 404, 1, 6, 8), + (18, 406, 1, 6, 10), + (18, 407, 1, 6, 11), + (18, 409, 1, 6, 13), + (18, 411, 1, 6, 15), + (18, 412, 1, 6, 16), + (18, 413, 1, 6, 17), + (18, 415, 1, 6, 19), + (19, 418, 4, 1, 0), + (19, 419, 4, 1, 1), + (19, 420, 4, 1, 2), + (19, 421, 4, 1, 3), + (19, 422, 4, 1, 4), + (19, 423, 4, 1, 5), + (19, 426, 4, 1, 8), + (19, 428, 4, 1, 10), + (19, 429, 4, 1, 11), + (19, 433, 4, 1, 15), + (19, 434, 4, 1, 16), + (19, 437, 4, 1, 19), + (20, 440, 4, 6, 0), + (20, 442, 4, 6, 2), + (20, 444, 4, 6, 4), + (20, 445, 4, 6, 5), + (20, 448, 4, 6, 8), + (20, 450, 4, 6, 10), + (20, 451, 4, 6, 11), + (20, 453, 4, 6, 13), + (20, 454, 4, 6, 14), + (20, 455, 4, 6, 15), + (20, 456, 4, 6, 16), + (20, 459, 4, 6, 19), + (21, 462, 1, 3, 0), + (21, 463, 1, 3, 1), + (21, 464, 1, 3, 2), + (21, 465, 1, 3, 3), + (21, 466, 1, 3, 4), + (21, 467, 1, 3, 5), + (21, 470, 1, 3, 8), + (21, 473, 1, 3, 11), + (21, 477, 1, 3, 15), + (21, 478, 1, 3, 16), + (21, 481, 1, 3, 19), + (22, 484, 6, 2, 0), + (22, 486, 6, 2, 2), + (22, 488, 6, 2, 4), + (22, 489, 6, 2, 5), + (22, 492, 6, 2, 8), + (22, 495, 6, 2, 11), + (22, 496, 6, 2, 12), + (22, 497, 6, 2, 13), + (22, 498, 6, 2, 14), + (22, 499, 6, 2, 15), + (22, 500, 6, 2, 16), + (22, 503, 6, 2, 19), + (23, 506, 2, 6, 0), + (23, 510, 2, 6, 4), + (23, 511, 2, 6, 5), + (23, 514, 2, 6, 8), + (23, 517, 2, 6, 11), + (23, 520, 2, 6, 14), + (23, 521, 2, 6, 15), + (23, 522, 2, 6, 16), + (23, 525, 2, 6, 19), + (24, 528, 3, 1, 0), + (24, 531, 3, 1, 3), + (24, 532, 3, 1, 4), + (24, 536, 3, 1, 8), + (24, 539, 3, 1, 11), + (24, 543, 3, 1, 15), + (24, 547, 3, 1, 19), + (25, 550, 6, 5, 0), + (25, 554, 6, 5, 4), + (25, 558, 6, 5, 8), + (25, 561, 6, 5, 11), + (25, 565, 6, 5, 15), + (25, 569, 6, 5, 19), + (26, 572, 1, 5, 0), + (26, 576, 1, 5, 4), + (26, 580, 1, 5, 8), + (26, 583, 1, 5, 11), + (26, 587, 1, 5, 15), + (26, 591, 1, 5, 19), + (27, 594, 5, 1, 0), + (27, 598, 5, 1, 4), + (27, 602, 5, 1, 8), + (27, 605, 5, 1, 11), + (27, 609, 5, 1, 
15), + (27, 613, 5, 1, 19), + (28, 616, 3, 7, 0), + (28, 620, 3, 7, 4), + (28, 624, 3, 7, 8), + (28, 627, 3, 7, 11), + (28, 631, 3, 7, 15), + (28, 635, 3, 7, 19), + (29, 638, 1, 2, 0), + (29, 642, 1, 2, 4), + (29, 646, 1, 2, 8), + (29, 649, 1, 2, 11), + (29, 653, 1, 2, 15), + (29, 657, 1, 2, 19), + (30, 660, 6, 3, 0), + (30, 664, 6, 3, 4), + (30, 668, 6, 3, 8), + (30, 671, 6, 3, 11), + (30, 675, 6, 3, 15), + (30, 679, 6, 3, 19), + (31, 682, 1, 7, 0), + (31, 686, 1, 7, 4), + (31, 690, 1, 7, 8), + (31, 693, 1, 7, 11), + (31, 697, 1, 7, 15), + (31, 701, 1, 7, 19), + (32, 704, 4, 0, 0), + (32, 708, 4, 0, 4), + (32, 712, 4, 0, 8), + (32, 715, 4, 0, 11), + (32, 719, 4, 0, 15), + (32, 723, 4, 0, 19), + (33, 726, 5, 6, 0), + (33, 730, 5, 6, 4), + (33, 734, 5, 6, 8), + (33, 737, 5, 6, 11), + (33, 741, 5, 6, 15), + (33, 745, 5, 6, 19), + (34, 748, 0, 4, 0), + (34, 752, 0, 4, 4), + (34, 756, 0, 4, 8), + (34, 759, 0, 4, 11), + (34, 763, 0, 4, 15), + (34, 767, 0, 4, 19), + (35, 770, 6, 1, 0), + (35, 774, 6, 1, 4), + (35, 778, 6, 1, 8), + (35, 781, 6, 1, 11), + (35, 785, 6, 1, 15), + (35, 789, 6, 1, 19), + (36, 792, 2, 7, 0), + (36, 796, 2, 7, 4), + (36, 800, 2, 7, 8), + (36, 803, 2, 7, 11), + (36, 807, 2, 7, 15), + (36, 811, 2, 7, 19), + (37, 814, 2, 1, 0), + (37, 818, 2, 1, 4), + (37, 822, 2, 1, 8), + (37, 825, 2, 1, 11), + (37, 829, 2, 1, 15), + (37, 833, 2, 1, 19), + (38, 836, 6, 4, 0), + (38, 840, 6, 4, 4), + (38, 844, 6, 4, 8), + (38, 847, 6, 4, 11), + (38, 851, 6, 4, 15), + (38, 855, 6, 4, 19), + (39, 862, 0, 6, 4), + (39, 866, 0, 6, 8), + (39, 873, 0, 6, 15), + (39, 877, 0, 6, 19), + (40, 884, 5, 0, 4), + (40, 888, 5, 0, 8), + (40, 895, 5, 0, 15), + (40, 899, 5, 0, 19), + (41, 906, 4, 7, 4), + (41, 910, 4, 7, 8), + (41, 917, 4, 7, 15), + (41, 921, 4, 7, 19), + (42, 928, 0, 3, 4), + (42, 932, 0, 3, 8), + (42, 939, 0, 3, 15), + (42, 943, 0, 3, 19), + (43, 950, 7, 2, 4), + (43, 961, 7, 2, 15), + (44, 972, 1, 8, 4), + (44, 983, 1, 8, 15), + (45, 994, 3, 0, 4), + (45, 1005, 
class GridDefinition(Model):
    """Base class for detector-bound image grid(s).

    Subclasses provide the image shape(s), the lookup table(s) from pixel
    keys to voxel indices, and the `forward` that fills the image tensors.
    """

    def __init__(
        self,
        detector: Detector,
        pixel_feature_names: List[str],
    ) -> None:
        """Construct `GridDefinition`.

        Args:
            detector: Detector this grid is defined for.
            pixel_feature_names: Column names expected on each pixel row;
                passed straight to :meth:`_set_image_feature_names`.
        """
        super().__init__(name=__name__, class_name=type(self).__name__)
        self._detector = detector
        # Subclasses decide which of the pixel columns become image features.
        self._set_image_feature_names(pixel_feature_names)

    @property
    def detector(self) -> Detector:
        """Detector passed at construction."""
        return self._detector

    @property
    @abstractmethod
    def map_pixels_by(self) -> List[str]:
        """Feature columns used as keys into :meth:`mappings`."""

    @abstractmethod
    def mappings(self) -> List[pd.DataFrame]:
        """Lookup table(s) keyed by ``map_pixels_by`` with voxel indices."""

    @abstractmethod
    def forward(self, data: Data, data_feature_names: List[str]) -> Data:
        """Place pixel features into image tensor(s); see subclasses."""

    @abstractmethod
    def _set_image_feature_names(self, input_feature_names: List[str]) -> None:
        """Set the final image feature names."""
        raise NotImplementedError

    @property
    @abstractmethod
    def shape(
        self,
    ) -> List[List[int]]:
        """Return the shape of each output image, one entry per image.

        Every entry starts with F, the number of features per pixel,
        followed by the dimensions of the image.
        """
def _set_indices(
    self,
    feature_names: List[str],
    dom_number_label: str,
    string_label: str,
) -> None:
    """Record which input columns hold the ids and the CNN features.

    Sets ``self._string_idx`` and ``self._dom_number_idx`` to the
    positions of the two id columns and ``self._cnn_features_idx`` to
    the positions of every remaining column, in input order.

    Args:
        feature_names: Names of the columns of the pixel rows.
        dom_number_label: Name of the DOM-number column.
        string_label: Name of the string-number column.

    Raises:
        ValueError: If either id column is absent from
            ``feature_names``.
    """
    id_labels = (string_label, dom_number_label)
    missing = [label for label in id_labels if label not in feature_names]
    if missing:
        # Fail at construction time instead of with an AttributeError
        # on the first forward pass.
        raise ValueError(
            f"Feature names {list(feature_names)} do not contain the "
            f"required id column(s) {missing}."
        )

    self._string_idx = feature_names.index(string_label)
    self._dom_number_idx = feature_names.index(dom_number_label)
    # Enumerate so every column keeps its own position, even when a
    # feature name occurs more than once.
    self._cnn_features_idx = [
        position
        for position, feature in enumerate(feature_names)
        if feature not in id_labels
    ]
def _set_image_feature_names(self, input_feature_names: List[str]) -> None:
    """Store the names of the features carried by the output images.

    The string and DOM-number columns are only used to look up voxel
    positions, so they are left out; all other names are kept in their
    input order.
    """
    lookup_only = (self._string_label, self._dom_number_label)
    kept: List[str] = []
    for name in input_feature_names:
        if name in lookup_only:
            continue
        kept.append(name)
    self.image_feature_names = kept
def _set_indices(
    self,
    feature_names: List[str],
    sensor_number_label: str,
    string_label: str,
) -> None:
    """Record which input columns hold the ids and the CNN features.

    Sets ``self._string_idx`` and ``self._sensor_number_idx`` to the
    positions of the two id columns and ``self._cnn_features_idx`` to
    the positions of every remaining column, in input order.

    Args:
        feature_names: Names of the columns of the pixel rows.
        sensor_number_label: Name of the sensor-number column.
        string_label: Name of the string-number column.

    Raises:
        ValueError: If either id column is absent from
            ``feature_names``.
    """
    id_labels = (string_label, sensor_number_label)
    missing = [label for label in id_labels if label not in feature_names]
    if missing:
        # Fail at construction time instead of with an AttributeError
        # on the first forward pass.
        raise ValueError(
            f"Feature names {list(feature_names)} do not contain the "
            f"required id column(s) {missing}."
        )

    self._string_idx = feature_names.index(string_label)
    self._sensor_number_idx = feature_names.index(sensor_number_label)
    # Enumerate so every column keeps its own position, even when a
    # feature name occurs more than once.
    self._cnn_features_idx = [
        position
        for position, feature in enumerate(feature_names)
        if feature not in id_labels
    ]
+ x = data.x + + # Direct coordinate and feature extraction + string_sensor_number = x[ + :, [self._string_idx, self._sensor_number_idx] + ].int() + batch_row_features = x[:, self._cnn_features_idx] + + # Look up the pixel position in the image for every (string, + # sensor_id) pair. Column references go through the configurable + # labels and the internal axis-column constants so this method + # does not rely on hard-coded column names from the data file. + match_indices = self._mapping.loc[ + zip(*string_sensor_number.t().tolist()) + ][ + [ + self._string_label, + self._sensor_number_label, + MAT_AX0_COL, + MAT_AX1_COL, + MAT_AX2_COL, + ] + ].values.astype( + int + ) + + # Copy CNN features to the appropriate arrays + for i, row in enumerate(match_indices): + # Select appropriate array and indexing + image_tensor[ + :, + row[2], # mat_ax0 + row[3], # mat_ax1 + row[4], # mat_ax2 + ] = batch_row_features[i] + + # unsqueeze to add dimension for batching + # with collate_fn Batch.from_data_list + ret: List[torch.Tensor] = [image_tensor.unsqueeze(0)] + + # Set list of images as data.x + data.x = ret + return data + + def _set_image_feature_names(self, input_feature_names: List[str]) -> None: + """Set the final output feature names.""" + # string and sensor_number are only used for mapping + # and will not be included in the output features. 
def ic86_full_detector_pixel_table() -> np.ndarray:
    """Build a pixel table with one row per IceCube-86 DOM.

    The four columns are ``string``, ``dom_number``,
    ``redundant_string`` and ``redundant_dom_number``; the two
    ``redundant_*`` columns repeat the first two. Rows run over strings
    ``1..86`` and, within each string, over DOMs ``1..60``.

    Returns:
        A ``float32`` array of shape ``(5160, 4)``.
    """
    string_ids = np.arange(1, 87, dtype=np.float32)
    dom_ids = np.arange(1, 61, dtype=np.float32)

    # "ij" indexing keeps strings on the slow axis, so flattening walks
    # through all DOMs of one string before moving to the next string.
    string_grid, dom_grid = np.meshgrid(string_ids, dom_ids, indexing="ij")
    string_column = string_grid.ravel()
    dom_column = dom_grid.ravel()

    return np.stack(
        (string_column, dom_column, string_column, dom_column), axis=1
    )
dom_number_label = "dom_number" + + dummy_data = Data( + x=torch.tensor( + [[1, 2, 5.8, 1e-4], [79, 46, 3.7, 1e-18], [84, 9, 6.87, 2e5]], + dtype=dtype, + ), + ) + + detector = IceCube86(replace_with_identity=pixel_feature_names) + grid_definition = IC86GridDefinition( + detector=detector, + dtype=dtype, + pixel_feature_names=pixel_feature_names, + string_label=string_label, + dom_number_label=dom_number_label, + include_lower_dc=True, + include_upper_dc=True, + ) + + picture = grid_definition(dummy_data, pixel_feature_names) + new_features = grid_definition.image_feature_names + n_features = len(new_features) + + basic_checks_picture(picture, dtype) + + assert ( + len(grid_definition.shape) == 3 + ), f"Expected shape to be 3 got {len(grid_definition.shape)}" + assert grid_definition.shape == [ + [n_features, 10, 10, 60], + [n_features, 1, 8, 10], + [n_features, 1, 8, 50], + ], ( + f"Expected shape to be [[{n_features},10,10,60], " + f"[{n_features},1,8,10], [{n_features},1,8,50]] got " + f"{grid_definition.shape}" + ) + assert isinstance( + new_features, list + ), f"Output should be a list of feature names got {type(new_features)}" + assert new_features == [ + "data1", + "data2", + ], f"Expected feature to be ['data1', 'data2'] names got: {new_features}" + assert len(picture.x) == 3, ( + "There should be three tensors in x ", + f"got list with length {len(picture.x)}" + "(main array, upper DeepCore, lower DeepCore)", + ) + assert picture.x[0].size() == torch.Size( + [1, 2, 10, 10, 60] + ), f"Main array should have shape (1,2,10,10,60) got {picture.x[0].size()}" + assert picture.x[1].size() == torch.Size( + [1, 2, 8, 10] + ), f"upper DeepCore should have shape (1,2,8,10) got {picture.x[1].size()}" + assert picture.x[2].size() == torch.Size( + [1, 2, 8, 50] + ), f"lower DeepCore should have shape (1,2,8,50) got {picture.x[2].size()}" + assert not torch.all( + picture.x[0] == 0 + ), "Main array should not be all zeros, got all zeros." 
+ assert not torch.all( + picture.x[1] == 0 + ), "Upper DeepCore should not be all zeros, got all zeros." + assert not torch.all( + picture.x[2] == 0 + ), "Lower DeepCore should not be all zeros, got all zeros." + + dummy_data = Data( + x=torch.tensor( + [ + [100, 5, 5.8, 1e-4], + [54, 230, 3.7, 1e-18], + [1294, 500, 6.87, 2e5], + ], + dtype=dtype, + ), + ) + + with pytest.raises(KeyError): + grid_definition(dummy_data, pixel_feature_names) + + +def test_ic86_grid_segments() -> None: + """IC86 grid: single sub-image at a time vs reference arrays.""" + dtype = torch.float32 + string_label = "string" + dom_number_label = "dom_number" + pixel_feature_names = IC86_TEST_PIXEL_COLUMNS + + grid_tensor = Data( + x=torch.tensor(ic86_full_detector_pixel_table(), dtype=dtype) + ) + + detector = IceCube86(replace_with_identity=pixel_feature_names) + + for image, inc_main, inc_upc, inc_lowdc, label in zip( + [TEST_IC86MAIN_IMAGE, TEST_IC86UPPERDC_IMAGE, TEST_IC86LOWERDC_IMAGE], + [True, False, False], + [False, True, False], + [False, False, True], + ["main array", "upper deepcore", "lower deepcore"], + ): + tmp = deepcopy(grid_tensor) + grid_definition = IC86GridDefinition( + detector=detector, + dtype=dtype, + pixel_feature_names=pixel_feature_names, + string_label=string_label, + dom_number_label=dom_number_label, + include_main_array=inc_main, + include_lower_dc=inc_lowdc, + include_upper_dc=inc_upc, + ) + picture = grid_definition(tmp, pixel_feature_names) + tensor_image: torch.Tensor = torch.tensor( + np.load(image), dtype=dtype + ).unsqueeze(0) + + basic_checks_picture(picture, dtype) + + assert len(picture.x) == 1, ( + "There should be one tensor in x ", + f"got list with length {len(picture.x)}", + ) + assert picture.x[0].size() == tensor_image.size(), ( + f"{label} should have shape {tensor_image.size()} " + f"got {picture.x[0].size()}" + ) + assert not torch.all( + picture.x[0] == 0 + ), f"{label} should not be all zeros, got all zeros." 
def test_ic86_image_representation() -> None:
    """Run the IC86Image pipeline and compare against reference images.

    A full-detector pixel table is passed through the pixel definition
    and the grid; the resulting tensors must equal the stored reference
    arrays for the main array, upper DeepCore and lower DeepCore.
    """
    dtype = torch.float32
    columns = IC86_TEST_PIXEL_COLUMNS
    pixel_table = ic86_full_detector_pixel_table()

    pixel_def = NodesAsPulses(input_feature_names=columns)
    detector = IceCube86(replace_with_identity=columns)

    image_representation = IC86Image(
        pixel_definition=pixel_def,
        input_feature_names=columns,
        include_lower_dc=True,
        include_upper_dc=True,
        string_label="string",
        dom_number_label="dom_number",
        dtype=dtype,
        detector=detector,
    )

    assert (
        image_representation.nb_outputs == 2
    ), f"Expected 2 outputs, got {image_representation.nb_outputs}"

    # The id columns are expected to be absent from the output features.
    output_feature_names = [
        name for name in columns if name not in ("string", "dom_number")
    ]
    assert image_representation.output_feature_names == output_feature_names, (
        f"Output feature names do not match expected output: "
        f"{image_representation.output_feature_names} != {output_feature_names}"
    )

    image = image_representation(
        pixel_table,
        input_feature_names=columns,
    )

    assert isinstance(
        image, Data
    ), "Expected output to be a torch_geometric.data.Data object"
    assert isinstance(image.x, list), "Expected image.x to be a list"
    assert all(
        isinstance(x, torch.Tensor) for x in image.x
    ), "Expected all elements in image.x to be torch.Tensor"
    assert (
        len(image.x) == 3
    ), f"Expected image.x to have 3 elements, got {len(image.x)}"
    assert (
        "num_nodes" in image.keys()
    ), "Expected 'num_nodes' in image attributes"

    # Reference arrays, listed in the order the test expects in image.x.
    image_list = [
        TEST_IC86MAIN_IMAGE,
        TEST_IC86UPPERDC_IMAGE,
        TEST_IC86LOWERDC_IMAGE,
    ]
    for i, img in enumerate(image_list):
        expected_image = torch.tensor(np.load(img), dtype=dtype).unsqueeze(0)
        # Report the size of the tensor directly; re-wrapping it in
        # torch.tensor() would copy it for no benefit.
        assert image.x[i].size() == expected_image.size(), (
            f"Image at index {i} size mismatch: "
            f"expected {expected_image.size()}, "
            f"got {image.x[i].size()}"
        )
        assert torch.equal(
            image.x[i], expected_image
        ), f"Image at index {i} does not match expected image"