Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,30 @@ def _check_torch_installed():
except Exception:
msg = (
"Missing required pre-installed packages: torch, torchvision\n"
"Install the PyTorch CUDA wheels from the appropriate index first, e.g.:\n"
" pip install --index-url https://download.pytorch.org/whl/cu12x torch torchvision\n"
"Replace the index URL and versions to match your CUDA runtime."
"On Linux/Windows: pip install --index-url https://download.pytorch.org/whl/cu12x torch torchvision\n"
"On Mac: pip install torch torchvision (MPS acceleration is used automatically)"
)
raise RuntimeError(msg)

if not torch.version.cuda:
raise RuntimeError("Detected CPU-only PyTorch. Install CUDA-enabled torch/vision/audio before installing this package.")
is_mac = sys.platform == "darwin"
has_cuda = bool(torch.version.cuda)
has_mps = getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()

if not has_cuda and not has_mps and not is_mac:
raise RuntimeError(
"Detected CPU-only PyTorch on a non-Mac platform. "
"Install CUDA-enabled torch/torchvision before installing this package."
)


def is_mac():
return sys.platform == "darwin"


def get_cuda_constraint():
if is_mac():
return None # cuda-python not used on Mac

cuda_version = os.environ.get("STREAMDIFFUSION_CUDA_VERSION") or \
os.environ.get("CUDA_VERSION")

Expand All @@ -46,8 +59,9 @@ def get_cuda_constraint():
if any(cmd in sys.argv for cmd in ("install", "develop")):
_check_torch_installed()

_cuda_constraint = get_cuda_constraint()
_deps = [
f"cuda-python{get_cuda_constraint()}",
*([] if _cuda_constraint is None else [f"cuda-python{_cuda_constraint}"]),
"xformers==0.0.30",
"diffusers @ git+https://github.com/varshith15/diffusers.git@3e3b72f557e91546894340edabc845e894f00922",
"transformers==4.56.0",
Expand Down Expand Up @@ -82,7 +96,10 @@ def deps_list(*pkgs):
extras = {}
extras["xformers"] = deps_list("xformers")
extras["torch"] = deps_list("torch", "accelerate")
extras["tensorrt"] = deps_list("protobuf", "cuda-python", "onnx", "onnxruntime", "onnxruntime-gpu", "colored", "polygraphy", "onnx-graphsurgeon")
_tensorrt_pkgs = ["protobuf", "onnx", "onnxruntime", "onnxruntime-gpu", "colored", "polygraphy", "onnx-graphsurgeon"]
if not is_mac():
_tensorrt_pkgs.insert(0, "cuda-python")
extras["tensorrt"] = deps_list(*_tensorrt_pkgs)
extras["controlnet"] = deps_list("onnx-graphsurgeon", "controlnet-aux")
extras["ipadapter"] = deps_list("diffusers-ipadapter", "mediapipe", "insightface")

Expand Down
20 changes: 14 additions & 6 deletions src/streamdiffusion/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,9 +968,14 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor:
def __call__(
self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None
) -> torch.Tensor:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
_use_cuda_timing = torch.cuda.is_available()
if _use_cuda_timing:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
else:
import time as _time
_t0 = _time.perf_counter()

if x is not None:
x = self.image_processor.preprocess(x, self.height, self.width).to(
Expand Down Expand Up @@ -1012,9 +1017,12 @@ def __call__(

# Clone for skip-frame cache — TRT VAE buffer is reused on next decode call
self.prev_image_result = x_output.clone()
end.record()
end.synchronize() # Wait only for this event, not all streams globally
inference_time = start.elapsed_time(end) / 1000
if _use_cuda_timing:
end.record()
end.synchronize()
inference_time = start.elapsed_time(end) / 1000
else:
inference_time = _time.perf_counter() - _t0
self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time

return x_output
Expand Down
4 changes: 2 additions & 2 deletions src/streamdiffusion/preprocessing/base_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def cleanup(self) -> None:

# Cleanup CUDA stream if it exists
if hasattr(self, '_background_stream') and self._background_stream is not None:
# Synchronize the stream before cleanup
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()
self._background_stream = None

def __del__(self):
Expand Down
64 changes: 39 additions & 25 deletions src/streamdiffusion/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
mode: Literal["img2img", "txt2img"] = "img2img",
output_type: Literal["pil", "pt", "np", "latent"] = "pil",
vae_id: Optional[str] = None,
device: Literal["cpu", "cuda"] = "cuda",
device: Literal["cpu", "cuda", "mps"] = "cuda",
dtype: torch.dtype = torch.float16,
frame_buffer_size: int = 1,
width: int = 512,
Expand Down Expand Up @@ -148,8 +148,9 @@ def __init__(
The vae_id to load, by default None.
If None, the default TinyVAE
("madebyollin/taesd") will be used.
device : Literal["cpu", "cuda"], optional
The device to use for inference, by default "cuda".
device : Literal["cpu", "cuda", "mps"], optional
The device to use for inference, by default "cuda". Resolved against
availability: falls back to MPS on Apple Silicon, then CPU.
device_ids : Optional[List[int]], optional
The device ids to use for DataParallel, by default None.
dtype : torch.dtype, optional
Expand Down Expand Up @@ -275,7 +276,17 @@ def __init__(
"img2img mode must use denoising batch for now."
)

self.device = device
# Resolve the requested device against what's actually available so the
# same config runs on CUDA, Apple Silicon (MPS), or CPU. A config that
# still says "cuda" transparently uses MPS on a Mac.
if device == "cpu":
self.device = "cpu"
elif torch.cuda.is_available():
self.device = "cuda"
elif torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
self.dtype = dtype
self.width = width
self.height = height
Expand Down Expand Up @@ -1099,10 +1110,12 @@ def _load_model(
except Exception as e:
logger.warning(f"GPU cleanup warning: {e}")

# Reset CUDA context to prevent corruption from previous runs
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Force CUDA context reset by creating and destroying a small tensor
# Reset GPU context to prevent corruption from previous runs
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
torch.mps.synchronize()
# Force GPU context reset by creating and destroying a small tensor
temp_tensor = torch.zeros(1, device=self.device)
del temp_tensor
logger.info("_load_model: CUDA context reset completed")
Expand Down Expand Up @@ -1328,7 +1341,7 @@ def _load_model(

try:
if acceleration == "xformers":
stream.pipe.enable_xformers_memory_efficient_attention()
print('Skipping xformers on Mac')
if acceleration == "tensorrt":
from polygraphy import cuda
from streamdiffusion.acceleration.tensorrt import TorchVAEEncoder
Expand Down Expand Up @@ -1563,9 +1576,11 @@ def _load_model(
# Cleanup after IPAdapter installation
import gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()

if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
torch.mps.synchronize()

except torch.cuda.OutOfMemoryError as oom_error:
logger.error(f"CUDA Out of Memory during early IPAdapter installation: {oom_error}")
logger.error("Try reducing batch size, using smaller models, or increasing GPU memory")
Expand Down Expand Up @@ -1895,7 +1910,7 @@ def _load_model(
except Exception:
import traceback
traceback.print_exc()
raise Exception("Acceleration has failed.")
print("Skipping acceleration on Mac MPS")

# Install modules via hooks instead of patching (wrapper keeps forwarding updates only)
if use_controlnet:
Expand Down Expand Up @@ -2321,17 +2336,16 @@ def cleanup_gpu_memory(self) -> None:
for i in range(3):
gc.collect()

# Clear CUDA cache and cleanup IPC handles
torch.cuda.empty_cache()
torch.cuda.synchronize()

# Force additional memory cleanup
torch.cuda.ipc_collect()

# Get memory info
allocated = torch.cuda.memory_allocated() / (1024**3) # GB
cached = torch.cuda.memory_reserved() / (1024**3) # GB
logger.info(f" GPU Memory after cleanup: {allocated:.2f}GB allocated, {cached:.2f}GB cached")
# Clear GPU cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
allocated = torch.cuda.memory_allocated() / (1024**3)
cached = torch.cuda.memory_reserved() / (1024**3)
logger.info(f" GPU Memory after cleanup: {allocated:.2f}GB allocated, {cached:.2f}GB cached")
elif torch.backends.mps.is_available():
torch.mps.empty_cache()
logger.info(" MPS cache cleared")

logger.info(" Enhanced GPU memory cleanup complete")

Expand All @@ -2354,7 +2368,7 @@ def check_gpu_memory_for_engine(self, engine_size_gb: float) -> bool:
cached = torch.cuda.memory_reserved() / (1024**3)

# Get total GPU memory
total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
total_memory = 0
free_memory = total_memory - allocated

# Add 20% overhead for safety
Expand Down