Added LongRoPe Model Causal Mask Pattern Fusion #2473
tadani3 wants to merge 29 commits into microsoft:main
…xscript into longrope_causal_mask
    """
    Pattern for LongRoPe GQA Causal Mask.
    This pattern computes the causal mask for Group Query Attention with LongRoPe.
    It constructs the mask based on input_ids and past_kv_cache, and handles the
Check notice (Code scanning / CodeQL): Unused local variable
    mask_key = _get_mask_key(attention_mask)

    if mask_key in self._mask_cache:
        total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key]
Check notice (Code scanning / CodeQL): Unused local variable
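The lookup above reuses values derived from the attention mask instead of recomputing them for every attention layer. A minimal sketch of that caching idea follows. It assumes the cached pair plays the role of the `total_sequence_length` and `seqlens_k` inputs of a GroupQueryAttention-style op, with `seqlens_k` equal to the valid length minus one per batch row. `MaskCache` and `derive_gqa_lengths` are hypothetical names, and the mask is a plain nested list rather than an IR value.

```python
from typing import Dict, List, Tuple

Mask = List[List[int]]  # shape (batch, total_seq_len); 1 = real token, 0 = padding


def derive_gqa_lengths(attention_mask: Mask) -> Tuple[int, List[int]]:
    """Return (total_seq_length, seqlens_k) for a 0/1 attention mask."""
    total_seq_length = len(attention_mask[0])
    # Assumption: seqlens_k stores "number of valid tokens - 1" per batch row.
    seqlens_k = [sum(row) - 1 for row in attention_mask]
    return total_seq_length, seqlens_k


class MaskCache:
    """Hypothetical cache keyed by a caller-supplied mask key."""

    def __init__(self) -> None:
        self._mask_cache: Dict[str, Tuple[int, List[int]]] = {}
        self.misses = 0  # counts how often the values were actually computed

    def get(self, mask_key: str, attention_mask: Mask) -> Tuple[int, List[int]]:
        if mask_key in self._mask_cache:
            return self._mask_cache[mask_key]
        self.misses += 1
        result = derive_gqa_lengths(attention_mask)
        self._mask_cache[mask_key] = result
        return result


cache = MaskCache()
mask = [[1, 1, 1, 0], [1, 1, 1, 1]]
first = cache.get("attention_mask", mask)
second = cache.get("attention_mask", mask)  # served from the cache
```

Keying on something stable for the mask (here a string) is what lets later layers hit the cache; how the real `_get_mask_key` builds its key is not shown in the snippet.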
    # Licensed under the MIT License. See License.txt in the project root for
    # license information.
    # --------------------------------------------------------------------------
    import onnx
Check notice (Code scanning / CodeQL): Unused import
    # --------------------------------------------------------------------------
    import onnx
    from onnxscript import ir
    import onnx.helper
Check notice (Code scanning / CodeQL): Unused import
    cache_length = self.rotemb_attrs["cache_length"]
    position_ids = torch.arange(cache_length, dtype=torch.int64).unsqueeze(0)  # Shape: (1, cache_length)

    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)  # (1, dim//2, 1)
Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)  # (1, cache_length, dim//2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (1, cache_length, dim)
        cos_cache = emb.cos() * attention_factor  # (1, cache_length, dim)
Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
    attention_factor = self.rotemb_attrs["multi_cache"]["short_mscale"]

    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device="cpu").float() / dim
    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
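The flagged snippets compute inverse frequencies and turn them into a scaled cosine cache. The arithmetic can be sketched in plain Python without torch, which makes the shapes easy to check by hand. `build_cos_sin_cache` is a hypothetical helper, `ext_factors` is taken to be one factor per frequency, and the sine cache is assumed to be scaled the same way as the cosine cache shown above.

```python
import math
from typing import List, Tuple


def build_cos_sin_cache(
    dim: int,
    base: float,
    cache_length: int,
    ext_factors: List[float],
    attention_factor: float,
) -> Tuple[List[List[float]], List[List[float]]]:
    """Return (cos_cache, sin_cache), each of shape (cache_length, dim)."""
    half = dim // 2
    # inv_freq[i] = 1 / (ext_factors[i] * base ** (2i / dim))
    inv_freq = [1.0 / (ext_factors[i] * base ** ((2 * i) / dim)) for i in range(half)]
    cos_cache: List[List[float]] = []
    sin_cache: List[List[float]] = []
    for pos in range(cache_length):
        freqs = [pos * f for f in inv_freq]  # (dim // 2,)
        emb = freqs + freqs  # mirrors torch.cat((freqs, freqs), dim=-1)
        cos_cache.append([math.cos(x) * attention_factor for x in emb])
        sin_cache.append([math.sin(x) * attention_factor for x in emb])
    return cos_cache, sin_cache


cos_cache, sin_cache = build_cos_sin_cache(
    dim=4, base=10000.0, cache_length=3, ext_factors=[1.0, 1.0], attention_factor=1.0
)
```

At position 0 every angle is zero, so the first cosine row equals `attention_factor` everywhere and the first sine row is all zeros, which is a quick sanity check for any implementation of this cache.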
    if "rescale_inv_freq" in self.rotemb_attrs:
        inv_freq = self.make_inv_freq_rescaled(inv_freq)

    return inv_freq, attention_factor
Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
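CodeQL reports "Potentially uninitialized local variable" when a name is assigned only on some branches and then used unconditionally. The branching that surrounds the flagged lines is not visible here, so the following is only a sketch of the usual fix: make every path either assign the variable or raise. `select_attention_factor`, the `cache_type` parameter, and the `long_mscale` key are assumptions made for illustration.

```python
from typing import Any, Dict


def select_attention_factor(rotemb_attrs: Dict[str, Any], cache_type: str) -> float:
    """Pick the scaling factor for a cache type, failing loudly on unknown types."""
    multi_cache = rotemb_attrs["multi_cache"]
    if cache_type == "short":
        attention_factor = multi_cache["short_mscale"]
    elif cache_type == "long":
        # Assumed key, named by analogy with "short_mscale".
        attention_factor = multi_cache["long_mscale"]
    else:
        # Without this branch, attention_factor could be unbound at the return.
        raise ValueError(f"Unknown cache type: {cache_type!r}")
    return attention_factor


attrs = {"multi_cache": {"short_mscale": 1.0, "long_mscale": 1.19}}
short_factor = select_attention_factor(attrs, "short")
```

Raising in the fallback branch keeps the failure close to its cause, where an unbound name would otherwise surface later as an `UnboundLocalError`.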
lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
…t#2465) Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Provide a way to indicate that a pattern variable can match successfully against a None-valued input. Clean up the current handling, which was inconsistent in one place. Add test cases.
---------
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This PR adds comprehensive documentation for the rewriter pattern
options that were previously undocumented. The rewriter pattern system
supports four key options for controlling pattern matching and
replacement behavior:
## New Documentation Added
### `_allow_other_inputs` option
- **File**: `docs/tutorial/rewriter/allow_other_inputs.md`
- **Purpose**: Controls whether patterns can match nodes with additional
inputs beyond those specified
- **Default**: `False` (exact input matching)
- **Example**: Matching `Conv` operations that may have optional bias
inputs
```python
def conv_pattern(op, input, weight):
    # Matches Conv with 2 or 3 inputs (input, weight, and an optional bias)
    return op.Conv(input, weight, _allow_other_inputs=True)
```
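To make the effect of `_allow_other_inputs` concrete, here is a toy model in plain Python. It is not the onnxscript matcher: it assumes the option only relaxes the input-count check so that extra trailing inputs are tolerated, and `ToyNode` and `inputs_match` are hypothetical names.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyNode:
    """Hypothetical stand-in for a graph node: an op type plus input names."""

    op_type: str
    inputs: List[str]


def inputs_match(
    node: ToyNode,
    op_type: str,
    num_pattern_inputs: int,
    allow_other_inputs: bool = False,
) -> bool:
    if node.op_type != op_type:
        return False
    if allow_other_inputs:
        # Extra trailing inputs (for example an optional bias) are tolerated.
        return len(node.inputs) >= num_pattern_inputs
    # Default: the node must have exactly the inputs the pattern names.
    return len(node.inputs) == num_pattern_inputs


conv_no_bias = ToyNode("Conv", ["x", "w"])
conv_with_bias = ToyNode("Conv", ["x", "w", "b"])
```

Under this model a two-input pattern rejects `conv_with_bias` by default and accepts it once `allow_other_inputs=True`.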
### `_domain` option
- **File**: `docs/tutorial/rewriter/domain_option.md`
- **Purpose**: Specifies operator domains for pattern matching and
replacement
- **Use cases**: Domain-specific rewrites, migrating between operator
domains
- **Example**: Targeting operations from specific domains like
"com.microsoft"
```python
def custom_relu_pattern(op, input):
    # Only matches Relu from custom domain
    return op.Relu(input, _domain="custom.domain")
```
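The "migrating between operator domains" use case can be pictured with a toy rewrite over plain Python objects. This is an illustration only: `ToyNode` and `migrate_domain` are hypothetical, and a real rewrite would be written as a pattern and a replacement function rather than a list comprehension.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class ToyNode:
    """Hypothetical stand-in for a graph node."""

    op_type: str
    domain: str = ""  # "" denotes the default ONNX domain


def migrate_domain(
    nodes: List[ToyNode], op_type: str, src_domain: str, dst_domain: str
) -> List[ToyNode]:
    """Move nodes of one op type from src_domain to dst_domain; leave others alone."""
    return [
        replace(node, domain=dst_domain)
        if (node.op_type, node.domain) == (op_type, src_domain)
        else node
        for node in nodes
    ]


graph = [ToyNode("Relu", "custom.domain"), ToyNode("Relu"), ToyNode("Add", "custom.domain")]
migrated = migrate_domain(graph, "Relu", "custom.domain", "")
```

Only the first node changes: the second `Relu` is already in the default domain, and the `Add` has the right domain but the wrong op type.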
### `_outputs` option
- **File**: `docs/tutorial/rewriter/outputs_option.md`
- **Purpose**: Specifies number and names of operation outputs
- **Formats**: Integer count (`_outputs=2`) or named list
(`_outputs=["first", "second"]`)
- **Example**: Handling multi-output operations like `Split`
```python
def split_pattern(op, input):
    # Matches Split operations with exactly 2 outputs
    return op.Split(input, num_outputs=2, axis=0, _outputs=2)
```
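The two `_outputs` formats can be read as two spellings of the same thing: a list of output slots that may or may not carry names. The sketch below models that reading in plain Python. `normalize_outputs` and `output_count_matches` are hypothetical helpers, and the assumption that a match requires equal output counts follows the "exactly 2 outputs" wording above rather than the library source.

```python
from typing import List, Optional, Sequence, Union

OutputsSpec = Union[int, Sequence[Optional[str]]]


def normalize_outputs(spec: OutputsSpec) -> List[Optional[str]]:
    """Turn `_outputs=2` or `_outputs=["first", "second"]` into a list of slot names."""
    if isinstance(spec, int):
        return [None] * spec  # unnamed output slots
    return list(spec)


def output_count_matches(node_outputs: Sequence[str], spec: OutputsSpec) -> bool:
    """Assumed rule: the node must produce as many outputs as the pattern declares."""
    return len(node_outputs) == len(normalize_outputs(spec))


by_count = normalize_outputs(2)
by_name = normalize_outputs(["first", "second"])
```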
### Enhanced `_allow_other_attributes` documentation
- **File**: `docs/tutorial/rewriter/attributes.md` (improved formatting)
- **Already documented**: Controls whether patterns match nodes with
additional attributes
- **Default**: `True` (allows extra attributes)
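A small plain-Python model shows how the default of `True` differs from exact attribute matching. It is a sketch under the assumption that attributes named in the pattern must always be present with equal values, and that the option only decides whether additional attributes on the node are acceptable. `attributes_match` is a hypothetical helper.

```python
from typing import Any, Dict


def attributes_match(
    node_attrs: Dict[str, Any],
    pattern_attrs: Dict[str, Any],
    allow_other_attributes: bool = True,
) -> bool:
    # Every attribute the pattern names must be present with the same value.
    for name, expected in pattern_attrs.items():
        if name not in node_attrs or node_attrs[name] != expected:
            return False
    if allow_other_attributes:
        return True
    # Strict mode: the node may not carry attributes the pattern did not name.
    return set(node_attrs) == set(pattern_attrs)


node_attrs = {"axis": 0, "keepdims": 1}
```

With these inputs, a pattern that names only `axis=0` matches by default and stops matching when `allow_other_attributes=False`, because `keepdims` is then an unexpected extra.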
## Documentation Structure Improvements
- Added "Pattern Options" section to main rewriter documentation
- Integrated all option docs into the tutorial flow
- Created working code examples for each option
- Followed existing documentation patterns and style
- All examples compile and run successfully
- Documentation builds correctly with Sphinx
The documentation now provides complete coverage of all rewriter pattern
options with practical examples showing real-world usage patterns.
Fixes microsoft#2405.
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
In onnx2script, `nan`, `inf`, etc. were converted to plain text, which causes evaluation to fail because those names don't exist in the generated script. I updated the logic to replace them with the corresponding `np.` values.
---------
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
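The failure mode is easy to reproduce: `repr(float("nan"))` is the bare text `nan`, which is not a defined name when the generated script is evaluated. The sketch below shows one way to emit evaluable text instead. `format_float_literal` is a hypothetical helper, not the function changed in the commit, and `np_stub` stands in for the numpy module so the example runs without numpy installed.

```python
import math
import types


def format_float_literal(value: float) -> str:
    """Render a float so that the generated source can be evaluated again."""
    if math.isnan(value):
        return "np.nan"
    if math.isinf(value):
        return "np.inf" if value > 0 else "-np.inf"
    return repr(value)


# Stand-in for numpy, exposing only the two attributes the sketch needs.
np_stub = types.SimpleNamespace(nan=float("nan"), inf=float("inf"))

try:
    eval(repr(float("nan")), {})  # the plain text "nan" is not a defined name
    plain_text_works = True
except NameError:
    plain_text_works = False

round_tripped = eval(format_float_literal(float("-inf")), {"np": np_stub})
```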
Simplify implementation for `aten_chunk` and allow it to work on all data types.
Original author: @xadupre
Updated: Conditionally use the new implementation when torch>=2.7
---------
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Xavier Dupré <xadupre@users.noreply.github.com>
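The commit message does not show the new implementation, so the following only illustrates the split sizes that `torch.chunk` documents: every chunk has size `ceil(dim_size / chunks)` except possibly the last, and fewer chunks than requested may be returned. `chunk_sizes` is a hypothetical pure-Python helper, not part of onnxscript or torch.

```python
from typing import List


def chunk_sizes(dim_size: int, chunks: int) -> List[int]:
    """Sizes of the pieces produced when splitting dim_size elements into chunks."""
    if dim_size <= 0 or chunks <= 0:
        raise ValueError("dim_size and chunks must be positive")
    size = -(-dim_size // chunks)  # ceiling division
    full, remainder = divmod(dim_size, size)
    # All chunks share one size; a smaller final chunk holds any remainder.
    return [size] * full + ([remainder] if remainder else [])


even = chunk_sizes(6, 3)
uneven = chunk_sizes(5, 3)
fewer = chunk_sizes(6, 4)
```

The last case is the one that tends to surprise: asking for 4 chunks of a length-6 dimension yields three chunks of size 2, because the chunk size is fixed first and the count follows from it.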
Codecov Report
Additional details and impacted files
@@            Coverage Diff             @@
## main #2473 +/- ##
==========================================
- Coverage 69.81% 69.01% -0.81%
==========================================
Files 209 211 +2
Lines 25313 25978 +665
Branches 2525 2612 +87
==========================================
+ Hits 17673 17928 +255
- Misses 6762 7175 +413
+ Partials      878      875       -3
This PR introduces a LongRoPe (Long Range Rotary Position Embedding) GQA (Group Query Attention) causal mask fusion rule designed for Phi-4-mini-reasoning and similar models. The implementation optimizes attention mask computation for models that use sliding window attention with LongRoPe position embeddings.
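For readers less familiar with the mask being fused, the sketch below shows what a causal mask with an optional sliding window looks like when some tokens are already in the KV cache. It is a generic illustration in plain Python, not the computation performed by this PR; `causal_mask` and its parameters are hypothetical names.

```python
from typing import List, Optional


def causal_mask(
    query_len: int, past_len: int, window: Optional[int] = None
) -> List[List[bool]]:
    """mask[i][k] is True when query token i may attend to key position k."""
    total_len = past_len + query_len
    mask: List[List[bool]] = []
    for i in range(query_len):
        q_pos = past_len + i  # absolute position of this query token
        row = []
        for k_pos in range(total_len):
            visible = k_pos <= q_pos  # causal: no attention to future positions
            if window is not None:
                # Sliding window: only the most recent `window` positions are kept.
                visible = visible and k_pos > q_pos - window
            row.append(visible)
        mask.append(row)
    return mask


full = causal_mask(query_len=2, past_len=2)
windowed = causal_mask(query_len=2, past_len=2, window=2)
```

With two cached tokens and two new ones, the unwindowed mask lets the first new token see positions 0 to 2 and the second see 0 to 3; a window of 2 trims each row to its last two visible positions.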
New LongRoPeGQACausalMask Class
Advanced Mask Computation
Note: This PR is meant to replace #2461 by introducing the requested changes.