Skip to content

Commit aa7fb9f

Browse files
committed
Add annotation as Any for args and kwargs
1 parent cff1a97 commit aa7fb9f

2 files changed

Lines changed: 8 additions & 6 deletions

File tree

tests/utils/architectures.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
from typing import Generic, TypeVar
2+
from typing import Any, Generic, TypeVar
33

44
import torch
55
import torchvision
@@ -14,7 +14,7 @@
1414

1515

1616
class ModuleFactory(Generic[_T]):
17-
def __init__(self, architecture: type[_T], *args, **kwargs) -> None:
17+
def __init__(self, architecture: type[_T], *args: Any, **kwargs: Any) -> None:
1818
self.architecture: type[_T] = architecture
1919
self.args = args
2020
self.kwargs = kwargs

tests/utils/asserts.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any
2+
13
import torch
24
from torch import Tensor
35
from torch.testing import assert_close
@@ -16,7 +18,7 @@ def assert_has_no_jac(t: Tensor) -> None:
1618
assert not is_tensor_with_jac(t)
1719

1820

19-
def assert_jac_close(t: Tensor, expected_jac: Tensor, **kwargs) -> None:
21+
def assert_jac_close(t: Tensor, expected_jac: Tensor, **kwargs: Any) -> None:
2022
assert is_tensor_with_jac(t)
2123
assert_close(t.jac, expected_jac, **kwargs)
2224

@@ -29,12 +31,12 @@ def assert_has_no_grad(t: Tensor) -> None:
2931
assert t.grad is None
3032

3133

32-
def assert_grad_close(t: Tensor, expected_grad: Tensor, **kwargs) -> None:
34+
def assert_grad_close(t: Tensor, expected_grad: Tensor, **kwargs: Any) -> None:
3335
assert t.grad is not None
3436
assert_close(t.grad, expected_grad, **kwargs)
3537

3638

37-
def assert_is_psd_matrix(matrix: Tensor, **kwargs) -> None:
39+
def assert_is_psd_matrix(matrix: Tensor, **kwargs: Any) -> None:
3840
assert is_psd_matrix(matrix)
3941
assert_close(matrix, matrix.mH, **kwargs)
4042

@@ -44,7 +46,7 @@ def assert_is_psd_matrix(matrix: Tensor, **kwargs) -> None:
4446
assert_close(eig_vals, expected_eig_vals, **kwargs)
4547

4648

47-
def assert_is_psd_tensor(t: Tensor, **kwargs) -> None:
49+
def assert_is_psd_tensor(t: Tensor, **kwargs: Any) -> None:
4850
assert is_psd_tensor(t)
4951
matrix = flatten(t)
5052
assert_is_psd_matrix(matrix, **kwargs)

0 commit comments

Comments
 (0)