Fix NaN in LTX2 scheduler when dynamic shift saturates sigmas before shift_terminal #410

Open
flyingblackshark wants to merge 1 commit into AI-Hypercomputer:main from flyingblackshark:main

@flyingblackshark

Fix NaN in LTX2 scheduler when dynamic shift saturates sigmas before shift_terminal

This commit fixes a numerical stability issue in the LTX2 scheduler when processing very long video sequences.

In ltx2_pipeline, the dynamic shift value is calculated from video_sequence_length:

mu = calculate_shift(
    video_sequence_length,
    self.scheduler.config.get("base_image_seq_len", 1024),
    self.scheduler.config.get("max_image_seq_len", 4096),
    self.scheduler.config.get("base_shift", 0.95),
    self.scheduler.config.get("max_shift", 2.05),
)
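To see why mu has no upper bound, here is a minimal sketch of a linear-interpolation calculate_shift. The function body is an assumption, modeled on the helper of the same name in Hugging Face diffusers; the implementation in this repository may differ. The sequence length 135_000 is a hypothetical input chosen only to push mu past 48.

```python
def calculate_shift(
    seq_len: int,
    base_seq_len: int = 1024,
    max_seq_len: int = 4096,
    base_shift: float = 0.95,
    max_shift: float = 2.05,
) -> float:
    """Linearly interpolate (and extrapolate) the shift from the sequence length.

    Assumed form: a straight line through (base_seq_len, base_shift)
    and (max_seq_len, max_shift). Nothing clamps seq_len, so the line
    keeps rising past max_seq_len.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return seq_len * slope + intercept


mu_base = calculate_shift(1024)      # lands on base_shift
mu_max = calculate_shift(4096)       # lands on max_shift
mu_long = calculate_shift(135_000)   # hypothetical very long sequence
```

Under this assumed formula, max_image_seq_len is only an interpolation point, not a cap, which is how a sufficiently long video can drive mu far beyond max_shift.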

For very long sequences, mu can become very large. For example, when mu reaches around 48, the exponential dynamic shifting logic in set_timesteps_ltx2 becomes numerically saturated:

sigmas = jnp.exp(current_shift) / (
    jnp.exp(current_shift)
    + (1 / jnp.clip(sigmas, 1e-7, 1.0) - 1) ** 1.0
)

The issue is not that exp(48) overflows in float32: exp(48) is roughly 7.0e20, well below the float32 maximum of about 3.4e38. The actual problem arises after dynamic shifting.

When current_shift is large enough, jnp.exp(current_shift) dominates the denominator: the (1 / sigma - 1) term becomes too small relative to it to survive float32 rounding. As a result, the shifted sigmas can saturate to exactly 1.0.
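The saturation can be reproduced outside the scheduler. The sketch below uses NumPy float32 as a stand-in for jnp, on the assumption that both round these operations the same way; the helper name dynamic_shift and the 40-step schedule are hypothetical and not taken from the repository.

```python
import numpy as np


def dynamic_shift(sigmas, mu):
    """Exponential dynamic shift, mirroring the expression quoted above."""
    sigmas = np.asarray(sigmas, dtype=np.float32)
    e = np.exp(np.float32(mu))  # stays finite in float32 for mu = 48
    return e / (e + (1 / np.clip(sigmas, 1e-7, 1.0) - 1) ** 1.0)


# Hypothetical 40-step schedule running from 1.0 down to 1/40.
base_sigmas = np.linspace(1.0, 1.0 / 40, 40, dtype=np.float32)

moderate = dynamic_shift(base_sigmas, 2.05)   # typical shift: values stay below 1
saturated = dynamic_shift(base_sigmas, 48.0)  # large shift: every entry becomes 1.0

exp_is_finite = bool(np.isfinite(np.exp(np.float32(48.0))))
all_ones = bool(np.all(saturated == 1.0))
```

Here exp_is_finite and all_ones are both true: the exponential itself does not overflow, yet every shifted sigma is exactly 1.0.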

The later shift_terminal logic assumes that 1 - sigmas[-1] is non-zero:

one_minus_z = 1 - sigmas
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
sigmas = 1 - (one_minus_z / scale_factor)

However, when dynamic shifting has already saturated all sigmas to 1.0, this becomes:

one_minus_z = 1 - 1.0 = 0
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
             = 0 / (1 - shift_terminal)
             = 0

Then the final rescaling step evaluates:

one_minus_z / scale_factor

which becomes:

0 / 0

and produces NaN.
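The 0 / 0 path can be checked in isolation. This is a minimal sketch in NumPy float32 rather than jnp; the helper name stretch_to_terminal and the value shift_terminal = 0.1 are assumptions made for illustration.

```python
import numpy as np


def stretch_to_terminal(sigmas, shift_terminal):
    """Rescale sigmas so the last one equals shift_terminal (same steps as above)."""
    sigmas = np.asarray(sigmas, dtype=np.float32)
    one_minus_z = 1 - sigmas
    scale_factor = one_minus_z[-1] / (1 - shift_terminal)
    # Silence the 0/0 warning so the NaN result can be inspected directly.
    with np.errstate(divide="ignore", invalid="ignore"):
        return 1 - (one_minus_z / scale_factor)


# A schedule that still decreases: the last sigma is stretched to shift_terminal.
healthy = stretch_to_terminal(np.array([1.0, 0.8, 0.5, 0.2]), 0.1)

# A schedule already saturated to 1.0: scale_factor is 0, so every entry is 0 / 0.
broken = stretch_to_terminal(np.ones(4), 0.1)
```

In the healthy case the final sigma lands on shift_terminal; in the saturated case the whole array is NaN, not just the last element, because every entry of one_minus_z is divided by the zero scale_factor.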

flyingblackshark requested a review from entrpn as a code owner on May 18, 2026 04:46
@google-cla

google-cla Bot commented May 18, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@prishajain1
Collaborator

Hardcoding the sequence length to 4096 will break the dynamic shifting logic for smaller videos, making the shift always static at 2.05 even when it should be smaller. Also, when running the LTX2 pipeline at standard resolution (768x512, 121 frames), the video sequence length is 6144.

I think we can use a min(6144, video seq length) kind of logic instead. However, could you please test the video quality obtained at slightly higher frame counts (say 131 and 161) and standard resolution (768x512) with this new logic to ensure there are no regressions in video quality?
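One possible reading of the min(6144, video seq length) suggestion is sketched below. It is not the code in this PR: calculate_shift is an assumed linear interpolation modeled on the diffusers helper of the same name, and clamped_shift and its cap parameter are hypothetical names.

```python
def calculate_shift(seq_len, base_seq_len=1024, max_seq_len=4096,
                    base_shift=0.95, max_shift=2.05):
    """Assumed linear interpolation of the shift from the sequence length."""
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    return seq_len * slope + (base_shift - slope * base_seq_len)


def clamped_shift(video_sequence_length, cap=6144):
    """Clamp the sequence length before computing mu (hypothetical helper).

    Lengths at or below the cap keep their dynamic shift; longer
    sequences reuse the shift computed at the cap.
    """
    return calculate_shift(min(video_sequence_length, cap))


mu_small = clamped_shift(2048)        # below the cap: unchanged dynamic shift
mu_standard = clamped_shift(6144)     # standard resolution, 121 frames
mu_very_long = clamped_shift(135_000) # hypothetical long video: same as the cap
```

Under the assumed formula, mu then never exceeds roughly 2.78, the value at a sequence length of 6144, while shorter videos still get a smaller, length-dependent shift.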
