diff --git a/tests/test_causal_softmax.py b/tests/test_causal_softmax.py index 79e6be8e..ef48601f 100644 --- a/tests/test_causal_softmax.py +++ b/tests/test_causal_softmax.py @@ -46,9 +46,10 @@ def _causal_softmax(input, out): def _torch_causal_softmax(input, out): - mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1]) - masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32)) + input_cpu = input.detach().cpu().to(torch.float32) + mask = torch.tril(torch.ones_like(input_cpu), diagonal=-1).flip(dims=[-2, -1]) + masked = torch.where(mask == 1, -torch.inf, input_cpu) result = torch.nn.functional.softmax(masked, dim=-1) - out.copy_(result) + out.copy_(result.to(device=out.device, dtype=out.dtype)) return out