From 961c2d96bb473fa09d3b4d22245d5e7d42f6280f Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Sat, 16 May 2026 10:38:39 +0800 Subject: [PATCH] fix(tests): run causal softmax reference on CPU --- tests/test_causal_softmax.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_causal_softmax.py b/tests/test_causal_softmax.py index 79e6be8e0..ef48601f7 100644 --- a/tests/test_causal_softmax.py +++ b/tests/test_causal_softmax.py @@ -46,9 +46,10 @@ def _causal_softmax(input, out): def _torch_causal_softmax(input, out): - mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1]) - masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32)) + input_cpu = input.detach().cpu().to(torch.float32) + mask = torch.tril(torch.ones_like(input_cpu), diagonal=-1).flip(dims=[-2, -1]) + masked = torch.where(mask == 1, -torch.inf, input_cpu) result = torch.nn.functional.softmax(masked, dim=-1) - out.copy_(result) + out.copy_(result.to(device=out.device, dtype=out.dtype)) return out