Skip to content

Add guard at lowest JAX version that still supports triton kernel calling#2741

Open
tdophung wants to merge 9 commits intoNVIDIA:mainfrom
tdophung:triton_jax_bwd_compat
Open

Add guard at lowest JAX version that still supports triton kernel calling#2741
tdophung wants to merge 9 commits intoNVIDIA:mainfrom
tdophung:triton_jax_bwd_compat

Conversation

@tdophung
Copy link
Collaborator

@tdophung tdophung commented Mar 6, 2026

Description

To provide backward compatibility with older jax versions, we need to have a safeguard in place for jax versions too old to work with triton kernel calling. Using Claude Code to automate bisecting through JAX toolbox nightly containers between Sep 1, 2025 and Oct 1, 2025 (*), I have found that the first passing version of the container starts on Sep 24th, 2025, corresponding to jax 0.8.0.dev20250924 hence the guard is put there.

(*) the date range is determined by having a data point that the officially released jax toolbox (nvcr.io/nvidia/jax:25.10-py3 fails while the nightly jax container on Oct 1st passed.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Handles jax < 0.8.0.dev20250924 segfault error when calling triton kernels frfom JAX side

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Copy link
Collaborator Author

tdophung commented Mar 6, 2026

/te-ci jax

@tdophung tdophung changed the title add guard at bisected jax version where lower is segfault Add guard at lowest JAX version that still supports triton kernel calling Mar 6, 2026
Copy link
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, LGTM pending CI, thanks!

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 6, 2026

Greptile Summary

This PR introduces a minimum JAX version guard (>= 0.8.0) to prevent segfaults when dispatching Triton kernels on older jaxlib versions, and refactors the existing ad-hoc version check in quantize/helper.py into a new shared transformer_engine/jax/version_utils.py module. The changes are well-structured overall and the approach of using a pytest.mark.triton marker with a pytest_collection_modifyitems hook provides clean, granular test skipping without module-collection failures.

Key changes:

  • New version_utils.py: Centralizes jax_version_meet_requirement (refactored from helper.py), is_triton_extension_supported, and TRITON_EXTENSION_MIN_JAX_VERSION = "0.8.0".
  • triton_extensions/utils.py: Adds a RuntimeError guard before the gpu_triton import; correctly placed and documented.
  • Test infrastructure: Replaces top-level triton API imports with lazy-loading autouse fixtures in four test files, adds a triton pytest marker registered in conftest.py, and introduces require_triton_or_skip_test_file() for files that cannot safely execute any module-level code on old JAX.
  • Minor style concerns: import triton lines in test_triton_custom_calls.py now follow an executable statement and may need # noqa: E402; the lazy-loading fixtures don't clean up injected module-level names on teardown.

Confidence Score: 4/5

  • This PR is safe to merge; the version guard and test infrastructure changes are correct, with only minor style issues remaining.
  • The core logic — version guard at >= 0.8.0, shared version_utils module, and marker-based test skipping — is sound. Previous thread concerns (missing version_utils.py, allow_module_level=True, hardcoded version strings) have all been addressed. The two remaining issues are style-level: potential E402 lint failures on import triton lines that now follow an executable statement, and missing fixture teardown cleanup for injected module-level names. Neither impacts production runtime behavior.
  • tests/jax/test_triton_custom_calls.py (potential E402 lint failure) and the _inject_* fixture teardown in test_permutation.py, test_fused_router.py, test_distributed_permutation.py, and test_distributed_router.py.

Important Files Changed

Filename Overview
transformer_engine/jax/version_utils.py New shared utility module cleanly centralizing JAX version checks; exports public jax_version_meet_requirement, is_triton_extension_supported, and TRITON_EXTENSION_MIN_JAX_VERSION = "0.8.0". Well-structured with lru_cache and proper __all__.
transformer_engine/jax/triton_extensions/utils.py Adds a JAX version guard (RuntimeError) before gpu_triton import. Guard is correctly placed after _check_triton_compatibility() (which doesn't touch gpu_triton), and the comment accurately explains why the segfault is at dispatch-time not import-time.
transformer_engine/jax/quantize/helper.py Clean refactor: removes the local _jax_version_meet_requirement implementation and delegates to the new shared jax_version_meet_requirement from version_utils. Unused imports (lru_cache, get_pkg_version, PkgVersion) correctly removed.
tests/jax/conftest.py Adds triton marker registration in pytest_configure (using the constant, not a hardcoded string) and a pytest_collection_modifyitems hook that gracefully skips all @pytest.mark.triton tests on old JAX without failing collection.
tests/jax/test_triton_custom_calls.py Calls require_triton_or_skip_test_file() at module level (correctly placed before import triton) to skip the entire file on old JAX. The import triton lines now come after an executable statement and may require # noqa: E402 annotations to pass linting.
tests/jax/test_permutation.py Replaces top-level permutation imports with a lazy-loading autouse fixture that injects names into the module namespace only for triton-marked tests. The injected names persist after yield (no teardown cleanup).

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["Import triton_extensions or run Triton test"] --> B["is_triton_extension_supported()\n(from version_utils)"]
    B --> C{JAX >= 0.8.0?}

    C -- No --> D1["triton_extensions/utils.py\nraises RuntimeError"]
    C -- No --> D2["conftest.py pytest_collection_modifyitems\nmarks @triton tests as skip"]
    C -- No --> D3["require_triton_or_skip_test_file()\npytest.skip(allow_module_level=True)"]

    C -- Yes --> E["Normal execution\n_check_triton_compatibility()"]
    E --> F{Triton installed?}
    F -- No --> G["ImportError: install triton"]
    F -- Yes --> H["Import gpu_triton, triton.compiler\nKernel dispatch works"]

    D2 --> I["_inject_* autouse fixture\nreturns early (no inject)"]
    H --> J["_inject_* autouse fixture\ninjects module-level names\n(token_dispatch, fused_topk…)"]

    subgraph version_utils.py
        B
    end
Loading

Comments Outside Diff (2)

  1. tests/jax/test_triton_custom_calls.py, line 14-15 (link)

    Mid-module imports may trigger E402 lint errors

    import triton and import triton.language as tl now appear after the executable function call require_triton_or_skip_test_file() at line 12. Flake8's E402 rule ("module level import not at top of file") flags imports that follow any non-import executable statement. If the lint configuration does not already suppress E402 for this file, the CI lint job (qa/L0_jax_lint/test.sh) will fail.

    Consider adding # noqa: E402 to suppress the warnings:

  2. tests/jax/test_permutation.py, line 17-40 (link)

    Injected module-level names are never cleaned up on fixture teardown

    The _inject_permutation fixture injects token_dispatch, token_combine, and sort_chunks_by_index into sys.modules[__name__] before the test runs but does not remove them after yield. The same pattern applies in test_distributed_permutation.py, test_fused_router.py, and test_distributed_router.py.

    While this is harmless in the current sequential test execution (a triton-marked test always injects before use, and non-triton tests don't use these names), the injected names silently leak across subsequent tests in the same session. If a future refactoring causes a non-triton test to accidentally reference one of these names, it would succeed or fail non-deterministically based on the test execution order.

    A defensive cleanup after yield would make the fixture fully idempotent:

        yield
        # Clean up injected names so they don't leak into subsequent tests
        for name in ("token_dispatch", "token_combine", "sort_chunks_by_index"):
            mod.__dict__.pop(name, None)

    The same cleanup pattern should be applied to _inject_router in test_fused_router.py and test_distributed_router.py.

Last reviewed commit: 015a804

@jberchtold-nvidia jberchtold-nvidia self-requested a review March 6, 2026 16:01
…lper.py

Signed-off-by: tdophung <tdophung@nvidia.com>
tdophung and others added 2 commits March 9, 2026 16:16
- Add version_utils.py with is_triton_extension_supported() checking JAX >= 0.8.0
  (release version, not dev snapshot) and TRITON_EXTENSION_MIN_JAX_VERSION constant
- Add pytest.mark.triton marker and conftest hook to skip marked tests on old JAX
- Add require_triton() for module-level skipping in test files
- Rewrite triton_extensions to use is_triton_extension_supported() instead of
  direct jaxlib dev-version comparison

Signed-off-by: tdophung <tdophung@nvidia.com>
…d re-export, revert test.sh

- require_triton(): add allow_module_level=True to pytest.skip() so module-level
  calls on old JAX produce a proper skip instead of a collection failure
- Remove is_triton_extension_supported from triton_extensions/utils.py __all__:
  importing triton_extensions on JAX < 0.8.0 raises immediately, so re-exporting
  the check from there defeats its purpose; callers should import directly from
  transformer_engine.jax.version_utils
- Revert qa/L0_jax_lint/test.sh TE_PATH to /opt/transformerengine (local dev
  path was accidentally committed; pass TE_PATH= at invocation time instead)

Signed-off-by: tdophung <tdophung@nvidia.com>
…l__ and hardcoded version

- Move is_triton_extension_supported() guard before the gpu_triton import block
  with a comment clarifying the segfault is at dispatch time, not import time
- Remove _jax_version_meet_requirement from version_utils __all__ (private helper,
  not a public API; callers import it explicitly as needed)
- Use TRITON_EXTENSION_MIN_JAX_VERSION constant in conftest marker description
  instead of hardcoded '0.8.0'

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Copy link
Collaborator Author

/te-ci jax

tdophung and others added 2 commits March 10, 2026 10:56
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Copy link
Collaborator Author

/te-ci jax

Copy link
Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants