Fix NaN in LTX2 scheduler when dynamic shift saturates sigmas before shift_terminal #410

flyingblackshark wants to merge 1 commit
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Hardcoding the sequence length to 4096 will break the dynamic shifting logic for smaller videos: the shift will always be static at 2.05, even when it should be smaller. Also, when running the LTX2 pipeline at standard resolution (768x512, 121 frames), the video sequence length is 6144. I think we can use a
Fix NaN in LTX2 scheduler when dynamic shift saturates sigmas before shift_terminal

This commit fixes a numerical stability issue in the LTX2 scheduler when processing very long video sequences.
In `ltx2_pipeline`, the dynamic shift value is calculated from `video_sequence_length`. For very long sequences, `mu` can become very large. For example, when `mu` reaches around `48`, the exponential dynamic shifting logic in `set_timesteps_ltx2` becomes numerically saturated.

The issue is not that `exp(48)` overflows in `float32`. `exp(48)` is still finite in fp32. The actual problem happens after dynamic shifting.

When `current_shift` is large enough, `jnp.exp(current_shift)` dominates the denominator. As a result, the shifted `sigmas` can become numerically saturated to exactly `1.0`.

The later `shift_terminal` logic assumes that `1 - sigmas[-1]` is non-zero. However, when dynamic shifting has already saturated all `sigmas` to `1.0`, `1 - sigmas[-1]` is exactly `0`. The final rescaling step then operates on this zero-valued term and produces `NaN`.
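The failure mode described above can be reproduced in isolation. The sketch below is not the scheduler's code: it assumes the commonly used exponential time shift `exp(mu) / (exp(mu) + (1 / sigma - 1))` and a stretch-to-terminal rescaling of the form `1 - (1 - sigmas) / scale`, and it uses NumPy `float32` arrays in place of `jnp`. The helper names `dynamic_shift` and `stretch_to_terminal`, the 8-step schedule, and the `shift_terminal` value of `0.1` are all illustrative assumptions.

```python
import numpy as np


def dynamic_shift(sigmas, mu):
    """Exponential time shift in float32 (assumed form, not the repo's code)."""
    sigmas = np.asarray(sigmas, dtype=np.float32)
    e = np.exp(np.float32(mu))  # still finite in float32 for mu = 48
    # For large mu, adding (1 / sigma - 1) no longer changes e in float32,
    # so the ratio collapses to exactly 1.0.
    return e / (e + (1.0 / sigmas - 1.0))


def stretch_to_terminal(sigmas, shift_terminal):
    """Rescale so the last sigma lands on shift_terminal (assumed form)."""
    one_minus = 1.0 - sigmas
    # This scale is exactly 0 when the last sigma has saturated to 1.0.
    scale = one_minus[-1] / np.float32(1.0 - shift_terminal)
    with np.errstate(invalid="ignore", divide="ignore"):
        return 1.0 - one_minus / scale  # 0 / 0 yields NaN


# Hypothetical 8-step schedule from 1.0 down to 1/8.
base_sigmas = np.linspace(1.0, 1.0 / 8, 8, dtype=np.float32)

moderate = stretch_to_terminal(dynamic_shift(base_sigmas, 2.05), 0.1)
saturated = dynamic_shift(base_sigmas, 48.0)
broken = stretch_to_terminal(saturated, 0.1)

print("mu=2.05 :", moderate)
print("mu=48   :", saturated)
print("stretched:", broken)
```

Under these assumptions, the moderate shift produces a finite schedule ending near `shift_terminal`, while the `mu = 48` case saturates every sigma to `1.0` and the rescaling turns the whole schedule into `NaN`, even though `exp(48)` itself never overflows.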