From e01a04df5159cca5c51f83a9df171e4e29a38d30 Mon Sep 17 00:00:00 2001 From: Cavan Riley Date: Wed, 4 Mar 2026 09:46:24 -0600 Subject: [PATCH 1/4] ADD: add multifile and multichannel flag to local datastore Adds flags to the local datastore init signaling the data is either multichannel or multi-file (each directory has a sample with multiple volumes Signed-off-by: Cavan Riley --- monailabel/datastore/cvat.py | 27 ++++++ monailabel/datastore/dicom.py | 26 ++++++ monailabel/datastore/dsa.py | 26 ++++++ monailabel/datastore/local.py | 102 ++++++++++++++++++---- monailabel/datastore/xnat.py | 26 ++++++ monailabel/endpoints/datastore.py | 16 +++- monailabel/interfaces/app.py | 15 +++- monailabel/interfaces/datastore.py | 26 ++++++ monailabel/tasks/activelearning/first.py | 8 ++ monailabel/tasks/activelearning/random.py | 28 +++++- monailabel/tasks/train/basic_train.py | 5 ++ 11 files changed, 278 insertions(+), 27 deletions(-) diff --git a/monailabel/datastore/cvat.py b/monailabel/datastore/cvat.py index 461883ea0..4ec6583e8 100644 --- a/monailabel/datastore/cvat.py +++ b/monailabel/datastore/cvat.py @@ -16,6 +16,7 @@ import tempfile import time import urllib.parse +from typing import Any, Dict import numpy as np import requests @@ -318,6 +319,32 @@ def download_from_cvat(self, max_retry_count=5, retry_wait_time=10): retry_count += 1 return None + def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str: + """ + Not implemented for this datastore + + Abstract method for adding a directory to cvat + """ + raise NotImplementedError("This datastore does not support adding directories") + + def get_is_multichannel(self) -> bool: + """ + Not implemented for this datastore + + Returns whether the application's studies is directed at multichannel (4D) data + """ + logger.info("The function get_is_multichannel is not implemented for this datastore") + return False + + def get_is_multi_file(self) -> bool: + """ + Not implemented for this datastore + + Returns whether the application's studies is directed at directories containing multiple images per sample + """ + logger.info("The function get_is_multi_file is not implemented for this datastore") + return False + """ def main(): diff --git a/monailabel/datastore/dicom.py b/monailabel/datastore/dicom.py index ed4733ca6..5447b826d 100644 --- a/monailabel/datastore/dicom.py +++ b/monailabel/datastore/dicom.py @@ -264,3 +264,29 @@ def _download_labeled_data(self): def datalist(self, full_path=True) -> List[Dict[str, Any]]: self._download_labeled_data() return super().datalist(full_path) + + def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str: + """ + Not implemented + + Abstract method for adding a directory to DICOMWeb + """ + raise NotImplementedError("This datastore does not support adding directories") + + def get_is_multichannel(self) -> bool: + """ + Not implemented for this datastore + + Returns whether the application's studies is directed at multichannel (4D) data + """ + logger.info("The function get_is_multichannel is not implemented for this datastore") + return False + + def get_is_multi_file(self) -> bool: + """ + Not implemented for this datastore + + Returns whether the application's studies is directed at directories containing multiple images per sample + """ + logger.info("The function get_is_multi_file is not implemented for this datastore") + return False diff --git a/monailabel/datastore/dsa.py b/monailabel/datastore/dsa.py index 365cef24e..f7e586251 100644 --- a/monailabel/datastore/dsa.py +++ b/monailabel/datastore/dsa.py @@ -270,6 +270,32 @@ def status(self) -> Dict[str, Any]: def json(self): return self.datalist() + def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str: + """ + Not implemented for this datastore + + Abstract method for adding a directory to dsa + """ + raise NotImplementedError("This datastore does not support adding directories") + + def get_is_multichannel(self) -> bool: + """ + Not implemented for this datastore + + Returns whether the application's studies is directed at multichannel (4D) data + """ + logger.info("The function get_is_multichannel is not implemented for this datastore") + return False + + def get_is_multi_file(self) -> bool: + """ + Not implemented for this datastore + + Returns whether the application's studies is directed at directories containing multiple images per sample + """ + logger.info("The function get_is_multi_file is not implemented for this datastore") + return False + """ def main(): diff --git a/monailabel/datastore/local.py b/monailabel/datastore/local.py index d8b0538aa..c31a93f5f 100644 --- a/monailabel/datastore/local.py +++ b/monailabel/datastore/local.py @@ -102,9 +102,11 @@ def __init__( images_dir: str = ".", labels_dir: str = "labels", datastore_config: str = "datastore_v2.json", - extensions=("*.nii.gz", "*.nii"), + extensions=("*.nii.gz", "*.nii", "*.nrrd"), auto_reload=False, read_only=False, + multichannel: bool = False, + multi_file: bool = False, ): """ Creates a `LocalDataset` object @@ -124,6 +126,14 @@ def __init__( self._ignore_event_config = False self._config_ts = 0 self._auto_reload = auto_reload + if multichannel and multi_file: + raise ValueError( + "multichannel and multi_file are mutually exclusive: " + "multichannel expects a single 4D NIfTI volume per sample, " + "while multi_file expects a directory of separate modality files." + ) + self._multichannel: bool = multichannel + self._multi_file: bool = multi_file logging.getLogger("filelock").setLevel(logging.ERROR) @@ -256,6 +266,18 @@ def datalist(self, full_path=True) -> List[Dict[str, Any]]: ds = json.loads(json.dumps(ds).replace(f"{self._datastore_path.rstrip(os.pathsep)}{os.pathsep}", "")) return ds + def get_is_multichannel(self) -> bool: + """ + Returns whether the dataset is multichannel or not + """ + return self._multichannel + + def get_is_multi_file(self) -> bool: + """ + Returns whether the dataset is multi-file or not + """ + return self._multi_file + def get_image(self, image_id: str, params=None) -> Any: """ Retrieve image object based on image id @@ -431,6 +453,43 @@ def refresh(self): """ self._reconcile_datastore() + def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str: + """ + Add a directory to the datastore + + :param directory_id: the directory id + :param filename: the filename + :param info: additional info + + :return: directory id + """ + id = os.path.basename(os.path.normpath(filename)) + if not directory_id: + directory_id = id + + logger.info(f"Adding Image: {directory_id} => {filename}") + name = directory_id + dest = os.path.realpath(os.path.join(self._datastore.image_path(), name)) + + with FileLock(self._lock_file): + logger.debug("Acquired the lock!") + if os.path.isdir(filename): + if os.path.exists(dest): + shutil.rmtree(dest) + shutil.copytree(filename, dest) + else: + shutil.copy2(filename, dest) + + info = info if info else {} + info["ts"] = int(time.time()) + info["name"] = name + + # images = get_directory_contents(filename) + self._datastore.objects[directory_id] = ImageLabelModel(image=DataModel(info=info, ext="")) + self._update_datastore_file(lock=False) + logger.debug("Released the lock!") + return directory_id + def add_image(self, image_id: str, image_filename: str, image_info: Dict[str, Any]) -> str: id, image_ext = self._to_id(os.path.basename(image_filename)) if not image_id: @@ -552,10 +611,17 @@ def _list_files(self, path, patterns): files = os.listdir(path) filtered = dict() - for pattern in patterns: - matching = fnmatch.filter(files, pattern) - for file in matching: - filtered[os.path.basename(file)] = file + if not self._multi_file: + for pattern in patterns: + matching = fnmatch.filter(files, pattern) + for file in matching: + filtered[os.path.basename(file)] = file + else: + ignored = {"labels", ".lock", os.path.basename(self._datastore_config_path).lower()} + for file in files: + abs_file = os.path.join(path, file) + if os.path.isdir(abs_file) and file.lower() not in ignored: + filtered[os.path.basename(file)] = file return filtered def _reconcile_datastore(self): @@ -585,24 +651,26 @@ def _add_non_existing_images(self) -> int: invalidate = 0 self._init_from_datastore_file() - local_images = self._list_files(self._datastore.image_path(), self._extensions) + local_files = self._list_files(self._datastore.image_path(), self._extensions) - image_ids = list(self._datastore.objects.keys()) - for image_file in local_images: - image_id, image_ext = self._to_id(image_file) - if image_id not in image_ids: - logger.info(f"Adding New Image: {image_id} => {image_file}") + ids = list(self._datastore.objects.keys()) + for file in local_files: + if self._multi_file: + # Directories have no extension — use the name as-is + file_id = file + file_ext_str = "" + else: + file_id, file_ext_str = self._to_id(file) - name = self._filename(image_id, image_ext) - image_info = { + if file_id not in ids: + logger.info(f"Adding New Image: {file_id} => {file}") + name = self._filename(file_id, file_ext_str) + file_info = { "ts": int(time.time()), - # "checksum": file_checksum(os.path.join(self._datastore.image_path(), name)), "name": name, } - invalidate += 1 - self._datastore.objects[image_id] = ImageLabelModel(image=DataModel(info=image_info, ext=image_ext)) - + self._datastore.objects[file_id] = ImageLabelModel(image=DataModel(info=file_info, ext=file_ext_str)) return invalidate def _add_non_existing_labels(self, tag) -> int: diff --git a/monailabel/datastore/xnat.py b/monailabel/datastore/xnat.py index c0904fd8d..acaa2af1c 100644 --- a/monailabel/datastore/xnat.py +++ b/monailabel/datastore/xnat.py @@ -386,6 +386,32 @@ def __upload_assessment(self, aiaa_model_name, image_id, file_path, type): self._request_put(url, data, type=type) + def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str: + """ + Not implemented for this datastore + + Abstract method for adding a directory to xnat + """ + raise NotImplementedError("This datastore does not support adding directories") + + def get_is_multichannel(self) -> bool: + """ + Not implemented for this datastore + + Returns whether the application's studies is directed at multichannel (4D) data + """ + logger.info("The function get_is_multichannel is not implemented for this datastore") + return False + + def get_is_multi_file(self) -> bool: + """ + Not implemented for this datastore + + Returns whether the application's studies is directed at directories containing multiple images per sample + """ + logger.info("The function get_is_multi_file is not implemented for this datastore") + return False + """ def main(): diff --git a/monailabel/endpoints/datastore.py b/monailabel/endpoints/datastore.py index fdd63bb6e..a6252c2b2 100644 --- a/monailabel/endpoints/datastore.py +++ b/monailabel/endpoints/datastore.py @@ -66,20 +66,28 @@ def add_image( user: Optional[str] = None, ): logger.info(f"Image: {image}; File: {file}; params: {params}") - file_ext = "".join(pathlib.Path(file.filename).suffixes) if file.filename else ".nii.gz" - image_id = image if image else os.path.basename(file.filename).replace(file_ext, "") + instance: MONAILabelApp = app_instance() + if instance.datastore().get_is_multi_file(): + raise HTTPException( + status_code=400, + detail="Multi-file datastore does not support single-file uploads. " + "Data must be pre-staged as sample subdirectories on the server filesystem.", + ) + + file_ext = "".join(pathlib.Path(file.filename).suffixes) if file.filename else ".nii.gz" + id = image if image else os.path.basename(file.filename).replace(file_ext, "") image_file = tempfile.NamedTemporaryFile(suffix=file_ext).name with open(image_file, "wb") as buffer: shutil.copyfileobj(file.file, buffer) background_tasks.add_task(remove_file, image_file) - instance: MONAILabelApp = app_instance() save_params: Dict[str, Any] = json.loads(params) if params else {} if user: save_params["user"] = user - image_id = instance.datastore().add_image(image_id, image_file, save_params) + + image_id = instance.datastore().add_image(id, image_file, save_params) return {"image": image_id} diff --git a/monailabel/interfaces/app.py b/monailabel/interfaces/app.py index f9b405bde..938987e38 100644 --- a/monailabel/interfaces/app.py +++ b/monailabel/interfaces/app.py @@ -90,7 +90,9 @@ def __init__( self.app_dir = app_dir self.studies = studies self.conf = conf if conf else {} - + self.multichannel: bool = strtobool(conf.get("multichannel", False)) + self.multi_file: bool = strtobool(conf.get("multi_file", False)) + self.input_channels = conf.get("input_channels", False) self.name = name self.description = description self.version = version @@ -146,6 +148,8 @@ def init_datastore(self) -> Datastore: extensions=settings.MONAI_LABEL_DATASTORE_FILE_EXT, auto_reload=settings.MONAI_LABEL_DATASTORE_AUTO_RELOAD, read_only=settings.MONAI_LABEL_DATASTORE_READ_ONLY, + multichannel=self.multichannel, + multi_file=self.multi_file, ) def init_remote_datastore(self) -> Datastore: @@ -282,6 +286,9 @@ def infer(self, request, datastore=None): ) request = copy.deepcopy(request) + request["multi_file"] = self.multi_file + request["multichannel"] = self.multichannel + request["input_channels"] = self.input_channels request["description"] = task.description image_id = request["image"] @@ -292,7 +299,7 @@ def infer(self, request, datastore=None): else: request["image"] = datastore.get_image_uri(request["image"]) - if os.path.isdir(request["image"]): + if os.path.isdir(request["image"]) and not self.multi_file: logger.info("Input is a Directory; Consider it as DICOM") logger.debug(f"Image => {request['image']}") @@ -431,6 +438,10 @@ def train(self, request): ) request = copy.deepcopy(request) + # 4D image support, send train task information regarding data + request["multi_file"] = self.multi_file + request["multichannel"] = self.multichannel + request["input_channels"] = self.input_channels result = task(request, self.datastore()) # Run all scoring methods diff --git a/monailabel/interfaces/datastore.py b/monailabel/interfaces/datastore.py index 78fa0aecc..702bff1bb 100644 --- a/monailabel/interfaces/datastore.py +++ b/monailabel/interfaces/datastore.py @@ -201,6 +201,18 @@ def refresh(self) -> None: """ pass + @abstractmethod + def add_directory(self, directory_id: str, filename: str, info: Dict[str, Any]) -> str: + """ + Save a directory for the given directory id and return the newly saved directory's id + + :param directory_id: the directory id for the image; If None then base filename will be used + :param filename: the path to the directory + :param info: additional info for the directory + :return: the directory id for the saved image filename + """ + pass + @abstractmethod def add_image(self, image_id: str, image_filename: str, image_info: Dict[str, Any]) -> str: """ @@ -279,3 +291,17 @@ def json(self): Return json representation of datastore """ pass + + @abstractmethod + def get_is_multichannel(self) -> bool: + """ + Returns whether the application's studies is directed at multichannel (4D) data + """ + pass + + @abstractmethod + def get_is_multi_file(self) -> bool: + """ + Returns whether the application's studies is directed at directories containing multiple images per sample + """ + pass diff --git a/monailabel/tasks/activelearning/first.py b/monailabel/tasks/activelearning/first.py index 2a0ffa675..e80ec7704 100644 --- a/monailabel/tasks/activelearning/first.py +++ b/monailabel/tasks/activelearning/first.py @@ -35,5 +35,13 @@ def __call__(self, request, datastore: Datastore): images.sort() image = images[0] + # If the datastore contains 4d images send the multichannel flag to ensure images are loaded as sequences + if datastore.get_is_multichannel(): + return {"id": image, "multichannel": True} + + # If the datastore is multi_file, each sample has a directory with multiple images + if datastore.get_is_multi_file(): + return {"id": image, "multi_file": True} + logger.info(f"First: Selected Image: {image}") return {"id": image} diff --git a/monailabel/tasks/activelearning/random.py b/monailabel/tasks/activelearning/random.py index b196f7a6b..1dabb8f53 100644 --- a/monailabel/tasks/activelearning/random.py +++ b/monailabel/tasks/activelearning/random.py @@ -42,7 +42,27 @@ def __call__(self, request, datastore: Datastore): current_ts = int(time.time()) weights = [current_ts - info.get("ts", 0) for info in images_info] - image = random.choices(images, weights=weights)[0] - logger.debug(f"Random: Images: {images}; Weight: {weights}") - logger.info(f"Random: Selected Image: {image}; Weight: {weights[0]}") - return {"id": image, "weight": weights[0]} + if not any(weights): + # All weights are zero (every image was seen at the current second); + # fall back to a uniform random pick to avoid a ValueError from random.choices. + selected_idx = random.randrange(len(images)) + else: + selected_idx = random.choices(range(len(images)), weights=weights, k=1)[0] + image = images[selected_idx] + selected_weight = weights[selected_idx] + + logger.info(f"Random: Selected Image: {image}; Weight: {selected_weight}") + + # If the datastore contains 4d images send the multichannel flag to ensure images are loaded as sequences + if datastore.get_is_multichannel(): + return {"id": image, "weight": selected_weight, "multichannel": True} + + # If the datastore is multi_file, each sample has a directory with multiple images + if datastore.get_is_multi_file(): + return { + "id": image, + "weight": selected_weight, + "multi_file": True, + } # this will send the directory and we will walk it later on + + return {"id": image, "weight": selected_weight} diff --git a/monailabel/tasks/train/basic_train.py b/monailabel/tasks/train/basic_train.py index 9e5d0b1a9..65211f47a 100644 --- a/monailabel/tasks/train/basic_train.py +++ b/monailabel/tasks/train/basic_train.py @@ -83,6 +83,8 @@ def __init__(self): self.multi_gpu = False # multi gpu enabled self.local_rank = 0 # local rank in case of multi gpu self.world_size = 0 # world size in case of multi gpu + self.input_channels = 1 + self.multi_file = False self.request = None self.trainer = None @@ -490,6 +492,9 @@ def train(self, rank, world_size, request, datalist): context.run_id = request["run_id"] context.multi_gpu = request["multi_gpu"] + context.multi_file = request.get("multi_file", False) + context.input_channels = request.get("input_channels", 1) + if context.multi_gpu: os.environ["LOCAL_RANK"] = str(context.local_rank) From 0070ca2d63d193a9e61a0280520f659b8b1efdd2 Mon Sep 17 00:00:00 2001 From: Cavan Riley Date: Wed, 4 Mar 2026 10:17:38 -0600 Subject: [PATCH 2/4] ENH: modified slicer3d plugin to load multi-volume datasets Adds ability to load multiple volumes per sample or load sequence data depending on request Signed-off-by: Cavan Riley --- plugins/slicer/MONAILabel/MONAILabel.py | 59 +++++++++++++++++++++---- 1 file changed, 51 insertions(+), 8 deletions(-) diff --git a/plugins/slicer/MONAILabel/MONAILabel.py b/plugins/slicer/MONAILabel/MONAILabel.py index 5051dacf3..5e0188280 100644 --- a/plugins/slicer/MONAILabel/MONAILabel.py +++ b/plugins/slicer/MONAILabel/MONAILabel.py @@ -1303,19 +1303,62 @@ def onNextSampleButton(self): return logging.info(sample) - image_id = sample["id"] + sample_id = sample["id"] image_file = sample.get("path") - image_name = sample.get("name", image_id) - node_name = sample.get("PatientID", sample.get("name", image_id)) + image_name = sample.get("name", sample_id) + node_name = sample.get("PatientID", sample.get("name", sample_id)) checksum = sample.get("checksum") local_exists = image_file and os.path.exists(image_file) + multichannel: bool = bool(sample.get("multichannel", False)) + multi_file: bool = bool(sample.get("multi_file", False)) logging.info(f"Check if file exists/shared locally: {image_file} => {local_exists}") if local_exists: - self._volumeNode = slicer.util.loadVolume(image_file) - self._volumeNode.SetName(node_name) + if multichannel: + # For 4D multichannel images, NOTE: slicer does not like 4D nifti images + # from https://github.com/Project-MONAI/MONAILabel/issues/241#issuecomment-1497788857 + volumeSequenceNode = slicer.util.loadSequence(image_file) + volumeSequenceNode.SetName(node_name) + # Get a volume node + browserNode = slicer.modules.sequences.logic().GetFirstBrowserNodeForSequenceNode( + volumeSequenceNode + ) + browserNode.SetOverwriteProxyName( + None, True + ) # set the proxy node name based on the sequence node name + self._volumeNode = browserNode.GetProxyNode(volumeSequenceNode) + else: + if not multi_file: + self._volumeNode = slicer.util.loadVolume(image_file) + self._volumeNode.SetName(node_name) + else: # in the case the underlying dataset is multi_file, we load all the images in the directory + dir_path = image_file + if not os.path.isdir(dir_path): + raise ValueError(f"multi_file=True but path is not a directory: {dir_path}") + + # get valid image paths + entries = sorted(os.listdir(dir_path)) + image_paths = [] + for name in entries: + full_path = os.path.join(dir_path, name) + if os.path.isfile(full_path) and name.lower().endswith((".nii", ".nii.gz", ".nrrd")): + image_paths.append(full_path) + + if not image_paths: + raise ValueError(f"No loadable modality files found in: {dir_path}") + + nodes = [] + for image in image_paths: + image_base_name = os.path.basename(image) + node = slicer.util.loadVolume(image) + if node is None: + raise RuntimeError(f"Failed to load modality volume: {image}") + node.SetName(image_base_name) + nodes.append(node) + + self._volumeNode = nodes[0] else: - download_uri = f"{self.serverUrl()}/datastore/image?image={quote_plus(image_id)}" + download_uri = f"{self.serverUrl()}/datastore/image?image={quote_plus(sample_id)}" logging.info(download_uri) sampleDataLogic = SampleData.SampleDataLogic() @@ -1326,7 +1369,7 @@ def onNextSampleButton(self): if slicer.util.settingsValue("MONAILabel/originalLabel", True, converter=slicer.util.toBool): try: datastore = self.logic.datastore() - label_info = datastore["objects"][image_id]["labels"]["original"]["info"] + label_info = datastore["objects"][sample_id]["labels"]["original"]["info"] labels = label_info.get("params", {}).get("label_names", {}) if labels: @@ -1338,7 +1381,7 @@ def onNextSampleButton(self): labels = self.logic.info().get("labels") # ext = datastore['objects'][image_id]['labels']['original']['ext'] - maskFile = self.logic.download_label(image_id, "original") + maskFile = self.logic.download_label(sample_id, "original") self.updateSegmentationMask(maskFile, list(labels)) print("Original label uploaded! ") From e244ec266a3546df01834e61cca87032d5108604 Mon Sep 17 00:00:00 2001 From: Cavan Riley Date: Wed, 4 Mar 2026 10:19:52 -0600 Subject: [PATCH 3/4] ADD: add BraTS segmentation sample application BraTS sample application addition using BraTS 2020 multi-volume data Signed-off-by: Cavan Riley --- .../lib/configs/segmentation_brats.py | 110 ++++++++ .../lib/infers/segmentation_brats.py | 175 +++++++++++++ .../lib/trainers/segmentation_brats.py | 242 ++++++++++++++++++ .../radiology/lib/transforms/transforms.py | 181 ++++++++++++- sample-apps/radiology/main.py | 23 +- 5 files changed, 729 insertions(+), 2 deletions(-) create mode 100644 sample-apps/radiology/lib/configs/segmentation_brats.py create mode 100644 sample-apps/radiology/lib/infers/segmentation_brats.py create mode 100644 sample-apps/radiology/lib/trainers/segmentation_brats.py diff --git a/sample-apps/radiology/lib/configs/segmentation_brats.py b/sample-apps/radiology/lib/configs/segmentation_brats.py new file mode 100644 index 000000000..4fcae6325 --- /dev/null +++ b/sample-apps/radiology/lib/configs/segmentation_brats.py @@ -0,0 +1,110 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Any, Dict, Optional, Union + +import lib.infers +import lib.trainers +from monai.networks.nets import SegResNet +from monai.utils import optional_import + +from monailabel.interfaces.config import TaskConfig +from monailabel.interfaces.tasks.infer_v2 import InferTask +from monailabel.interfaces.tasks.train import TrainTask +from monailabel.utils.others.generic import download_file, strtobool + +_, has_cp = optional_import("cupy") +_, has_cucim = optional_import("cucim") + +logger = logging.getLogger(__name__) + + +class Segmentation(TaskConfig): + def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **kwargs): + """Initializes the SegmentationBrats task.""" + super().init(name, model_dir, conf, planner, **kwargs) + + # BraTS labels: 3 multi-label channels produced by ConvertToMultiChannelBasedOnBratsClassesd + # Channel 0: TC - Tumor Core (label 2 OR label 3) + # Channel 1: WT - Whole Tumor (label 1 OR label 2 OR label 3) + # Channel 2: ET - Enhancing Tumor (label 2) + self.labels = { + "tumor core": 1, # Tumor Core + "whole tumor": 2, # Whole Tumor + "enhancing tumor": 3, # Enhancing Tumor + } + + # Model Files + self.path = [ + os.path.join(self.model_dir, f"pretrained_{name}.pt"), # pretrained + os.path.join(self.model_dir, f"{name}.pt"), # published + ] + + # Download PreTrained Model (optional) + if strtobool(self.conf.get("use_pretrained_model", "false")): + url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}" + url = f"{url}/radiology_segmentation_segresnet_brats.pt" + download_file(url, self.path[0]) + + # Spacing and ROI for BraTS (isotropic 1mm, large crop matching tutorial) + self.target_spacing = (1.0, 1.0, 1.0) + self.roi_size = (224, 224, 144) + + # Number of input channels: 4 MRI modalities (FLAIR, T1, T1Gd, T2) + # when multi_file=True the LoadDirectoryImagesd loader stacks them; + # when multi_file=False the image file must already be a 4-channel volume. + try: + input_channels = int(self.conf.get("input_channels", 4)) + except (ValueError, TypeError): + logger.warning("Could not parse input_channels, defaulting to 4") + input_channels = 4 + + # Network + self.network = SegResNet( + blocks_down=(1, 2, 2, 4), + blocks_up=(1, 1, 1), + init_filters=16, + in_channels=input_channels, + out_channels=len(self.labels), # TC, WT, ET — sigmoid multilabel, no background channel + dropout_prob=0.2, + ) + + def infer(self) -> Union[InferTask, Dict[str, InferTask]]: + """Creates the SegmentationBrats InferTask task.""" + task: InferTask = lib.infers.SegmentationBrats( + path=self.path, + network=self.network, + roi_size=self.roi_size, + target_spacing=self.target_spacing, + labels=self.labels, + preload=strtobool(self.conf.get("preload", "false")), + config={"largest_cc": True if has_cp and has_cucim else False}, + ) + return task + + def trainer(self) -> Optional[TrainTask]: + """Creates the SegmentationBrats Trainer task.""" + output_dir = os.path.join(self.model_dir, self.name) + load_path = self.path[0] if os.path.exists(self.path[0]) else self.path[1] + + task: TrainTask = lib.trainers.SegmentationBrats( + model_dir=output_dir, + network=self.network, + roi_size=self.roi_size, + target_spacing=self.target_spacing, + load_path=load_path, + publish_path=self.path[1], + description="Train BraTS Segmentation Model (TC/WT/ET multilabel)", + labels=self.labels, + ) + return task diff --git a/sample-apps/radiology/lib/infers/segmentation_brats.py b/sample-apps/radiology/lib/infers/segmentation_brats.py new file mode 100644 index 000000000..01efe7d3e --- /dev/null +++ b/sample-apps/radiology/lib/infers/segmentation_brats.py @@ -0,0 +1,175 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Sequence + +from lib.transforms.transforms import ConvertFromMultiChannelBasedOnBratsClassesd, GetCentroidsd, LoadDirectoryImagesd +from monai.inferers import Inferer, SlidingWindowInferer +from monai.transforms import ( + Activationsd, + AsDiscreted, + EnsureChannelFirstd, + EnsureTyped, + KeepLargestConnectedComponentd, + LoadImaged, + NormalizeIntensityd, + Orientationd, + Spacingd, +) + +from monailabel.interfaces.tasks.infer_v2 import InferType +from monailabel.tasks.infer.basic_infer import BasicInferTask +from monailabel.transform.post import Restored + + +class SegmentationBrats(BasicInferTask): + """ + Inference Engine for BraTS brain tumour segmentation using a SegResNet. + + The model outputs 3 channels (TC, WT, ET) with sigmoid activations — it is + a multilabel task, NOT a softmax classification. Each channel is thresholded + independently at 0.5 to produce binary maps. + + Two image loading modes are supported (set via ``data["multi_file"]``): + - False (default): the input image is a single 4-channel NIfTI volume. + - True: ``data["image"]`` is a directory containing 4 single- + modality NIfTI files; LoadDirectoryImagesd stacks them. + """ + + def __init__( + self, + path, + network=None, + target_spacing=(1.0, 1.0, 1.0), + type=InferType.SEGMENTATION, + labels=None, + dimension=3, + description="Pre-trained BraTS SegResNet — TC/WT/ET multilabel segmentation", + **kwargs, + ): + """ + Args: + path: path(s) to the model checkpoint(s). + network: optional pre-instantiated network; if None the checkpoint + is loaded directly. + target_spacing: voxel spacing to resample images to before inference. + type: inference type tag (default SEGMENTATION). + labels: label name → integer index mapping. + dimension: spatial dimension of the model (3 for volumetric). + description: human-readable description surfaced in the REST API. + **kwargs: forwarded to ``BasicInferTask``. + """ + super().__init__( + path=path, + network=network, + type=type, + labels=labels, + dimension=dimension, + description=description, + load_strict=False, + **kwargs, + ) + self.target_spacing = target_spacing + + def pre_transforms(self, data=None) -> Sequence[Callable]: + """ + Pre-processing pipeline matching the official MONAI BraTS tutorial. + + NOTE: ScaleIntensityRangePercentilesd and CenterSpatialCropd from the + original file have been removed — they are not part of the BraTS pipeline + and would distort MRI intensity normalisation. NormalizeIntensityd with + nonzero=True, channel_wise=True is the correct approach for multi-modal MRI. + """ + data = data or {} + channels = data.get("input_channels", 4) + t = [ + ( + LoadImaged(keys="image", reader="ITKReader", ensure_channel_first=True) + if data.get("multi_file", False) is False + else LoadDirectoryImagesd( + keys="image", + target_spacing=self.target_spacing, + channels=channels, + ) + ), + EnsureTyped(keys="image", device=data.get("device") if data else None), + # EnsureChannelFirstd is safe to keep as a guard; if the channel dim is + # already present (ITKReader + ensure_channel_first) it is a no-op. + EnsureChannelFirstd(keys="image", channel_dim=0), + Orientationd(keys="image", axcodes="RAS"), + Spacingd( + keys="image", + pixdim=self.target_spacing, + allow_missing_keys=True, + ), + # Channel-wise intensity normalisation on non-zero voxels only. + # This matches both the tutorial and the training pipeline exactly. + NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True), + ] + return t + + def inferer(self, data=None) -> Inferer: + """Return a SlidingWindowInferer configured for BraTS volumetric inference.""" + return SlidingWindowInferer( + roi_size=self.roi_size, + sw_batch_size=2, + overlap=0.4, + padding_mode="replicate", + mode="gaussian", + ) + + def inverse_transforms(self, data=None): + """No inverse transforms needed; Restored handles spatial restoration directly.""" + return [] + + def post_transforms(self, data=None) -> Sequence[Callable]: + """ + Post-processing for multilabel sigmoid output. + + IMPORTANT differences from a softmax segmentation: + - Activationsd uses sigmoid=True (not softmax=True). + - AsDiscreted thresholds each channel at 0.5 independently + (not argmax, because channels are not mutually exclusive). + - KeepLargestConnectedComponentd is applied per-channel if available. + """ + data = data or {} + t = [ + EnsureTyped(keys="pred", device=data.get("device") if data else None), + # Sigmoid: each of the 3 channels (TC, WT, ET) is activated independently. + Activationsd(keys="pred", sigmoid=True), + # Threshold each channel at 0.5 to produce binary masks. + AsDiscreted(keys="pred", threshold=0.5), + ] + + if data and data.get("largest_cc", False): + # Apply per-channel so TC, WT and ET are each cleaned independently. + t.append( + KeepLargestConnectedComponentd( + keys="pred", + independent=True, # treat each channel separately + ) + ) + + t.extend( + [ + # Merge 3 binary channels → single-channel integer label map + # Must happen before Restored so spatial metadata is applied + # to the final (1, H, W, D) output, not the intermediate (3, H, W, D). + ConvertFromMultiChannelBasedOnBratsClassesd(keys="pred"), + Restored( + keys="pred", + ref_image="image", + config_labels=self.labels if data.get("restore_label_idx", False) else None, + ), + GetCentroidsd(keys="pred", centroids_key="centroids"), + ] + ) + return t diff --git a/sample-apps/radiology/lib/trainers/segmentation_brats.py b/sample-apps/radiology/lib/trainers/segmentation_brats.py new file mode 100644 index 000000000..6dff359e8 --- /dev/null +++ b/sample-apps/radiology/lib/trainers/segmentation_brats.py @@ -0,0 +1,242 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +from lib.transforms.transforms import LoadDirectoryImagesd +from monai.handlers import TensorBoardImageHandler, from_engine +from monai.inferers import SlidingWindowInferer +from monai.losses import DiceLoss +from monai.transforms import ( + Activationsd, + AsDiscreted, + ConvertToMultiChannelBasedOnBratsClassesd, + EnsureTyped, + LoadImaged, + NormalizeIntensityd, + Orientationd, + RandFlipd, + RandScaleIntensityd, + RandShiftIntensityd, + RandSpatialCropd, + Spacingd, +) + +from monailabel.tasks.train.basic_train import BasicTrainTask, Context +from monailabel.tasks.train.utils import region_wise_metrics + +logger = logging.getLogger(__name__) + + +class SegmentationBrats(BasicTrainTask): + """ + Training task for BraTS brain tumor segmentation using a SegResNet. + + Implements a sigmoid multilabel pipeline predicting three overlapping + regions: Tumour Core (TC), Whole Tumor (WT), and Enhancing Tumor (ET). + Supports two data loading modes via ``context.multi_file``: + - False (default): image is a single 4-channel NIfTI volume. + - True: image is a directory of per-modality NIfTI files + stacked by ``LoadDirectoryImagesd``. + """ + + def __init__( + self, + model_dir, + network, + roi_size=(224, 224, 144), + target_spacing=(1.0, 1.0, 1.0), + num_samples=4, + description="Train BraTS Segmentation model (TC/WT/ET multilabel)", + **kwargs, + ): + """ + Args: + model_dir: directory where checkpoints are saved. + network: instantiated segmentation network (e.g. SegResNet). + roi_size: spatial crop size used during training and sliding-window + validation. + target_spacing: voxel spacing images are resampled to. + num_samples: random crops drawn per volume per epoch. + description: human-readable label surfaced in the REST API. + **kwargs: forwarded to ``BasicTrainTask``. + """ + self._network = network + self.roi_size = roi_size + self.target_spacing = target_spacing + self.num_samples = num_samples + super().__init__(model_dir, description, **kwargs) + + def network(self, context: Context): + """Return the SegResNet instance used for training.""" + return self._network + + def optimizer(self, context: Context): + """Adam optimizer with lr=1e-4 and weight_decay=1e-5, matching the BraTS tutorial.""" + return torch.optim.Adam(context.network.parameters(), lr=1e-4, weight_decay=1e-5) + + def loss_function(self, context: Context): + """Loss function used during training.""" + # BraTS is a sigmoid multilabel task (TC, WT, ET). + # to_onehot_y=False because the label is already 3-channel after + # ConvertToMultiChannelBasedOnBratsClassesd. + # sigmoid=True because each channel is independent (not mutually exclusive). + return DiceLoss( + smooth_nr=0, + smooth_dr=1e-5, + squared_pred=True, + to_onehot_y=False, + sigmoid=True, + ) + + def lr_scheduler_handler(self, context: Context): + """No LR scheduler — constant learning rate throughout training.""" + return None + + def train_data_loader(self, context, num_workers=0, shuffle=False): + """Training data loader with shuffling always enabled.""" + return super().train_data_loader(context, num_workers, True) + + def train_pre_transforms(self, context: Context): + """ + Transforms follow the official MONAI BraTS tutorial exactly. + + Two loading paths: + - multi_file=False : image is already a single 4-channel .nii.gz volume + (LoadImaged handles it, then EnsureChannelFirstd is a no-op + because ITKReader + ensure_channel_first already adds the channel dim) + - multi_file=True : a directory of 4 single-modality files is stacked by + LoadDirectoryImagesd into a (4, H, W, D) tensor + """ + channels = context.input_channels + return [ + ( + LoadImaged(keys="image", reader="ITKReader", ensure_channel_first=True) + if context.multi_file is False + else LoadDirectoryImagesd(keys="image", target_spacing=self.target_spacing, channels=channels) + ), + LoadImaged(keys="label", reader="ITKReader", ensure_channel_first=True), + # ConvertToMultiChannelBasedOnBratsClassesd converts the integer label map + # to a 3-channel binary tensor: [TC, WT, ET]. + ConvertToMultiChannelBasedOnBratsClassesd(keys="label"), + EnsureTyped(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RAS"), + Spacingd( + keys=["image", "label"], + pixdim=self.target_spacing, + mode=("bilinear", "nearest"), + ), + # Random crop matching the official tutorial roi + RandSpatialCropd( + keys=["image", "label"], + roi_size=self.roi_size, + random_size=False, + ), + RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0), + RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1), + RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2), + # Channel-wise zero-mean / unit-std normalisation on non-zero voxels only + NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True), + RandScaleIntensityd(keys="image", factors=0.1, prob=1.0), + RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0), + ] + + def train_post_transforms(self, context: Context): + """ + Post-transforms for TRAINING metrics. + + Because this is a sigmoid multilabel task: + - Apply sigmoid activation per channel. + - Threshold at 0.5 to get binary predictions. + - The label is already binary 3-channel — no argmax / to_onehot needed. + """ + return [ + EnsureTyped(keys="pred", device=context.device), + Activationsd(keys="pred", sigmoid=True), + AsDiscreted(keys="pred", threshold=0.5), + # label is already binary 3-channel, nothing to do + ] + + def val_pre_transforms(self, context: Context): + """ + Validation pre-processing: same loading and normalization as training + but without any random augmentation or spatial cropping. + """ + channels = context.input_channels + return [ + ( + LoadImaged(keys="image", reader="ITKReader", ensure_channel_first=True) + if context.multi_file is False + else LoadDirectoryImagesd(keys="image", target_spacing=self.target_spacing, channels=channels) + ), + LoadImaged(keys="label", reader="ITKReader", ensure_channel_first=True), + ConvertToMultiChannelBasedOnBratsClassesd(keys="label"), + EnsureTyped(keys=["image", "label"]), + Orientationd(keys=["image", "label"], axcodes="RAS"), + Spacingd( + keys=["image", "label"], + pixdim=self.target_spacing, + mode=("bilinear", "nearest"), + ), + # No crop during validation — sliding window covers the full volume + NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True), + ] + + def val_inferer(self, context: Context): + """SlidingWindowInferer for full-volume validation with Gaussian blending.""" + return SlidingWindowInferer( + roi_size=self.roi_size, + sw_batch_size=2, + overlap=0.4, + padding_mode="replicate", + mode="gaussian", + ) + + def norm_labels(self): + """ + Return a label-name → channel-index mapping with contiguous indices + starting at 0, skipping ``"background"``. + + Used by ``region_wise_metrics`` to align predicted channels (TC=0, + WT=1, ET=2) with the correct label names for Dice computation. + """ + new_label_nums = {} + idx = 0 + for key_label in self._labels.keys(): + if key_label == "background": + continue + new_label_nums[key_label] = idx + idx += 1 + return new_label_nums + + def train_key_metric(self, context: Context): + """Per-region Dice metrics logged during training (TC, WT, ET).""" + return region_wise_metrics(self.norm_labels(), "train_mean_dice", "train") + + def val_key_metric(self, context: Context): + """Per-region Dice metrics logged during validation (TC, WT, ET).""" + return region_wise_metrics(self.norm_labels(), "val_mean_dice", "val") + + def train_handlers(self, context: Context): + """Extend default handlers with TensorBoard image logging every 20 epochs.""" + handlers = super().train_handlers(context) + if context.local_rank == 0: + handlers.append( + TensorBoardImageHandler( + log_dir=context.events_dir, + batch_transform=from_engine(["image", "label"]), + output_transform=from_engine(["pred"]), + interval=20, + epoch_level=True, + ) + ) + return handlers diff --git a/sample-apps/radiology/lib/transforms/transforms.py b/sample-apps/radiology/lib/transforms/transforms.py index c24202328..d03d1e1a2 100644 --- a/sample-apps/radiology/lib/transforms/transforms.py +++ b/sample-apps/radiology/lib/transforms/transforms.py @@ -10,6 +10,7 @@ # limitations under the License. import copy import logging +import os from typing import Any, Dict, Hashable, Mapping import numpy as np @@ -18,7 +19,17 @@ from monai.config import KeysCollection, NdarrayOrTensor from monai.data import MetaTensor from monai.networks.layers import GaussianFilter -from monai.transforms import CropForeground, GaussianSmooth, Randomizable, Resize, ScaleIntensity, SpatialCrop +from monai.transforms import ( + ConcatItemsd, + CropForeground, + EnsureChannelFirst, + GaussianSmooth, + LoadImage, + Randomizable, + Resize, + ScaleIntensity, + SpatialCrop, +) from monai.transforms.transform import MapTransform, Transform from monai.utils.enums import CommonKeys @@ -27,6 +38,174 @@ logger = logging.getLogger(__name__) +class ConvertFromMultiChannelBasedOnBratsClassesd(MapTransform): + """ + Dictionary-based transform that reverses + ``ConvertToMultiChannelBasedOnBratsClassesd``. + + Converts a 3-channel binary prediction (TC, WT, ET) back to a + single-channel integer label map: + + Output shape: (1, H, W, D), dtype ``torch.long`` by default. + + Args: + keys: keys of the items to be transformed. + dtype: output dtype, default ``torch.long``. + allow_missing_keys: don't raise an error if a key is missing. + """ + + def __init__( + self, + keys: KeysCollection, + dtype: torch.dtype = torch.long, + allow_missing_keys: bool = False, + ) -> None: + super().__init__(keys, allow_missing_keys) + self.dtype = dtype + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + """ + Merge 3 binary channels into a single-channel integer label map. + + Channel → label assignment: + - WT only (wt & ~tc) -> 1 (oedema / peritumoral) + - ET (et) -> 2 (enhancing tumor) + - TC only (tc & ~et) -> 3 (necrotic core) + - background -> 0 + """ + d = dict(data) + for key in self.key_iterator(d): + img = d[key] + + if img.shape[0] != 3: + raise ValueError( + f"Expected 3-channel input (TC, WT, ET) for key '{key}', " f"got {img.shape[0]} channels." + ) + + tc = img[0].bool() + wt = img[1].bool() + et = img[2].bool() + + label_map = torch.zeros_like(img[0], dtype=self.dtype) + label_map[wt & ~tc] = 1 # Oedema + label_map[et] = 2 # Enhancing tumour + label_map[tc & ~et] = 3 # Necrotic core + + result = label_map.unsqueeze(0) + + d[key] = MetaTensor(result, meta=img.meta) if isinstance(img, MetaTensor) else result + + return d + + +# Adapted from https://github.com/Project-MONAI/MONAILabel/issues/241#issuecomment-1497561538 +class LoadDirectoryImagesd(MapTransform): + """ + Load all 3D images from a directory, stack them along a new axis, + and preserve MONAI-style metadata similar to LoadImaged. + + Each key should point to a directory of NIfTI/NRRD files (.nii, .nii.gz, + .nrrd). Files are loaded in the BraTS modality order (T1, T1ce, T2, FLAIR) + by matching the ``_t1``, ``_t1ce``, ``_t2``, ``_flair`` filename suffixes, + resized to a common spatial shape, and concatenated into a single + (C, H, W, D) tensor. Metadata from the first file is preserved and stored + in ``d[f"{key}_meta_dict"]``. A ``ValueError`` is raised if any expected + modality file is missing or ambiguous. + + Args: + keys: keys of the directory paths to transform. + target_spacing: if provided, voxel spacing to which each image is + resampled before stacking (passed to ``Spacingd`` in the pipeline). + When ``None`` no resampling is applied here. + allow_missing_keys: don't raise an error if a key is absent. + channels: expected number of modality files in the directory. A + ``ValueError`` is raised if the count doesn't match. Set to 0 to + skip the check. + """ + + def __init__(self, keys: KeysCollection, target_spacing=None, allow_missing_keys: bool = False, channels: int = 2): + super().__init__(keys, allow_missing_keys) + self.target_spacing = target_spacing + self.loader = LoadImage(reader="ITKReader", image_only=False) + self.ensure_channel_first = EnsureChannelFirst() + self.channels = int(channels) + + def __call__(self, data: Dict): + d = dict(data) + + for key in self.key_iterator(d): + dir_path = d[key] + if not os.path.isdir(dir_path): + raise ValueError(f"Expected a directory path for key '{key}', got: {dir_path}") + + # Gather files in the required BraTS modality order: T1, T1ce, T2, FLAIR. + # Alphabetical sort is intentionally avoided, it would place FLAIR first, + # assigning the wrong channel index to each modality. + _MODALITY_SUFFIXES = ["_t1", "_t1ce", "_t2", "_flair"] + all_files = [f for f in os.listdir(dir_path) if f.lower().endswith((".nii", ".nii.gz", ".nrrd"))] + image_files = [] + for suffix in _MODALITY_SUFFIXES: + matches = [ + os.path.join(dir_path, f) + for f in all_files + if os.path.splitext(os.path.splitext(f)[0])[0].lower().endswith(suffix) + or f.lower().endswith(suffix + ".nrrd") + ] + if len(matches) != 1: + raise ValueError( + f"Expected exactly one file matching '*{suffix}' in {dir_path}, " + f"found {len(matches)}: {matches}" + ) + image_files.append(matches[0]) + + if 0 < self.channels != len(image_files): + raise ValueError(f"Expected {self.channels} modality files in {dir_path}, found {len(image_files)}") + + channel_keys = [] + meta_dicts = [] + resizer = None + + logger.info(f"Loading {len(image_files)} images from {dir_path}") + + for idx, img_path in enumerate(image_files): + img, meta = self.loader(img_path) + img = self.ensure_channel_first(img) + + if resizer is None: + resizer = Resize(spatial_size=img.shape[1:], mode="bilinear") + img = resizer(img) + + ch_key = f"{key}_ch{idx + 1}" + d[ch_key] = img + d[f"{ch_key}_meta_dict"] = meta + + channel_keys.append(ch_key) + meta_dicts.append(meta) + + logger.debug(f"Loaded {ch_key}: {img.shape}") + + # MONAI-native concatenation + concat = ConcatItemsd(keys=channel_keys, name=key, dim=0) + d = concat(d) + + # Clean up temporary channel keys + for ch_key in channel_keys: + d.pop(ch_key, None) + d.pop(f"{ch_key}_meta_dict", None) + + # Construct merged metadata + merged_meta = copy.deepcopy(meta_dicts[0]) + merged_meta["filename_or_obj"] = image_files + merged_meta["num_channels"] = len(channel_keys) + merged_meta["original_channel_dim"] = 0 + + d[f"{key}_meta_dict"] = merged_meta + + logger.info(f"Concatenated {len(channel_keys)} images → {d[key].shape}") + + return d + + class BinaryMaskd(MapTransform): def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False): """ diff --git a/sample-apps/radiology/main.py b/sample-apps/radiology/main.py index bb7f8ae18..a9e3672cb 100644 --- a/sample-apps/radiology/main.py +++ b/sample-apps/radiology/main.py @@ -307,12 +307,18 @@ def main(): parser.add_argument("-s", "--studies", default=studies) parser.add_argument("-m", "--model", default="segmentation") parser.add_argument("-t", "--test", default="batch_infer", choices=("train", "infer", "batch_infer")) + parser.add_argument("-multi", "--multichannel", action="store_true", help="Enable multichannel (4D) data loading") + parser.add_argument("-c", "--input_channels", type=int, default=1, help="Number of input channels") + parser.add_argument("-multif", "--multi_file", action="store_true", help="Enable multi-file data loading") args = parser.parse_args() app_dir = os.path.dirname(__file__) studies = args.studies conf = { "models": args.model, + "multichannel": args.multichannel, + "input_channels": args.input_channels, + "multi_file": args.multi_file, "preload": "false", } @@ -326,7 +332,16 @@ def main(): # Run on all devices for device in device_list(): - res = app.infer(request={"model": args.model, "image": image_id, "device": device}) + res = app.infer( + request={ + "model": args.model, + "image": image_id, + "device": device, + "multichannel": conf["multichannel"], + "input_channels": conf["input_channels"], + "multi_file": conf["multi_file"], + } + ) # res = app.infer( # request={"model": "vertebra_pipeline", "image": image_id, "device": device, "slicer": False} # ) @@ -354,6 +369,9 @@ def main(): "label_tag": "original", "max_workers": 1, "max_batch_size": 0, + "multichannel": conf["multichannel"], + "input_channels": conf["input_channels"], + "multi_file": conf["multi_file"], } ) @@ -380,6 +398,9 @@ def main(): "val_batch_size": 1, "multi_gpu": False, "val_split": 0.1, + "multichannel": conf["multichannel"], + "input_channels": conf["input_channels"], + "multi_file": conf["multi_file"], }, ) From 9dc1c0a834f9519485d06c5e1f4804b665765773 Mon Sep 17 00:00:00 2001 From: Cavan Riley Date: Wed, 4 Mar 2026 10:20:35 -0600 Subject: [PATCH 4/4] ADD: add documentation readme for training BraTS model Signed-off-by: Cavan Riley --- sample-apps/radiology/README.md | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/sample-apps/radiology/README.md b/sample-apps/radiology/README.md index 5c7356940..6ea21f95e 100644 --- a/sample-apps/radiology/README.md +++ b/sample-apps/radiology/README.md @@ -215,6 +215,34 @@ the model to learn on new organ. - Output: N channels representing the segmented organs/tumors/tissues +
+ + Segmentation BraTS is a model based on UNet for automated multilabel brain tumor segmentation. This model is designed for multi-label segmentation tasks using pre-aligned, multi-modal MRI volumes. + + +> monailabel start_server --app workspace/radiology --studies workspace/images --conf models segmentation_brats --conf input_channels 4 --conf multi_file true + +- Additional Configs *(pass them as **--conf name value** while starting MONAILabel Server)* + +| Name | Values | Description | +|----------------------|------------------|--------------------------------------------------------------------------| +| use_pretrained_model | **true**, false | Set to `false` to skip loading pretrained weights | +| preload | true, **false** | Preload model into GPU at startup | +| scribbles | **true**, false | Set to `false` to disable scribble-based interactive segmentation models | + +- Network: This model uses the [UNet](https://docs.monai.io/en/latest/networks.html#unet) as the default network. Researchers can define their own network or use one of the listed [MONAI network architectures](https://docs.monai.io/en/latest/networks.html) +- Labels + ```json + { + "tumor core": 1, + "whole tumor": 2, + "enhancing tumor": 3 + } + ``` +- Dataset: The model is trained over the dataset: https://www.med.upenn.edu/cbica/brats2020/ +- Inputs: 4 channels for the 4 BRATS image modalities +- Output: N channels representing the segmented tumors/tissues +