Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Transformer Block Surgery

## Introduction

Transformer-style ONNX graphs often contain shape-only or bookkeeping operators around
attention blocks. This example shows how to use ONNX GraphSurgeon for conservative graph
surgery on a small transformer-like block while keeping the model in standard ONNX
operators.

The example performs two local rewrites:

- Remove `Identity` nodes by rewiring their consumers.
- Cancel adjacent `Transpose` nodes when their permutations compose to the identity permutation.

These cleanups are intentionally small and semantics-preserving. They can make generated
transformer graphs easier to inspect in Netron and prepare for downstream tooling without
introducing custom fused operators.

## Running the example

1. Generate a transformer-like ONNX model:
```bash
python3 generate.py
```

2. Remove no-op graph structure:
```bash
python3 surgeon.py --input model.onnx --output cleaned.onnx
```

The generated model contains a linear projection (`MatMul` plus bias `Add`) preceded by an
`Identity` and a canceling `Transpose` pair. The surgery pass exports a checked ONNX model
after cleanup and topological sorting.
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import onnx
import onnx_graphsurgeon as gs


# Dimensions of the toy transformer-like block.
hidden_size = 8
num_heads = 2
head_dim = hidden_size // num_heads
sequence_length = 4

# Static shapes the activation takes on its way through the block.
flat_dims = (1, sequence_length, hidden_size)
per_head_dims = (1, sequence_length, num_heads, head_dim)
heads_first_dims = (1, num_heads, sequence_length, head_dim)

# Swapping axes 1 and 2 twice restores the original layout, so the two
# Transpose nodes below form a pair that cancels out.
swap_sequence_and_heads = [0, 2, 1, 3]


def float_variable(name, dims):
    """Create a float32 graph variable called ``name`` with static shape ``dims``."""
    return gs.Variable(name, dtype=np.float32, shape=dims)


tokens = float_variable("tokens", flat_dims)
identity_out = float_variable("identity_out", flat_dims)
split_heads = float_variable("split_heads", per_head_dims)
heads_first = float_variable("heads_first", heads_first_dims)
tokens_again = float_variable("tokens_again", per_head_dims)
merged = float_variable("merged", flat_dims)
projected = float_variable("projected", flat_dims)
output = float_variable("output", flat_dims)

# Target shapes fed to the two Reshape nodes.
split_shape = gs.Constant("split_shape", values=np.array(per_head_dims, dtype=np.int64))
merged_shape = gs.Constant("merged_shape", values=np.array(flat_dims, dtype=np.int64))

# Deterministic projection parameters so the generated model is reproducible.
weight_values = np.linspace(-0.5, 0.5, hidden_size * hidden_size, dtype=np.float32)
weights = gs.Constant(
    "projection_weight",
    values=weight_values.reshape(hidden_size, hidden_size),
)
bias = gs.Constant(
    "projection_bias",
    values=np.linspace(-0.1, 0.1, hidden_size, dtype=np.float32),
)

# Identity -> Reshape -> Transpose -> Transpose -> Reshape -> MatMul -> Add
nodes = [
    gs.Node("Identity", inputs=[tokens], outputs=[identity_out]),
    gs.Node("Reshape", inputs=[identity_out, split_shape], outputs=[split_heads]),
    gs.Node(
        "Transpose",
        attrs={"perm": list(swap_sequence_and_heads)},
        inputs=[split_heads],
        outputs=[heads_first],
    ),
    gs.Node(
        "Transpose",
        attrs={"perm": list(swap_sequence_and_heads)},
        inputs=[heads_first],
        outputs=[tokens_again],
    ),
    gs.Node("Reshape", inputs=[tokens_again, merged_shape], outputs=[merged]),
    gs.Node("MatMul", inputs=[merged, weights], outputs=[projected]),
    gs.Node("Add", inputs=[projected, bias], outputs=[output]),
]

graph = gs.Graph(nodes=nodes, inputs=[tokens], outputs=[output], opset=18)
sorted_graph = graph.cleanup().toposort()
model = gs.export_onnx(sorted_graph)
# Validate the exported model before writing it to disk.
onnx.checker.check_model(model)
onnx.save(model, "model.onnx")
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
from typing import Sequence

import onnx
import onnx_graphsurgeon as gs


def replace_tensor(graph: gs.Graph, old: gs.Tensor, new: gs.Tensor) -> None:
    """Redirect every use of ``old`` in ``graph`` to ``new``.

    The input list of every node and the graph's output list are rebuilt.
    Tensors are matched by object identity (``is``).

    Args:
        graph: Graph whose nodes and outputs are rewritten in place.
        old: Tensor to be replaced.
        new: Tensor that takes the place of ``old``.
    """

    def substituted(tensors):
        # Same order and length as ``tensors``, with ``old`` swapped for ``new``.
        return [new if candidate is old else candidate for candidate in tensors]

    for consumer in graph.nodes:
        consumer.inputs = substituted(consumer.inputs)
    graph.outputs = substituted(graph.outputs)


def remove_identity_nodes(graph: gs.Graph) -> int:
    """Bypass single-input, single-output ``Identity`` nodes in ``graph``.

    Consumers of an Identity's output are rewired to read the Identity's input
    directly, and the node is then disconnected by clearing its inputs and
    outputs. Disconnected nodes stay in ``graph.nodes``; the caller is expected
    to run ``graph.cleanup()`` afterwards to drop them.

    An Identity whose output is a graph output is left in place. Bypassing it
    would substitute its input tensor into ``graph.outputs``, which changes the
    model's output name and may expose a graph input or constant directly as
    an output.

    Args:
        graph: Graph to rewrite in place.

    Returns:
        The number of Identity nodes that were disconnected.
    """
    removed = 0
    for node in graph.nodes:
        if node.op != "Identity" or len(node.inputs) != 1 or len(node.outputs) != 1:
            continue

        source, result = node.inputs[0], node.outputs[0]

        # Keep the model interface stable: never rename a graph output.
        if any(result is graph_output for graph_output in graph.outputs):
            continue

        replace_tensor(graph, result, source)
        node.inputs.clear()
        node.outputs.clear()
        removed += 1

    return removed


def compose_permutations(first: Sequence[int], second: Sequence[int]) -> list[int]:
    """Return the single permutation equivalent to applying ``first`` then ``second``.

    Position ``i`` of the result is ``first[second[i]]``: each entry of
    ``second`` is used as an index into ``first``.

    Args:
        first: Permutation applied first.
        second: Permutation applied to the result of ``first``.

    Returns:
        A new list with one entry per element of ``second``.
    """
    composed = []
    for position in second:
        composed.append(first[position])
    return composed


def cancel_transpose_pairs(graph: gs.Graph) -> int:
    """Remove back-to-back ``Transpose`` nodes whose permutations cancel out.

    A pair is removed only when all of the following hold:

    * both nodes are single-input, single-output ``Transpose`` nodes;
    * the first node's output feeds exactly one consumer, the second node;
    * neither the intermediate tensor nor the pair's final output is a graph
      output, so no model output is renamed or left without a producer;
    * both nodes carry an explicit ``perm`` list, each is a valid permutation
      of ``0..rank-1``, and their composition is the identity permutation.

    Consumers of the second node's output are rewired to the first node's
    input and both nodes are disconnected by clearing their inputs and
    outputs. The caller is expected to run ``graph.cleanup()`` afterwards to
    drop the disconnected nodes.

    Args:
        graph: Graph to rewrite in place.

    Returns:
        The number of Transpose pairs that were disconnected.
    """

    def is_graph_output(tensor) -> bool:
        # Identity comparison, matching how replace_tensor matches tensors.
        return any(tensor is graph_output for graph_output in graph.outputs)

    removed_pairs = 0

    for node in graph.nodes:
        if node.op != "Transpose" or len(node.inputs) != 1 or len(node.outputs) != 1:
            continue

        intermediate = node.outputs[0]
        # If the transposed tensor is itself a model output, removing its
        # producer would leave that output dangling.
        if is_graph_output(intermediate):
            continue

        consumers = list(intermediate.outputs)
        if len(consumers) != 1:
            continue

        next_node = consumers[0]
        if (
            next_node.op != "Transpose"
            or len(next_node.inputs) != 1
            or len(next_node.outputs) != 1
        ):
            continue

        # Rewiring a graph output to the pair's input would rename the model output.
        if is_graph_output(next_node.outputs[0]):
            continue

        first_perm = node.attrs.get("perm")
        second_perm = next_node.attrs.get("perm")
        # A missing ``perm`` is not handled here; such pairs are left untouched.
        if not isinstance(first_perm, list) or not isinstance(second_perm, list):
            continue
        if len(first_perm) != len(second_perm):
            continue

        identity = list(range(len(first_perm)))
        # Reject malformed permutations up front: out-of-range entries would
        # raise IndexError during composition, and negative entries would wrap
        # around and could falsely compose to the identity.
        if sorted(first_perm) != identity or sorted(second_perm) != identity:
            continue

        if compose_permutations(first_perm, second_perm) != identity:
            continue

        replace_tensor(graph, next_node.outputs[0], node.inputs[0])
        node.inputs.clear()
        node.outputs.clear()
        next_node.inputs.clear()
        next_node.outputs.clear()
        removed_pairs += 1

    return removed_pairs


def main() -> None:
    """Command-line entry point: load a model, apply the rewrites, and save it.

    Reads the model at ``--input``, removes Identity nodes and canceling
    Transpose pairs, then cleans up, topologically sorts, checks, and writes
    the result to ``--output``. A short summary is printed to stdout.
    """
    cli = argparse.ArgumentParser(
        description="Apply conservative ONNX GraphSurgeon rewrites to a transformer-like block."
    )
    cli.add_argument("--input", required=True, help="Path to the input ONNX model")
    cli.add_argument("--output", required=True, help="Path for the cleaned ONNX model")
    options = cli.parse_args()

    source_model = onnx.load(options.input)
    graph = gs.import_onnx(source_model)

    # Identities are bypassed first, then the Transpose pairs are canceled.
    identity_count = remove_identity_nodes(graph)
    transpose_pair_count = cancel_transpose_pairs(graph)

    # cleanup() drops the nodes disconnected by the rewrites above.
    tidy_graph = graph.cleanup().toposort()
    cleaned_model = gs.export_onnx(tidy_graph)
    onnx.checker.check_model(cleaned_model)
    onnx.save(cleaned_model, options.output)

    print(f"Removed Identity nodes: {identity_count}")
    print(f"Removed Transpose pairs: {transpose_pair_count}")
    print(f"Wrote cleaned model: {options.output}")


if __name__ == "__main__":
    main()
4 changes: 4 additions & 0 deletions tools/onnx-graphsurgeon/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def __init__(self, name, infer=True):
"12_using_numpy_unsupported_dtypes",
[Artifact("test_conv_bf16.onnx", infer=False)],
),
(
"13_transformer_block_surgery",
[Artifact("model.onnx"), Artifact("cleaned.onnx")],
),
]


Expand Down