Skip to content

Conversation

@tdophung
Copy link
Collaborator

@tdophung tdophung commented Dec 23, 2025

Description

pytorch-triton and triton packages install to the same location at site-packages/triton, and triton does not work for pytorch's torch.compile() call as there are a few things pytorch has added onto their version of triton (creating pytorch-triton to make it work and validated it with the release of torch). However pytorch-triton should in theory (and experimented) still be compatible with how jax uses it*.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add new env var to control when to use pytorch-triton in jax
  • switch pytorch back to using/checking for pytorch-triton by default
  • Add documentation (comments) on this contention of packages

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

num_ctas, # arg2: num_ctas (int)
compiled.metadata.shared, # arg3: shared_mem_bytes (int)
compiled.asm["ptx"], # arg4: ptx (str)
"", # arg5: ttir (str) - empty
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will soon be the same as main. as this change is made here in: #1921, to be merged. it is just in this PR so I can test triton calls locally with the nitghtly jax container without running into errors because of jax 0.8.2+

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Dec 23, 2025

Greptile Summary

This PR resolves the package contention between pytorch-triton and triton packages, which both install to site-packages/triton. The changes ensure PyTorch uses pytorch-triton (required for torch.compile()) while JAX defaults to standard triton but can optionally use pytorch-triton via the NVTE_USE_PYTORCH_TRITON environment variable.

Key changes:

  • build_tools/pytorch.py: Changed dependency from triton to pytorch-triton for PyTorch compatibility
  • build_tools/jax.py: Added NVTE_USE_PYTORCH_TRITON env var to conditionally select triton package
  • transformer_engine/jax/triton_extensions/utils.py: Added runtime detection of triton package type with warnings for mixed environments
  • Added comprehensive documentation throughout explaining the package options and installation requirements

Confidence Score: 4/5

  • This PR is safe to merge - changes are primarily documentation and configuration with well-thought-out detection logic and helpful error messages.
  • The changes are well-documented, handle edge cases properly (placeholder package detection), and provide clear user guidance. The only minor concern is that pytorch-triton in install_requirements cannot be automatically installed from PyPI - users must use PyTorch's package index.
  • build_tools/pytorch.py - the pytorch-triton dependency requires installation from PyTorch's index, not standard PyPI

Important Files Changed

Filename Overview
build_tools/jax.py Added NVTE_USE_PYTORCH_TRITON env var to conditionally select between 'triton' (default) and 'pytorch-triton' packages for JAX test requirements. Well-documented docstring added.
build_tools/pytorch.py Changed default triton dependency from 'triton' to 'pytorch-triton' for PyTorch compatibility with torch.compile(). Added documentation explaining that pytorch-triton must be installed from PyTorch's package index, not PyPI.
transformer_engine/jax/triton_extensions/init.py Added comprehensive documentation about Triton package options, the NVTE_USE_PYTORCH_TRITON environment variable, and usage examples for the new get_triton_info() function.
transformer_engine/jax/triton_extensions/utils.py Added Triton package detection logic with _detect_triton_package() and _check_triton_compatibility() functions. Detects standard triton vs pytorch-triton, handles placeholder package from PyPI, and emits helpful warnings/errors. Added get_triton_info() API.

Sequence Diagram

sequenceDiagram
    participant User as User/CI
    participant Setup as setup.py
    participant PyTorch as build_tools/pytorch.py
    participant JAX as build_tools/jax.py
    participant Utils as triton_extensions/utils.py
    
    User->>Setup: pip install (PyTorch)
    Setup->>PyTorch: install_requirements()
    PyTorch-->>Setup: ["pytorch-triton", ...]
    Note right of PyTorch: Always uses pytorch-triton<br/>for torch.compile()
    
    User->>Setup: pip install (JAX)
    Setup->>JAX: test_requirements()
    JAX->>JAX: Check NVTE_USE_PYTORCH_TRITON
    alt NVTE_USE_PYTORCH_TRITON=1
        JAX-->>Setup: ["pytorch-triton", ...]
    else Default (unset or 0)
        JAX-->>Setup: ["triton", ...]
    end
    
    User->>Utils: import triton_extensions
    Utils->>Utils: _detect_triton_package()
    Utils->>Utils: _check_triton_compatibility()
    alt Placeholder package (0.0.1)
        Utils-->>User: ImportError with fix instructions
    else pytorch-triton without env var
        Utils-->>User: Warning (suppressible)
    else Valid triton package
        Utils-->>User: Success
    end
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. transformer_engine/jax/triton_extensions/utils.py, line 322 (link)

    syntax: Typo: compile.name should be compiled.name. The variable compile is not defined in this scope - only compiled exists from line 300. This will cause a NameError at runtime for JAX versions < 0.8.2.

4 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. build_tools/pytorch.py, line 21 (link)

    style: Placeholder text <version??> should be replaced with an actual version (e.g., cu121 or cu124) or made generic.

4 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Comment on lines 37 to 41
use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in (
"1",
"true",
"yes",
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in (
"1",
"true",
"yes",
)
use_pytorch_triton = bool(int(os.environ.get("NVTE_USE_PYTORCH_TRITON", "0")))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to be specific with envvars and follow convention from rest of the codebase.

"pydantic",
"nvdlfw-inspect",
"triton",
"pytorch-triton",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If pytorch-triton from PyPI is actually a placeholder, then we shouldn't list it here as a dependency

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytorch-triton should be the default for transformerengine-pytorch. Not just a placeholder. It should be used, all the time when pytorch framework is used.

triton is the default for jax, unless either in 2 scenarios happen:

  • there is both jax and pytorch installed, and they are using TE pytorch to call the triton kernels
  • The user specify NVTE_USE_PYTORCH_TRITON=1 while using TE jax, to make sure there is no performance diff between using different versions of triton, between pytorch and jax.

tdophung and others added 7 commits January 2, 2026 09:59
…for jax

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. build_tools/jax.py, line 37 (link)

    logic: Inconsistent env var parsing: utils.py accepts "1", "true", "yes" (via .lower() check), but this line only accepts "0" or "1" (via int() cast). If user sets NVTE_USE_PYTORCH_TRITON=true, this will raise ValueError: invalid literal for int() with base 10: 'true'.

4 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@tdophung tdophung force-pushed the teddy/pytorch-triton-contention branch from e2773e1 to 469ed68 Compare January 2, 2026 18:02
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. build_tools/pytorch.py, line 35 (link)

    style: Listing pytorch-triton here won't work with standard pip install from PyPI since the real package is only available from PyTorch's package index. Users running pip install . without specifying --index-url will get the placeholder package that raises "Should never be installed".

    Consider adding an installation note to the README or setup.py, or documenting that this package requires:

    pip install --index-url https://download.pytorch.org/whl/cu121 pytorch-triton
    

4 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants