Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,154 changes: 1,019 additions & 135 deletions paconvert/api_mapping.json

Large diffs are not rendered by default.

884 changes: 850 additions & 34 deletions paconvert/api_matcher.py

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions paconvert/transformer/basic_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,33 @@ def trans_class_attribute(self, node, torch_api):
)
return node

def insert_paddle_tensor_int_helper(self):
    """Insert a module-level ``paddle_tensor_int`` helper into the output.

    The helper mimics the built-in ``int()`` for paddle tensors: a
    one-element tensor is flattened and its single item converted, while
    every other value falls back to the plain ``int()`` builtin.
    """
    helper_source = """
def paddle_tensor_int(x):
    module_name = type(x).__module__
    if module_name.startswith("paddle") and hasattr(x, "numel") and hasattr(x, "reshape"):
        assert x.numel() == 1, "only one element variable can be converted to int."
        return int(x.reshape([-1])[0].item())
    return int(x)
"""
    helper_nodes = ast.parse(helper_source).body
    self.insert_multi_node(helper_nodes)

def trans_builtin_int(self, node):
    """Rewrite a bare builtin ``int(x)`` call into ``paddle_tensor_int(x)``.

    Only calls of the exact shape ``int(<single positional arg>)`` are
    rewritten; any other call returns ``None`` so the caller leaves the
    node untouched.  The module-level helper is inserted on demand.
    """
    func = node.func
    is_builtin_int = isinstance(func, ast.Name) and func.id == "int"
    single_positional = len(node.args) == 1 and not node.keywords
    if not (is_builtin_int and single_positional):
        return None

    self.insert_paddle_tensor_int_helper()
    replacement = ast.Call(
        func=ast.Name(id="paddle_tensor_int", ctx=ast.Load()),
        args=node.args,
        keywords=[],
    )
    # Keep the original source location so downstream tooling stays accurate.
    return ast.copy_location(replacement, node)

def visit_Call(self, node):
"""
if one line has N torch function, it has 2^N method of
Expand Down Expand Up @@ -414,6 +441,10 @@ def visit_Call(self, node):
# Use Postorder traversal
super(BasicTransformer, self).generic_visit(node)

builtin_int_node = self.trans_builtin_int(node)
if builtin_int_node:
return builtin_int_node

full_attr = self.get_full_attr_for_apiname(node.func)
# 1) Torch Package Call, include torch third_party
# such as : torch.add(x, y) / torch.add(torch.abs(x), y)
Expand Down
104 changes: 82 additions & 22 deletions tests/code_library/code_case/paddle_code/paddlenlp_Qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,30 +45,92 @@ def _post_init(self):
self._init_weights()
setattr(paddleformers.transformers.model_utils.PretrainedModel, "post_init", _post_init)

def apply_rotary_position_embeddings(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets=0,
    cu_seqlens=None,
    max_seqlen=None,
):
    """Paddle re-implementation of flash-attn's ``apply_rotary_emb_func``.

    Rotates the first ``rotary_dim`` features of ``x`` using the ``cos``/
    ``sin`` tables and passes the remaining features through unchanged.

    Args:
        x: input tensor; the embedding is applied along its last axis.
        cos, sin: rotation tables holding ``rotary_dim / 2`` features each.
        interleaved: rotate (even, odd) feature pairs instead of the
            (first half, second half) split.
        inplace: write the result back into ``x`` and return ``x``.
        seqlen_offsets, cu_seqlens, max_seqlen: flash-attn options that are
            not supported here; only their default values are accepted.
    """
    if seqlen_offsets not in (0, None):
        raise NotImplementedError(
            "PaConvert only supports apply_rotary_emb_func with default seqlen_offsets"
        )
    if cu_seqlens is not None or max_seqlen is not None:
        raise NotImplementedError(
            "PaConvert only supports apply_rotary_emb_func without cu_seqlens or max_seqlen"
        )
    # Promote array-likes to tensors matching x's dtype and device.
    if not isinstance(cos, paddle.Tensor):
        cos = paddle.to_tensor(cos, dtype=x.dtype, place=x.place, stop_gradient=True)
    if not isinstance(sin, paddle.Tensor):
        sin = paddle.to_tensor(sin, dtype=x.dtype, place=x.place, stop_gradient=True)

    def _rotate_half(x):
        # Build the rotated companion of x: (x1, x2) -> (-x2, x1).
        if interleaved:
            x1 = x[..., ::2]
            x2 = x[..., 1::2]
            return paddle.reshape(paddle.stack((-x2, x1), axis=-1), shape=x.shape)
        x1, x2 = paddle.split(x, num_or_sections=2, axis=-1)
        return paddle.concat((-x2, x1), axis=-1)

    # Expand the half-size tables to rotary_dim features, matching the
    # layout _rotate_half expects for the chosen mode.
    if interleaved:
        cos = paddle.repeat_interleave(cos, repeats=2, axis=-1)
        sin = paddle.repeat_interleave(sin, repeats=2, axis=-1)
    else:
        cos = paddle.concat([cos, cos], axis=-1)
        sin = paddle.concat([sin, sin], axis=-1)

    # Insert a head axis so the tables broadcast over x.
    cos = cos.unsqueeze(axis=-2)
    sin = sin.unsqueeze(axis=-2)
    rotary_dim = cos.shape[-1]
    assert rotary_dim <= x.shape[-1]
    t_rot, t_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    out = paddle.concat(
        x=((t_rot * cos) + (_rotate_half(t_rot) * sin), t_pass), axis=-1
    )
    if inplace:
        paddle.assign(out, output=x)
        return x
    return out

def paddle_flash_attn_rms_norm(x, weight, epsilon):
    """Paddle replacement for flash-attn's ``rms_norm``.

    Tries the fused GPU kernel first and falls back to an eager
    implementation computed in float32 for half-precision inputs.
    """
    # Best-effort fast path: the fused kernel only applies on GPU and may be
    # unavailable in some paddle builds, so any failure deliberately falls
    # through to the eager implementation below.
    if weight is not None and x.place.is_gpu_place():
        try:
            out = paddle.incubate.nn.functional.fused_rms_norm(
                x, weight, paddle.zeros_like(weight), epsilon, len(x.shape) - 1
            )
            # NOTE(review): fused_rms_norm appears to return either a tensor
            # or a (tensor, ...) sequence depending on version — both handled.
            if isinstance(out, (tuple, list)):
                return out[0]
            return out
        except Exception:
            pass

    original_dtype = x.dtype
    # Upcast half-precision inputs so mean/rsqrt run in float32 for stability.
    if x.dtype in [paddle.float16, paddle.bfloat16]:
        compute_x = paddle.cast(x, "float32")
    else:
        compute_x = x

    # RMSNorm: x * rsqrt(mean(x^2) + eps) over the last axis, then scale.
    out = compute_x * paddle.rsqrt(
        paddle.mean(paddle.square(compute_x), axis=-1, keepdim=True) + epsilon
    )
    if weight is not None:
        if weight.dtype != out.dtype:
            weight = paddle.cast(weight, out.dtype)
        out = out * weight

    # Restore the caller's dtype after the float32 compute path.
    if out.dtype != original_dtype:
        out = paddle.cast(out, original_dtype)
    return out
############################## 相关utils函数,如上 ##############################


Expand Down Expand Up @@ -177,6 +239,4 @@ class QWenTokenizer(paddleformers.PreTrainedTokenizer):
print("#########################case16#########################")
apply_rotary_position_embeddings(x=x, cos=cos, sin=sin)
print("#########################case17#########################")
paddle.incubate.nn.functional.fused_rms_norm(
x, weight, paddle.zeros_like(weight), eps, len(x.shape) - 1
)[0]
paddle_flash_attn_rms_norm(x, weight, eps)
19 changes: 19 additions & 0 deletions tests/flash_attn_tests/test_flash_attn_apply_rotary_emb_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,22 @@ def test_case_1():
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
    # The converted helper must be pure paddle (no einops dependency) and
    # must keep the in-place write-back and the keyword arguments intact.
    source = textwrap.dedent(
        """
        import torch
        from flash_attn.layers.rotary import apply_rotary_emb_func
        x = torch.ones([1, 2, 2, 4]).cuda()
        cos = torch.ones([2, 2]).cuda()
        sin = torch.ones([2, 2]).cuda()
        result = apply_rotary_emb_func(
            x, cos, sin, interleaved=False, inplace=False, seqlen_offsets=0
        )
        """
    )
    converted = obj.convert(source)
    assert "from einops import rearrange" not in converted
    assert "paddle.assign(out, output=x)" in converted
    assert "interleaved=False" in converted
44 changes: 44 additions & 0 deletions tests/flash_attn_tests/test_flash_attn_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,47 @@ def test_case_1():
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
    """Positional flash_attn_func call converts to the keyword paddle API."""
    expected = textwrap.dedent(
        """
        import paddle

        q = paddle.ones([1, 8, 8, 8], dtype=paddle.float16).cuda()
        result = paddle.nn.functional.flash_attention.flash_attention(
            query=q, key=q, value=q, dropout=0, causal=False
        )[0]
        """
    )
    obj.run(
        textwrap.dedent(
            """
            import torch
            import flash_attn
            q = torch.ones([1,8,8,8],dtype=torch.float16).cuda()
            result = flash_attn.flash_attn_interface.flash_attn_func(q,q,q,0,None,False)
            """
        ),
        expect_paddle_code=expected,
    )


def test_case_3():
    """Extra default positional arguments are dropped during conversion."""
    expected = textwrap.dedent(
        """
        import paddle

        q = paddle.ones([1, 8, 8, 8], dtype=paddle.float16).cuda()
        result = paddle.nn.functional.flash_attention.flash_attention(
            query=q, key=q, value=q, dropout=0, causal=False
        )[0]
        """
    )
    obj.run(
        textwrap.dedent(
            """
            import torch
            import flash_attn
            q = torch.ones([1,8,8,8],dtype=torch.float16).cuda()
            result = flash_attn.flash_attn_interface.flash_attn_func(q, q, q, 0, None, False, None, 0.0, None, False)
            """
        ),
        expect_paddle_code=expected,
    )
18 changes: 18 additions & 0 deletions tests/flash_attn_tests/test_flash_attn_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,21 @@ def test_case_3():
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
    # rms_norm must be lowered to the paddle_flash_attn_rms_norm helper with
    # its arguments normalized by the code generator (1e-6 becomes 1e-06).
    converted = obj.convert(
        textwrap.dedent(
            """
            import torch
            from flash_attn.ops.rms_norm import rms_norm
            x = torch.tensor([
                [[0.4742, 3.5466, -4.8008, -8.9079, 0.4742, 9.5466, -8.8008, -6.9079]],
                [[3.4742, 0.5466, -0.8008, -0.9079, 3.4742, 0.5466, -0.8008, -0.9079]]
            ]).cuda()
            weight = torch.ones(8).cuda()
            result = rms_norm(x, weight,1e-6)
            """
        )
    )
    assert "paddle_flash_attn_rms_norm" in converted
    assert "result = paddle_flash_attn_rms_norm(x, weight, 1e-06)" in converted
37 changes: 37 additions & 0 deletions tests/flash_attn_tests/test_flash_attn_unpadded_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,40 @@ def test_case_1():
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
    """Positional flash_attn_unpadded_func call converts to the keyword form."""
    expected = textwrap.dedent(
        """
        import math

        import paddle

        q = paddle.ones([8, 8, 8], dtype=paddle.float16).cuda()
        cu_seqlens_q = paddle.ones([8], dtype=paddle.int32).cuda()
        assert (
            paddle.device.cuda.get_device_capability()[0] >= 8
        ), "Device capabilities should be at least 8"
        result = paddle.nn.functional.flash_attention.flash_attn_unpadded(
            query=q,
            key=q,
            value=q,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_q,
            max_seqlen_q=4,
            max_seqlen_k=4,
            dropout=0.25,
            scale=1.0 / math.sqrt(q.shape[-1]),
        )[0]
        """
    )
    obj.run(
        textwrap.dedent(
            """
            import torch
            from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
            q = torch.ones([8,8,8],dtype=torch.float16).cuda()
            cu_seqlens_q = torch.ones([8],dtype=torch.int32).cuda()
            result = flash_attn_unpadded_func(q,q,q,cu_seqlens_q,cu_seqlens_q,4,4,0.25)
            """
        ),
        expect_paddle_code=expected,
    )
19 changes: 19 additions & 0 deletions tests/test_Tensor_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,22 @@ def test_case_7():
"""
)
obj.run(pytorch_code, ["result"])


def test_case_8():
    # NOTE(review): the expected text mirrors the converter's current output
    # verbatim (including `paddle.tensor`); it is compared textually, not run.
    obj.run(
        textwrap.dedent(
            """
            import torch
            input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
            result = input.std(keepdim=True, correction=0, dim=1)
            """
        ),
        expect_paddle_code=textwrap.dedent(
            """
            import paddle

            input = paddle.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
            result = input.std(keepdim=True, correction=0, axis=1)
            """
        ),
    )
19 changes: 19 additions & 0 deletions tests/test_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,22 @@ def test_case_11():
"""
)
obj.run(pytorch_code, ["result"])


def test_case_12():
    # torch.std keyword call: input/dim must be renamed to x/axis while the
    # correction and keepdim keywords pass through unchanged.
    obj.run(
        textwrap.dedent(
            """
            import torch
            input = torch.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
            result = torch.std(input=input, dim=1, correction=1, keepdim=True)
            """
        ),
        expect_paddle_code=textwrap.dedent(
            """
            import paddle

            input = paddle.tensor([[1.4907, 1.0593, 1.5696], [1.4907, 1.0593, 1.5696]])
            result = paddle.std(x=input, axis=1, correction=1, keepdim=True)
            """
        ),
    )
32 changes: 32 additions & 0 deletions tests/test_std_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,35 @@ def test_case_9():
"""
)
obj.run(pytorch_code, ["std", "mean"])


def test_case_10():
    # torch.std_mean splits into a (paddle.std, paddle.mean) tuple assignment;
    # the expected text is the converter's current output, compared verbatim.
    expected = textwrap.dedent(
        """
        import paddle

        a = paddle.tensor(
            [
                [0.2035, 1.2959, 1.8101, -0.4644],
                [1.5027, -0.327, 0.5905, 0.6538],
                [-1.5745, 1.333, -0.5596, -0.6548],
                [0.1264, -0.508, 1.642, 0.1992],
            ]
        )
        std, mean = paddle.std(correction=0, keepdim=True, x=a, axis=1), paddle.mean(
            keepdim=True, x=a, axis=1
        )
        """
    )
    obj.run(
        textwrap.dedent(
            """
            import torch
            a = torch.tensor(
                [[ 0.2035, 1.2959, 1.8101, -0.4644],
                [ 1.5027, -0.3270, 0.5905, 0.6538],
                [-1.5745, 1.3330, -0.5596, -0.6548],
                [ 0.1264, -0.5080, 1.6420, 0.1992]])
            std, mean = torch.std_mean(input=a, correction=0, dim=1, keepdim=True)
            """
        ),
        expect_paddle_code=expected,
    )
4 changes: 3 additions & 1 deletion tools/consistency/consistency_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def _compare_content(actual_dir, expect_dir):
result = True
if os.path.isfile(actual_dir):
assert os.path.isfile(expect_dir), f"{expect_dir} shoule be a file!"
with open(actual_dir, "r") as f1, open(expect_dir, "r") as f2:
with open(actual_dir, "r", encoding="utf-8") as f1, open(
expect_dir, "r", encoding="utf-8"
) as f2:
content1 = f1.read().strip()
content2 = f2.read().strip()
# 对随机的辅助代码路径进行处理,使用正则表达式匹配并替换
Expand Down