diff --git a/t4_devkit/helper/rendering.py b/t4_devkit/helper/rendering.py index fc576163..bbe094d4 100644 --- a/t4_devkit/helper/rendering.py +++ b/t4_devkit/helper/rendering.py @@ -59,16 +59,17 @@ def __init__(self, t4: Tier4) -> None: self._label2id: dict[str, int] = { category.name: category.index for category in self._t4.category } - self._sample_data_to_lidarseg_filename: dict[str, str] | None = ( - {lidarseg.sample_data_token: lidarseg.filename for lidarseg in self._t4.lidarseg} - if self._t4.lidarseg - else None - ) + self._sample_data_to_lidarseg_filename: dict[str, str] = { + lidarseg.sample_data_token: lidarseg.filename for lidarseg in self._t4.lidarseg + } self._executor = concurrent.futures.ThreadPoolExecutor() def _has_lidarseg(self) -> bool: - return self._sample_data_to_lidarseg_filename is not None + return bool(self._sample_data_to_lidarseg_filename) + + def _find_lidarseg_file(self, sample_data_token: str) -> str | None: + return self._sample_data_to_lidarseg_filename.get(sample_data_token) def _init_viewer( self, @@ -432,13 +433,10 @@ def _render_single_lidar(first_lidar_token: str) -> None: # render segmentation pointcloud if available, otherwise render raw pointcloud if color_mode == PointCloudColorMode.SEGMENTATION: - if not ( - self._has_lidarseg() - and sample_data.token in self._sample_data_to_lidarseg_filename - ): + label_filename = self._find_lidarseg_file(sample_data.token) + if label_filename is None: continue - label_filename = self._sample_data_to_lidarseg_filename[sample_data.token] pointcloud = SegmentationPointCloud.from_file( point_filepath=osp.join(self._t4.data_root, sample_data.filename), label_filepath=osp.join(self._t4.data_root, label_filename), diff --git a/t4_devkit/viewer/viewer.py b/t4_devkit/viewer/viewer.py index 8cf2bf95..c450eeb7 100644 --- a/t4_devkit/viewer/viewer.py +++ b/t4_devkit/viewer/viewer.py @@ -427,18 +427,19 @@ def render_pointcloud( # TODO(ktro2828): add support of rendering pointcloud on images rr.set_time_seconds(self.config.timeline, seconds) + entity_path = format_entity(self.config.ego_entity, channel) if color_mode == PointCloudColorMode.SEGMENTATION: - assert isinstance(pointcloud, SegmentationPointCloud) - rr.log( - format_entity(self.config.ego_entity, channel), - rr.Points3D(pointcloud.points[:3].T, class_ids=pointcloud.labels), - ) + if not isinstance(pointcloud, SegmentationPointCloud): + raise TypeError( + f"Expected SegmentationPointCloud instance, but got {type(pointcloud)}" + ) + + entity = rr.Points3D(pointcloud.points[:3].T, class_ids=pointcloud.labels) else: colors = pointcloud_color(pointcloud, color_mode=color_mode) - rr.log( - format_entity(self.config.ego_entity, channel), - rr.Points3D(pointcloud.points[:3].T, colors=colors), - ) + entity = rr.Points3D(pointcloud.points[:3].T, colors=colors) + + rr.log(entity_path, entity) @_check_spatial2d def render_image(self, seconds: float, camera: str, image: str | NDArrayU8) -> None: @@ -451,10 +452,10 @@ def render_image(self, seconds: float, camera: str, image: str | NDArrayU8) -> N """ rr.set_time_seconds(self.config.timeline, seconds) - if isinstance(image, str): - rr.log(format_entity(self.config.ego_entity, camera), rr.ImageEncoded(path=image)) - else: - rr.log(format_entity(self.config.ego_entity, camera), rr.Image(image)) + entity_path = format_entity(self.config.ego_entity, camera) + entity = rr.ImageEncoded(path=image) if isinstance(image, str) else rr.Image(image) + + rr.log(entity_path, entity) @overload def render_ego(self, ego_pose: EgoPose) -> None: @@ -519,18 +520,12 @@ def _render_ego_without_schema( ), ) + entity_path = self.config.geocoordinate_entity if geocoordinate is not None: - latitude, longitude, _ = geocoordinate - rr.log( - self.config.geocoordinate_entity, - rr.GeoPoints(lat_lon=(latitude, longitude)), - ) + rr.log(entity_path, rr.GeoPoints(lat_lon=geocoordinate[:2])) elif self.latlon is not None: latitude, longitude = calculate_geodetic_point(translation, self.latlon) - rr.log( - self.config.geocoordinate_entity, - rr.GeoPoints(lat_lon=(latitude, longitude)), - ) + rr.log(entity_path, rr.GeoPoints(lat_lon=(latitude, longitude))) @overload def render_calibration(