Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ module = [
# ipywidgets is an optional dependency
"ipywidgets.*",
"requests_toolbelt.*",
"rfdetr.*",
"torch.*",
"ultralytics.*",
]
Expand Down
91 changes: 70 additions & 21 deletions roboflow/util/model_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,30 @@
)
from roboflow.util.versions import print_warn_for_wrong_dependencies_versions

# Minimum rf-detr release shipping `RFDETR.export_for_roboflow`.
RFDETR_MIN_VERSION = "1.8.0" # TODO: pin once rf-detr release is cut

# rf-detr model_type -> RFDETR subclass name. Single source of truth for both the
# supported-type check and the `from_checkpoint` fallback (used when a raw checkpoint
# lacks the metadata rf-detr needs to infer its own class).
_RFDETR_MODEL_TYPE_TO_CLASS = {
# Detection
"rfdetr-base": "RFDETRBase",
"rfdetr-nano": "RFDETRNano",
"rfdetr-small": "RFDETRSmall",
"rfdetr-medium": "RFDETRMedium",
"rfdetr-large": "RFDETRLarge",
"rfdetr-xlarge": "RFDETRXLarge",
"rfdetr-2xlarge": "RFDETR2XLarge",
# Segmentation
"rfdetr-seg-nano": "RFDETRSegNano",
"rfdetr-seg-small": "RFDETRSegSmall",
"rfdetr-seg-medium": "RFDETRSegMedium",
"rfdetr-seg-large": "RFDETRSegLarge",
"rfdetr-seg-xlarge": "RFDETRSegXLarge",
"rfdetr-seg-2xlarge": "RFDETRSeg2XLarge",
}


def task_of_model_type(model_type: str) -> str:
"""Canonical task for a deploy model_type string.
Expand Down Expand Up @@ -339,24 +363,35 @@ def _detect_rfdetr_task(checkpoint) -> Optional[str]:
return None


def _is_ptl_checkpoint(checkpoint) -> bool:
"""True if `checkpoint` is a raw PyTorch-Lightning rf-detr checkpoint dict."""
return isinstance(checkpoint, dict) and "pytorch-lightning_version" in checkpoint


def _require_rfdetr():
"""Lazily import `rfdetr` and verify it ships the upload-bundle helpers.

Raises a RuntimeError with an actionable hint if rfdetr is missing or too old.
"""
try:
import rfdetr
except ImportError:
raise RuntimeError(
"rfdetr is required to upload PyTorch-Lightning rf-detr checkpoints. "
f"Please install it with `pip install 'rfdetr>={RFDETR_MIN_VERSION}'`."
)

if not hasattr(rfdetr.RFDETR, "export_for_roboflow"):
raise RuntimeError(
"The installed rfdetr is too old to upload PyTorch-Lightning rf-detr checkpoints. "
f"Please upgrade it with `pip install --upgrade 'rfdetr>={RFDETR_MIN_VERSION}'`."
)

return rfdetr


def _process_rfdetr(model_type: str, model_path: str, filename: str) -> tuple[str, str]:
_supported_types = [
# Detection models
"rfdetr-base",
"rfdetr-nano",
"rfdetr-small",
"rfdetr-medium",
"rfdetr-large",
"rfdetr-xlarge",
"rfdetr-2xlarge",
# Segmentation models
"rfdetr-seg-nano",
"rfdetr-seg-small",
"rfdetr-seg-medium",
"rfdetr-seg-large",
"rfdetr-seg-xlarge",
"rfdetr-seg-2xlarge",
]
_supported_types = list(_RFDETR_MODEL_TYPE_TO_CLASS)
if model_type not in _supported_types:
raise ValueError(f"Model type {model_type} not supported. Supported types are {_supported_types}")

Expand All @@ -382,11 +417,25 @@ def _process_rfdetr(model_type: str, model_path: str, filename: str) -> tuple[st
f".pt is a '{detected_task}' rfdetr checkpoint. Use a matching model_type."
)

get_classnames_txt_for_rfdetr(model_path, pt_file, checkpoint=checkpoint)
if _is_ptl_checkpoint(checkpoint):
# Raw PyTorch-Lightning checkpoint: let rf-detr rebuild a proper upload
# bundle (weights.pt with `args.resolution` + class_names.txt).
rfdetr = _require_rfdetr()
pth = os.path.join(model_path, pt_file)
try:
model = rfdetr.RFDETR.from_checkpoint(pth)
except ValueError:
# Checkpoint lacks model_name/pretrain_weights signals; fall back to
# the already-validated user-provided model_type to pick the subclass.
model_cls = getattr(rfdetr, _RFDETR_MODEL_TYPE_TO_CLASS[model_type])
model = model_cls(pretrain_weights=pth)
model.export_for_roboflow(model_path) # writes weights.pt + class_names.txt
else:
get_classnames_txt_for_rfdetr(model_path, pt_file, checkpoint=checkpoint)

# Copy the .pt file to weights.pt if not already named weights.pt
if pt_file != "weights.pt":
shutil.copy(os.path.join(model_path, pt_file), os.path.join(model_path, "weights.pt"))
# Copy the .pt file to weights.pt if not already named weights.pt
if pt_file != "weights.pt":
shutil.copy(os.path.join(model_path, pt_file), os.path.join(model_path, "weights.pt"))

required_files = ["weights.pt"]

Expand Down
172 changes: 172 additions & 0 deletions tests/util/test_model_processor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
import os
import sys
import tempfile
import unittest
import zipfile
from types import SimpleNamespace
from unittest import mock

try:
# torch is an optional, lazily-imported SDK dependency; absent in CI. Tests that
# round-trip a real checkpoint through `_process_rfdetr` are skipped without it.
import torch

_HAS_TORCH = True
except ImportError:
_HAS_TORCH = False

from roboflow.config import TASK_CLS, TASK_DET, TASK_OBB, TASK_POSE, TASK_SEG
from roboflow.util.model_processor import (
_RFDETR_MODEL_TYPE_TO_CLASS,
_detect_rfdetr_task,
_detect_yolo_task,
_is_ptl_checkpoint,
_process_rfdetr,
_require_rfdetr,
get_classnames_txt_for_rfdetr,
task_of_model_type,
)
Expand Down Expand Up @@ -104,5 +120,161 @@ def test_namespace_args(self):
)


class RfdetrModelTypeToClassTest(unittest.TestCase):
def test_representative_mappings(self):
self.assertEqual(_RFDETR_MODEL_TYPE_TO_CLASS["rfdetr-seg-medium"], "RFDETRSegMedium")
self.assertEqual(_RFDETR_MODEL_TYPE_TO_CLASS["rfdetr-base"], "RFDETRBase")

def test_keys_are_rfdetr_types_and_values_are_class_names(self):
for model_type, class_name in _RFDETR_MODEL_TYPE_TO_CLASS.items():
self.assertTrue(model_type.startswith("rfdetr-"), model_type)
self.assertTrue(class_name.startswith("RFDETR"), class_name)
# Segmentation types must map to Seg classes (and detection types must not).
for model_type, class_name in _RFDETR_MODEL_TYPE_TO_CLASS.items():
self.assertEqual("seg" in model_type, "Seg" in class_name, model_type)


class IsPtlCheckpointTest(unittest.TestCase):
def test_true_when_lightning_version_present(self):
self.assertTrue(_is_ptl_checkpoint({"pytorch-lightning_version": "2.1.0", "args": {}}))

def test_false_for_plain_checkpoint(self):
self.assertFalse(_is_ptl_checkpoint({"args": {}, "model": {}}))

def test_false_for_non_dict(self):
self.assertFalse(_is_ptl_checkpoint(None))
self.assertFalse(_is_ptl_checkpoint(SimpleNamespace(**{"pytorch-lightning_version": "2.1.0"})))


class _StubBundleModel:
"""Stub rf-detr model whose export_for_roboflow writes a dummy bundle on disk."""

def __init__(self, class_names=("cat", "dog")):
self.class_names = list(class_names)

def export_for_roboflow(self, output_dir):
torch.save({"dummy": True}, os.path.join(output_dir, "weights.pt"))
with open(os.path.join(output_dir, "class_names.txt"), "w") as f:
for name in self.class_names:
f.write(name + "\n")


def _make_fake_rfdetr(*, from_checkpoint_raises=False, capabilities=True):
"""Build a fake `rfdetr` module for injection via sys.modules."""

stub_model = _StubBundleModel()
calls = {"from_checkpoint": 0, "fallback_constructed": 0, "constructor_kwargs": None}

class _RFDETR:
@staticmethod
def from_checkpoint(path):
calls["from_checkpoint"] += 1
if from_checkpoint_raises:
raise ValueError("cannot infer model class")
return stub_model

class _SizedModel(_StubBundleModel):
def __init__(self, *, pretrain_weights):
super().__init__()
calls["fallback_constructed"] += 1
calls["constructor_kwargs"] = {"pretrain_weights": pretrain_weights}

module = SimpleNamespace()
module.RFDETR = _RFDETR
# The SDK fallback resolves the subclass by name via _RFDETR_MODEL_TYPE_TO_CLASS,
# e.g. "rfdetr-seg-medium" -> getattr(rfdetr, "RFDETRSegMedium").
module.RFDETRSegMedium = _SizedModel

if capabilities:
_RFDETR.export_for_roboflow = _StubBundleModel.export_for_roboflow # capability marker on class

module._calls = calls
return module


class RequireRfdetrTest(unittest.TestCase):
def test_raises_when_not_installed(self):
with mock.patch.dict(sys.modules, {"rfdetr": None}):
with self.assertRaises(RuntimeError) as ctx:
_require_rfdetr()
self.assertIn("pip install", str(ctx.exception).lower())
self.assertIn("rfdetr", str(ctx.exception).lower())

def test_raises_when_capability_missing(self):
# rfdetr present but RFDETR lacks export_for_roboflow (too old)
fake = SimpleNamespace(RFDETR=type("RFDETR", (), {}))
with mock.patch.dict(sys.modules, {"rfdetr": fake}):
with self.assertRaises(RuntimeError) as ctx:
_require_rfdetr()
self.assertIn("upgrade", str(ctx.exception).lower())

def test_returns_module_when_capable(self):
fake = _make_fake_rfdetr()
with mock.patch.dict(sys.modules, {"rfdetr": fake}):
self.assertIs(_require_rfdetr(), fake)


@unittest.skipUnless(_HAS_TORCH, "requires torch")
class ProcessRfdetrPtlTest(unittest.TestCase):
def _write_ptl_checkpoint(self, model_path, *, segmentation_head=False, class_names=("cat", "dog")):
checkpoint = {
"pytorch-lightning_version": "2.1.0",
"args": {"segmentation_head": segmentation_head, "class_names": list(class_names)},
}
torch.save(checkpoint, os.path.join(model_path, "checkpoint_best_ema.pth"))

def test_from_checkpoint_success_produces_bundle(self):
fake = _make_fake_rfdetr()
with tempfile.TemporaryDirectory() as model_path:
self._write_ptl_checkpoint(model_path)
with mock.patch.dict(sys.modules, {"rfdetr": fake}):
zip_name, model_type = _process_rfdetr("rfdetr-base", model_path, "checkpoint_best_ema.pth")
self.assertEqual(model_type, "rfdetr-base")
self.assertEqual(fake._calls["from_checkpoint"], 1)
self.assertEqual(fake._calls["fallback_constructed"], 0)
with zipfile.ZipFile(os.path.join(model_path, zip_name)) as z:
self.assertIn("weights.pt", z.namelist())

def test_from_checkpoint_valueerror_falls_back_to_model_type(self):
fake = _make_fake_rfdetr(from_checkpoint_raises=True)
with tempfile.TemporaryDirectory() as model_path:
self._write_ptl_checkpoint(model_path, segmentation_head=True)
pth = os.path.join(model_path, "checkpoint_best_ema.pth")
with mock.patch.dict(sys.modules, {"rfdetr": fake}):
zip_name, model_type = _process_rfdetr("rfdetr-seg-medium", model_path, "checkpoint_best_ema.pth")
self.assertEqual(model_type, "rfdetr-seg-medium")
self.assertEqual(fake._calls["from_checkpoint"], 1)
self.assertEqual(fake._calls["fallback_constructed"], 1)
self.assertEqual(fake._calls["constructor_kwargs"], {"pretrain_weights": pth})
with zipfile.ZipFile(os.path.join(model_path, zip_name)) as z:
self.assertIn("weights.pt", z.namelist())

def test_ptl_path_raises_when_rfdetr_absent(self):
with tempfile.TemporaryDirectory() as model_path:
self._write_ptl_checkpoint(model_path)
with mock.patch.dict(sys.modules, {"rfdetr": None}):
with self.assertRaises(RuntimeError):
_process_rfdetr("rfdetr-base", model_path, "checkpoint_best_ema.pth")


@unittest.skipUnless(_HAS_TORCH, "requires torch")
class ProcessRfdetrLegacyTest(unittest.TestCase):
def _write_legacy_checkpoint(self, model_path):
checkpoint = {"args": {"segmentation_head": False, "class_names": ["cat", "dog"]}}
torch.save(checkpoint, os.path.join(model_path, "weights.pt"))

def test_legacy_path_produces_bundle_without_importing_rfdetr(self):
with tempfile.TemporaryDirectory() as model_path:
self._write_legacy_checkpoint(model_path)
# Make any attempt to import rfdetr fail loudly.
with mock.patch.dict(sys.modules, {"rfdetr": None}):
zip_name, model_type = _process_rfdetr("rfdetr-base", model_path, "weights.pt")
self.assertEqual(model_type, "rfdetr-base")
with zipfile.ZipFile(os.path.join(model_path, zip_name)) as z:
names = z.namelist()
self.assertIn("weights.pt", names)
self.assertIn("class_names.txt", names)


if __name__ == "__main__":
unittest.main()
Loading