Skip to content

Commit cb92d27

Browse files
committed
added test
1 parent db0024b commit cb92d27

File tree

1 file changed

+229
-0
lines changed

1 file changed

+229
-0
lines changed

tests/test_sba_steps.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
import os
2+
import sys
3+
import sysconfig
4+
5+
import numpy as np
6+
import pytest
7+
8+
9+
def _import_easysba():
10+
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
11+
sys.path = [p for p in sys.path if p not in ("", repo_root)]
12+
13+
purelib = sysconfig.get_paths().get("purelib")
14+
platlib = sysconfig.get_paths().get("platlib")
15+
for path in [platlib, purelib]:
16+
if path and path not in sys.path:
17+
sys.path.insert(0, path)
18+
19+
import easysba
20+
21+
return easysba
22+
23+
24+
def _make_inputs(num_points=4, num_cams=2, cam_param_size=12, seed=0):
25+
rng = np.random.default_rng(seed)
26+
image_uv = rng.normal(scale=0.5, size=(num_points, num_cams, 2)).astype(np.float64)
27+
visibility_mask = np.ones((num_points, num_cams), dtype=np.uint8)
28+
world_xyz = rng.normal(scale=2.0, size=(num_points, 3)).astype(np.float64)
29+
30+
if cam_param_size == 7:
31+
camera_params = np.zeros((num_cams, 7), dtype=np.float64)
32+
camera_params[:, 0] = 1.0
33+
camera_params[:, 4:] = rng.normal(scale=0.1, size=(num_cams, 3))
34+
elif cam_param_size == 12:
35+
camera_params = np.zeros((num_cams, 12), dtype=np.float64)
36+
camera_params[:, 0] = 1.0
37+
camera_params[:, 3] = 1.0
38+
camera_params[:, 5] = 1.0
39+
camera_params[:, 9:] = rng.normal(scale=0.1, size=(num_cams, 3))
40+
elif cam_param_size == 17:
41+
camera_params = np.zeros((num_cams, 17), dtype=np.float64)
42+
camera_params[:, 0] = 1.0
43+
camera_params[:, 3] = 1.0
44+
camera_params[:, 5] = 1.0
45+
camera_params[:, 10:13] = rng.normal(scale=0.01, size=(num_cams, 3))
46+
camera_params[:, 14:] = rng.normal(scale=0.1, size=(num_cams, 3))
47+
else:
48+
raise ValueError("Unsupported cam_param_size")
49+
50+
return image_uv, visibility_mask, world_xyz, camera_params
51+
52+
53+
@pytest.mark.parametrize("world_shape", [(4,), (4, 4)])
54+
def test_world_shape_validation(world_shape):
55+
easysba = _import_easysba()
56+
image_uv, visibility_mask, world_xyz, camera_params = _make_inputs()
57+
world_xyz = np.zeros(world_shape, dtype=np.float64)
58+
59+
with pytest.raises(RuntimeError, match="world_xyz must have shape"):
60+
easysba.easy_sba(image_uv, visibility_mask, world_xyz, camera_params)
61+
62+
63+
def test_camera_shape_validation():
64+
easysba = _import_easysba()
65+
image_uv, visibility_mask, world_xyz, _ = _make_inputs()
66+
camera_params = np.zeros((2,), dtype=np.float64)
67+
68+
with pytest.raises(RuntimeError, match="camera_params must have shape"):
69+
easysba.easy_sba(image_uv, visibility_mask, world_xyz, camera_params)
70+
71+
72+
def test_visibility_shape_validation():
73+
easysba = _import_easysba()
74+
image_uv, _, world_xyz, camera_params = _make_inputs()
75+
visibility_mask = np.ones((1, 1), dtype=np.uint8)
76+
77+
with pytest.raises(RuntimeError, match="visibility_mask shape does not match"):
78+
easysba.easy_sba(image_uv, visibility_mask, world_xyz, camera_params)
79+
80+
81+
def test_camera_param_size_validation():
82+
easysba = _import_easysba()
83+
image_uv, visibility_mask, world_xyz, _ = _make_inputs(cam_param_size=7)
84+
camera_params = np.zeros((visibility_mask.shape[1], 9), dtype=np.float64)
85+
86+
with pytest.raises(RuntimeError, match="num_camera_params must be 7, 12, or 17"):
87+
easysba.easy_sba(image_uv, visibility_mask, world_xyz, camera_params)
88+
89+
90+
@pytest.mark.parametrize("ndim", [1, 4])
91+
def test_image_uv_ndim_validation(ndim):
92+
easysba = _import_easysba()
93+
_, visibility_mask, world_xyz, camera_params = _make_inputs()
94+
image_uv = np.zeros((2,) * ndim, dtype=np.float64)
95+
96+
with pytest.raises(RuntimeError, match="image_uv must be a 2D or 3D array"):
97+
easysba.easy_sba(image_uv, visibility_mask, world_xyz, camera_params)
98+
99+
100+
def test_image_uv_shape_3d_validation():
101+
easysba = _import_easysba()
102+
_, visibility_mask, world_xyz, camera_params = _make_inputs()
103+
image_uv = np.zeros((world_xyz.shape[0], visibility_mask.shape[1], 3), dtype=np.float64)
104+
105+
with pytest.raises(RuntimeError, match="image_uv must have shape"):
106+
easysba.easy_sba(image_uv, visibility_mask, world_xyz, camera_params)
107+
108+
109+
def test_image_uv_shape_2d_validation():
110+
easysba = _import_easysba()
111+
image_uv, visibility_mask, world_xyz, camera_params = _make_inputs()
112+
image_uv = image_uv.reshape(world_xyz.shape[0], -1)
113+
bad_image = image_uv[:, :-1]
114+
115+
with pytest.raises(RuntimeError, match="image_uv must have shape"):
116+
easysba.easy_sba(bad_image, visibility_mask, world_xyz, camera_params)
117+
118+
119+
def test_intrinsics_fixed_bounds():
120+
easysba = _import_easysba()
121+
image_uv, visibility_mask, world_xyz, camera_params = _make_inputs(cam_param_size=7)
122+
123+
with pytest.raises(RuntimeError, match="intrinsics_fixed must be between"):
124+
easysba.easy_sba(image_uv, visibility_mask, world_xyz, camera_params, intrinsics_fixed=6)
125+
126+
127+
def test_distortion_fixed_bounds():
128+
easysba = _import_easysba()
129+
image_uv, visibility_mask, world_xyz, camera_params = _make_inputs(cam_param_size=7)
130+
131+
with pytest.raises(RuntimeError, match="distortion_fixed must be"):
132+
easysba.easy_sba(image_uv, visibility_mask, world_xyz, camera_params, distortion_fixed=6)
133+
134+
135+
def test_intrinsics_fixed_mismatch_with_params():
136+
easysba = _import_easysba()
137+
image_uv, visibility_mask, world_xyz, camera_params = _make_inputs(cam_param_size=12)
138+
139+
with pytest.raises(RuntimeError, match="intrinsics_fixed cannot be -1"):
140+
easysba.easy_sba(image_uv, visibility_mask, world_xyz, camera_params, intrinsics_fixed=-1)
141+
142+
143+
def test_intrinsics_fixed_invalid_for_default_intrinsics():
144+
easysba = _import_easysba()
145+
image_uv, visibility_mask, world_xyz, camera_params = _make_inputs(cam_param_size=7)
146+
147+
with pytest.raises(RuntimeError, match="intrinsics_fixed must be -1"):
148+
easysba.easy_sba(image_uv, visibility_mask, world_xyz, camera_params, intrinsics_fixed=0)
149+
150+
151+
def test_visibility_nan_mismatch():
152+
easysba = _import_easysba()
153+
image_uv, visibility_mask, world_xyz, camera_params = _make_inputs()
154+
155+
visibility_mask[0, 0] = 0
156+
image_uv[0, 0] = 1.0
157+
with pytest.raises(RuntimeError, match="visibility_mask is false"):
158+
easysba.easy_sba(
159+
image_uv,
160+
visibility_mask,
161+
world_xyz,
162+
camera_params,
163+
intrinsics_fixed=5,
164+
)
165+
166+
visibility_mask[0, 0] = 1
167+
image_uv[0, 0] = np.nan
168+
with pytest.raises(RuntimeError, match="visibility_mask is true"):
169+
easysba.easy_sba(
170+
image_uv,
171+
visibility_mask,
172+
world_xyz,
173+
camera_params,
174+
intrinsics_fixed=5,
175+
)
176+
177+
178+
def test_visibility_empty():
179+
easysba = _import_easysba()
180+
image_uv, visibility_mask, world_xyz, camera_params = _make_inputs()
181+
visibility_mask[:] = 0
182+
image_uv[:] = np.nan
183+
184+
with pytest.raises(RuntimeError, match="visibility_mask has no valid projections"):
185+
easysba.easy_sba(
186+
image_uv,
187+
visibility_mask,
188+
world_xyz,
189+
camera_params,
190+
intrinsics_fixed=5,
191+
)
192+
193+
194+
def test_quaternion_zero_norm():
195+
easysba = _import_easysba()
196+
image_uv, visibility_mask, world_xyz, camera_params = _make_inputs(cam_param_size=12)
197+
camera_params[0, 5:9] = 0.0
198+
199+
with pytest.raises(RuntimeError, match="Quaternion has zero norm"):
200+
easysba.easy_sba(image_uv, visibility_mask, world_xyz, camera_params, intrinsics_fixed=5)
201+
202+
203+
@pytest.mark.skipif(os.environ.get("EASYSBA_RUN_SOLVER") != "1", reason="Set EASYSBA_RUN_SOLVER=1 to exercise solver")
204+
@pytest.mark.parametrize("cam_param_size,intrinsics_fixed,distortion_fixed", [
205+
(7, -1, -1),
206+
(12, 5, -1),
207+
(17, 5, 0),
208+
])
209+
def test_solver_runs_seeded(cam_param_size, intrinsics_fixed, distortion_fixed):
210+
easysba = _import_easysba()
211+
image_uv, visibility_mask, world_xyz, camera_params = _make_inputs(
212+
num_points=6, num_cams=3, cam_param_size=cam_param_size, seed=42
213+
)
214+
215+
world_out, cams_out, info = easysba.easy_sba(
216+
image_uv,
217+
visibility_mask,
218+
world_xyz,
219+
camera_params,
220+
intrinsics_fixed=intrinsics_fixed,
221+
distortion_fixed=distortion_fixed,
222+
max_iter=2,
223+
verbose=False,
224+
)
225+
226+
assert world_out.shape == world_xyz.shape
227+
assert cams_out.shape == camera_params.shape
228+
assert "return_code" in info
229+
assert "final_error" in info

0 commit comments

Comments
 (0)