From 2841ccdf8b9439a0a13bf8ec318eb493f23411af Mon Sep 17 00:00:00 2001
From: Fahad Alghanim <163377666+KOKOSde@users.noreply.github.com>
Date: Sat, 9 May 2026 20:15:35 -0700
Subject: [PATCH] Add transformer block graph surgery example

Signed-off-by: Fahad Alghanim <163377666+KOKOSde@users.noreply.github.com>
---
 .../13_transformer_block_surgery/README.md   |  33 ++++
 .../13_transformer_block_surgery/generate.py |  88 ++++++++++
 .../13_transformer_block_surgery/surgeon.py  | 159 ++++++++++++++++++
 .../onnx-graphsurgeon/tests/test_examples.py |   4 +
 4 files changed, 284 insertions(+)
 create mode 100644 tools/onnx-graphsurgeon/examples/13_transformer_block_surgery/README.md
 create mode 100644 tools/onnx-graphsurgeon/examples/13_transformer_block_surgery/generate.py
 create mode 100644 tools/onnx-graphsurgeon/examples/13_transformer_block_surgery/surgeon.py

diff --git a/tools/onnx-graphsurgeon/examples/13_transformer_block_surgery/README.md b/tools/onnx-graphsurgeon/examples/13_transformer_block_surgery/README.md
new file mode 100644
index 000000000..fb1f5da20
--- /dev/null
+++ b/tools/onnx-graphsurgeon/examples/13_transformer_block_surgery/README.md
@@ -0,0 +1,33 @@
+# Transformer Block Surgery
+
+## Introduction
+
+Transformer-style ONNX graphs often contain shape-only or bookkeeping operators around
+attention blocks. This example shows how to use ONNX GraphSurgeon for conservative graph
+surgery on a small transformer-like block while keeping the model in standard ONNX
+operators.
+
+The example performs two local rewrites:
+
+- Remove `Identity` nodes by rewiring their consumers.
+- Cancel adjacent `Transpose` nodes when their permutations compose to the identity permutation.
+
+These cleanups are intentionally small and semantics-preserving. They can make generated
+transformer graphs easier to inspect in Netron and prepare for downstream tooling without
+introducing custom fused operators.
+
+## Running the example
+
+1. Generate a transformer-like ONNX model:
+    ```bash
+    python3 generate.py
+    ```
+
+2. Remove no-op graph structure:
+    ```bash
+    python3 surgeon.py --input model.onnx --output cleaned.onnx
+    ```
+
+The generated model contains a bias-added linear projection preceded by an `Identity` and a
+canceling `Transpose` pair. The surgery pass exports a checked ONNX model after cleanup and
+topological sorting.
diff --git a/tools/onnx-graphsurgeon/examples/13_transformer_block_surgery/generate.py b/tools/onnx-graphsurgeon/examples/13_transformer_block_surgery/generate.py
new file mode 100644
index 000000000..ab8bff96a
--- /dev/null
+++ b/tools/onnx-graphsurgeon/examples/13_transformer_block_surgery/generate.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python3
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import onnx
+import onnx_graphsurgeon as gs
+
+
+hidden_size = 8
+num_heads = 2
+head_dim = hidden_size // num_heads
+sequence_length = 4
+
+tokens = gs.Variable("tokens", dtype=np.float32, shape=(1, sequence_length, hidden_size))
+identity_out = gs.Variable(
+    "identity_out", dtype=np.float32, shape=(1, sequence_length, hidden_size)
+)
+split_shape = gs.Constant(
+    "split_shape", values=np.array([1, sequence_length, num_heads, head_dim], dtype=np.int64)
+)
+merged_shape = gs.Constant(
+    "merged_shape", values=np.array([1, sequence_length, hidden_size], dtype=np.int64)
+)
+split_heads = gs.Variable(
+    "split_heads", dtype=np.float32, shape=(1, sequence_length, num_heads, head_dim)
+)
+heads_first = gs.Variable(
+    "heads_first", dtype=np.float32, shape=(1, num_heads, sequence_length, head_dim)
+)
+tokens_again = gs.Variable(
+    "tokens_again", dtype=np.float32, shape=(1, sequence_length, num_heads, head_dim)
+)
+merged = gs.Variable("merged", dtype=np.float32, shape=(1, sequence_length, hidden_size))
+
+weights = gs.Constant(
+    "projection_weight",
+    values=np.linspace(-0.5, 0.5, hidden_size * hidden_size, dtype=np.float32).reshape(
+        hidden_size, hidden_size
+    ),
+)
+bias = gs.Constant(
+    "projection_bias",
+    values=np.linspace(-0.1, 0.1, hidden_size, dtype=np.float32),
+)
+projected = gs.Variable("projected", dtype=np.float32, shape=(1, sequence_length, hidden_size))
+output = gs.Variable("output", dtype=np.float32, shape=(1, sequence_length, hidden_size))
+
+# The two Transposes apply the same axis swap back to back, so together they are a no-op,
+# and the leading Identity is redundant too; surgeon.py removes both patterns.
+nodes = [
+    gs.Node("Identity", inputs=[tokens], outputs=[identity_out]),
+    gs.Node("Reshape", inputs=[identity_out, split_shape], outputs=[split_heads]),
+    gs.Node(
+        "Transpose",
+        attrs={"perm": [0, 2, 1, 3]},
+        inputs=[split_heads],
+        outputs=[heads_first],
+    ),
+    gs.Node(
+        "Transpose",
+        attrs={"perm": [0, 2, 1, 3]},
+        inputs=[heads_first],
+        outputs=[tokens_again],
+    ),
+    gs.Node("Reshape", inputs=[tokens_again, merged_shape], outputs=[merged]),
+    gs.Node("MatMul", inputs=[merged, weights], outputs=[projected]),
+    gs.Node("Add", inputs=[projected, bias], outputs=[output]),
+]
+
+graph = gs.Graph(nodes=nodes, inputs=[tokens], outputs=[output], opset=18)
+model = gs.export_onnx(graph.cleanup().toposort())
+onnx.checker.check_model(model)
+onnx.save(model, "model.onnx")
diff --git a/tools/onnx-graphsurgeon/examples/13_transformer_block_surgery/surgeon.py b/tools/onnx-graphsurgeon/examples/13_transformer_block_surgery/surgeon.py
new file mode 100644
index 000000000..fa97e2e04
--- /dev/null
+++ b/tools/onnx-graphsurgeon/examples/13_transformer_block_surgery/surgeon.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python3
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+from typing import Sequence
+
+import onnx
+import onnx_graphsurgeon as gs
+
+
+def _is_graph_output(graph: gs.Graph, tensor: gs.Tensor) -> bool:
+    """Return True if ``tensor`` is one of ``graph``'s outputs (matched by identity)."""
+    return any(output is tensor for output in graph.outputs)
+
+
+def _is_permutation(perm: object) -> bool:
+    """Return True if ``perm`` is a list/tuple holding each of 0..len(perm)-1 exactly once."""
+    return isinstance(perm, (list, tuple)) and sorted(perm) == list(range(len(perm)))
+
+
+def replace_tensor(graph: gs.Graph, old: gs.Tensor, new: gs.Tensor) -> None:
+    """Rewire every use of ``old`` in ``graph`` so that it refers to ``new`` instead.
+
+    Node inputs and graph outputs are both updated. Tensors are matched by identity
+    so that only the exact ``old`` object is replaced.
+    """
+    for node in graph.nodes:
+        node.inputs = [new if tensor is old else tensor for tensor in node.inputs]
+    graph.outputs = [new if tensor is old else tensor for tensor in graph.outputs]
+
+
+def remove_identity_nodes(graph: gs.Graph) -> int:
+    """Bypass ``Identity`` nodes by pointing their consumers at the node's input.
+
+    An ``Identity`` whose output is a graph output is left in place: rewiring it would
+    rename that output (or alias it to a graph input) and change the model's interface.
+
+    Returns:
+        The number of ``Identity`` nodes that were disconnected from the graph.
+    """
+    removed = 0
+    for node in graph.nodes:
+        if node.op != "Identity" or len(node.inputs) != 1 or len(node.outputs) != 1:
+            continue
+        if _is_graph_output(graph, node.outputs[0]):
+            continue
+
+        replace_tensor(graph, node.outputs[0], node.inputs[0])
+        # Detach the node so that the later graph.cleanup() call discards it.
+        node.inputs.clear()
+        node.outputs.clear()
+        removed += 1
+
+    return removed
+
+
+def compose_permutations(first: Sequence[int], second: Sequence[int]) -> list[int]:
+    """Return the permutation equivalent to transposing by ``first`` and then ``second``.
+
+    ``Transpose`` takes output axis ``i`` from input axis ``perm[i]``, so the chained
+    result takes output axis ``i`` from original axis ``first[second[i]]``.
+    """
+    return [first[index] for index in second]
+
+
+def cancel_transpose_pairs(graph: gs.Graph) -> int:
+    """Remove back-to-back ``Transpose`` nodes whose permutations compose to the identity.
+
+    A pair is only removed when the tensor between the two nodes has no other consumer
+    and neither the intermediate nor the final tensor is a graph output, so the rewrite
+    cannot change anything observable from outside the pair.
+
+    Returns:
+        The number of ``Transpose`` pairs that were disconnected from the graph.
+    """
+    removed_pairs = 0
+
+    for node in graph.nodes:
+        if node.op != "Transpose" or len(node.inputs) != 1 or len(node.outputs) != 1:
+            continue
+
+        intermediate = node.outputs[0]
+        consumers = list(intermediate.outputs)
+        if len(consumers) != 1 or _is_graph_output(graph, intermediate):
+            continue
+
+        next_node = consumers[0]
+        if (
+            next_node.op != "Transpose"
+            or len(next_node.inputs) != 1
+            or len(next_node.outputs) != 1
+        ):
+            continue
+        if _is_graph_output(graph, next_node.outputs[0]):
+            continue
+
+        first_perm = node.attrs.get("perm")
+        second_perm = next_node.attrs.get("perm")
+        # A missing "perm" means "reverse all axes", which depends on the tensor rank.
+        # Skip those nodes (and malformed permutations) rather than guess.
+        if not _is_permutation(first_perm) or not _is_permutation(second_perm):
+            continue
+        if len(first_perm) != len(second_perm):
+            continue
+
+        composed = compose_permutations(first_perm, second_perm)
+        if composed != list(range(len(composed))):
+            continue
+
+        replace_tensor(graph, next_node.outputs[0], node.inputs[0])
+        # Detach both nodes so that the later graph.cleanup() call discards them.
+        node.inputs.clear()
+        node.outputs.clear()
+        next_node.inputs.clear()
+        next_node.outputs.clear()
+        removed_pairs += 1
+
+    return removed_pairs
+
+
+def main() -> None:
+    """Load the input model, apply both rewrites, then validate and save the result."""
+    parser = argparse.ArgumentParser(
+        description="Apply conservative ONNX GraphSurgeon rewrites to a transformer-like block."
+    )
+    parser.add_argument("--input", required=True, help="Path to the input ONNX model")
+    parser.add_argument("--output", required=True, help="Path for the cleaned ONNX model")
+    args = parser.parse_args()
+
+    graph = gs.import_onnx(onnx.load(args.input))
+    removed_identities = remove_identity_nodes(graph)
+    removed_transpose_pairs = cancel_transpose_pairs(graph)
+
+    model = gs.export_onnx(graph.cleanup().toposort())
+    onnx.checker.check_model(model)
+    onnx.save(model, args.output)
+
+    print(f"Removed Identity nodes: {removed_identities}")
+    print(f"Removed Transpose pairs: {removed_transpose_pairs}")
+    print(f"Wrote cleaned model: {args.output}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/onnx-graphsurgeon/tests/test_examples.py b/tools/onnx-graphsurgeon/tests/test_examples.py
index 345b9f388..6e09674cc 100644
--- a/tools/onnx-graphsurgeon/tests/test_examples.py
+++ b/tools/onnx-graphsurgeon/tests/test_examples.py
@@ -56,6 +56,10 @@ def __init__(self, name, infer=True):
         "12_using_numpy_unsupported_dtypes",
         [Artifact("test_conv_bf16.onnx", infer=False)],
     ),
+    (
+        "13_transformer_block_surgery",
+        [Artifact("model.onnx"), Artifact("cleaned.onnx")],
+    ),
 ]
 
 