From bce64b26525ae384d6ef393b665132fc1cf5beca Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Mon, 11 May 2026 21:34:11 -0400 Subject: [PATCH 01/18] readme upd --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 1e8dca1..aeb9b6c 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ ![LOGO](https://raw.githubusercontent.com/Yehor-Mishchyriak/microbiome2function/main/assets/M2F_banner.png) +[![Test](https://github.com/Yehor-Mishchyriak/microbiome2function/actions/workflows/test.yml/badge.svg)](https://github.com/Yehor-Mishchyriak/microbiome2function/actions/workflows/test.yml) # microbiome2function (M2F) From 975efe4474f09ce3197518da937358a5d49af041 Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Mon, 8 Jun 2026 03:27:30 -0400 Subject: [PATCH 02/18] added GATConv --- src/M2F/__init__.py | 3 +- src/M2F/gnn.py | 350 +++++++++++++++++++++++++++++++++++++++++++- tests/test_gnn.py | 43 +++++- 3 files changed, 393 insertions(+), 3 deletions(-) diff --git a/src/M2F/__init__.py b/src/M2F/__init__.py index 072a68d..f0d4cdc 100644 --- a/src/M2F/__init__.py +++ b/src/M2F/__init__.py @@ -39,7 +39,7 @@ # models from .ffnn import FFNN -from .gnn import GraphConv, GraphConvNodeClassifier +from .gnn import GraphConv, GraphConvNodeClassifier, GATNodeClassifier # metrics from .testing_utils import accuracy, recall, precision, f1 @@ -90,6 +90,7 @@ "FFNN", "GraphConv", "GraphConvNodeClassifier", + "GATNodeClassifier", # metrics "accuracy", "recall", diff --git a/src/M2F/gnn.py b/src/M2F/gnn.py index ff13000..6a6dcb9 100644 --- a/src/M2F/gnn.py +++ b/src/M2F/gnn.py @@ -1,5 +1,5 @@ # third party -from torch_geometric.nn import MessagePassing +from torch_geometric.nn import GATConv, MessagePassing from torch.nn import Dropout, Linear, Module from torch.nn.functional import relu, sigmoid from torch.optim.lr_scheduler import ExponentialLR @@ -445,9 +445,357 @@ def test(self, test: NeighborLoader, *, threshold: float = 0.5) -> dict[str, flo return metrics +class GATNodeClassifier(Module): + """ + Represent the `GATNodeClassifier` type. + """ + def __init__(self, + in_dim: int, + edge_dim: int, + msg_dim: int, + state_dim: int, + out_dim: int, + *, + heads: int = 1, + attention_dropout_p: float = 0.0, + dropout_p: float = 0.5): + """ + Initialize a `GATNodeClassifier` instance. + + Args: + in_dim: Input value for `in_dim`. + edge_dim: Input value for `edge_dim`. + msg_dim: Input value for `msg_dim`. + state_dim: Input value for `state_dim`. + out_dim: Input value for `out_dim`. + heads: Input value for `heads`. + attention_dropout_p: Input value for `attention_dropout_p`. + dropout_p: Input value for `dropout_p`. + """ + super().__init__() + if edge_dim < 0: + raise ValueError("`edge_dim` must be >= 0") + if heads < 1: + raise ValueError("`heads` must be >= 1") + if state_dim % heads != 0: + raise ValueError("`state_dim` must be divisible by `heads`") + if not (0.0 <= attention_dropout_p <= 1.0): + raise ValueError("`attention_dropout_p` must be in [0, 1]") + + gat_out_dim = state_dim // heads + gat_edge_dim = edge_dim if edge_dim > 0 else None + self.edge_dim = edge_dim + self.msg_dim = msg_dim + self.conv1 = GATConv( + in_channels=in_dim, + out_channels=gat_out_dim, + heads=heads, + concat=True, + edge_dim=gat_edge_dim, + dropout=attention_dropout_p, + ) + self.conv2 = GATConv( + in_channels=state_dim, + out_channels=gat_out_dim, + heads=heads, + concat=True, + edge_dim=gat_edge_dim, + dropout=attention_dropout_p, + ) + self.lin = Linear(state_dim, out_dim) + self.dropout = Dropout(p=dropout_p) + + def _forward_logits(self, x, edge_index, edge_attr): + """ + Execute `forward logits`. + + Args: + x: Input value for `x`. + edge_index: Input value for `edge_index`. + edge_attr: Input value for `edge_attr`. + """ + gat_edge_attr = edge_attr if self.edge_dim > 0 else None + h = self.conv1(x, edge_index, edge_attr=gat_edge_attr) + h = relu(h) + h = self.dropout(h) + h = self.conv2(h, edge_index, edge_attr=gat_edge_attr) + return self.lin(h) + + def forward(self, x, edge_index, edge_attr): + """ + Run forward propagation for `GATNodeClassifier`. + + Args: + x: Input value for `x`. + edge_index: Input value for `edge_index`. + edge_attr: Input value for `edge_attr`. + """ + out = self._forward_logits(x, edge_index, edge_attr) + if self.training: + return out + return sigmoid(out) + + def fit(self, + train: NeighborLoader, + val: NeighborLoader, + epochs: int, + early_stopping: bool = True, + save_model_to: Path | str | None = None, + *, + tolerance: int = 5, + optimizer=None, + optimizer_kwargs: dict = None, + lr_sched=None, + lr_sched_kwargs: dict = None, + report_performance_every_kth_epoch: int = 10): + """ + Fit the current object. + + Args: + train: Input value for `train`. + val: Input value for `val`. + epochs: Input value for `epochs`. + early_stopping: Input value for `early_stopping`. + save_model_to: Input value for `save_model_to`. + tolerance: Input value for `tolerance`. + optimizer: Input value for `optimizer`. + optimizer_kwargs: Input value for `optimizer_kwargs`. + lr_sched: Input value for `lr_sched`. + lr_sched_kwargs: Input value for `lr_sched_kwargs`. + report_performance_every_kth_epoch: Input value for `report_performance_every_kth_epoch`. + """ + if epochs < 1: + raise ValueError("`epochs` must be >= 1") + if report_performance_every_kth_epoch < 1: + raise ValueError("`report_performance_every_kth_epoch` must be >= 1") + if tolerance < 0: + raise ValueError("`tolerance` must be >= 0") + + k = report_performance_every_kth_epoch + save_model_to = Path(save_model_to if save_model_to is not None else os.getcwd()) + save_model_to.mkdir(parents=True, exist_ok=True) + + device = next(self.parameters()).device # note, need to take `next` of `self.parameters()` because it is an iterator + criterion = torch.nn.BCEWithLogitsLoss() + _logger.info( + "Starting GAT fit (epochs=%d, early_stopping=%s, tolerance=%d, device=%s, save_dir=%s)", + epochs, + early_stopping, + tolerance, + device, + save_model_to, + ) + + # ------------------------------- optimizer ------------------------------- + if optimizer is None: + optimizer = torch.optim.Adam(params=self.parameters(), lr=1e-3, weight_decay=1e-4) + elif not isinstance(optimizer, torch.optim.Optimizer): + kwargs = dict(optimizer_kwargs or {}) + optimizer = optimizer(params=self.parameters(), **kwargs) + # ------------------------------------------------------------------------- + + # ------------------------------- scheduler ------------------------------- + if lr_sched is None: + lr_sched = ExponentialLR(optimizer=optimizer, gamma=0.99) + elif isinstance(lr_sched, type): + kwargs = dict(lr_sched_kwargs or {}) + kwargs.setdefault("optimizer", optimizer) + lr_sched = lr_sched(**kwargs) + elif not hasattr(lr_sched, "step"): + raise TypeError("`lr_sched` must be a scheduler instance or scheduler class.") + # ------------------------------------------------------------------------- + + no_generalization_after = 0 + best_val_loss = float("inf") + best_model_path: Path | None = None + history: list[dict[str, float | int]] = [] + + for epoch in range(1, epochs + 1): + # ------------------------------- train ------------------------------- + self.train() + train_loss_sum = 0.0 + train_acc_sum = 0.0 + train_recall_sum = 0.0 + train_examples = 0 + + for batch in train: + batch = batch.to(device) + batch_size = int(getattr(batch, "batch_size", batch.y.size(0))) + if batch_size == 0: + continue + + mask = torch.zeros(batch.y.size(0), dtype=torch.bool, device=device) + mask[:batch_size] = True + + optimizer.zero_grad() + logits = self._forward_logits(batch.x, batch.edge_index, batch.edge_attr) + y = batch.y.float() + loss = criterion(logits[mask], y[mask]) + loss.backward() + optimizer.step() + + with torch.no_grad(): + train_loss_sum += float(loss.item()) * batch_size + train_acc_sum += accuracy(logits, y, mask) * batch_size + train_recall_sum += recall(logits, y, mask) * batch_size + train_examples += batch_size + + if train_examples == 0: + raise RuntimeError("Train loader produced no batches with seed nodes.") + + train_loss = train_loss_sum / train_examples + train_acc = train_acc_sum / train_examples + train_recall = train_recall_sum / train_examples + # ------------------------------------------------------------------- + + # -------------------------------- val ------------------------------ + self.eval() + val_loss_sum = 0.0 + val_acc_sum = 0.0 + val_recall_sum = 0.0 + val_examples = 0 + with torch.no_grad(): + for batch in val: + batch = batch.to(device) + batch_size = int(getattr(batch, "batch_size", batch.y.size(0))) + if batch_size == 0: + continue + + mask = torch.zeros(batch.y.size(0), dtype=torch.bool, device=device) + mask[:batch_size] = True + + logits = self._forward_logits(batch.x, batch.edge_index, batch.edge_attr) + y = batch.y.float() + loss = criterion(logits[mask], y[mask]) + + val_loss_sum += float(loss.item()) * batch_size + val_acc_sum += accuracy(logits, y, mask) * batch_size + val_recall_sum += recall(logits, y, mask) * batch_size + val_examples += batch_size + + if val_examples == 0: + raise RuntimeError("Validation loader produced no batches with seed nodes.") + + current_val_loss = val_loss_sum / val_examples + val_acc = val_acc_sum / val_examples + val_recall = val_recall_sum / val_examples + # ------------------------------------------------------------------- + + # -------------------------- scheduler + early stop ------------------ + try: + lr_sched.step(current_val_loss) + except TypeError: + lr_sched.step() + + improved = current_val_loss < best_val_loss + if improved: + best_val_loss = current_val_loss + no_generalization_after = 0 + best_model_path = save_model_to / f"m2f_gat_{current_time()}.pt" + torch.save(self.state_dict(), best_model_path) + _logger.debug( + "New best validation loss %.6f at epoch %d; saved checkpoint to %s", + best_val_loss, + epoch, + best_model_path, + ) + else: + no_generalization_after += 1 + + history.append({ + "epoch": epoch, + "train_loss": train_loss, + "train_acc": train_acc, + "train_recall": train_recall, + "val_loss": current_val_loss, + "val_acc": val_acc, + "val_recall": val_recall, + }) + + if epoch == 1 or epoch % k == 0: + _logger.info( + "Epoch %d | train_loss=%.6f train_acc=%.4f train_recall=%.4f | " + "val_loss=%.6f val_acc=%.4f val_recall=%.4f", + epoch, train_loss, train_acc, train_recall, current_val_loss, val_acc, val_recall + ) + + if early_stopping and no_generalization_after > tolerance: + _logger.info( + "No validation improvement for %d epoch(s). Stopping early.", + no_generalization_after + ) + break + # ------------------------------------------------------------------- + + out = { + "best_val_loss": best_val_loss, + "best_model_path": str(best_model_path) if best_model_path is not None else None, + "history": history, + } + _logger.info( + "Finished GAT fit (epochs_ran=%d, best_val_loss=%.6f, best_model_path=%s)", + len(history), + best_val_loss, + out["best_model_path"], + ) + return out + + def test(self, test: NeighborLoader, *, threshold: float = 0.5) -> dict[str, float]: + """ + Test the current object. + + Args: + test: Input value for `test`. + threshold: Input value for `threshold`. + """ + device = next(self.parameters()).device + criterion = torch.nn.BCEWithLogitsLoss() + _logger.info("Starting GAT test (threshold=%.3f, device=%s)", threshold, device) + + self.eval() + test_loss_sum = 0.0 + test_acc_sum = 0.0 + test_recall_sum = 0.0 + test_examples = 0 + + with torch.no_grad(): + for batch in test: + batch = batch.to(device) + batch_size = int(getattr(batch, "batch_size", batch.y.size(0))) + if batch_size == 0: + continue + + mask = torch.zeros(batch.y.size(0), dtype=torch.bool, device=device) + mask[:batch_size] = True + + logits = self._forward_logits(batch.x, batch.edge_index, batch.edge_attr) + y = batch.y.float() + loss = criterion(logits[mask], y[mask]) + + test_loss_sum += float(loss.item()) * batch_size + test_acc_sum += accuracy(logits, y, mask, threshold=threshold) * batch_size + test_recall_sum += recall(logits, y, mask, threshold=threshold) * batch_size + test_examples += batch_size + + if test_examples == 0: + raise RuntimeError("Test loader produced no batches with seed nodes.") + + metrics = { + "test_loss": test_loss_sum / test_examples, + "test_acc": test_acc_sum / test_examples, + "test_recall": test_recall_sum / test_examples, + } + _logger.info( + "Test metrics | loss=%.6f acc=%.4f recall=%.4f", + metrics["test_loss"], metrics["test_acc"], metrics["test_recall"] + ) + return metrics + + __all__ = [ "GraphConv", "GraphConvNodeClassifier", + "GATNodeClassifier", ] diff --git a/tests/test_gnn.py b/tests/test_gnn.py index e2451e9..8c9200d 100644 --- a/tests/test_gnn.py +++ b/tests/test_gnn.py @@ -9,7 +9,7 @@ sys.path.insert(0, os.path.abspath("src")) -from M2F.gnn import GraphConv, GraphConvNodeClassifier +from M2F.gnn import GATNodeClassifier, GraphConv, GraphConvNodeClassifier class TestGNN(unittest.TestCase): @@ -61,6 +61,47 @@ def test_classifier_fit_and_test(self): self.assertIn("test_acc", metrics) self.assertIn("test_recall", metrics) + def test_gat_classifier_fit_and_test(self): + model = GATNodeClassifier(in_dim=2, edge_dim=1, msg_dim=4, state_dim=4, out_dim=1, heads=2, dropout_p=0.0) + train = [self._batch()] + val = [self._batch()] + test = [self._batch()] + + logits = model._forward_logits(self._batch().x, self._batch().edge_index, self._batch().edge_attr) + self.assertEqual(logits.shape, (3, 1)) + + hist = model.fit( + train, + val, + epochs=2, + report_performance_every_kth_epoch=1, + save_model_to=self.tmp, + early_stopping=False, + ) + self.assertIn("best_val_loss", hist) + self.assertEqual(len(hist["history"]), 2) + + metrics = model.test(test) + self.assertIn("test_loss", metrics) + self.assertIn("test_acc", metrics) + self.assertIn("test_recall", metrics) + + def test_gat_classifier_accepts_empty_edge_attrs(self): + batch = self._batch() + batch.edge_attr = torch.empty((batch.edge_index.size(1), 0), dtype=torch.float32) + model = GATNodeClassifier(in_dim=2, edge_dim=0, msg_dim=4, state_dim=4, out_dim=1, dropout_p=0.0) + + logits = model._forward_logits(batch.x, batch.edge_index, batch.edge_attr) + self.assertEqual(logits.shape, (3, 1)) + + def test_gat_validation(self): + with self.assertRaises(ValueError): + GATNodeClassifier(in_dim=2, edge_dim=-1, msg_dim=4, state_dim=4, out_dim=1) + with self.assertRaises(ValueError): + GATNodeClassifier(in_dim=2, edge_dim=1, msg_dim=4, state_dim=4, out_dim=1, heads=0) + with self.assertRaises(ValueError): + GATNodeClassifier(in_dim=2, edge_dim=1, msg_dim=4, state_dim=5, out_dim=1, heads=2) + def test_fit_validation(self): model = GraphConvNodeClassifier(in_dim=2, edge_dim=1, msg_dim=4, state_dim=4, out_dim=1) train = [self._batch()] From a08656171c9bd2c8b7704673fcadb351d008ce88 Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Mon, 8 Jun 2026 03:31:56 -0400 Subject: [PATCH 03/18] readme and docs upd --- README.md | 16 +++++++++++++-- docs.md | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 71 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index aeb9b6c..7560b55 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ Current top-level exports include: - Cleaning: `clean_col`, `clean_cols` - Embedding/encoding: `AAChainEmbedder`, `FreeTXTEmbedder`, `MultiHotEncoder`, `GOEncoder`, `ECEncoder`, `encode_multihot`, `get_GODag` - Feature engineering/persistence: `embed_ft_domains`, `embed_AAsequences`, `embed_freetxt_cols`, `encode_go`, `encode_ec`, `empty_tuples_to_NaNs`, `save_df`, `load_df` -- Models: `FFNN`, `GraphConv`, `GraphConvNodeClassifier` +- Models: `FFNN`, `GraphConv`, `GraphConvNodeClassifier`, `GATNodeClassifier` - Metrics: `accuracy`, `recall`, `precision`, `f1` - Dataset interfaces: `DatasetInput`, `build_topology_from_DatasetInput`, `build_features_from_DatasetInput`, `ProteinGraphInMemoryDataset`, `ProteinGraphOnDiskDataset`, `ProteinDataset` - Utility namespace: `util` @@ -138,7 +138,7 @@ For graph mode, edge chunk files must exist and match the expected naming patter ```python from pathlib import Path import torch -from M2F import ProteinGraphOnDiskDataset, GraphConvNodeClassifier +from M2F import ProteinGraphOnDiskDataset, GraphConvNodeClassifier, GATNodeClassifier ds = ProteinGraphOnDiskDataset( root=Path("runs/graph_ondisk"), @@ -160,6 +160,17 @@ model = GraphConvNodeClassifier( out_dim=int(ds.meta["y_dim"]), ) +# Drop-in attention variant for the same graph/loaders: +# model = GATNodeClassifier( +# in_dim=int(ds.meta["x_dim"]), +# edge_dim=int(ds.meta["edge_attr_dim"]), +# msg_dim=128, # kept for constructor compatibility +# state_dim=128, +# out_dim=int(ds.meta["y_dim"]), +# heads=4, +# attention_dropout_p=0.1, +# ) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) @@ -221,6 +232,7 @@ dset.close() - Topology for on-disk graph datasets is built after feature processing so filtered-node reindexing is stable. - Feature shards with duplicate `Entry` rows are rejected. - Inconsistent per-row feature dimensions are rejected. +- `GATNodeClassifier` uses the same graph datasets/loaders as `GraphConvNodeClassifier`; `state_dim` must be divisible by `heads`. - `force_reload=True` rebuilds raw/processed artifacts from scratch. ## Logging diff --git a/docs.md b/docs.md index 675ecd1..f6d279e 100644 --- a/docs.md +++ b/docs.md @@ -87,7 +87,7 @@ Current exported API (`M2F.__all__`) includes: - Cleaning: `clean_col`, `clean_cols`. - Embedding / Encoding: `AAChainEmbedder`, `FreeTXTEmbedder`, `MultiHotEncoder`, `GOEncoder`, `ECEncoder`, `encode_multihot`, `get_GODag`. - Feature engineering / persistence: `embed_ft_domains`, `embed_AAsequences`, `embed_freetxt_cols`, `encode_go`, `encode_ec`, `empty_tuples_to_NaNs`, `save_df`, `load_df`. -- Models: `FFNN`, `GraphConv`, `GraphConvNodeClassifier`. +- Models: `FFNN`, `GraphConv`, `GraphConvNodeClassifier`, `GATNodeClassifier`. - Metrics: `accuracy`, `recall`, `precision`, `f1`. - Dataset interfaces: `DatasetInput`, `build_topology_from_DatasetInput`, `build_features_from_DatasetInput`, `ProteinGraphInMemoryDataset`, `ProteinGraphOnDiskDataset`, `ProteinDataset`. - Utility namespace: `util`. @@ -500,7 +500,48 @@ Implementation details worth knowing: - During neighbor sampling, only seed nodes are supervised in each batch (`batch_size` mask logic). - `fit(...)` returns `best_val_loss`, `best_model_path`, and epoch-wise `history`. -## 8.2 FFNN: `FFNN` +## 8.2 GNN Attention: `GATNodeClassifier` + +`GATNodeClassifier` is the attention-based graph model. It uses the same `ProteinGraphInMemoryDataset` / `ProteinGraphOnDiskDataset` loaders as `GraphConvNodeClassifier`. + +```python +import torch +from M2F import GATNodeClassifier + +model = GATNodeClassifier( + in_dim=128, + edge_dim=4, + msg_dim=64, # retained for constructor compatibility + state_dim=64, + out_dim=1, + heads=4, + attention_dropout_p=0.1, + dropout_p=0.3, +) + +model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) + +history = model.fit( + train=train_loader, + val=val_loader, + epochs=50, + early_stopping=True, + tolerance=5, + report_performance_every_kth_epoch=1, + save_model_to="runs/checkpoints_gat", +) + +metrics = model.test(test_loader, threshold=0.5) +print(history["best_val_loss"], metrics) +``` + +Implementation details worth knowing: +- Internally uses PyTorch Geometric `GATConv`. +- `state_dim` must be divisible by `heads` because head outputs are concatenated. +- Edge attributes are used when `edge_dim > 0`; empty edge-attribute tensors are ignored when `edge_dim=0`. +- Training, evaluation, masking, loss, and returned history match `GraphConvNodeClassifier`. + +## 8.3 FFNN: `FFNN` ```python import torch @@ -527,7 +568,7 @@ Implementation details: - Loss: `BCEWithLogitsLoss`. - `forward(...)` returns logits during training, sigmoid probabilities during eval. -## 8.3 Metrics Utilities +## 8.4 Metrics Utilities Available helpers (`M2F.testing_utils`): - `accuracy(logits, y_true, mask, threshold=0.5)` @@ -598,6 +639,7 @@ from M2F import ( DatasetInput, ProteinGraphOnDiskDataset, GraphConvNodeClassifier, + GATNodeClassifier, ) configure_logging("logs", file_level=logging.DEBUG, console_level=logging.INFO) @@ -637,6 +679,17 @@ model = GraphConvNodeClassifier( out_dim=y_dim, ) +# Drop-in attention variant for the same dataset and loaders: +# model = GATNodeClassifier( +# in_dim=x_dim, +# edge_dim=edge_dim, +# msg_dim=128, +# state_dim=128, +# out_dim=y_dim, +# heads=4, +# attention_dropout_p=0.1, +# ) + history = model.fit( train=train_loader, val=val_loader, @@ -686,7 +739,7 @@ python -m pip install dist/microbiome2function-0.1.0-py3-none-any.whl - `M2F.embedding_utils`: ESM and OpenAI embedding + GO/EC/multihot encoders. - `M2F.feature_engineering_utils`: high-level embedding wrappers + zarr zip persistence. - `M2F.pyg_data_interfaces`: graph and FFNN dataset interfaces + standalone builders. -- `M2F.gnn`: graph convolution model and training/eval loops. +- `M2F.gnn`: graph convolution, graph attention, and training/eval loops. - `M2F.ffnn`: feed-forward model and training/eval loops. - `M2F.testing_utils`: metric helpers. - `M2F.util`: utility helpers and zarr feature-store backend. From e31adbc9fb94d5819b5346e06c6fa76d1a1d0937 Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 05:17:48 -0400 Subject: [PATCH 04/18] comment edits etc. --- src/M2F/__init__.py | 2 +- src/M2F/ffnn.py | 2 +- src/M2F/gnn.py | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/M2F/__init__.py b/src/M2F/__init__.py index f0d4cdc..bf8104c 100644 --- a/src/M2F/__init__.py +++ b/src/M2F/__init__.py @@ -104,5 +104,5 @@ "ProteinGraphOnDiskDataset", "ProteinDataset", # utility module namespace - "util", + "util" ] diff --git a/src/M2F/ffnn.py b/src/M2F/ffnn.py index 32198a1..07e01b1 100644 --- a/src/M2F/ffnn.py +++ b/src/M2F/ffnn.py @@ -164,7 +164,7 @@ def fit(self, loss.backward() optimizer.step() - with torch.no_grad(): + with torch.no_grad(): # each step's loss is weighted by num_examples_in_step / total_num_examples train_loss_sum += float(loss.item()) * batch_size train_acc_sum += accuracy(logits, y, mask) * batch_size train_recall_sum += recall(logits, y, mask) * batch_size diff --git a/src/M2F/gnn.py b/src/M2F/gnn.py index 6a6dcb9..b8a0e2e 100644 --- a/src/M2F/gnn.py +++ b/src/M2F/gnn.py @@ -279,6 +279,9 @@ def fit(self, mask = torch.zeros(batch.y.size(0), dtype=torch.bool, device=device) mask[:batch_size] = True + # ^ ^ ^ + # With NeighborLoader, each returned batch is a sampled subgraph. PyG puts the seed/input nodes first in the batch, then appends sampled neighbor nodes after them. + # So, it means: compute loss only on the seed nodes for this NeighborLoader batch optimizer.zero_grad() logits = self._forward_logits(batch.x, batch.edge_index, batch.edge_attr) From d78b58b1976c01481eee51380eb80ce312be4af9 Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 05:18:11 -0400 Subject: [PATCH 05/18] made ondisk datasets me CMs --- src/M2F/pyg_data_interfaces.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/M2F/pyg_data_interfaces.py b/src/M2F/pyg_data_interfaces.py index 2bfbd2e..f57785f 100644 --- a/src/M2F/pyg_data_interfaces.py +++ b/src/M2F/pyg_data_interfaces.py @@ -1870,6 +1870,23 @@ def close(self) -> None: self.feature_store.close() self.feature_store = None + def __enter__(self): + """ + Enter the context manager. + """ + return self + + def __exit__(self, exc_type, exc, tb): + """ + Exit the context manager. + + Args: + exc_type: Input value for `exc_type`. + exc: Input value for `exc`. + tb: Input value for `tb`. + """ + self.close() + class _ProteinDatasetView(Dataset): """ @@ -2602,6 +2619,23 @@ def close(self) -> None: self.feature_store.close() self.feature_store = None + def __enter__(self): + """ + Enter the context manager. + """ + return self + + def __exit__(self, exc_type, exc, tb): + """ + Exit the context manager. + + Args: + exc_type: Input value for `exc_type`. + exc: Input value for `exc`. + tb: Input value for `tb`. + """ + self.close() + __all__ = [ "DatasetInput", From 23ff2b0a70bdb4adf8a5a813bbd25d5048e2f939 Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 05:18:32 -0400 Subject: [PATCH 06/18] readme and docs upd --- README.md | 14 +++++--------- docs.md | 41 +++++++++++++++-------------------------- 2 files changed, 20 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 7560b55..04d8d3b 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # microbiome2function (M2F) -A toolkit for turning UniProt-linked protein annotations into machine-learning datasets. +A toolkit for mining UniProt-linked protein annotations, engineering protein features, building graph/non-graph datasets, and training neural models for function prediction. M2F supports: - UniProt mining from UniRef accessions @@ -12,6 +12,9 @@ M2F supports: - Dataset interfaces for: - PyTorch Geometric GNN training (`ProteinGraphInMemoryDataset`, `ProteinGraphOnDiskDataset`) - Plain PyTorch FFNN training (`ProteinDataset`) +- Model implementations for: + - Graph convolution and graph attention node classifiers (`GraphConvNodeClassifier`, `GATNodeClassifier`) + - Feed-forward neural networks (`FFNN`) ## Package Status @@ -229,6 +232,7 @@ dset.close() ## Important Operational Notes - `ProteinGraphOnDiskDataset` and `ProteinDataset` process features in batches and build a global node reindex map. +- `ProteinGraphOnDiskDataset` and `ProteinDataset` support `with ... as ...:` context-manager usage to release zarr handles automatically. - Topology for on-disk graph datasets is built after feature processing so filtered-node reindexing is stable. - Feature shards with duplicate `Entry` rows are rejected. - Inconsistent per-row feature dimensions are rejected. @@ -253,14 +257,6 @@ configure_logging( - `/.github/workflows/test.yml`: multi-version test matrix (3.11, 3.12) - `/.github/workflows/build.yml`: test + build distribution artifacts (`sdist`, wheel) -## Repository Layout - -- `src/M2F`: package code -- `tests`: unit tests -- `model_notebooks`: active notebooks -- `legacy_code_examples`: old examples -- `docs.md`: detailed technical guide - ## Detailed Documentation For full API behavior, data contracts, and extended cookbook usage, see: diff --git a/docs.md b/docs.md index f6d279e..1bb5fba 100644 --- a/docs.md +++ b/docs.md @@ -1,19 +1,19 @@ # M2F Documentation -This document is the implementation-accurate user guide for `microbiome2function` (M2F). -It is written against the current code under `src/M2F`. - ## 1. What M2F Is For -M2F is a practical toolkit for turning protein identifiers and UniProt annotations into ML-ready inputs. +M2F is a practical toolkit for mining protein annotations, engineering protein features, building graph/non-graph datasets, and training neural models for function prediction. Primary use-cases: - Mine UniProt features from UniRef IDs. - Clean and normalize noisy annotation text. - Convert biology fields into numeric tensors (embeddings + encodings). - Build datasets for graph and non-graph modeling: -- Graph neural networks (PyTorch Geometric): `ProteinGraphInMemoryDataset`, `ProteinGraphOnDiskDataset`. -- Feed-forward neural networks (plain PyTorch): `ProteinDataset` (features + labels, no edges). + - Graph neural networks (PyTorch Geometric): `ProteinGraphInMemoryDataset`, `ProteinGraphOnDiskDataset`. + - Feed-forward neural networks (plain PyTorch): `ProteinDataset` (features + labels, no edges). +- Train neural models: + - Graph convolution / graph attention node classifiers: `GraphConvNodeClassifier`, `GATNodeClassifier`. + - Feed-forward neural networks: `FFNN`. Design goals: - Scalable processing for large accession sets (batched UniProt mining and batched feature shards). @@ -38,10 +38,6 @@ python -m pip install --upgrade pip python -m pip install -e . ``` -Why editable install: -- Keeps imports stable (`import M2F`) while you iterate code. -- Ensures tests and notebooks run against your local working tree. - ## 2.2 Heavy Dependencies to Plan For `requirements.txt` includes large ML packages: @@ -94,8 +90,6 @@ Current exported API (`M2F.__all__`) includes: ## 4. Data Contracts You Must Respect -M2F works well only if input schemas are strict. This is intentional. - ## 4.1 Accession Index CSV Expected columns exactly: @@ -108,10 +102,6 @@ Constraints enforced by `DatasetInput.validate(...)`: - `i` must not contain duplicates. - `uniref` values must start with `UniRef90_`. -Why strict index requirements: -- All reindexing and topology construction depend on deterministic old node IDs (`i - 1`). -- Relaxed IDs would make edge mapping ambiguous and error-prone. - ## 4.2 Edge CSV Files (Graph Datasets Only) Required only when `require_graph=True` (graph interfaces). @@ -134,6 +124,7 @@ Why one chunk per source node: `DatasetInput` uses: - `X: dict[str, str]` mapping UniProt query field -> return column name. - `Y: dict[str, str]` singleton mapping UniProt query field -> return column name. +NOTE: See UniProt for query field and return field names Important: - `Y` must contain exactly one entry. @@ -141,10 +132,6 @@ Important: - `Y` key cannot be `accession`. - `accession` is always injected into `X` internally as `"Entry"`. -Why mapping instead of plain list: -- You control the semantic output names used by downstream feature builders. -- It decouples UniProt field identifiers from model-facing column names. - ## 5. Quick Start: End-to-End Patterns ## 5.1 Mining Accessions from HUMAnN @@ -157,7 +144,7 @@ all_unirefs, all_uniclusts = extract_all_accessions_from_dir("humann_outputs/") ``` Notes: -- UniRef IDs prefixed with `UNK`/`UPI` are excluded before UniProt mining because they are not queryable reliably. +- UniRef IDs prefixed with `UNK`/`UPI` are excluded before UniProt mining because they are not queryable. ## 5.2 Fetch UniProt Fields @@ -175,8 +162,8 @@ df = fetch_uniprotkb_fields( Field-name note: - `fields` values must be valid UniProt API field identifiers. -- Returned DataFrame column names can differ from query names (for example, title-cased labels). -- Your later mapping/transforms must match the actual returned column names. +- Returned DataFrame column names will likely differ from query names (that's how UniProt works, sorry). +- Your later mapping/transforms must match the actual **returned** column names. Recommended defaults for stability: - Start with moderate `request_size` (25-100). @@ -209,6 +196,8 @@ What you get: Why tuple outputs: - Deterministic multi-label representation that plugs directly into encoders. +Values passed to `col_names` may need you to implement regexes for them. + ## 5.4 Encode and Embed ```python @@ -385,7 +374,7 @@ Under the hood: - Edge attributes are attached per mini-batch via batch `e_id` lookup. Operational note: -- Call `ondisk.close()` when done to release store handles. +- Prefer `with ProteinGraphOnDiskDataset(...) as ondisk:` to release store handles automatically; otherwise call `ondisk.close()` when done. ## 6.4 Feature and Topology Builders as Standalone Functions @@ -458,7 +447,7 @@ Why separate FFNN dataset class: - Reuses robust batch-processing, filtering, reindexing, and zarr growth path. Operational note: -- Call `dset.close()` when done. +- Prefer `with ProteinDataset(...) as dset:` to release store handles automatically; otherwise call `dset.close()` when done. ## 8. Model Training Cookbook @@ -728,7 +717,7 @@ python -m pip install dist/microbiome2function-0.1.0-py3-none-any.whl - Start small: validate your `DatasetInput` and preprocessing on a tiny accession subset first. - Keep transform contracts strict: `pre_transform` must return DataFrame; `pre_filter` must return boolean mask with matching length. - Use explicit checkpoints: preserve `meta.pt`, vocab maps, and model checkpoints per experiment. -- Close on-disk datasets: call `close()` to release zarr handles after training/inference. +- Close on-disk datasets: prefer context-manager usage (`with ... as ...:`), or call `close()` to release zarr handles after training/inference. - Avoid silent schema drift: pin requested UniProt fields and return names in code, not notebooks-only state. ## 13. Module Index From 82589612a9874b3c2baa37fcf3a1a564261c5236 Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 05:21:30 -0400 Subject: [PATCH 07/18] readme and docs upd --- docs.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs.md b/docs.md index 1bb5fba..b777539 100644 --- a/docs.md +++ b/docs.md @@ -1,3 +1,6 @@ +![LOGO](https://raw.githubusercontent.com/Yehor-Mishchyriak/microbiome2function/main/assets/M2F_banner.png) +[![Test](https://github.com/Yehor-Mishchyriak/microbiome2function/actions/workflows/test.yml/badge.svg)](https://github.com/Yehor-Mishchyriak/microbiome2function/actions/workflows/test.yml) + # M2F Documentation ## 1. What M2F Is For From d5a59168996d65bae5437b3f62f9c33a794451ec Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 05:22:29 -0400 Subject: [PATCH 08/18] test upd --- tests/test_pyg_data_interfaces.py | 41 +++++++++++++++---------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/tests/test_pyg_data_interfaces.py b/tests/test_pyg_data_interfaces.py index c277551..eb83582 100644 --- a/tests/test_pyg_data_interfaces.py +++ b/tests/test_pyg_data_interfaces.py @@ -177,42 +177,41 @@ def test_protein_graph_inmemory_dataset(self): def test_protein_graph_ondisk_dataset(self): di = self._dataset_input_graph(num_feature_batches=2) with patch("M2F.pyg_data_interfaces.fetch_uniprotkb_fields", side_effect=self._fake_fetch): - ds = ProteinGraphOnDiskDataset( + with ProteinGraphOnDiskDataset( root=self.root_ondisk, dataset_input=di, force_reload=True, val_set_size=0.2, test_set_size=0.2, - ) - self.assertTrue(os.path.exists(ds.edge_index_path)) - self.assertTrue(os.path.exists(ds.id_map_path)) - self.assertTrue(os.path.exists(ds.meta_path)) - self.assertEqual(ds.meta["num_nodes"], 3) - ds.close() + ) as ds: + self.assertTrue(os.path.exists(ds.edge_index_path)) + self.assertTrue(os.path.exists(ds.id_map_path)) + self.assertTrue(os.path.exists(ds.meta_path)) + self.assertEqual(ds.meta["num_nodes"], 3) + self.assertIsNone(ds.feature_store) def test_protein_dataset_ffnn_interface(self): di = self._dataset_input_ffnn(num_feature_batches=2) with patch("M2F.pyg_data_interfaces.fetch_uniprotkb_fields", side_effect=self._fake_fetch): - ds = ProteinDataset( + with ProteinDataset( root=self.root_ffnn, dataset_input=di, force_reload=True, split="train", val_set_size=0.2, test_set_size=0.2, - ) - - self.assertGreaterEqual(len(ds), 1) - item = ds[0] - self.assertEqual(len(item), 2) # x, y - x, y = item - self.assertEqual(tuple(x.shape), (1,)) - self.assertEqual(tuple(y.shape), (1,)) - - pred_loader = ds.predict_loader(batch_size=2) - first = next(iter(pred_loader)) - self.assertEqual(first.ndim, 2) - ds.close() + ) as ds: + self.assertGreaterEqual(len(ds), 1) + item = ds[0] + self.assertEqual(len(item), 2) # x, y + x, y = item + self.assertEqual(tuple(x.shape), (1,)) + self.assertEqual(tuple(y.shape), (1,)) + + pred_loader = ds.predict_loader(batch_size=2) + first = next(iter(pred_loader)) + self.assertEqual(first.ndim, 2) + self.assertIsNone(ds.feature_store) if __name__ == "__main__": From 45687b5c69c024d47680adda2a6171a5fb068d9c Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 05:37:55 -0400 Subject: [PATCH 09/18] cleaned proj requirements --- requirements.txt | 56 +----------------------------------------------- 1 file changed, 1 insertion(+), 55 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3786fff..9c3ccdc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,66 +1,12 @@ -annotated-types==0.7.0 -anyio==4.10.0 -certifi==2025.8.3 -charset-normalizer==3.4.2 -crc32c==2.7.1 -distro==1.9.0 -docopt==0.6.2 -donfig==0.8.1.post1 -et_xmlfile==2.0.0 -filelock==3.18.0 -fsspec==2025.7.0 -ftpretty==0.4.0 goatools==1.4.12 -h11==0.16.0 -hf-xet==1.1.7 -httpcore==1.0.9 -httpx==0.28.1 -huggingface-hub==0.34.3 -idna==3.10 -Jinja2==3.1.6 -jiter==0.10.0 -joblib==1.5.1 -markdown-it-py==3.0.0 -MarkupSafe==3.0.2 -mdurl==0.1.2 -mpmath==1.3.0 -networkx==3.5 -numcodecs==0.16.1 numpy==2.3.2 openai==1.99.3 -openpyxl==3.1.5 -packaging==25.0 pandas==2.3.1 -patsy==1.0.1 pyarrow==21.0.0 -pydantic==2.11.7 -pydantic_core==2.33.2 -pydot==4.0.1 -Pygments==2.19.2 -pyparsing==3.2.3 -python-dateutil==2.9.0.post0 -pytz==2025.2 -PyYAML==6.0.2 -regex==2025.7.34 requests==2.32.4 -rich==14.1.0 -safetensors==0.6.1 scikit-learn==1.7.1 -scipy==1.16.1 -setuptools==80.9.0 -six==1.17.0 -sniffio==1.3.1 -statsmodels==0.14.5 -sympy==1.14.0 -threadpoolctl==3.6.0 -tokenizers==0.21.4 torch==2.8.0 torch-geometric==2.7.0 -tqdm==4.67.1 transformers==4.55.0 -typing-inspection==0.4.1 -typing_extensions==4.14.1 -tzdata==2025.2 -urllib3==2.5.0 -xlsxwriter==3.2.5 +wandb==0.27.2 zarr==3.1.1 From d90aaf2aac37b746197f742aec0b5436976a62d1 Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 08:25:59 -0400 Subject: [PATCH 10/18] added wandb module and more metrics to training and testing loops --- src/M2F/ffnn.py | 53 ++++++++++++++++++------ src/M2F/gnn.py | 105 +++++++++++++++++++++++++++++++++++++----------- src/M2F/wb.py | 43 ++++++++++++++++++++ 3 files changed, 165 insertions(+), 36 deletions(-) create mode 100644 src/M2F/wb.py diff --git a/src/M2F/ffnn.py b/src/M2F/ffnn.py index 07e01b1..1bb2e91 100644 --- a/src/M2F/ffnn.py +++ b/src/M2F/ffnn.py @@ -11,7 +11,7 @@ import os # local -from .testing_utils import accuracy, recall +from .testing_utils import accuracy, precision, recall, f1 from .util import current_time _logger = logging.getLogger(__name__) @@ -146,8 +146,10 @@ def fit(self, # ------------------------------- train ------------------------------- self.train() train_loss_sum = 0.0 - train_acc_sum = 0.0 + train_accuracy_sum = 0.0 + train_precision_sum = 0.0 train_recall_sum = 0.0 + train_f1_sum = 0.0 train_examples = 0 for X, y in train: @@ -166,23 +168,29 @@ def fit(self, with torch.no_grad(): # each step's loss is weighted by num_examples_in_step / total_num_examples train_loss_sum += float(loss.item()) * batch_size - train_acc_sum += accuracy(logits, y, mask) * batch_size + train_accuracy_sum += accuracy(logits, y, mask) * batch_size + train_precision_sum += precision(logits, y, mask) * batch_size train_recall_sum += recall(logits, y, mask) * batch_size + train_f1_sum += f1(logits, y, mask) * batch_size train_examples += batch_size if train_examples == 0: raise RuntimeError("Train loader produced no non-empty batches.") train_loss = train_loss_sum / train_examples - train_acc = train_acc_sum / train_examples + train_accuracy = train_accuracy_sum / train_examples + train_precision = train_precision_sum / train_examples train_recall = train_recall_sum / train_examples + train_f1 = train_f1_sum / train_examples # ------------------------------------------------------------------- # -------------------------------- val ------------------------------ self.eval() val_loss_sum = 0.0 - val_acc_sum = 0.0 + val_accuracy_sum = 0.0 + val_precision_sum = 0.0 val_recall_sum = 0.0 + val_f1_sum = 0.0 val_examples = 0 with torch.no_grad(): for X, y in val: @@ -197,16 +205,20 @@ def fit(self, loss = criterion(logits, y) val_loss_sum += float(loss.item()) * batch_size - val_acc_sum += accuracy(logits, y, mask) * batch_size + val_accuracy_sum += accuracy(logits, y, mask) * batch_size + val_precision_sum += precision(logits, y, mask) * batch_size val_recall_sum += recall(logits, y, mask) * batch_size + val_f1_sum += f1(logits, y, mask) * batch_size val_examples += batch_size if val_examples == 0: raise RuntimeError("Validation loader produced no non-empty batches.") current_val_loss = val_loss_sum / val_examples - val_acc = val_acc_sum / val_examples + val_accuracy = val_accuracy_sum / val_examples + val_precision = val_precision_sum / val_examples val_recall = val_recall_sum / val_examples + val_f1 = val_f1_sum / val_examples # ------------------------------------------------------------------- # -------------------------- scheduler + early stop ------------------ @@ -233,18 +245,35 @@ def fit(self, history.append({ "epoch": epoch, "train_loss": train_loss, - "train_acc": train_acc, + "train_acc": train_accuracy, + "train_accuracy": train_accuracy, + "train_precision": train_precision, "train_recall": train_recall, + "train_f1": train_f1, "val_loss": current_val_loss, - "val_acc": val_acc, + "val_acc": val_accuracy, + "val_accuracy": val_accuracy, + "val_precision": val_precision, "val_recall": val_recall, + "val_f1": val_f1, }) if epoch == 1 or epoch % k == 0: _logger.info( - "Epoch %d | train_loss=%.6f train_acc=%.4f train_recall=%.4f | " - "val_loss=%.6f val_acc=%.4f val_recall=%.4f", - epoch, train_loss, train_acc, train_recall, current_val_loss, val_acc, val_recall + "Epoch %d | train_loss=%.6f train_accuracy=%.4f train_precision=%.4f " + "train_recall=%.4f train_f1=%.4f | val_loss=%.6f val_accuracy=%.4f " + "val_precision=%.4f val_recall=%.4f val_f1=%.4f", + epoch, + train_loss, + train_accuracy, + train_precision, + train_recall, + train_f1, + current_val_loss, + val_accuracy, + val_precision, + val_recall, + val_f1, ) if early_stopping and no_generalization_after > tolerance: diff --git a/src/M2F/gnn.py b/src/M2F/gnn.py index b8a0e2e..998308d 100644 --- a/src/M2F/gnn.py +++ b/src/M2F/gnn.py @@ -11,9 +11,8 @@ import logging import os -# local from torch_geometric.loader import NeighborLoader -from .testing_utils import accuracy, recall +from .testing_utils import accuracy, precision, recall, f1 from .util import current_time _logger = logging.getLogger(__name__) @@ -267,8 +266,10 @@ def fit(self, # ------------------------------- train ------------------------------- self.train() train_loss_sum = 0.0 - train_acc_sum = 0.0 + train_accuracy_sum = 0.0 + train_precision_sum = 0.0 train_recall_sum = 0.0 + train_f1_sum = 0.0 train_examples = 0 for batch in train: @@ -292,23 +293,29 @@ def fit(self, with torch.no_grad(): train_loss_sum += float(loss.item()) * batch_size - train_acc_sum += accuracy(logits, y, mask) * batch_size + train_accuracy_sum += accuracy(logits, y, mask) * batch_size + train_precision_sum += precision(logits, y, mask) * batch_size train_recall_sum += recall(logits, y, mask) * batch_size + train_f1_sum += f1(logits, y, mask) * batch_size train_examples += batch_size if train_examples == 0: raise RuntimeError("Train loader produced no batches with seed nodes.") train_loss = train_loss_sum / train_examples - train_acc = train_acc_sum / train_examples + train_accuracy = train_accuracy_sum / train_examples + train_precision = train_precision_sum / train_examples train_recall = train_recall_sum / train_examples + train_f1 = train_f1_sum / train_examples # ------------------------------------------------------------------- # -------------------------------- val ------------------------------ self.eval() val_loss_sum = 0.0 - val_acc_sum = 0.0 + val_accuracy_sum = 0.0 + val_precision_sum = 0.0 val_recall_sum = 0.0 + val_f1_sum = 0.0 val_examples = 0 with torch.no_grad(): for batch in val: @@ -325,16 +332,20 @@ def fit(self, loss = criterion(logits[mask], y[mask]) val_loss_sum += float(loss.item()) * batch_size - val_acc_sum += accuracy(logits, y, mask) * batch_size + val_accuracy_sum += accuracy(logits, y, mask) * batch_size + val_precision_sum += precision(logits, y, mask) * batch_size val_recall_sum += recall(logits, y, mask) * batch_size + val_f1_sum += f1(logits, y, mask) * batch_size val_examples += batch_size if val_examples == 0: raise RuntimeError("Validation loader produced no batches with seed nodes.") current_val_loss = val_loss_sum / val_examples - val_acc = val_acc_sum / val_examples + val_accuracy = val_accuracy_sum / val_examples + val_precision = val_precision_sum / val_examples val_recall = val_recall_sum / val_examples + val_f1 = val_f1_sum / val_examples # ------------------------------------------------------------------- # -------------------------- scheduler + early stop ------------------ @@ -361,18 +372,35 @@ def fit(self, history.append({ "epoch": epoch, "train_loss": train_loss, - "train_acc": train_acc, + "train_acc": train_accuracy, + "train_accuracy": train_accuracy, + "train_precision": train_precision, "train_recall": train_recall, + "train_f1": train_f1, "val_loss": current_val_loss, - "val_acc": val_acc, + "val_acc": val_accuracy, + "val_accuracy": val_accuracy, + "val_precision": val_precision, "val_recall": val_recall, + "val_f1": val_f1, }) if epoch == 1 or epoch % k == 0: _logger.info( - "Epoch %d | train_loss=%.6f train_acc=%.4f train_recall=%.4f | " - "val_loss=%.6f val_acc=%.4f val_recall=%.4f", - epoch, train_loss, train_acc, train_recall, current_val_loss, val_acc, val_recall + "Epoch %d | train_loss=%.6f train_accuracy=%.4f train_precision=%.4f " + "train_recall=%.4f train_f1=%.4f | val_loss=%.6f val_accuracy=%.4f " + "val_precision=%.4f val_recall=%.4f val_f1=%.4f", + epoch, + train_loss, + train_accuracy, + train_precision, + train_recall, + train_f1, + current_val_loss, + val_accuracy, + val_precision, + val_recall, + val_f1, ) if early_stopping and no_generalization_after > tolerance: @@ -617,8 +645,10 @@ def fit(self, # ------------------------------- train ------------------------------- self.train() train_loss_sum = 0.0 - train_acc_sum = 0.0 + train_accuracy_sum = 0.0 + train_precision_sum = 0.0 train_recall_sum = 0.0 + train_f1_sum = 0.0 train_examples = 0 for batch in train: @@ -639,23 +669,29 @@ def fit(self, with torch.no_grad(): train_loss_sum += float(loss.item()) * batch_size - train_acc_sum += accuracy(logits, y, mask) * batch_size + train_accuracy_sum += accuracy(logits, y, mask) * batch_size + train_precision_sum += precision(logits, y, mask) * batch_size train_recall_sum += recall(logits, y, mask) * batch_size + train_f1_sum += f1(logits, y, mask) * batch_size train_examples += batch_size if train_examples == 0: raise RuntimeError("Train loader produced no batches with seed nodes.") train_loss = train_loss_sum / train_examples - train_acc = train_acc_sum / train_examples + train_accuracy = train_accuracy_sum / train_examples + train_precision = train_precision_sum / train_examples train_recall = train_recall_sum / train_examples + train_f1 = train_f1_sum / train_examples # ------------------------------------------------------------------- # -------------------------------- val ------------------------------ self.eval() val_loss_sum = 0.0 - val_acc_sum = 0.0 + val_accuracy_sum = 0.0 + val_precision_sum = 0.0 val_recall_sum = 0.0 + val_f1_sum = 0.0 val_examples = 0 with torch.no_grad(): for batch in val: @@ -672,16 +708,20 @@ def fit(self, loss = criterion(logits[mask], y[mask]) val_loss_sum += float(loss.item()) * batch_size - val_acc_sum += accuracy(logits, y, mask) * batch_size + val_accuracy_sum += accuracy(logits, y, mask) * batch_size + val_precision_sum += precision(logits, y, mask) * batch_size val_recall_sum += recall(logits, y, mask) * batch_size + val_f1_sum += f1(logits, y, mask) * batch_size val_examples += batch_size if val_examples == 0: raise RuntimeError("Validation loader produced no batches with seed nodes.") current_val_loss = val_loss_sum / val_examples - val_acc = val_acc_sum / val_examples + val_accuracy = val_accuracy_sum / val_examples + val_precision = val_precision_sum / val_examples val_recall = val_recall_sum / val_examples + val_f1 = val_f1_sum / val_examples # ------------------------------------------------------------------- # -------------------------- scheduler + early stop ------------------ @@ -708,18 +748,35 @@ def fit(self, history.append({ "epoch": epoch, "train_loss": train_loss, - "train_acc": train_acc, + "train_acc": train_accuracy, + "train_accuracy": train_accuracy, + "train_precision": train_precision, "train_recall": train_recall, + "train_f1": train_f1, "val_loss": current_val_loss, - "val_acc": val_acc, + "val_acc": val_accuracy, + "val_accuracy": val_accuracy, + "val_precision": val_precision, "val_recall": val_recall, + "val_f1": val_f1, }) if epoch == 1 or epoch % k == 0: _logger.info( - "Epoch %d | train_loss=%.6f train_acc=%.4f train_recall=%.4f | " - "val_loss=%.6f val_acc=%.4f val_recall=%.4f", - epoch, train_loss, train_acc, train_recall, current_val_loss, val_acc, val_recall + "Epoch %d | train_loss=%.6f train_accuracy=%.4f train_precision=%.4f " + "train_recall=%.4f train_f1=%.4f | val_loss=%.6f val_accuracy=%.4f " + "val_precision=%.4f val_recall=%.4f val_f1=%.4f", + epoch, + train_loss, + train_accuracy, + train_precision, + train_recall, + train_f1, + current_val_loss, + val_accuracy, + val_precision, + val_recall, + val_f1, ) if early_stopping and no_generalization_after > tolerance: diff --git a/src/M2F/wb.py b/src/M2F/wb.py new file mode 100644 index 0000000..4ccd1c3 --- /dev/null +++ b/src/M2F/wb.py @@ -0,0 +1,43 @@ +import wandb +import torch + +def log_epoch(epoch, + train_loss, + train_accuracy, + train_precision, + train_recall, + train_f1, + val_loss, + val_accuracy, + val_precision, + val_recall, + val_f1, + optimizer=None): + + metrics = { + "epoch": epoch, + + "train/loss": train_loss, + "train/accuracy": train_accuracy, + "train/precision": train_precision, + "train/recall": train_recall, + "train/f1": train_f1, + + "val/loss": val_loss, + "val/accuracy": val_accuracy, + "val/precision": val_precision, + "val/recall": val_recall, + "val/f1": val_f1, + } + + if optimizer is not None: + metrics["lr"] = optimizer.param_groups[0]["lr"] + + wandb.log(metrics) + + +def save_best_model(model, model_name, path): + torch.save(model.state_dict(), path) + artifact = wandb.Artifact(model_name, type="model") + artifact.add_file(str(path)) + wandb.log_artifact(artifact) From 67cb3ac4e54a59794c598aa673481884dac25ea9 Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 08:30:57 -0400 Subject: [PATCH 11/18] integrated wb into the training and testing loops --- src/M2F/ffnn.py | 18 +++++++++++++++++- src/M2F/gnn.py | 36 ++++++++++++++++++++++++++++++++++-- src/M2F/wb.py | 19 +++++++++++++++---- 3 files changed, 66 insertions(+), 7 deletions(-) diff --git a/src/M2F/ffnn.py b/src/M2F/ffnn.py index 1bb2e91..a822699 100644 --- a/src/M2F/ffnn.py +++ b/src/M2F/ffnn.py @@ -103,6 +103,8 @@ def fit(self, if tolerance < 0: raise ValueError("`tolerance` must be >= 0") + from . import wb + k = report_performance_every_kth_epoch save_model_to = Path(save_model_to if save_model_to is not None else os.getcwd()) save_model_to.mkdir(parents=True, exist_ok=True) @@ -232,7 +234,7 @@ def fit(self, best_val_loss = current_val_loss no_generalization_after = 0 best_model_path = save_model_to / f"m2f_ffnn_{current_time()}.pt" - torch.save(self.state_dict(), best_model_path) + wb.save_best_model(self, "m2f-ffnn-best", best_model_path) _logger.debug( "New best validation loss %.6f at epoch %d; saved checkpoint to %s", best_val_loss, @@ -257,6 +259,20 @@ def fit(self, "val_recall": val_recall, "val_f1": val_f1, }) + wb.log_epoch( + epoch, + train_loss, + train_accuracy, + train_precision, + train_recall, + train_f1, + current_val_loss, + val_accuracy, + val_precision, + val_recall, + val_f1, + optimizer, + ) if epoch == 1 or epoch % k == 0: _logger.info( diff --git a/src/M2F/gnn.py b/src/M2F/gnn.py index 998308d..00657e6 100644 --- a/src/M2F/gnn.py +++ b/src/M2F/gnn.py @@ -223,6 +223,8 @@ def fit(self, if tolerance < 0: raise ValueError("`tolerance` must be >= 0") + from . import wb + k = report_performance_every_kth_epoch save_model_to = Path(save_model_to if save_model_to is not None else os.getcwd()) save_model_to.mkdir(parents=True, exist_ok=True) @@ -359,7 +361,7 @@ def fit(self, best_val_loss = current_val_loss no_generalization_after = 0 best_model_path = save_model_to / f"m2f_gnn_{current_time()}.pt" - torch.save(self.state_dict(), best_model_path) + wb.save_best_model(self, "m2f-graphconv-best", best_model_path) _logger.debug( "New best validation loss %.6f at epoch %d; saved checkpoint to %s", best_val_loss, @@ -384,6 +386,20 @@ def fit(self, "val_recall": val_recall, "val_f1": val_f1, }) + wb.log_epoch( + epoch, + train_loss, + train_accuracy, + train_precision, + train_recall, + train_f1, + current_val_loss, + val_accuracy, + val_precision, + val_recall, + val_f1, + optimizer, + ) if epoch == 1 or epoch % k == 0: _logger.info( @@ -602,6 +618,8 @@ def fit(self, if tolerance < 0: raise ValueError("`tolerance` must be >= 0") + from . import wb + k = report_performance_every_kth_epoch save_model_to = Path(save_model_to if save_model_to is not None else os.getcwd()) save_model_to.mkdir(parents=True, exist_ok=True) @@ -735,7 +753,7 @@ def fit(self, best_val_loss = current_val_loss no_generalization_after = 0 best_model_path = save_model_to / f"m2f_gat_{current_time()}.pt" - torch.save(self.state_dict(), best_model_path) + wb.save_best_model(self, "m2f-gat-best", best_model_path) _logger.debug( "New best validation loss %.6f at epoch %d; saved checkpoint to %s", best_val_loss, @@ -760,6 +778,20 @@ def fit(self, "val_recall": val_recall, "val_f1": val_f1, }) + wb.log_epoch( + epoch, + train_loss, + train_accuracy, + train_precision, + train_recall, + train_f1, + current_val_loss, + val_accuracy, + val_precision, + val_recall, + val_f1, + optimizer, + ) if epoch == 1 or epoch % k == 0: _logger.info( diff --git a/src/M2F/wb.py b/src/M2F/wb.py index 4ccd1c3..d1a6735 100644 --- a/src/M2F/wb.py +++ b/src/M2F/wb.py @@ -1,6 +1,11 @@ import wandb import torch + +def is_active(): + return wandb.run is not None + + def log_epoch(epoch, train_loss, train_accuracy, @@ -33,11 +38,17 @@ def log_epoch(epoch, if optimizer is not None: metrics["lr"] = optimizer.param_groups[0]["lr"] - wandb.log(metrics) + if is_active(): + wandb.log(metrics) + + return metrics def save_best_model(model, model_name, path): torch.save(model.state_dict(), path) - artifact = wandb.Artifact(model_name, type="model") - artifact.add_file(str(path)) - wandb.log_artifact(artifact) + if is_active(): + artifact = wandb.Artifact(model_name, type="model") + artifact.add_file(str(path)) + wandb.log_artifact(artifact) + + return path From 69047fcf6c2bbe191167cbfec2df2a6b7df04ea6 Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 08:45:19 -0400 Subject: [PATCH 12/18] Fixed the CI dependency issue. --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 9c3ccdc..50d1c64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ pandas==2.3.1 pyarrow==21.0.0 requests==2.32.4 scikit-learn==1.7.1 +setuptools==80.9.0 torch==2.8.0 torch-geometric==2.7.0 transformers==4.55.0 From c36d5e5365c6b089625ffe1545948b21cbf7a00e Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 09:06:40 -0400 Subject: [PATCH 13/18] Fix on-disk graph target encoding across feature shards Fix on-disk graph target encoding across feature shards Combine raw feature shards before applying ProteinGraphOnDiskDataset pre_transform so target encoders fit on the full feature table instead of per shard. This keeps multi-hot target dimensions and class meanings consistent when using sharded on-disk graph processing. --- src/M2F/pyg_data_interfaces.py | 86 ++++++++++++++++++---------------- 1 file changed, 45 insertions(+), 41 deletions(-) diff --git a/src/M2F/pyg_data_interfaces.py b/src/M2F/pyg_data_interfaces.py index f57785f..b9cbae7 100644 --- a/src/M2F/pyg_data_interfaces.py +++ b/src/M2F/pyg_data_interfaces.py @@ -1644,56 +1644,58 @@ def process(self) -> None: required_cols = [self.dataset_input.Y_return_field_name, *self.dataset_input.X_return_field_names] feature_store: _ProteinFeatureStore | None = None + combined_features_path: Path | None = None total_rows = 0 try: # ---------------------- pass 1: node features ------------------- - for i in range(1, self.num_feature_batches + 1): - batch_path = self._feature_batch_path(i) - x_np, y_np = build_features_from_DatasetInput( - pre_transform=self.pre_transform, - pre_filter=self.pre_filter, - accessions_path=accessions_path, - features_path=batch_path, - required_cols=required_cols, - global_id_map=id_map, - X_return_field_names=self.dataset_input.X_return_field_names, - Y_return_field_name=self.dataset_input.Y_return_field_name, + batch_paths = [self._feature_batch_path(i) for i in range(1, self.num_feature_batches + 1)] + if len(batch_paths) == 1: + features_path = batch_paths[0] + else: + combined_features_path = self.processed_dir / "_combined_features.csv" + _logger.info( + "Combining %d feature shard(s) before pre_transform so target encodings are fitted globally", + len(batch_paths), ) + pd.concat( + (pd.read_csv(batch_path) for batch_path in batch_paths), + ignore_index=True, + ).to_csv(combined_features_path, index=False) + features_path = combined_features_path + + x_np, y_np = build_features_from_DatasetInput( + pre_transform=self.pre_transform, + pre_filter=self.pre_filter, + accessions_path=accessions_path, + features_path=features_path, + required_cols=required_cols, + global_id_map=id_map, + X_return_field_names=self.dataset_input.X_return_field_names, + Y_return_field_name=self.dataset_input.Y_return_field_name, + ) - if x_np.shape[0] == 0: - _logger.debug("Feature shard %s produced 0 kept rows; skipping", batch_path) - continue + if x_np.shape[0] == 0: + raise ValueError("All nodes were filtered out; cannot build on-disk dataset") - if feature_store is None: - feature_store = _ProteinFeatureStore( - store_on_disk_location=self.feature_store_dir, - node_feature_dim=(x_np.shape[1],), - edge_feature_dim=(0,), # rewritten after topology construction - target_feature_dim=(y_np.shape[1],), - read_only=False, - ) - else: - if x_np.shape[1] != int(feature_store.store.which_tensors["x"][0][1]): - raise ValueError("Inconsistent X feature dimensionality across feature batches") - if y_np.shape[1] != int(feature_store.store.which_tensors["y"][0][1]): - raise ValueError("Inconsistent Y dimensionality across feature batches") + feature_store = _ProteinFeatureStore( + store_on_disk_location=self.feature_store_dir, + node_feature_dim=(x_np.shape[1],), + edge_feature_dim=(0,), # rewritten after topology construction + target_feature_dim=(y_np.shape[1],), + read_only=False, + ) - feature_store.append_tensor(x_np, group_name=None, attr_name="x", index=None) - feature_store.append_tensor(y_np, group_name=None, attr_name="y", index=None) - total_rows += int(x_np.shape[0]) - _logger.info( - "Processed feature shard %d/%d -> rows=%d x_dim=%d y_dim=%d (total_rows=%d)", - i, - self.num_feature_batches, - int(x_np.shape[0]), - int(x_np.shape[1]), - int(y_np.shape[1]) if y_np.ndim > 1 else 1, - total_rows, - ) + feature_store.append_tensor(x_np, group_name=None, attr_name="x", index=None) + feature_store.append_tensor(y_np, group_name=None, attr_name="y", index=None) + total_rows = int(x_np.shape[0]) + _logger.info( + "Processed feature table -> rows=%d x_dim=%d y_dim=%d", + total_rows, + int(x_np.shape[1]), + int(y_np.shape[1]) if y_np.ndim > 1 else 1, + ) # ---------------------------------------------------------------- - if feature_store is None or total_rows == 0: - raise ValueError("All nodes were filtered out; cannot build on-disk dataset") _logger.info("Feature pass complete: total_nodes=%d", total_rows) # ---------------------- pass 2: topology ------------------------ @@ -1786,6 +1788,8 @@ def process(self) -> None: meta["edge_attr_dim"], ) finally: + if combined_features_path is not None and combined_features_path.exists(): + combined_features_path.unlink() if feature_store is not None: feature_store.close() From 1910acc3ce53da8f554884984293ae8770a2362d Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 09:47:20 -0400 Subject: [PATCH 14/18] req upd --- requirements.txt | 73 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/requirements.txt b/requirements.txt index 50d1c64..fe28033 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,86 @@ +aiohappyeyeballs==2.6.2 +aiohttp==3.14.1 +aiosignal==1.4.0 +annotated-types==0.7.0 +anyio==4.13.0 +attrs==26.1.0 +certifi==2026.5.20 +charset-normalizer==3.4.7 +click==8.4.1 +crc32c==2.8 +distro==1.9.0 +docopt==0.6.2 +donfig==0.8.1.post1 +et_xmlfile==2.0.0 +filelock==3.29.3 +frozenlist==1.8.0 +fsspec==2026.4.0 +ftpretty==0.4.0 +gitdb==4.0.12 +GitPython==3.1.50 goatools==1.4.12 +h11==0.16.0 +hf-xet==1.5.1 +httpcore==1.0.9 +httpx==0.28.1 +huggingface_hub==0.36.2 +idna==3.18 +Jinja2==3.1.6 +jiter==0.15.0 +joblib==1.5.3 +markdown-it-py==4.2.0 +MarkupSafe==3.0.3 +mdurl==0.1.2 +-e git+ssh://git@github.com/Yehor-Mishchyriak/microbiome2function.git@c36d5e5365c6b089625ffe1545948b21cbf7a00e#egg=microbiome2function +mpmath==1.3.0 +multidict==6.7.1 +networkx==3.6.1 +numcodecs==0.16.5 numpy==2.3.2 openai==1.99.3 +openpyxl==3.1.5 +packaging==26.2 pandas==2.3.1 +patsy==1.0.2 +platformdirs==4.10.0 +propcache==0.5.2 +protobuf==7.35.1 +psutil==7.2.2 pyarrow==21.0.0 +pydantic==2.13.4 +pydantic_core==2.46.4 +pydot==4.0.1 +pyg_lib==0.6.0+pt28 +Pygments==2.20.0 +pyparsing==3.3.2 +python-dateutil==2.9.0.post0 +pytz==2026.2 +PyYAML==6.0.3 +regex==2026.5.9 requests==2.32.4 +rich==15.0.0 +safetensors==0.8.0 scikit-learn==1.7.1 +scipy==1.17.1 +sentry-sdk==2.62.0 setuptools==80.9.0 +six==1.17.0 +smmap==5.0.3 +sniffio==1.3.1 +statsmodels==0.14.6 +sympy==1.14.0 +threadpoolctl==3.6.0 +tokenizers==0.21.4 torch==2.8.0 torch-geometric==2.7.0 +tqdm==4.68.2 transformers==4.55.0 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +tzdata==2026.2 +urllib3==2.7.0 wandb==0.27.2 +xlsxwriter==3.2.9 +xxhash==3.7.0 +yarl==1.24.2 zarr==3.1.1 From 6d28d7b22cb96c25ecac50601bdb7aaf691b3a40 Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 09:54:47 -0400 Subject: [PATCH 15/18] added training scripts --- .gitignore | 5 + training_scripts/ffnn_training.py | 0 training_scripts/gat_training.py | 0 training_scripts/gcnn_training.py | 335 +++++++++++++++++++++++ training_scripts/run_gcnn_test_subset.sh | 18 ++ 5 files changed, 358 insertions(+) create mode 100644 training_scripts/ffnn_training.py create mode 100644 training_scripts/gat_training.py create mode 100644 training_scripts/gcnn_training.py create mode 100644 training_scripts/run_gcnn_test_subset.sh diff --git a/.gitignore b/.gitignore index b284644..596c497 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,11 @@ !model_notebooks/PROT*/*.py !model_notebooks/PROT*/*.ipynb +# Training scripts +!training_scripts/ +!training_scripts/*.py +!training_scripts/*.sh + # Code examples (if present) !code_examples/ !code_examples/data_mining/ diff --git a/training_scripts/ffnn_training.py b/training_scripts/ffnn_training.py new file mode 100644 index 0000000..e69de29 diff --git a/training_scripts/gat_training.py b/training_scripts/gat_training.py new file mode 100644 index 0000000..e69de29 diff --git a/training_scripts/gcnn_training.py b/training_scripts/gcnn_training.py new file mode 100644 index 0000000..e38c0cb --- /dev/null +++ b/training_scripts/gcnn_training.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import argparse +import json +import logging +import os +import random +from pathlib import Path + +import numpy as np +import pandas as pd +import torch + +from M2F.cleaning_utils import clean_col +from M2F.embedding_utils import AAChainEmbedder +from M2F.feature_engineering_utils import embed_AAsequences, encode_go +from M2F.gnn import GraphConvNodeClassifier +from M2F.logging_utils import configure_logging +from M2F.pyg_data_interfaces import DatasetInput, ProteinGraphInMemoryDataset +import M2F.wb as wb + + +_logger = logging.getLogger(__name__) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Train a GraphConv node classifier on an M2F protein co-occurrence graph." + ) + + parser.add_argument("--data-dir", type=Path, required=True) + parser.add_argument("--run-dir", type=Path, required=True) + parser.add_argument("--force-reload", action="store_true") + + parser.add_argument("--go-depth", type=int, default=5) + parser.add_argument("--aa-model-key", default="esm2_t6_8M_UR50D") + parser.add_argument("--aa-batch-size", type=int, default=16) + parser.add_argument("--aa-device", default="auto") + + parser.add_argument("--epochs", type=int, default=50) + parser.add_argument("--batch-size", type=int, default=512) + parser.add_argument("--num-neighbors", default="15,10") + parser.add_argument("--tolerance", type=int, default=5) + parser.add_argument("--report-every", type=int, default=1) + parser.add_argument("--threshold", type=float, default=0.5) + + parser.add_argument("--msg-dim", type=int, default=128) + parser.add_argument("--state-dim", type=int, default=256) + parser.add_argument("--dropout-p", type=float, default=0.5) + parser.add_argument( + "--edge-features-used-as", + choices=("scaling", "catting"), + default="scaling", + ) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--weight-decay", type=float, default=1e-4) + + parser.add_argument("--val-set-size", type=float, default=0.1) + parser.add_argument("--test-set-size", type=float, default=0.1) + parser.add_argument("--request-size", type=int, default=25) + parser.add_argument("--rps", type=float, default=1.0) + parser.add_argument("--max-retry", type=int, default=20) + + parser.add_argument("--sequence-col", default="Sequence") + parser.add_argument("--go-col", default="Gene Ontology (molecular function)") + parser.add_argument("--edge-dst-column", default="j") + parser.add_argument("--edge-attr-columns", default="v") + + parser.add_argument("--device", default="auto") + parser.add_argument("--seed", type=int, default=13) + + parser.add_argument("--wandb-project", default=os.environ.get("WANDB_PROJECT")) + parser.add_argument("--wandb-name", default=os.environ.get("WANDB_NAME")) + parser.add_argument("--wandb-group", default=os.environ.get("WANDB_RUN_GROUP")) + parser.add_argument("--wandb-mode", default=os.environ.get("WANDB_MODE")) + + args = parser.parse_args() + if args.sequence_col != "Sequence": + parser.error("--sequence-col must be 'Sequence'; embed_AAsequences currently expects that column.") + return args + + +def resolve_device(device: str) -> str: + if device == "auto": + return "cuda:0" if torch.cuda.is_available() else "cpu" + return device + + +def parse_num_neighbors(value: str) -> list[int]: + return [int(part.strip()) for part in value.split(",") if part.strip()] + + +def parse_edge_attr_columns(value: str) -> tuple[str, ...] | None: + value = value.strip() + if value.lower() in {"", "none", "null"}: + return None + return tuple(part.strip() for part in value.split(",") if part.strip()) + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def json_safe(value): + if isinstance(value, Path): + return str(value) + if isinstance(value, dict): + return {key: json_safe(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [json_safe(item) for item in value] + return value + + +def as_go_multihot(idx_tuple, y_dim: int): + if not isinstance(idx_tuple, tuple) or y_dim == 0: + return np.nan + vec = np.zeros(y_dim, dtype=np.float32) + if idx_tuple: + vec[np.asarray(idx_tuple, dtype=np.int64)] = 1.0 + return vec if vec.sum() > 0 else np.nan + + +def make_pre_transform(args: argparse.Namespace, aa_device: str, go_label_map: dict[str, int]): + aa_encoder = AAChainEmbedder(model_key=args.aa_model_key, device=aa_device) + + def composed_pre_transform(node_df: pd.DataFrame) -> pd.DataFrame: + df = node_df.copy() + + df = clean_col( + df, + args.sequence_col, + apply_norm=False, + apply_strip_pubmed=False, + inplace=True, + ) + df = clean_col( + df, + args.go_col, + apply_norm=False, + apply_strip_pubmed=True, + inplace=True, + ) + + df, labels = encode_go( + df, + col_name=args.go_col, + depth=args.go_depth, + inplace=True, + ) + go_label_map.clear() + go_label_map.update(labels) + y_dim = len(go_label_map) + + df.loc[:, args.go_col] = df[args.go_col].map( + lambda idx_tuple: as_go_multihot(idx_tuple, y_dim) + ) + + return embed_AAsequences( + df, + embedder=aa_encoder, + batch_size=args.aa_batch_size, + inplace=True, + ) + + return composed_pre_transform + + +def make_pre_filter(args: argparse.Namespace): + def pre_filter_mask(df: pd.DataFrame): + x_ok = df[args.sequence_col].map( + lambda x: isinstance(x, np.ndarray) and x.size > 0 and np.isfinite(x).all() + ) + y_ok = df[args.go_col].map( + lambda y: ( + isinstance(y, np.ndarray) + and y.size > 0 + and np.isfinite(y).all() + and y.sum() > 0 + ) + ) + return x_ok & y_ok + + return pre_filter_mask + + +def start_wandb_if_requested(args: argparse.Namespace, config: dict) -> None: + if not args.wandb_project: + return + + init_kwargs = { + "project": args.wandb_project, + "name": args.wandb_name, + "group": args.wandb_group, + "job_type": "gcnn-training", + "config": config, + } + if args.wandb_mode: + init_kwargs["mode"] = args.wandb_mode + + init_kwargs["config"] = json_safe(init_kwargs["config"]) + wb.wandb.init(**init_kwargs) + + +def main() -> None: + args = parse_args() + args.run_dir.mkdir(parents=True, exist_ok=True) + configure_logging(str(args.run_dir / "logs"), console_level=logging.INFO) + set_seed(args.seed) + + model_device = resolve_device(args.device) + aa_device = resolve_device(args.aa_device) + num_neighbors = parse_num_neighbors(args.num_neighbors) + edge_attr_columns = parse_edge_attr_columns(args.edge_attr_columns) + + dataset_root = args.run_dir / "dataset" + model_dir = args.run_dir / "models" + model_dir.mkdir(parents=True, exist_ok=True) + + go_label_map: dict[str, int] = {} + dataset_input = DatasetInput( + path_to_accession_ids_csv_file=args.data_dir / "uniref_index_count.csv", + path_to_edge_csv_dir=args.data_dir, + X={"sequence": args.sequence_col}, + Y={"go_f": args.go_col}, + edge_dst_column=args.edge_dst_column, + edge_attr_columns=edge_attr_columns, + request_size=args.request_size, + rps=args.rps, + max_retry=args.max_retry, + ) + + pre_transform = make_pre_transform(args, aa_device, go_label_map) + pre_filter = make_pre_filter(args) + + ds = ProteinGraphInMemoryDataset( + root=dataset_root, + dataset_input=dataset_input, + pre_transform=pre_transform, + pre_filter=pre_filter, + force_reload=args.force_reload, + val_set_size=args.val_set_size, + test_set_size=args.test_set_size, + ) + + data = ds[0] + in_dim = int(data.x.size(-1)) + out_dim = int(data.y.size(-1)) if data.y.ndim > 1 else 1 + edge_dim = int(data.edge_attr.size(-1)) if data.edge_attr is not None else 0 + + if edge_dim == 0 and args.edge_features_used_as == "scaling": + raise ValueError("GraphConv scaling mode requires at least one edge attribute.") + + model_config = { + "num_nodes": int(data.num_nodes), + "num_edges": int(data.num_edges), + "in_dim": in_dim, + "edge_dim": edge_dim, + "out_dim": out_dim, + "num_neighbors": num_neighbors, + "model_device": model_device, + "aa_device": aa_device, + } + _logger.info("Dataset/model config: %s", model_config) + + if go_label_map: + with open(args.run_dir / "go_label_map.json", "w", encoding="utf-8") as handle: + json.dump(go_label_map, handle, indent=2, sort_keys=True) + + start_wandb_if_requested(args, {**vars(args), **model_config}) + + try: + train_loader = ds.train_loader( + num_neighbors=num_neighbors, + batch_size=args.batch_size, + shuffle=True, + ) + val_loader = ds.val_loader( + num_neighbors=num_neighbors, + batch_size=args.batch_size, + ) + test_loader = ds.test_loader( + num_neighbors=num_neighbors, + batch_size=args.batch_size, + ) + + model = GraphConvNodeClassifier( + in_dim=in_dim, + edge_dim=edge_dim, + msg_dim=args.msg_dim, + state_dim=args.state_dim, + out_dim=out_dim, + edge_features_used_as=args.edge_features_used_as, + dropout_p=args.dropout_p, + ).to(model_device) + + fit_result = model.fit( + train=train_loader, + val=val_loader, + epochs=args.epochs, + save_model_to=model_dir, + tolerance=args.tolerance, + optimizer=torch.optim.Adam, + optimizer_kwargs={"lr": args.lr, "weight_decay": args.weight_decay}, + report_performance_every_kth_epoch=args.report_every, + ) + + if fit_result["best_model_path"] is not None: + state = torch.load(fit_result["best_model_path"], map_location=model_device) + model.load_state_dict(state) + + test_metrics = model.test(test_loader, threshold=args.threshold) + if wb.is_active(): + wb.wandb.log({key.replace("_", "/"): value for key, value in test_metrics.items()}) + + output = { + "args": json_safe(vars(args)), + "model_config": json_safe(model_config), + "fit_result": fit_result, + "test_metrics": test_metrics, + } + with open(args.run_dir / "results.json", "w", encoding="utf-8") as handle: + json.dump(output, handle, indent=2) + + print(json.dumps({"best_model_path": fit_result["best_model_path"], **test_metrics}, indent=2)) + finally: + if wb.is_active(): + wb.wandb.finish() + + +if __name__ == "__main__": + main() diff --git a/training_scripts/run_gcnn_test_subset.sh b/training_scripts/run_gcnn_test_subset.sh new file mode 100644 index 0000000..20bb4fb --- /dev/null +++ b/training_scripts/run_gcnn_test_subset.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +set -e + +set -a +source .env +set +a + +python training_scripts/gcnn_training.py \ + --data-dir untracked/test_data_subset \ + --run-dir untracked/runs/gcnn_test_subset \ + --epochs 20 \ + --batch-size 16 \ + --num-neighbors 5,5 \ + --report-every 1 \ + --aa-device cpu \ + --device cpu \ + --wandb-project m2f \ + --wandb-name gcnn-test-subset \ No newline at end of file From 00f6386fae91abec5358b1560bf713be13bba350 Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 10:01:51 -0400 Subject: [PATCH 16/18] reqs upod --- requirements.txt | 73 ------------------------------------------------ 1 file changed, 73 deletions(-) diff --git a/requirements.txt b/requirements.txt index fe28033..50d1c64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,86 +1,13 @@ -aiohappyeyeballs==2.6.2 -aiohttp==3.14.1 -aiosignal==1.4.0 -annotated-types==0.7.0 -anyio==4.13.0 -attrs==26.1.0 -certifi==2026.5.20 -charset-normalizer==3.4.7 -click==8.4.1 -crc32c==2.8 -distro==1.9.0 -docopt==0.6.2 -donfig==0.8.1.post1 -et_xmlfile==2.0.0 -filelock==3.29.3 -frozenlist==1.8.0 -fsspec==2026.4.0 -ftpretty==0.4.0 -gitdb==4.0.12 -GitPython==3.1.50 goatools==1.4.12 -h11==0.16.0 -hf-xet==1.5.1 -httpcore==1.0.9 -httpx==0.28.1 -huggingface_hub==0.36.2 -idna==3.18 -Jinja2==3.1.6 -jiter==0.15.0 -joblib==1.5.3 -markdown-it-py==4.2.0 -MarkupSafe==3.0.3 -mdurl==0.1.2 --e git+ssh://git@github.com/Yehor-Mishchyriak/microbiome2function.git@c36d5e5365c6b089625ffe1545948b21cbf7a00e#egg=microbiome2function -mpmath==1.3.0 -multidict==6.7.1 -networkx==3.6.1 -numcodecs==0.16.5 numpy==2.3.2 openai==1.99.3 -openpyxl==3.1.5 -packaging==26.2 pandas==2.3.1 -patsy==1.0.2 -platformdirs==4.10.0 -propcache==0.5.2 -protobuf==7.35.1 -psutil==7.2.2 pyarrow==21.0.0 -pydantic==2.13.4 -pydantic_core==2.46.4 -pydot==4.0.1 -pyg_lib==0.6.0+pt28 -Pygments==2.20.0 -pyparsing==3.3.2 -python-dateutil==2.9.0.post0 -pytz==2026.2 -PyYAML==6.0.3 -regex==2026.5.9 requests==2.32.4 -rich==15.0.0 -safetensors==0.8.0 scikit-learn==1.7.1 -scipy==1.17.1 -sentry-sdk==2.62.0 setuptools==80.9.0 -six==1.17.0 -smmap==5.0.3 -sniffio==1.3.1 -statsmodels==0.14.6 -sympy==1.14.0 -threadpoolctl==3.6.0 -tokenizers==0.21.4 torch==2.8.0 torch-geometric==2.7.0 -tqdm==4.68.2 transformers==4.55.0 -typing-inspection==0.4.2 -typing_extensions==4.15.0 -tzdata==2026.2 -urllib3==2.7.0 wandb==0.27.2 -xlsxwriter==3.2.9 -xxhash==3.7.0 -yarl==1.24.2 zarr==3.1.1 From bb128929b6e37bc43c3c8377eea858ad29fab08d Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 10:16:29 -0400 Subject: [PATCH 17/18] added GAT and FFNN training scripts --- training_scripts/ffnn_training.py | 295 ++++++++++++++++++++++++++ training_scripts/gat_training.py | 330 ++++++++++++++++++++++++++++++ 2 files changed, 625 insertions(+) diff --git a/training_scripts/ffnn_training.py b/training_scripts/ffnn_training.py index e69de29..9623c99 100644 --- a/training_scripts/ffnn_training.py +++ b/training_scripts/ffnn_training.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import argparse +import json +import logging +import os +import random +from pathlib import Path + +import numpy as np +import pandas as pd +import torch + +from M2F.cleaning_utils import clean_col +from M2F.embedding_utils import AAChainEmbedder +from M2F.feature_engineering_utils import embed_AAsequences, encode_go +from M2F.ffnn import FFNN +from M2F.logging_utils import configure_logging +from M2F.pyg_data_interfaces import DatasetInput, ProteinDataset +import M2F.wb as wb + + +_logger = logging.getLogger(__name__) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Train an FFNN classifier on M2F protein sequence features." + ) + + parser.add_argument("--data-dir", type=Path, required=True) + parser.add_argument("--run-dir", type=Path, required=True) + parser.add_argument("--force-reload", action="store_true") + + parser.add_argument("--go-depth", type=int, default=5) + parser.add_argument("--aa-model-key", default="esm2_t6_8M_UR50D") + parser.add_argument("--aa-batch-size", type=int, default=16) + parser.add_argument("--aa-device", default="auto") + + parser.add_argument("--epochs", type=int, default=50) + parser.add_argument("--batch-size", type=int, default=512) + parser.add_argument("--tolerance", type=int, default=5) + parser.add_argument("--report-every", type=int, default=1) + parser.add_argument("--threshold", type=float, default=0.5) + + parser.add_argument("--hidden-dim1", type=int, default=512) + parser.add_argument("--hidden-dim2", type=int, default=256) + parser.add_argument("--dropout-p", type=float, default=0.5) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--weight-decay", type=float, default=1e-4) + + parser.add_argument("--val-set-size", type=float, default=0.1) + parser.add_argument("--test-set-size", type=float, default=0.1) + parser.add_argument("--num-feature-batches", type=int, default=1) + parser.add_argument("--request-size", type=int, default=25) + parser.add_argument("--rps", type=float, default=1.0) + parser.add_argument("--max-retry", type=int, default=20) + + parser.add_argument("--sequence-col", default="Sequence") + parser.add_argument("--go-col", default="Gene Ontology (molecular function)") + + parser.add_argument("--device", default="auto") + parser.add_argument("--seed", type=int, default=13) + + parser.add_argument("--wandb-project", default=os.environ.get("WANDB_PROJECT")) + parser.add_argument("--wandb-name", default=os.environ.get("WANDB_NAME")) + parser.add_argument("--wandb-group", default=os.environ.get("WANDB_RUN_GROUP")) + parser.add_argument("--wandb-mode", default=os.environ.get("WANDB_MODE")) + + args = parser.parse_args() + if args.sequence_col != "Sequence": + parser.error("--sequence-col must be 'Sequence'; embed_AAsequences currently expects that column.") + if args.num_feature_batches < 1: + parser.error("--num-feature-batches must be >= 1.") + return args + + +def resolve_device(device: str) -> str: + if device == "auto": + return "cuda:0" if torch.cuda.is_available() else "cpu" + return device + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def json_safe(value): + if isinstance(value, Path): + return str(value) + if isinstance(value, dict): + return {key: json_safe(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [json_safe(item) for item in value] + return value + + +def as_go_multihot(idx_tuple, y_dim: int): + if not isinstance(idx_tuple, tuple) or y_dim == 0: + return np.nan + vec = np.zeros(y_dim, dtype=np.float32) + if idx_tuple: + vec[np.asarray(idx_tuple, dtype=np.int64)] = 1.0 + return vec if vec.sum() > 0 else np.nan + + +def make_pre_transform(args: argparse.Namespace, aa_device: str, go_label_map: dict[str, int]): + aa_encoder = AAChainEmbedder(model_key=args.aa_model_key, device=aa_device) + + def composed_pre_transform(node_df: pd.DataFrame) -> pd.DataFrame: + df = node_df.copy() + + df = clean_col( + df, + args.sequence_col, + apply_norm=False, + apply_strip_pubmed=False, + inplace=True, + ) + df = clean_col( + df, + args.go_col, + apply_norm=False, + apply_strip_pubmed=True, + inplace=True, + ) + + df, labels = encode_go( + df, + col_name=args.go_col, + depth=args.go_depth, + inplace=True, + ) + go_label_map.clear() + go_label_map.update(labels) + y_dim = len(go_label_map) + + df.loc[:, args.go_col] = df[args.go_col].map( + lambda idx_tuple: as_go_multihot(idx_tuple, y_dim) + ) + + return embed_AAsequences( + df, + embedder=aa_encoder, + batch_size=args.aa_batch_size, + inplace=True, + ) + + return composed_pre_transform + + +def make_pre_filter(args: argparse.Namespace): + def pre_filter_mask(df: pd.DataFrame): + x_ok = df[args.sequence_col].map( + lambda x: isinstance(x, np.ndarray) and x.size > 0 and np.isfinite(x).all() + ) + y_ok = df[args.go_col].map( + lambda y: ( + isinstance(y, np.ndarray) + and y.size > 0 + and np.isfinite(y).all() + and y.sum() > 0 + ) + ) + return x_ok & y_ok + + return pre_filter_mask + + +def start_wandb_if_requested(args: argparse.Namespace, config: dict) -> None: + if not args.wandb_project: + return + + init_kwargs = { + "project": args.wandb_project, + "name": args.wandb_name, + "group": args.wandb_group, + "job_type": "ffnn-training", + "config": config, + } + if args.wandb_mode: + init_kwargs["mode"] = args.wandb_mode + + init_kwargs["config"] = json_safe(init_kwargs["config"]) + wb.wandb.init(**init_kwargs) + + +def main() -> None: + args = parse_args() + args.run_dir.mkdir(parents=True, exist_ok=True) + configure_logging(str(args.run_dir / "logs"), console_level=logging.INFO) + set_seed(args.seed) + + model_device = resolve_device(args.device) + aa_device = resolve_device(args.aa_device) + + dataset_root = args.run_dir / "dataset" + model_dir = args.run_dir / "models" + model_dir.mkdir(parents=True, exist_ok=True) + + go_label_map: dict[str, int] = {} + dataset_input = DatasetInput( + path_to_accession_ids_csv_file=args.data_dir / "uniref_index_count.csv", + X={"sequence": args.sequence_col}, + Y={"go_f": args.go_col}, + request_size=args.request_size, + rps=args.rps, + max_retry=args.max_retry, + num_feature_batches=args.num_feature_batches, + ) + + pre_transform = make_pre_transform(args, aa_device, go_label_map) + pre_filter = make_pre_filter(args) + + try: + with ProteinDataset( + root=dataset_root, + dataset_input=dataset_input, + pre_transform=pre_transform, + pre_filter=pre_filter, + force_reload=args.force_reload, + val_set_size=args.val_set_size, + test_set_size=args.test_set_size, + split="train", + include_targets=True, + ) as ds: + in_dim = int(ds.meta["x_dim"]) + out_dim = int(ds.meta["y_dim"]) + model_config = { + "num_nodes": int(ds.meta["num_nodes"]), + "in_dim": in_dim, + "out_dim": out_dim, + "model_device": model_device, + "aa_device": aa_device, + } + _logger.info("Dataset/model config: %s", model_config) + + start_wandb_if_requested(args, {**vars(args), **model_config}) + + if go_label_map: + with open(args.run_dir / "go_label_map.json", "w", encoding="utf-8") as handle: + json.dump(go_label_map, handle, indent=2, sort_keys=True) + + train_loader = ds.train_loader(batch_size=args.batch_size, shuffle=True) + val_loader = ds.val_loader(batch_size=args.batch_size) + test_loader = ds.test_loader(batch_size=args.batch_size) + + model = FFNN( + in_dim=in_dim, + hidden_dim1=args.hidden_dim1, + hidden_dim2=args.hidden_dim2, + out_dim=out_dim, + dropout_p=args.dropout_p, + ).to(model_device) + + fit_result = model.fit( + train=train_loader, + val=val_loader, + epochs=args.epochs, + save_model_to=model_dir, + tolerance=args.tolerance, + optimizer=torch.optim.Adam, + optimizer_kwargs={"lr": args.lr, "weight_decay": args.weight_decay}, + report_performance_every_kth_epoch=args.report_every, + ) + + if fit_result["best_model_path"] is not None: + state = torch.load(fit_result["best_model_path"], map_location=model_device) + model.load_state_dict(state) + + test_metrics = model.test(test_loader, threshold=args.threshold) + if wb.is_active(): + wb.wandb.log({key.replace("_", "/"): value for key, value in test_metrics.items()}) + + output = { + "args": json_safe(vars(args)), + "model_config": json_safe(model_config), + "fit_result": fit_result, + "test_metrics": test_metrics, + } + with open(args.run_dir / "results.json", "w", encoding="utf-8") as handle: + json.dump(output, handle, indent=2) + + print(json.dumps({"best_model_path": fit_result["best_model_path"], **test_metrics}, indent=2)) + finally: + if wb.is_active(): + wb.wandb.finish() + + +if __name__ == "__main__": + main() diff --git a/training_scripts/gat_training.py b/training_scripts/gat_training.py index e69de29..1a5da0b 100644 --- a/training_scripts/gat_training.py +++ b/training_scripts/gat_training.py @@ -0,0 +1,330 @@ +from __future__ import annotations + +import argparse +import json +import logging +import os +import random +from pathlib import Path + +import numpy as np +import pandas as pd +import torch + +from M2F.cleaning_utils import clean_col +from M2F.embedding_utils import AAChainEmbedder +from M2F.feature_engineering_utils import embed_AAsequences, encode_go +from M2F.gnn import GATNodeClassifier +from M2F.logging_utils import configure_logging +from M2F.pyg_data_interfaces import DatasetInput, ProteinGraphInMemoryDataset +import M2F.wb as wb + + +_logger = logging.getLogger(__name__) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Train a GAT node classifier on an M2F protein co-occurrence graph." + ) + + parser.add_argument("--data-dir", type=Path, required=True) + parser.add_argument("--run-dir", type=Path, required=True) + parser.add_argument("--force-reload", action="store_true") + + parser.add_argument("--go-depth", type=int, default=5) + parser.add_argument("--aa-model-key", default="esm2_t6_8M_UR50D") + parser.add_argument("--aa-batch-size", type=int, default=16) + parser.add_argument("--aa-device", default="auto") + + parser.add_argument("--epochs", type=int, default=50) + parser.add_argument("--batch-size", type=int, default=512) + parser.add_argument("--num-neighbors", default="15,10") + parser.add_argument("--tolerance", type=int, default=5) + parser.add_argument("--report-every", type=int, default=1) + parser.add_argument("--threshold", type=float, default=0.5) + + parser.add_argument("--msg-dim", type=int, default=128) + parser.add_argument("--state-dim", type=int, default=256) + parser.add_argument("--heads", type=int, default=4) + parser.add_argument("--attention-dropout-p", type=float, default=0.0) + parser.add_argument("--dropout-p", type=float, default=0.5) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--weight-decay", type=float, default=1e-4) + + parser.add_argument("--val-set-size", type=float, default=0.1) + parser.add_argument("--test-set-size", type=float, default=0.1) + parser.add_argument("--request-size", type=int, default=25) + parser.add_argument("--rps", type=float, default=1.0) + parser.add_argument("--max-retry", type=int, default=20) + + parser.add_argument("--sequence-col", default="Sequence") + parser.add_argument("--go-col", default="Gene Ontology (molecular function)") + parser.add_argument("--edge-dst-column", default="j") + parser.add_argument("--edge-attr-columns", default="v") + + parser.add_argument("--device", default="auto") + parser.add_argument("--seed", type=int, default=13) + + parser.add_argument("--wandb-project", default=os.environ.get("WANDB_PROJECT")) + parser.add_argument("--wandb-name", default=os.environ.get("WANDB_NAME")) + parser.add_argument("--wandb-group", default=os.environ.get("WANDB_RUN_GROUP")) + parser.add_argument("--wandb-mode", default=os.environ.get("WANDB_MODE")) + + args = parser.parse_args() + if args.sequence_col != "Sequence": + parser.error("--sequence-col must be 'Sequence'; embed_AAsequences currently expects that column.") + return args + + +def resolve_device(device: str) -> str: + if device == "auto": + return "cuda:0" if torch.cuda.is_available() else "cpu" + return device + + +def parse_num_neighbors(value: str) -> list[int]: + return [int(part.strip()) for part in value.split(",") if part.strip()] + + +def parse_edge_attr_columns(value: str) -> tuple[str, ...] | None: + value = value.strip() + if value.lower() in {"", "none", "null"}: + return None + return tuple(part.strip() for part in value.split(",") if part.strip()) + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def json_safe(value): + if isinstance(value, Path): + return str(value) + if isinstance(value, dict): + return {key: json_safe(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [json_safe(item) for item in value] + return value + + +def as_go_multihot(idx_tuple, y_dim: int): + if not isinstance(idx_tuple, tuple) or y_dim == 0: + return np.nan + vec = np.zeros(y_dim, dtype=np.float32) + if idx_tuple: + vec[np.asarray(idx_tuple, dtype=np.int64)] = 1.0 + return vec if vec.sum() > 0 else np.nan + + +def make_pre_transform(args: argparse.Namespace, aa_device: str, go_label_map: dict[str, int]): + aa_encoder = AAChainEmbedder(model_key=args.aa_model_key, device=aa_device) + + def composed_pre_transform(node_df: pd.DataFrame) -> pd.DataFrame: + df = node_df.copy() + + df = clean_col( + df, + args.sequence_col, + apply_norm=False, + apply_strip_pubmed=False, + inplace=True, + ) + df = clean_col( + df, + args.go_col, + apply_norm=False, + apply_strip_pubmed=True, + inplace=True, + ) + + df, labels = encode_go( + df, + col_name=args.go_col, + depth=args.go_depth, + inplace=True, + ) + go_label_map.clear() + go_label_map.update(labels) + y_dim = len(go_label_map) + + df.loc[:, args.go_col] = df[args.go_col].map( + lambda idx_tuple: as_go_multihot(idx_tuple, y_dim) + ) + + return embed_AAsequences( + df, + embedder=aa_encoder, + batch_size=args.aa_batch_size, + inplace=True, + ) + + return composed_pre_transform + + +def make_pre_filter(args: argparse.Namespace): + def pre_filter_mask(df: pd.DataFrame): + x_ok = df[args.sequence_col].map( + lambda x: isinstance(x, np.ndarray) and x.size > 0 and np.isfinite(x).all() + ) + y_ok = df[args.go_col].map( + lambda y: ( + isinstance(y, np.ndarray) + and y.size > 0 + and np.isfinite(y).all() + and y.sum() > 0 + ) + ) + return x_ok & y_ok + + return pre_filter_mask + + +def start_wandb_if_requested(args: argparse.Namespace, config: dict) -> None: + if not args.wandb_project: + return + + init_kwargs = { + "project": args.wandb_project, + "name": args.wandb_name, + "group": args.wandb_group, + "job_type": "gat-training", + "config": config, + } + if args.wandb_mode: + init_kwargs["mode"] = args.wandb_mode + + init_kwargs["config"] = json_safe(init_kwargs["config"]) + wb.wandb.init(**init_kwargs) + + +def main() -> None: + args = parse_args() + args.run_dir.mkdir(parents=True, exist_ok=True) + configure_logging(str(args.run_dir / "logs"), console_level=logging.INFO) + set_seed(args.seed) + + model_device = resolve_device(args.device) + aa_device = resolve_device(args.aa_device) + num_neighbors = parse_num_neighbors(args.num_neighbors) + edge_attr_columns = parse_edge_attr_columns(args.edge_attr_columns) + + dataset_root = args.run_dir / "dataset" + model_dir = args.run_dir / "models" + model_dir.mkdir(parents=True, exist_ok=True) + + go_label_map: dict[str, int] = {} + dataset_input = DatasetInput( + path_to_accession_ids_csv_file=args.data_dir / "uniref_index_count.csv", + path_to_edge_csv_dir=args.data_dir, + X={"sequence": args.sequence_col}, + Y={"go_f": args.go_col}, + edge_dst_column=args.edge_dst_column, + edge_attr_columns=edge_attr_columns, + request_size=args.request_size, + rps=args.rps, + max_retry=args.max_retry, + ) + + pre_transform = make_pre_transform(args, aa_device, go_label_map) + pre_filter = make_pre_filter(args) + + ds = ProteinGraphInMemoryDataset( + root=dataset_root, + dataset_input=dataset_input, + pre_transform=pre_transform, + pre_filter=pre_filter, + force_reload=args.force_reload, + val_set_size=args.val_set_size, + test_set_size=args.test_set_size, + ) + + data = ds[0] + in_dim = int(data.x.size(-1)) + out_dim = int(data.y.size(-1)) if data.y.ndim > 1 else 1 + edge_dim = int(data.edge_attr.size(-1)) if data.edge_attr is not None else 0 + + model_config = { + "num_nodes": int(data.num_nodes), + "num_edges": int(data.num_edges), + "in_dim": in_dim, + "edge_dim": edge_dim, + "out_dim": out_dim, + "num_neighbors": num_neighbors, + "model_device": model_device, + "aa_device": aa_device, + } + _logger.info("Dataset/model config: %s", model_config) + + if go_label_map: + with open(args.run_dir / "go_label_map.json", "w", encoding="utf-8") as handle: + json.dump(go_label_map, handle, indent=2, sort_keys=True) + + start_wandb_if_requested(args, {**vars(args), **model_config}) + + try: + train_loader = ds.train_loader( + num_neighbors=num_neighbors, + batch_size=args.batch_size, + shuffle=True, + ) + val_loader = ds.val_loader( + num_neighbors=num_neighbors, + batch_size=args.batch_size, + ) + test_loader = ds.test_loader( + num_neighbors=num_neighbors, + batch_size=args.batch_size, + ) + + model = GATNodeClassifier( + in_dim=in_dim, + edge_dim=edge_dim, + msg_dim=args.msg_dim, + state_dim=args.state_dim, + out_dim=out_dim, + heads=args.heads, + attention_dropout_p=args.attention_dropout_p, + dropout_p=args.dropout_p, + ).to(model_device) + + fit_result = model.fit( + train=train_loader, + val=val_loader, + epochs=args.epochs, + save_model_to=model_dir, + tolerance=args.tolerance, + optimizer=torch.optim.Adam, + optimizer_kwargs={"lr": args.lr, "weight_decay": args.weight_decay}, + report_performance_every_kth_epoch=args.report_every, + ) + + if fit_result["best_model_path"] is not None: + state = torch.load(fit_result["best_model_path"], map_location=model_device) + model.load_state_dict(state) + + test_metrics = model.test(test_loader, threshold=args.threshold) + if wb.is_active(): + wb.wandb.log({key.replace("_", "/"): value for key, value in test_metrics.items()}) + + output = { + "args": json_safe(vars(args)), + "model_config": json_safe(model_config), + "fit_result": fit_result, + "test_metrics": test_metrics, + } + with open(args.run_dir / "results.json", "w", encoding="utf-8") as handle: + json.dump(output, handle, indent=2) + + print(json.dumps({"best_model_path": fit_result["best_model_path"], **test_metrics}, indent=2)) + finally: + if wb.is_active(): + wb.wandb.finish() + + +if __name__ == "__main__": + main() From 6fa395a3707f5cb0c8d9dcb700ee933775c75706 Mon Sep 17 00:00:00 2001 From: Yehor Mishchyriak Date: Fri, 12 Jun 2026 11:11:07 -0400 Subject: [PATCH 18/18] wandb mention in the docs and readme --- README.md | 15 +++++++++++++++ docs.md | 43 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 04d8d3b..7ca8d66 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,7 @@ Current top-level exports include: - Models: `FFNN`, `GraphConv`, `GraphConvNodeClassifier`, `GATNodeClassifier` - Metrics: `accuracy`, `recall`, `precision`, `f1` - Dataset interfaces: `DatasetInput`, `build_topology_from_DatasetInput`, `build_features_from_DatasetInput`, `ProteinGraphInMemoryDataset`, `ProteinGraphOnDiskDataset`, `ProteinDataset` +- W&B helpers: `M2F.wb` for lightweight metric logging and best-model artifacts during `.fit(...)` - Utility namespace: `util` ## Typical Workflow @@ -252,6 +253,20 @@ configure_logging( ) ``` +### Weights & Biases + +`FFNN.fit(...)`, `GraphConvNodeClassifier.fit(...)`, and `GATNodeClassifier.fit(...)` log epoch metrics and best-model artifacts through `M2F.wb` when a W&B run is active: + +```python +import wandb + +wandb.init(project="m2f", name="example-run", config={"model": "gat"}) +history = model.fit(train_loader, val_loader, epochs=30, save_model_to="runs/checkpoints") +wandb.finish() +``` + +For training scripts, pass `--wandb-project`, `--wandb-name`, and optionally `--wandb-mode online`. + ## CI Workflows - `/.github/workflows/test.yml`: multi-version test matrix (3.11, 3.12) diff --git a/docs.md b/docs.md index b777539..2bc585b 100644 --- a/docs.md +++ b/docs.md @@ -72,6 +72,11 @@ Why: - M2F emits useful progress and validation messages during mining, processing, and training. - Debug logs are especially useful for long batched dataset builds. +W&B: +- Training loops in `FFNN`, `GraphConvNodeClassifier`, and `GATNodeClassifier` use `M2F.wb` opportunistically. +- If `wandb.init(...)` has been called, `.fit(...)` logs train/validation loss, accuracy, precision, recall, F1, learning rate, and best-model artifacts. +- If no W&B run is active, training still works normally and only local checkpoints are saved. + ## 3. Public API Overview Top-level import path: @@ -89,6 +94,7 @@ Current exported API (`M2F.__all__`) includes: - Models: `FFNN`, `GraphConv`, `GraphConvNodeClassifier`, `GATNodeClassifier`. - Metrics: `accuracy`, `recall`, `precision`, `f1`. - Dataset interfaces: `DatasetInput`, `build_topology_from_DatasetInput`, `build_features_from_DatasetInput`, `ProteinGraphInMemoryDataset`, `ProteinGraphOnDiskDataset`, `ProteinDataset`. +- W&B helpers: `M2F.wb` for lightweight metric logging and best-model artifacts when a W&B run is active. - Utility namespace: `util`. ## 4. Data Contracts You Must Respect @@ -491,6 +497,7 @@ Implementation details worth knowing: - Loss: `BCEWithLogitsLoss`. - During neighbor sampling, only seed nodes are supervised in each batch (`batch_size` mask logic). - `fit(...)` returns `best_val_loss`, `best_model_path`, and epoch-wise `history`. +- When W&B is active, `fit(...)` logs epoch metrics and the best checkpoint through `M2F.wb`. ## 8.2 GNN Attention: `GATNodeClassifier` @@ -532,6 +539,7 @@ Implementation details worth knowing: - `state_dim` must be divisible by `heads` because head outputs are concatenated. - Edge attributes are used when `edge_dim > 0`; empty edge-attribute tensors are ignored when `edge_dim=0`. - Training, evaluation, masking, loss, and returned history match `GraphConvNodeClassifier`. +- When W&B is active, `fit(...)` logs epoch metrics and the best checkpoint through `M2F.wb`. ## 8.3 FFNN: `FFNN` @@ -559,8 +567,39 @@ print(history["best_val_loss"], metrics) Implementation details: - Loss: `BCEWithLogitsLoss`. - `forward(...)` returns logits during training, sigmoid probabilities during eval. +- When W&B is active, `fit(...)` logs epoch metrics and the best checkpoint through `M2F.wb`. + +## 8.4 Weights & Biases Integration + +M2F keeps W&B optional. The model loops do not start runs themselves; they only log if a run is already active. + +```python +import wandb + +wandb.init(project="m2f", name="gcnn-example", config={"model": "gcnn"}) +history = model.fit( + train=train_loader, + val=val_loader, + epochs=50, + save_model_to="runs/checkpoints", +) +metrics = model.test(test_loader) +wandb.log({key.replace("_", "/"): value for key, value in metrics.items()}) +wandb.finish() +``` + +The training scripts expose the same behavior with CLI flags: + +```bash +python training_scripts/gat_training.py \ + --data-dir untracked/prev_30-0.005 \ + --run-dir untracked/runs/gat_prev_30_0005 \ + --wandb-project m2f \ + --wandb-name gat-prev-30-0005 \ + --wandb-mode online +``` -## 8.4 Metrics Utilities +## 8.5 Metrics Utilities Available helpers (`M2F.testing_utils`): - `accuracy(logits, y_true, mask, threshold=0.5)` @@ -720,6 +759,7 @@ python -m pip install dist/microbiome2function-0.1.0-py3-none-any.whl - Start small: validate your `DatasetInput` and preprocessing on a tiny accession subset first. - Keep transform contracts strict: `pre_transform` must return DataFrame; `pre_filter` must return boolean mask with matching length. - Use explicit checkpoints: preserve `meta.pt`, vocab maps, and model checkpoints per experiment. +- Use W&B for experiment comparison when running multiple FFNN/GCNN/GAT jobs; keep local `results.json` and checkpoint folders as the source of record. - Close on-disk datasets: prefer context-manager usage (`with ... as ...:`), or call `close()` to release zarr handles after training/inference. - Avoid silent schema drift: pin requested UniProt fields and return names in code, not notebooks-only state. @@ -733,5 +773,6 @@ python -m pip install dist/microbiome2function-0.1.0-py3-none-any.whl - `M2F.pyg_data_interfaces`: graph and FFNN dataset interfaces + standalone builders. - `M2F.gnn`: graph convolution, graph attention, and training/eval loops. - `M2F.ffnn`: feed-forward model and training/eval loops. +- `M2F.wb`: optional W&B metric logging and best-model artifact helpers. - `M2F.testing_utils`: metric helpers. - `M2F.util`: utility helpers and zarr feature-store backend.