fix: scope get_full_cu_seqlens cache key by device and inference mode#2728

Open
DmCarpe93 wants to merge 3 commits into NVIDIA:main from DmCarpe93:fix/get_full_cu_seqlens_cache_key_error

Conversation

@DmCarpe93 DmCarpe93 commented Mar 3, 2026

Description

Fixed an issue where the cu_seqlens tensor was incorrectly retrieved from the cache.

  • Currently, only (batch_size, max_seqlen) is used as the cache key when retrieving cu_seqlens.
  • This could result in errors, especially during Knowledge Distillation training, because the teacher and student models can run on the same node.
    • When the teacher model runs first, the cu_seqlens tensor is created and cached.
    • When the student model subsequently trains on the same node with the same (batch_size, max_seqlen), the cached cu_seqlens tensor is reused.
    • Since the cached cu_seqlens tensor from the teacher model may have a different inference mode and device, this can result in an error.
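The failure scenario above can be reproduced with a minimal sketch (the function and cache names here are hypothetical stand-ins, not the real TransformerEngine code): because the old key omits device and inference mode, the student silently receives the teacher's cached entry.

```python
# Hypothetical reproduction of the old, under-scoped cache key.
# The dict values record which context created each entry.
old_cache = {}

def old_get_full_cu_seqlens(batch_size, max_seqlen, device, inference_mode):
    key = (batch_size, max_seqlen)  # device and mode not part of the key
    if key not in old_cache:
        old_cache[key] = {"device": device, "inference_mode": inference_mode}
    return old_cache[key]

# Teacher runs first, under inference mode, on cuda:0 ...
teacher = old_get_full_cu_seqlens(2, 128, "cuda:0", True)
# ... the student, with the same shape, gets the teacher's entry back:
student = old_get_full_cu_seqlens(2, 128, "cuda:1", False)
assert student["device"] == "cuda:0"      # wrong device for the student
assert student["inference_mode"] is True  # wrong autograd context
```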

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • The cache key for retrieving cu_seqlens was updated from (batch_size, max_seqlen) to include both the device and inference mode.
  • Added test cases for the cu_seqlens cache.
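A minimal sketch of the fixed keying scheme (hypothetical stand-ins for the real helpers; in the actual code the key components would be a torch.device and the result of torch.is_inference_mode_enabled(), and the value a torch.arange tensor):

```python
# Cache keyed on (batch_size, max_seqlen, device, inference_mode):
# shape-equal requests from different contexts no longer collide.
_cu_seqlens_cache = {}

def get_full_cu_seqlens(batch_size, max_seqlen, device, inference_mode):
    key = (batch_size, max_seqlen, device, inference_mode)
    if key not in _cu_seqlens_cache:
        # stand-in for torch.arange(0, (batch_size + 1) * max_seqlen, max_seqlen)
        _cu_seqlens_cache[key] = list(
            range(0, (batch_size + 1) * max_seqlen, max_seqlen)
        )
    return _cu_seqlens_cache[key]

# Same shape, different device/mode -> two distinct cache entries.
get_full_cu_seqlens(2, 8, "cuda:0", True)
get_full_cu_seqlens(2, 8, "cuda:1", False)
assert len(_cu_seqlens_cache) == 2
```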

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps bot commented Mar 3, 2026

Greptile Summary

This PR fixes a cache-poisoning bug in get_full_cu_seqlens where the _cu_seqlens_cache was keyed only on (batch_size, max_seqlen), allowing a tensor created on one device or under torch.inference_mode() to be returned to a caller on a different device or in a different autograd context. The fix extends the cache key to (batch_size, max_seqlen, device, is_inference_mode_enabled), which is the minimal correct scope for this tensor. The change is tightly scoped to a 6-line modification in utils.py and is accompanied by a new test file covering both device-isolation and inference-vs-training isolation.

Key changes:

  • utils.py: Cache key extended from 2-tuple to 4-tuple including torch.device and the boolean result of torch.is_inference_mode_enabled().
  • test_cu_seqlens_cache.py: Two new integration tests validate that separate cache entries are maintained per-device and per-mode, and that the tensors stored actually reside on the expected device and use correct autograd context.

Confidence Score: 5/5

  • PR is safe to merge. The fix is a minimal, correct, and well-tested correction to a cache-poisoning vulnerability.
  • This PR is safe to merge: (1) the change is minimal and tightly scoped (6 lines in one function); (2) the fix is correct — extending the cache key with torch.device and torch.is_inference_mode_enabled() eliminates both failure modes described in the PR; (3) both key components are hashable, ensuring cache consistency; (4) all call sites pass tensor.device which is always a fully qualified torch.device object; (5) comprehensive tests exercise both device isolation and inference-vs-training isolation, with assertions verifying that tensors reside on expected devices and use correct autograd contexts; (6) no regressions are introduced — the change only makes cache key scoping stricter, which cannot break existing code that relied on the previous (incorrect) behavior.
  • No files require special attention.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["get_full_cu_seqlens called"] --> B{ONNX export mode?}
    B -- Yes --> C["Build tensor directly, skip cache"]
    B -- No --> D["Form lookup tuple: batch_size, max_seqlen, device, inference_mode"]
    D --> F{Tuple in cache?}
    F -- No --> G["Allocate cu_seqlens tensor and store in cache"]
    G --> H["Return tensor"]
    F -- Yes --> H
```

Last reviewed commit: 86151e8

@greptile-apps greptile-apps bot left a comment


2 files reviewed, no comments


@ptrendx ptrendx requested a review from cyanguwa March 3, 2026 18:54
