Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions t4_devkit/helper/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,17 @@ def __init__(self, t4: Tier4) -> None:
self._label2id: dict[str, int] = {
category.name: category.index for category in self._t4.category
}
self._sample_data_to_lidarseg_filename: dict[str, str] | None = (
{lidarseg.sample_data_token: lidarseg.filename for lidarseg in self._t4.lidarseg}
if self._t4.lidarseg
else None
)
self._sample_data_to_lidarseg_filename: dict[str, str] = {
lidarseg.sample_data_token: lidarseg.filename for lidarseg in self._t4.lidarseg
}

self._executor = concurrent.futures.ThreadPoolExecutor()

def _has_lidarseg(self) -> bool:
return self._sample_data_to_lidarseg_filename is not None
return bool(self._sample_data_to_lidarseg_filename)

def _find_lidarseg_file(self, sample_data_token: str) -> str | None:
return self._sample_data_to_lidarseg_filename.get(sample_data_token)

def _init_viewer(
self,
Expand Down Expand Up @@ -432,13 +433,10 @@ def _render_single_lidar(first_lidar_token: str) -> None:

# render segmentation pointcloud if available, otherwise render raw pointcloud
if color_mode == PointCloudColorMode.SEGMENTATION:
if not (
self._has_lidarseg()
and sample_data.token in self._sample_data_to_lidarseg_filename
):
label_filename = self._find_lidarseg_file(sample_data.token)
if label_filename is None:
continue

label_filename = self._sample_data_to_lidarseg_filename[sample_data.token]
pointcloud = SegmentationPointCloud.from_file(
point_filepath=osp.join(self._t4.data_root, sample_data.filename),
label_filepath=osp.join(self._t4.data_root, label_filename),
Expand Down
39 changes: 17 additions & 22 deletions t4_devkit/viewer/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,18 +427,19 @@ def render_pointcloud(
# TODO(ktro2828): add support of rendering pointcloud on images
rr.set_time_seconds(self.config.timeline, seconds)

entity_path = format_entity(self.config.ego_entity, channel)
if color_mode == PointCloudColorMode.SEGMENTATION:
assert isinstance(pointcloud, SegmentationPointCloud)
rr.log(
format_entity(self.config.ego_entity, channel),
rr.Points3D(pointcloud.points[:3].T, class_ids=pointcloud.labels),
)
if not isinstance(pointcloud, SegmentationPointCloud):
raise TypeError(
f"Expected SegmentationPointCloud instance, but got {type(pointcloud)}"
)

entity = rr.Points3D(pointcloud.points[:3].T, class_ids=pointcloud.labels)
else:
colors = pointcloud_color(pointcloud, color_mode=color_mode)
rr.log(
format_entity(self.config.ego_entity, channel),
rr.Points3D(pointcloud.points[:3].T, colors=colors),
)
entity = rr.Points3D(pointcloud.points[:3].T, colors=colors)

rr.log(entity_path, entity)

@_check_spatial2d
def render_image(self, seconds: float, camera: str, image: str | NDArrayU8) -> None:
Expand All @@ -451,10 +452,10 @@ def render_image(self, seconds: float, camera: str, image: str | NDArrayU8) -> N
"""
rr.set_time_seconds(self.config.timeline, seconds)

if isinstance(image, str):
rr.log(format_entity(self.config.ego_entity, camera), rr.ImageEncoded(path=image))
else:
rr.log(format_entity(self.config.ego_entity, camera), rr.Image(image))
entity_path = format_entity(self.config.ego_entity, camera)
entity = rr.ImageEncoded(path=image) if isinstance(image, str) else rr.Image(image)

rr.log(entity_path, entity)

@overload
def render_ego(self, ego_pose: EgoPose) -> None:
Expand Down Expand Up @@ -519,18 +520,12 @@ def _render_ego_without_schema(
),
)

entity_path = self.config.geocoordinate_entity
if geocoordinate is not None:
latitude, longitude, _ = geocoordinate
rr.log(
self.config.geocoordinate_entity,
rr.GeoPoints(lat_lon=(latitude, longitude)),
)
rr.log(entity_path, rr.GeoPoints(lat_lon=geocoordinate[:2]))
elif self.latlon is not None:
latitude, longitude = calculate_geodetic_point(translation, self.latlon)
rr.log(
self.config.geocoordinate_entity,
rr.GeoPoints(lat_lon=(latitude, longitude)),
)
rr.log(entity_path, rr.GeoPoints(lat_lon=(latitude, longitude)))

@overload
def render_calibration(
Expand Down
Loading