Optimize aten::min/max.dim with TopK op #2780

danielhumanmod wants to merge 4 commits into microsoft:main from
Conversation
@microsoft-github-policy-service agree
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

    @@            Coverage Diff             @@
    ##             main    #2780      +/-   ##
    ==========================================
    + Coverage   70.46%   70.71%   +0.24%
    ==========================================
      Files         228      230       +2
      Lines       27258    27442     +184
      Branches     2761     2754       -7
    ==========================================
    + Hits        19208    19406     +198
    + Misses       7100     7097       -3
    + Partials      950      939      -11

☔ View full report in Codecov by Sentry.
Thanks so much for the review! That is a great point. I took some time to dig into the ONNX Runtime implementations to see how they handle this.

To the best of my knowledge, TopK may bring more instruction overhead, but it needs less I/O because it avoids a second scan of the input. I would appreciate your thoughts here: which approach aligns better with the community's needs? I am happy to pivot to other tasks if we want to keep the original implementation.
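To make that trade-off concrete, here is a small illustrative sketch (not from the PR; NumPy operations stand in for the ONNX operators). ReduceMax plus ArgMax each read the full input once, while a single TopK(k=1) yields both outputs from one scan:

    import numpy as np

    x = np.random.rand(4, 8).astype(np.float32)  # random floats, so ties are vanishingly unlikely

    # Two-op form: ReduceMax and ArgMax each scan the whole input.
    values_two_ops = x.max(axis=1)       # stands in for ReduceMax
    indices_two_ops = x.argmax(axis=1)   # stands in for ArgMax

    # TopK(k=1) form: a single op produces both values and indices.
    top_indices = np.argsort(-x, axis=1, kind="stable")[:, :1]
    top_values = np.take_along_axis(x, top_indices, axis=1)

    assert np.allclose(values_two_ops, top_values.squeeze(1))
    assert np.array_equal(indices_two_ops, top_indices.squeeze(1))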
I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is not used at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules? This way we can do the fusion only when both outputs are used (if not, the second output will be removed by the dead code elimination pass).
Yeah, that's a good point. It makes more sense to handle this in the rewriter/optimizer. I will take a look at the rules and follow up. Thanks for the feedback!
Hey @justinchuby, I've added a new rewrite rule to optimize this case based on our previous discussion. Whenever you have a moment, I'd appreciate your thoughts on it. Thanks!
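For readers following along, a loose sketch of what such a ReduceMax/ArgMax-to-TopK rule could look like is below. This is an assumption-heavy illustration, not the PR's actual code: it hardcodes the last axis, assumes an opset where ReduceMax takes axes as an attribute, and elides the axis/attribute consistency checks a real rule needs (exact builder details such as _outputs may also differ):

    from onnxscript.rewriter import pattern

    def max_dim_pattern(op, x):
        # Match the op pair emitted for torch.max(x, dim=-1):
        # ReduceMax for the values, ArgMax for the indices.
        values = op.ReduceMax(x, axes=[-1], keepdims=0)
        indices = op.ArgMax(x, axis=-1, keepdims=0)
        return values, indices

    def max_dim_to_topk(op, x):
        # Replace the pair with one TopK(k=1), which computes both outputs
        # in a single pass, then squeeze away the size-1 reduced axis.
        k = op.Constant(value_ints=[1])
        values, indices = op.TopK(x, k, largest=1, _outputs=2)
        axis = op.Constant(value_ints=[-1])
        return op.Squeeze(values, axis), op.Squeeze(indices, axis)

    max_dim_rule = pattern.RewriteRule(max_dim_pattern, max_dim_to_topk)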
From the reviewed diff, the abstract interface that parameterizes the rule over Min and Max:

    @abstractmethod
    def reduce_op_type(self) -> str:
        """Return the name of the Reduce operation"""
        ...

    @abstractmethod
    def arg_op_type(self) -> str:
        """Return the name of the Arg operation"""
        ...

    @abstractmethod
    def largest(self) -> int:
        """Return 1 for Max operations (largest elements), 0 for Min operations (smallest elements)."""
        ...

Code scanning / CodeQL notice on each of the three "..." placeholder bodies: Statement has no effect.
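To illustrate how that interface is meant to be filled in, here is a self-contained sketch for the Min case (the class names are hypothetical, not the PR's actual ones):

    from abc import ABC, abstractmethod

    class ReduceArgRule(ABC):  # hypothetical stand-in for the PR's base class
        @property
        @abstractmethod
        def reduce_op_type(self) -> str:
            """Return the name of the Reduce operation"""
            ...

        @property
        @abstractmethod
        def arg_op_type(self) -> str:
            """Return the name of the Arg operation"""
            ...

        @property
        @abstractmethod
        def largest(self) -> int:
            """Return 1 for Max operations, 0 for Min operations."""
            ...

    class MinDimRule(ReduceArgRule):
        @property
        def reduce_op_type(self) -> str:
            return "ReduceMin"

        @property
        def arg_op_type(self) -> str:
            return "ArgMin"

        @property
        def largest(self) -> int:
            return 0  # TopK must select the smallest elements for min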
From a newly added file (297 lines):

    @@ -0,0 +1,297 @@
    # Copyright (c) Microsoft Corporation.

Code scanning / lintrunner warnings on this line: RUFF/format, RUFF-FORMAT/format.
    @property
    def reduce_op_type(self) -> str:
        return "ReduceMax"

Code scanning / lintrunner warnings on the trailing blank lines: RUFF/W293 (whitespace on blank line), EDITORCONFIG-CHECKER/editorconfig.
From a second newly added file (543 lines):

    @@ -0,0 +1,543 @@
    # Copyright (c) Microsoft Corporation.

Code scanning / lintrunner warnings on this line: EDITORCONFIG-CHECKER/editorconfig, RUFF/format, RUFF-FORMAT/format.
Not sure what's wrong with this line; lintrunner -a already passed locally.
    if __name__ == "__main__":
        unittest.main()

Code scanning / lintrunner warning: RUFF/W292 (no newline at end of file).
Fix pytorch/pytorch#76344

Context

As mentioned in the issue, torch.max(dim=...) can be optimized with TopK, replacing the current ReduceMax and ArgMax implementation. This optimization removes a redundant scan of the input and avoids potential performance overhead in certain execution providers (e.g., the ONNX Runtime CUDA EP; see microsoft/onnxruntime#11348). In addition, since torch.min(dim=...) follows the same pattern as max, I applied the optimization to it as well.
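For intuition (this check is illustrative, not part of the PR), TopK with k=1 reproduces torch.max(dim=...) on tie-free input:

    import torch

    x = torch.randn(3, 5)  # random floats, so ties are effectively impossible

    values, indices = torch.max(x, dim=1)
    topk_values, topk_indices = torch.topk(x, k=1, dim=1, largest=True)

    assert torch.equal(values, topk_values.squeeze(1))
    assert torch.equal(indices, topk_indices.squeeze(1))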
Verification

Successfully passed existing OpInfo consistency tests: