diff --git a/pyproject.toml b/pyproject.toml index 4de8c9ef..d0d0c35b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,6 +116,7 @@ module = [ # ipywidgets is an optional dependency "ipywidgets.*", "requests_toolbelt.*", + "rfdetr.*", "torch.*", "ultralytics.*", ] diff --git a/roboflow/util/model_processor.py b/roboflow/util/model_processor.py index 4ff0ce1a..cca98012 100644 --- a/roboflow/util/model_processor.py +++ b/roboflow/util/model_processor.py @@ -19,6 +19,30 @@ ) from roboflow.util.versions import print_warn_for_wrong_dependencies_versions +# Minimum rf-detr release shipping `RFDETR.export_for_roboflow`. +RFDETR_MIN_VERSION = "1.8.0" # TODO: pin once rf-detr release is cut + +# rf-detr model_type -> RFDETR subclass name. Single source of truth for both the +# supported-type check and the `from_checkpoint` fallback (used when a raw checkpoint +# lacks the metadata rf-detr needs to infer its own class). +_RFDETR_MODEL_TYPE_TO_CLASS = { + # Detection + "rfdetr-base": "RFDETRBase", + "rfdetr-nano": "RFDETRNano", + "rfdetr-small": "RFDETRSmall", + "rfdetr-medium": "RFDETRMedium", + "rfdetr-large": "RFDETRLarge", + "rfdetr-xlarge": "RFDETRXLarge", + "rfdetr-2xlarge": "RFDETR2XLarge", + # Segmentation + "rfdetr-seg-nano": "RFDETRSegNano", + "rfdetr-seg-small": "RFDETRSegSmall", + "rfdetr-seg-medium": "RFDETRSegMedium", + "rfdetr-seg-large": "RFDETRSegLarge", + "rfdetr-seg-xlarge": "RFDETRSegXLarge", + "rfdetr-seg-2xlarge": "RFDETRSeg2XLarge", +} + def task_of_model_type(model_type: str) -> str: """Canonical task for a deploy model_type string. @@ -339,24 +363,35 @@ def _detect_rfdetr_task(checkpoint) -> Optional[str]: return None +def _is_ptl_checkpoint(checkpoint) -> bool: + """True if `checkpoint` is a raw PyTorch-Lightning rf-detr checkpoint dict.""" + return isinstance(checkpoint, dict) and "pytorch-lightning_version" in checkpoint + + +def _require_rfdetr(): + """Lazily import `rfdetr` and verify it ships the upload-bundle helpers. + + Raises a RuntimeError with an actionable hint if rfdetr is missing or too old. + """ + try: + import rfdetr + except ImportError: + raise RuntimeError( + "rfdetr is required to upload PyTorch-Lightning rf-detr checkpoints. " + f"Please install it with `pip install 'rfdetr>={RFDETR_MIN_VERSION}'`." + ) + + if not hasattr(rfdetr.RFDETR, "export_for_roboflow"): + raise RuntimeError( + "The installed rfdetr is too old to upload PyTorch-Lightning rf-detr checkpoints. " + f"Please upgrade it with `pip install --upgrade 'rfdetr>={RFDETR_MIN_VERSION}'`." + ) + + return rfdetr + + def _process_rfdetr(model_type: str, model_path: str, filename: str) -> tuple[str, str]: - _supported_types = [ - # Detection models - "rfdetr-base", - "rfdetr-nano", - "rfdetr-small", - "rfdetr-medium", - "rfdetr-large", - "rfdetr-xlarge", - "rfdetr-2xlarge", - # Segmentation models - "rfdetr-seg-nano", - "rfdetr-seg-small", - "rfdetr-seg-medium", - "rfdetr-seg-large", - "rfdetr-seg-xlarge", - "rfdetr-seg-2xlarge", - ] + _supported_types = list(_RFDETR_MODEL_TYPE_TO_CLASS) if model_type not in _supported_types: raise ValueError(f"Model type {model_type} not supported. Supported types are {_supported_types}") @@ -382,11 +417,25 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> tuple[st f".pt is a '{detected_task}' rfdetr checkpoint. Use a matching model_type." ) - get_classnames_txt_for_rfdetr(model_path, pt_file, checkpoint=checkpoint) + if _is_ptl_checkpoint(checkpoint): + # Raw PyTorch-Lightning checkpoint: let rf-detr rebuild a proper upload + # bundle (weights.pt with `args.resolution` + class_names.txt). + rfdetr = _require_rfdetr() + pth = os.path.join(model_path, pt_file) + try: + model = rfdetr.RFDETR.from_checkpoint(pth) + except ValueError: + # Checkpoint lacks model_name/pretrain_weights signals; fall back to + # the already-validated user-provided model_type to pick the subclass. + model_cls = getattr(rfdetr, _RFDETR_MODEL_TYPE_TO_CLASS[model_type]) + model = model_cls(pretrain_weights=pth) + model.export_for_roboflow(model_path) # writes weights.pt + class_names.txt + else: + get_classnames_txt_for_rfdetr(model_path, pt_file, checkpoint=checkpoint) - # Copy the .pt file to weights.pt if not already named weights.pt - if pt_file != "weights.pt": - shutil.copy(os.path.join(model_path, pt_file), os.path.join(model_path, "weights.pt")) + # Copy the .pt file to weights.pt if not already named weights.pt + if pt_file != "weights.pt": + shutil.copy(os.path.join(model_path, pt_file), os.path.join(model_path, "weights.pt")) required_files = ["weights.pt"] diff --git a/tests/util/test_model_processor.py b/tests/util/test_model_processor.py index 37ecb186..604e25ca 100644 --- a/tests/util/test_model_processor.py +++ b/tests/util/test_model_processor.py @@ -1,12 +1,28 @@ import os +import sys import tempfile import unittest +import zipfile from types import SimpleNamespace +from unittest import mock + +try: + # torch is an optional, lazily-imported SDK dependency; absent in CI. Tests that + # round-trip a real checkpoint through `_process_rfdetr` are skipped without it. + import torch + + _HAS_TORCH = True +except ImportError: + _HAS_TORCH = False from roboflow.config import TASK_CLS, TASK_DET, TASK_OBB, TASK_POSE, TASK_SEG from roboflow.util.model_processor import ( + _RFDETR_MODEL_TYPE_TO_CLASS, _detect_rfdetr_task, _detect_yolo_task, + _is_ptl_checkpoint, + _process_rfdetr, + _require_rfdetr, get_classnames_txt_for_rfdetr, task_of_model_type, ) @@ -104,5 +120,161 @@ def test_namespace_args(self): ) +class RfdetrModelTypeToClassTest(unittest.TestCase): + def test_representative_mappings(self): + self.assertEqual(_RFDETR_MODEL_TYPE_TO_CLASS["rfdetr-seg-medium"], "RFDETRSegMedium") + self.assertEqual(_RFDETR_MODEL_TYPE_TO_CLASS["rfdetr-base"], "RFDETRBase") + + def test_keys_are_rfdetr_types_and_values_are_class_names(self): + for model_type, class_name in _RFDETR_MODEL_TYPE_TO_CLASS.items(): + self.assertTrue(model_type.startswith("rfdetr-"), model_type) + self.assertTrue(class_name.startswith("RFDETR"), class_name) + # Segmentation types must map to Seg classes (and detection types must not). + for model_type, class_name in _RFDETR_MODEL_TYPE_TO_CLASS.items(): + self.assertEqual("seg" in model_type, "Seg" in class_name, model_type) + + +class IsPtlCheckpointTest(unittest.TestCase): + def test_true_when_lightning_version_present(self): + self.assertTrue(_is_ptl_checkpoint({"pytorch-lightning_version": "2.1.0", "args": {}})) + + def test_false_for_plain_checkpoint(self): + self.assertFalse(_is_ptl_checkpoint({"args": {}, "model": {}})) + + def test_false_for_non_dict(self): + self.assertFalse(_is_ptl_checkpoint(None)) + self.assertFalse(_is_ptl_checkpoint(SimpleNamespace(**{"pytorch-lightning_version": "2.1.0"}))) + + +class _StubBundleModel: + """Stub rf-detr model whose export_for_roboflow writes a dummy bundle on disk.""" + + def __init__(self, class_names=("cat", "dog")): + self.class_names = list(class_names) + + def export_for_roboflow(self, output_dir): + torch.save({"dummy": True}, os.path.join(output_dir, "weights.pt")) + with open(os.path.join(output_dir, "class_names.txt"), "w") as f: + for name in self.class_names: + f.write(name + "\n") + + +def _make_fake_rfdetr(*, from_checkpoint_raises=False, capabilities=True): + """Build a fake `rfdetr` module for injection via sys.modules.""" + + stub_model = _StubBundleModel() + calls = {"from_checkpoint": 0, "fallback_constructed": 0, "constructor_kwargs": None} + + class _RFDETR: + @staticmethod + def from_checkpoint(path): + calls["from_checkpoint"] += 1 + if from_checkpoint_raises: + raise ValueError("cannot infer model class") + return stub_model + + class _SizedModel(_StubBundleModel): + def __init__(self, *, pretrain_weights): + super().__init__() + calls["fallback_constructed"] += 1 + calls["constructor_kwargs"] = {"pretrain_weights": pretrain_weights} + + module = SimpleNamespace() + module.RFDETR = _RFDETR + # The SDK fallback resolves the subclass by name via _RFDETR_MODEL_TYPE_TO_CLASS, + # e.g. "rfdetr-seg-medium" -> getattr(rfdetr, "RFDETRSegMedium"). + module.RFDETRSegMedium = _SizedModel + + if capabilities: + _RFDETR.export_for_roboflow = _StubBundleModel.export_for_roboflow # capability marker on class + + module._calls = calls + return module + + +class RequireRfdetrTest(unittest.TestCase): + def test_raises_when_not_installed(self): + with mock.patch.dict(sys.modules, {"rfdetr": None}): + with self.assertRaises(RuntimeError) as ctx: + _require_rfdetr() + self.assertIn("pip install", str(ctx.exception).lower()) + self.assertIn("rfdetr", str(ctx.exception).lower()) + + def test_raises_when_capability_missing(self): + # rfdetr present but RFDETR lacks export_for_roboflow (too old) + fake = SimpleNamespace(RFDETR=type("RFDETR", (), {})) + with mock.patch.dict(sys.modules, {"rfdetr": fake}): + with self.assertRaises(RuntimeError) as ctx: + _require_rfdetr() + self.assertIn("upgrade", str(ctx.exception).lower()) + + def test_returns_module_when_capable(self): + fake = _make_fake_rfdetr() + with mock.patch.dict(sys.modules, {"rfdetr": fake}): + self.assertIs(_require_rfdetr(), fake) + + +@unittest.skipUnless(_HAS_TORCH, "requires torch") +class ProcessRfdetrPtlTest(unittest.TestCase): + def _write_ptl_checkpoint(self, model_path, *, segmentation_head=False, class_names=("cat", "dog")): + checkpoint = { + "pytorch-lightning_version": "2.1.0", + "args": {"segmentation_head": segmentation_head, "class_names": list(class_names)}, + } + torch.save(checkpoint, os.path.join(model_path, "checkpoint_best_ema.pth")) + + def test_from_checkpoint_success_produces_bundle(self): + fake = _make_fake_rfdetr() + with tempfile.TemporaryDirectory() as model_path: + self._write_ptl_checkpoint(model_path) + with mock.patch.dict(sys.modules, {"rfdetr": fake}): + zip_name, model_type = _process_rfdetr("rfdetr-base", model_path, "checkpoint_best_ema.pth") + self.assertEqual(model_type, "rfdetr-base") + self.assertEqual(fake._calls["from_checkpoint"], 1) + self.assertEqual(fake._calls["fallback_constructed"], 0) + with zipfile.ZipFile(os.path.join(model_path, zip_name)) as z: + self.assertIn("weights.pt", z.namelist()) + + def test_from_checkpoint_valueerror_falls_back_to_model_type(self): + fake = _make_fake_rfdetr(from_checkpoint_raises=True) + with tempfile.TemporaryDirectory() as model_path: + self._write_ptl_checkpoint(model_path, segmentation_head=True) + pth = os.path.join(model_path, "checkpoint_best_ema.pth") + with mock.patch.dict(sys.modules, {"rfdetr": fake}): + zip_name, model_type = _process_rfdetr("rfdetr-seg-medium", model_path, "checkpoint_best_ema.pth") + self.assertEqual(model_type, "rfdetr-seg-medium") + self.assertEqual(fake._calls["from_checkpoint"], 1) + self.assertEqual(fake._calls["fallback_constructed"], 1) + self.assertEqual(fake._calls["constructor_kwargs"], {"pretrain_weights": pth}) + with zipfile.ZipFile(os.path.join(model_path, zip_name)) as z: + self.assertIn("weights.pt", z.namelist()) + + def test_ptl_path_raises_when_rfdetr_absent(self): + with tempfile.TemporaryDirectory() as model_path: + self._write_ptl_checkpoint(model_path) + with mock.patch.dict(sys.modules, {"rfdetr": None}): + with self.assertRaises(RuntimeError): + _process_rfdetr("rfdetr-base", model_path, "checkpoint_best_ema.pth") + + +@unittest.skipUnless(_HAS_TORCH, "requires torch") +class ProcessRfdetrLegacyTest(unittest.TestCase): + def _write_legacy_checkpoint(self, model_path): + checkpoint = {"args": {"segmentation_head": False, "class_names": ["cat", "dog"]}} + torch.save(checkpoint, os.path.join(model_path, "weights.pt")) + + def test_legacy_path_produces_bundle_without_importing_rfdetr(self): + with tempfile.TemporaryDirectory() as model_path: + self._write_legacy_checkpoint(model_path) + # Make any attempt to import rfdetr fail loudly. + with mock.patch.dict(sys.modules, {"rfdetr": None}): + zip_name, model_type = _process_rfdetr("rfdetr-base", model_path, "weights.pt") + self.assertEqual(model_type, "rfdetr-base") + with zipfile.ZipFile(os.path.join(model_path, zip_name)) as z: + names = z.namelist() + self.assertIn("weights.pt", names) + self.assertIn("class_names.txt", names) + + if __name__ == "__main__": unittest.main()