Skip to content

Commit aeef2f4

Browse files
authored
fix(autojac): Replace unsafe view with reshape (#582)
* Add test for jac_to_grad with non-contiguous jac (that would have failed in v0.8.1) * Replace another potentially unsafe view with reshape
1 parent 8862d13 commit aeef2f4

2 files changed

Lines changed: 15 additions & 1 deletion

File tree

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _disunite_gradient(
9292
tensors: list[TensorWithJac],
9393
) -> list[Tensor]:
9494
gradient_vectors = gradient_vector.split([t.numel() for t in tensors])
95-
gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)]
95+
gradients = [g.reshape(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)]
9696
return gradients
9797

9898

tests/unit/autojac/test_jac_to_grad.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,17 @@ def test_jacs_are_freed(retain_jac: bool):
101101
check = assert_has_jac if retain_jac else assert_has_no_jac
102102
check(t1)
103103
check(t2)
104+
105+
106+
def test_noncontiguous_jac():
107+
"""Tests that jac_to_grad works when the .jac field is non-contiguous."""
108+
109+
aggregator = UPGrad()
110+
t = tensor_([2.0, 3.0, 4.0], requires_grad=True)
111+
jac_T = tensor_([[-4.0, 1.0], [1.0, 6.0], [1.0, 1.0]])
112+
jac = jac_T.T
113+
t.__setattr__("jac", jac)
114+
g = aggregator(jac)
115+
116+
jac_to_grad([t], aggregator)
117+
assert_grad_close(t, g)

0 commit comments

Comments
 (0)