
Loading the Cosmos2 pipeline also disables gradient tracking globally #13790

@haixpham


Describe the bug

I'm using Diffusers 0.35.2.

After loading Cosmos2VideoToWorldPipeline from "nvidia/Cosmos-Predict2-2B-Video2World", gradient tracking is disabled globally. I have to re-enable it manually with torch.set_grad_enabled(True).

Reproduction

  • Bug reproduction:

```python
import torch
from diffusers import Cosmos2VideoToWorldPipeline

model_id = "nvidia/Cosmos-Predict2-2B-Video2World"
pipe = Cosmos2VideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Expected True (the PyTorch default); loading a pipeline should not change global grad mode.
print(torch.is_grad_enabled())
```

Output:

False

  • Note: loading each component (VAE, transformer) individually does not disable gradient tracking.
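Because the component loads alone leave grad mode untouched, a small harness can help narrow down which load step flips the flag. This sketch assumes each step is wrapped in a zero-argument callable. The function find_grad_mode_changers and the step names are hypothetical, and the demo uses dummy callables instead of real from_pretrained calls.

```python
from typing import Callable, Dict, List

import torch


def find_grad_mode_changers(steps: Dict[str, Callable[[], object]]) -> List[str]:
    """Run each named step starting from grad mode enabled and return the
    names of steps that left grad mode disabled. Restores the caller's mode."""
    original = torch.is_grad_enabled()
    culprits: List[str] = []
    try:
        for name, step in steps.items():
            torch.set_grad_enabled(True)  # known starting state for every step
            step()
            if not torch.is_grad_enabled():
                culprits.append(name)
    finally:
        torch.set_grad_enabled(original)
    return culprits


# Hypothetical stand-ins; in practice each lambda would wrap a real load call
# (e.g. the VAE, the transformer, and the full pipeline from_pretrained).
steps = {
    "vae": lambda: None,
    "transformer": lambda: None,
    "full_pipeline": lambda: torch.set_grad_enabled(False),
}
print(find_grad_mode_changers(steps))  # ['full_pipeline']
```

The harness resets grad mode before each step, so a leak in one step cannot mask or cause a false positive in the next.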

Logs

System Info

diffusers: 0.35.2
transformers: 5.3.0

Who can help?

No response
