Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/graphnet/models/cnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""CNN-specific modules, for performing the main learnable operations."""

from .cnn import CNN
from .icecube_dnn import IceCubeDNN
35 changes: 35 additions & 0 deletions src/graphnet/models/cnn/cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Base CNN-specific `Model` class(es)."""

from abc import abstractmethod

from torch import Tensor
from torch_geometric.data import Data

from graphnet.models import Model


class CNN(Model):
"""Base class for all core CNN models in graphnet."""

def __init__(self, nb_inputs: int, nb_outputs: int) -> None:
"""Construct `CNN`."""
# Base class constructor
super().__init__()

# Member variables
self._nb_inputs = nb_inputs
self._nb_outputs = nb_outputs

@property
def nb_inputs(self) -> int:
"""Return number of input features."""
return self._nb_inputs

@property
def nb_outputs(self) -> int:
"""Return number of output features."""
return self._nb_outputs

@abstractmethod
def forward(self, data: Data) -> Tensor:
"""Apply learnable forward pass in model."""
Loading
Loading