Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/segger/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def segment(
validator=validators.Path(exists=True, dir_okay=True),
)] = registry.get_default("output_directory"),


save_anndata: Annotated[bool, registry.get_parameter(
"save_anndata",
group=group_io,
)] = registry.get_default("save_anndata"),

# Cell Representation
node_representation_dim: Annotated[int, Parameter(
help="Number of dimensions used to represent each node type.",
Expand Down Expand Up @@ -121,7 +125,6 @@ def segment(
group=group_nodes,
)] = registry.get_default("genes_clusters_resolution"),


# Transcript-Transcript Graph
transcripts_max_k: Annotated[int, registry.get_parameter(
"transcripts_graph_max_k",
Expand All @@ -145,6 +148,12 @@ def segment(
)
] = registry.get_default("prediction_graph_mode"),

prediction_expansion_ratio: Annotated[float, registry.get_parameter(
"prediction_graph_buffer_ratio",
validator=validators.Number(gt=0),
group=group_prediction,
)] = registry.get_default("prediction_graph_buffer_ratio"),

prediction_max_k: Annotated[int | None, registry.get_parameter(
"prediction_graph_max_k",
validator=validators.Number(gt=0),
Expand Down Expand Up @@ -342,7 +351,10 @@ def segment(
from ..data import ISTSegmentationWriter
from lightning.pytorch import Trainer
logger = CSVLogger(output_directory)
writer = ISTSegmentationWriter(output_directory)
writer = ISTSegmentationWriter(
output_directory=output_directory,
save_anndata=save_anndata,
)
trainer = Trainer(
logger=logger,
max_epochs=n_epochs,
Expand Down
65 changes: 61 additions & 4 deletions src/segger/data/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ..io import TrainingTranscriptFields, TrainingBoundaryFields
from . import ISTDataModule
from .utils.anndata import anndata_from_transcripts


def threshold(x):
Expand All @@ -24,9 +25,14 @@ class ISTSegmentationWriter(BasePredictionWriter):
Path to write outputs.
"""

def __init__(self, output_directory: Path):
def __init__(
self,
output_directory: Path,
save_anndata: bool = True,
):
super().__init__(write_interval="epoch")
self.output_directory = Path(output_directory)
self.save_anndata = save_anndata

def write_on_epoch_end(
self,
Expand Down Expand Up @@ -125,10 +131,61 @@ def write_on_epoch_end(
.alias("similarity_threshold")
)
)
# Join and write output to file
# Join thresholds
segmentation = segmentation.join(thresholds, on=tx_fields.feature, how='left')

# Map gene encoding to gene names
gene_index = (
pl
.from_pandas(trainer.datamodule.ad.var.reset_index())
.rename({"index": tx_fields.feature})
.select([tx_fields.feature, tx_fields.gene_encoding])
)
segmentation = (
segmentation
.rename({tx_fields.feature: tx_fields.gene_encoding})
.join(gene_index, on=tx_fields.gene_encoding, how='left')
)

# Write segmentation output (keep prior columns)
(
segmentation
.join(thresholds, on=tx_fields.feature, how='left')
.drop(tx_fields.feature)
.drop([tx_fields.feature, tx_fields.gene_encoding])
.write_parquet(self.output_directory / 'segger_segmentation.parquet')
)

# Optional: save AnnData
if self.save_anndata:
tx = trainer.datamodule.tx
transcripts = (
segmentation
.join(
tx.select([
tx_fields.row_index,
tx_fields.x,
tx_fields.y,
tx_fields.feature,
]),
on=tx_fields.row_index,
how='left',
)
.rename({tx_fields.feature: "segger_gene"})
.select([
tx_fields.row_index,
"segger_gene",
"segger_cell_id",
"segger_similarity",
"similarity_threshold",
tx_fields.x,
tx_fields.y,
])
)

adata = anndata_from_transcripts(
transcripts,
feature_column="segger_gene",
cell_id_column="segger_cell_id",
score_column="segger_similarity",
coordinate_columns=[tx_fields.x, tx_fields.y],
)
adata.write_h5ad(self.output_directory / 'segger_anndata.h5ad')