
Optimize aten::min/max.dim with TopK op#2780

Open
danielhumanmod wants to merge 4 commits into microsoft:main from danielhumanmod:optimize-max-dim

Conversation

@danielhumanmod

Fix pytorch/pytorch#76344

Context

As mentioned in the issue, torch.max(dim=...) can be optimized with TopK to replace the current ReduceMax and ArgMax implementation. This optimization reduces redundant input scans and avoids potential performance overhead in certain execution providers (e.g., ONNX Runtime CUDA EP microsoft/onnxruntime#11348).

In addition, since torch.min(dim=...) follows the same pattern as max, I applied this optimization to it as well.
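For intuition, TopK with k=1 along a dimension yields the same values and indices as the ReduceMax+ArgMax pair it replaces. A minimal sketch in plain Python (not the actual ONNX kernels; both forms here break ties toward the lowest index, matching ArgMax's default select_last_index=0):

```python
def max_dim(rows):
    """Two-op form: ReduceMax then ArgMax over each row (dim=-1)."""
    values = [max(r) for r in rows]
    indices = [r.index(max(r)) for r in rows]
    return values, indices

def topk1_dim(rows):
    """TopK(k=1, largest=1) over each row: one pass yields both outputs."""
    values, indices = [], []
    for r in rows:
        best_i = 0
        for i, v in enumerate(r):
            if v > r[best_i]:
                best_i = i
        values.append(r[best_i])
        indices.append(best_i)
    return values, indices

x = [[3, 1, 4], [1, 5, 9]]
assert max_dim(x) == topk1_dim(x)  # ([4, 9], [2, 2])
```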

Verification

Successfully passed existing OpInfo consistency tests:

  • pytest tests/function_libs/torch_lib/ops_test.py
  • pytest tests/function_libs/torch_lib/e2e_ops_tests.py

@danielhumanmod
Author

@danielhumanmod please read the following Contributor License Agreement (CLA). If you agree with the CLA, please reply with the following information.

@microsoft-github-policy-service agree [company="{your company}"]

Options:

  • (default - no company specified) I have sole ownership of intellectual property rights to my Submissions and I am not making Submissions in the course of work for my employer.
@microsoft-github-policy-service agree
  • (when company given) I am making Submissions in the course of work for my employer (or my employer has intellectual property rights in my Submissions by contract or applicable law). I have permission from my employer to make Submissions and enter into this Agreement on behalf of my employer. By signing below, the defined term “You” includes me and my employer.
@microsoft-github-policy-service agree company="Microsoft"

Contributor License Agreement

@microsoft-github-policy-service agree

@codecov

codecov bot commented Jan 25, 2026

Codecov Report

❌ Patch coverage is 95.47170% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.71%. Comparing base (e06dd92) to head (312ef9e).
⚠️ Report is 6 commits behind head on main.

Files with missing lines Patch % Lines
.../rewriter/rules/common/_fuse_reduce_arg_to_topk.py 90.99% 7 Missing and 3 partials ⚠️
...iter/rules/common/_fuse_reduce_arg_to_topk_test.py 98.70% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2780      +/-   ##
==========================================
+ Coverage   70.46%   70.71%   +0.24%     
==========================================
  Files         228      230       +2     
  Lines       27258    27442     +184     
  Branches     2761     2754       -7     
==========================================
+ Hits        19208    19406     +198     
+ Misses       7100     7097       -3     
+ Partials      950      939      -11     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Collaborator

@justinchuby justinchuby left a comment


Thanks for creating the PR. Reading it again it seems like topk is more general than ReduceMax and ArgMax. From a node count perspective this may be fewer nodes, but I wonder if the original is easier to optimize with.

@github-project-automation github-project-automation bot moved this from Todo to In Progress in ONNX Script Review Board Jan 25, 2026
@danielhumanmod
Author

Thanks for creating the PR. Reading it again it seems like topk is more general than ReduceMax and ArgMax. From a node count perspective this may be fewer nodes, but I wonder if the original is easier to optimize with.

Thanks so much for the review! That is a great point, I took some time to dig into the ONNX Runtime implementations to see how they handle this.

  1. From the ONNX Runtime perspective:

    1. The CPU EP provides a fast path when k = 1, which performs a simple linear scan. So on CPU it behaves essentially the same as a fused max+argmax.
    2. The CUDA EP walks through the full Bitonic/Radix sort process, which can involve more complex instructions. The upside is that these operations happen primarily in shared memory.
  2. PyTorch Inductor (as a reference): it adopts a similar approach, splitting into reduce_max/arg_max in the IR, but leaves it to the runtime (Scheduler) to fuse them. However, when I checked ONNX Runtime, it didn't seem to have an optimization rule that automatically fuses ReduceMax and ArgMax, which implies the split approach effectively incurs one extra IO pass compared to TopK.

So to the best of my knowledge, TopK might bring more instruction overhead but less IO. I would appreciate your thoughts here: which approach aligns more with the community's needs? I am flexible to pivot to other tasks if we want to keep the original implementation.
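The "one extra IO pass" claim can be made concrete by counting element reads: unfused ReduceMax followed by ArgMax touches the input roughly twice, while a k=1 TopK-style scan touches it once. A toy model in plain Python (instrumented reads stand in for memory traffic; not the actual kernels):

```python
class CountingList:
    """Wraps a list and counts element reads, as a stand-in for memory traffic."""
    def __init__(self, data):
        self.data, self.reads = data, 0
    def __getitem__(self, i):
        self.reads += 1
        return self.data[i]
    def __len__(self):
        return len(self.data)

def reduce_then_arg(x):
    # Pass 1: ReduceMax
    m = x[0]
    for i in range(1, len(x)):
        m = max(m, x[i])
    # Pass 2: ArgMax re-reads the input
    for i in range(len(x)):
        if x[i] == m:
            return m, i

def topk1(x):
    # Single pass: value and index computed together
    m, idx = x[0], 0
    for i in range(1, len(x)):
        v = x[i]
        if v > m:
            m, idx = v, i
    return m, idx

a = CountingList([3, 1, 4, 1, 5])
assert reduce_then_arg(a) == (5, 4)
two_pass_reads = a.reads          # ~2n element reads
b = CountingList([3, 1, 4, 1, 5])
assert topk1(b) == (5, 4)
assert b.reads < two_pass_reads   # fused form scans the input once
```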

@justinchuby
Collaborator

I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is unused at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules ? This way we can do fusion only when the two outputs are used (if not the second output will be removed by the dead code elimination pass)

@danielhumanmod
Author

I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is unused at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules ? This way we can do fusion only when the two outputs are used (if not the second output will be removed by the dead code elimination pass)

Yeah, that's a good point. It makes more sense to handle this in the rewriter/optimizer. I will take a look at the rules and follow up. Thanks for the feedback!
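The gating condition being discussed, fuse only when both outputs are consumed, can be sketched as a simple predicate over the matched nodes' users (a toy model for illustration, not the actual onnxscript rewriter API):

```python
def should_fuse_to_topk(reduce_out_users, arg_out_users):
    """Fuse ReduceMax+ArgMax into TopK only when both outputs are
    actually consumed. If one output is unused, the single-op form is
    kept so dead code elimination can drop the unused branch entirely."""
    return len(reduce_out_users) > 0 and len(arg_out_users) > 0

# Both outputs used: fusion pays off (one input scan, two outputs).
assert should_fuse_to_topk(["Add"], ["Gather"])
# Only the values are used: keep ReduceMax alone; ArgMax is dead code.
assert not should_fuse_to_topk(["Add"], [])
```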

@danielhumanmod
Author

Hey @justinchuby, I’ve added a new rewrite rule to optimize this case based on our previous discussion. Whenever you have a moment, I’d appreciate your thoughts on it. Thanks!

@abstractmethod
def reduce_op_type(self) -> str:
"""Return the name of the Reduce operation"""
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
@abstractmethod
def arg_op_type(self) -> str:
"""Return the name of the Arg operation"""
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
@abstractmethod
def largest(self) -> int:
"""Return 1 for Max operations (largest elements), 0 for Min operations (smallest elements)."""
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
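The CodeQL notes above all flag the bare `...` after a docstring: a docstring is already a valid method body, so the trailing Ellipsis is a statement with no effect and can simply be dropped. A sketch (the class name here is hypothetical, for illustration only):

```python
from abc import ABC, abstractmethod

class ReduceArgFusionBase(ABC):  # hypothetical name for illustration
    @abstractmethod
    def reduce_op_type(self) -> str:
        """Return the name of the Reduce operation."""
        # The docstring alone is a valid body, so the trailing `...`
        # that CodeQL flags as a no-op statement is not needed.

class MaxFusion(ReduceArgFusionBase):
    def reduce_op_type(self) -> str:
        return "ReduceMax"

assert MaxFusion().reduce_op_type() == "ReduceMax"
```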
@@ -0,0 +1,297 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF/format Warning

Run lintrunner -a to apply this patch.
@@ -0,0 +1,297 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning

Run lintrunner -a to apply this patch.
@property
def reduce_op_type(self) -> str:
return "ReduceMax"

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

@property
def reduce_op_type(self) -> str:
return "ReduceMax"

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
@@ -0,0 +1,543 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Final newline expected
@@ -0,0 +1,543 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF/format Warning

Run lintrunner -a to apply this patch.
@@ -0,0 +1,543 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning

Run lintrunner -a to apply this patch.
Author


Not sure what's wrong with this line; lintrunner -a already passed locally.



if __name__ == "__main__":
unittest.main() No newline at end of file

Check warning

Code scanning / lintrunner

RUFF/W292 Warning


Labels

None yet

Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

[ONNX] Use topk to export max(dim,keepdim) to onnx

2 participants