1+ from typing import Any
2+
13import torch
24from torch import Tensor
35from torch .testing import assert_close
@@ -16,7 +18,7 @@ def assert_has_no_jac(t: Tensor) -> None:
1618 assert not is_tensor_with_jac (t )
1719
1820
19- def assert_jac_close (t : Tensor , expected_jac : Tensor , ** kwargs ) -> None :
21+ def assert_jac_close (t : Tensor , expected_jac : Tensor , ** kwargs : Any ) -> None :
2022 assert is_tensor_with_jac (t )
2123 assert_close (t .jac , expected_jac , ** kwargs )
2224
@@ -29,12 +31,12 @@ def assert_has_no_grad(t: Tensor) -> None:
2931 assert t .grad is None
3032
3133
32- def assert_grad_close (t : Tensor , expected_grad : Tensor , ** kwargs ) -> None :
34+ def assert_grad_close (t : Tensor , expected_grad : Tensor , ** kwargs : Any ) -> None :
3335 assert t .grad is not None
3436 assert_close (t .grad , expected_grad , ** kwargs )
3537
3638
37- def assert_is_psd_matrix (matrix : Tensor , ** kwargs ) -> None :
39+ def assert_is_psd_matrix (matrix : Tensor , ** kwargs : Any ) -> None :
3840 assert is_psd_matrix (matrix )
3941 assert_close (matrix , matrix .mH , ** kwargs )
4042
@@ -44,7 +46,7 @@ def assert_is_psd_matrix(matrix: Tensor, **kwargs) -> None:
4446 assert_close (eig_vals , expected_eig_vals , ** kwargs )
4547
4648
47- def assert_is_psd_tensor (t : Tensor , ** kwargs ) -> None :
49+ def assert_is_psd_tensor (t : Tensor , ** kwargs : Any ) -> None :
4850 assert is_psd_tensor (t )
4951 matrix = flatten (t )
5052 assert_is_psd_matrix (matrix , ** kwargs )
0 commit comments