Solve pytorch-triton and triton package contention #2540
Conversation
| num_ctas, # arg2: num_ctas (int) | ||
| compiled.metadata.shared, # arg3: shared_mem_bytes (int) | ||
| compiled.asm["ptx"], # arg4: ptx (str) | ||
| "", # arg5: ttir (str) - empty |
This will soon match main, since the same change is made in #1921, which is still to be merged. It is only included in this PR so I can test Triton calls locally with the nightly JAX container without hitting errors caused by JAX 0.8.2+.
Greptile Summary
This PR resolves the package contention between pytorch-triton and triton.
Key changes:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User as User/CI
participant Setup as setup.py
participant PyTorch as build_tools/pytorch.py
participant JAX as build_tools/jax.py
participant Utils as triton_extensions/utils.py
User->>Setup: pip install (PyTorch)
Setup->>PyTorch: install_requirements()
PyTorch-->>Setup: ["pytorch-triton", ...]
Note right of PyTorch: Always uses pytorch-triton<br/>for torch.compile()
User->>Setup: pip install (JAX)
Setup->>JAX: test_requirements()
JAX->>JAX: Check NVTE_USE_PYTORCH_TRITON
alt NVTE_USE_PYTORCH_TRITON=1
JAX-->>Setup: ["pytorch-triton", ...]
else Default (unset or 0)
JAX-->>Setup: ["triton", ...]
end
User->>Utils: import triton_extensions
Utils->>Utils: _detect_triton_package()
Utils->>Utils: _check_triton_compatibility()
alt Placeholder package (0.0.1)
Utils-->>User: ImportError with fix instructions
else pytorch-triton without env var
Utils-->>User: Warning (suppressible)
else Valid triton package
Utils-->>User: Success
end
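The import-time branch at the end of the diagram can be sketched roughly as follows. This is only an illustration of the three outcomes shown, not the code in triton_extensions/utils.py: the function name, its arguments, the return strings, and the message wording are all assumptions, and the version strings in the usage lines are made up.

```python
import warnings

# Version the diagram associates with the placeholder package.
PLACEHOLDER_VERSION = "0.0.1"


def check_triton_compatibility(dist_name, version, use_pytorch_triton):
    """Hypothetical stand-in for the compatibility check in the diagram.

    dist_name: name of the installed distribution ("triton" or "pytorch-triton").
    version: its installed version string.
    use_pytorch_triton: whether NVTE_USE_PYTORCH_TRITON was set to a true value.
    """
    if version == PLACEHOLDER_VERSION:
        # Placeholder package: fail loudly and tell the user how to fix it.
        raise ImportError(
            f"{dist_name} {version} looks like a placeholder package; "
            "install a real Triton build instead."
        )
    if dist_name == "pytorch-triton" and not use_pytorch_triton:
        # Usable, but the user did not opt in explicitly, so only warn.
        warnings.warn(
            "Using pytorch-triton with JAX; set NVTE_USE_PYTORCH_TRITON=1 "
            "to acknowledge this."
        )
        return "warning"
    return "ok"


# Illustrative calls (version numbers are invented for the example).
print(check_triton_compatibility("triton", "3.5.0", False))
print(check_triton_compatibility("pytorch-triton", "3.5.0", True))
```

Keeping the check a pure function of (distribution, version, flag) makes each branch of the diagram easy to exercise without installing either package.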
Additional Comments (1)
transformer_engine/jax/triton_extensions/utils.py, line 322 (link)
syntax: Typo: compile.name should be compiled.name. The variable compile is not defined in this scope; only compiled exists, from line 300. This will cause a NameError at runtime for JAX versions < 0.8.2.
4 files reviewed, 1 comment
Additional Comments (1)
build_tools/pytorch.py, line 21 (link)
style: Placeholder text <version??> should be replaced with an actual version (e.g., cu121 or cu124) or made generic.
4 files reviewed, 1 comment
build_tools/jax.py
Outdated
| use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ( | ||
| "1", | ||
| "true", | ||
| "yes", | ||
| ) |
| use_pytorch_triton = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").lower() in ( | |
| "1", | |
| "true", | |
| "yes", | |
| ) | |
| use_pytorch_triton = bool(int(os.environ.get("NVTE_USE_PYTORCH_TRITON", "0"))) |
It's better to be specific with env vars and follow the convention from the rest of the codebase.
| "pydantic", | ||
| "nvdlfw-inspect", | ||
| "triton", | ||
| "pytorch-triton", |
If pytorch-triton from PyPI is actually a placeholder, then we shouldn't list it here as a dependency
pytorch-triton should be the default for transformerengine-pytorch, not just a placeholder. It should always be used when the PyTorch framework is in use.
triton is the default for JAX, unless one of two scenarios applies:
- Both JAX and PyTorch are installed, and TE PyTorch is used to call the Triton kernels.
- The user specifies NVTE_USE_PYTORCH_TRITON=1 while using TE JAX, to make sure there is no performance difference from using different versions of Triton between PyTorch and JAX.
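A minimal sketch of the selection rule described above, assuming the int-cast env var convention suggested earlier in the review. The function name and the framework argument are hypothetical and are not the signatures used in build_tools.

```python
import os


def select_triton_requirement(framework, env=None):
    """Return the Triton distribution name to list as a requirement.

    framework: "pytorch" or "jax" (hypothetical argument for this sketch).
    env: mapping to read NVTE_USE_PYTORCH_TRITON from; defaults to os.environ.
    """
    env = os.environ if env is None else env
    if framework == "pytorch":
        # PyTorch always uses pytorch-triton so torch.compile() keeps working.
        return "pytorch-triton"
    # JAX defaults to triton unless the user opts in to pytorch-triton.
    use_pytorch_triton = bool(int(env.get("NVTE_USE_PYTORCH_TRITON", "0")))
    return "pytorch-triton" if use_pytorch_triton else "triton"


print(select_triton_requirement("jax", env={}))
print(select_triton_requirement("jax", env={"NVTE_USE_PYTORCH_TRITON": "1"}))
```

Passing the environment in as a mapping is only a convenience for testing; the real build scripts read os.environ directly in the snippets quoted above.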
…for jax Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: tdophung <tdophung@nvidia.com>
Additional Comments (1)
build_tools/jax.py, line 37 (link)
logic: Inconsistent env var parsing: utils.py accepts "1", "true", "yes" (via a .lower() check), but this line only accepts integer strings such as "0" or "1" (via an int() cast). If the user sets NVTE_USE_PYTORCH_TRITON=true, this will raise ValueError: invalid literal for int() with base 10: 'true'.
4 files reviewed, 1 comment
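To make the mismatch concrete, here is a small sketch of the two parsing styles side by side. The helper names are made up for illustration; only the two expressions themselves come from the snippets quoted in this review.

```python
def parse_flag_int(value):
    """Parse a flag the way the int() cast does: only integer strings work."""
    return bool(int(value))


def parse_flag_lenient(value):
    """Parse a flag the way the .lower() membership check does."""
    return value.lower() in ("1", "true", "yes")


for raw in ("0", "1", "true", "2"):
    try:
        strict = parse_flag_int(raw)
    except ValueError as exc:
        # int("true") is not a valid integer literal, so the strict parser fails.
        strict = f"ValueError: {exc}"
    print(f"{raw!r}: int-cast -> {strict}, lenient -> {parse_flag_lenient(raw)}")
```

The two styles also disagree on inputs like "2", which the int cast treats as true and the membership check treats as false, so whichever convention is chosen should be applied in both files.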
e2773e1 to
469ed68
Additional Comments (1)
build_tools/pytorch.py, line 35 (link)
style: Listing pytorch-triton here won't work with a standard pip install from PyPI, since the real package is only available from PyTorch's package index. Users running pip install . without specifying --index-url will get the placeholder package that raises "Should never be installed".
Consider adding an installation note to the README or setup.py, or documenting that this package requires:
pip install --index-url https://download.pytorch.org/whl/cu121 pytorch-triton
4 files reviewed, 1 comment
Description
The pytorch-triton and triton packages install to the same location, site-packages/triton, and triton does not work for PyTorch's torch.compile() call, because PyTorch has added a few things to its own version of Triton (creating pytorch-triton to make it work, and validating it with the release of torch). However, pytorch-triton should in theory (and in our experiments) still be compatible with how JAX uses it*.
Fixes # (issue)
Type of change
Changes
Checklist: