Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions python/tvm/relax/frontend/tflite/tflite_flexbuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,7 @@ def __init__(self, buffer):

def indirect_jump(self, offset, byte_width):
    """Read the back-offset stored at ``offset`` and return the location it points to.

    Parameters
    ----------
    offset : int
        Position in ``self.buffer`` where the stored offset value starts.
    byte_width : int
        Size in bytes of the stored offset value; one of 1, 2, 4 or 8.

    Returns
    -------
    int
        ``offset`` minus the unsigned little-endian value read from the buffer.

    Raises
    ------
    ValueError
        If ``byte_width`` is not 1, 2, 4 or 8.
    """
    # The stored value is decoded as unsigned for every width. The earlier
    # if/elif + assert only accepted widths 1 and 4 (and decoded width 4 as
    # signed), so it rejected 2- and 8-byte offsets before the table was used.
    unpack_formats = {1: "<B", 2: "<H", 4: "<I", 8: "<Q"}
    try:
        unpack_str = unpack_formats[byte_width]
    except KeyError:
        raise ValueError(
            f"Unsupported byte width {byte_width!r}; expected 1, 2, 4 or 8"
        ) from None
    # unpack_from reads in place, avoiding a temporary slice of the buffer.
    back_jump = struct.unpack_from(unpack_str, self.buffer, offset)[0]
    return offset - back_jump

Expand All @@ -107,19 +102,26 @@ def decode_vector(self, end, size, byte_width):
# Each entry in the vector can have a different datatype. Each entry is of fixed length. The
# format is a sequence of all values followed by a sequence of the datatypes of all values. For
# example: (4)(3.56)(int)(float). The end here points to the start of the values.
# Each type byte contains: (type << 2) | bit_width, where bit_width determines actual size.
values = list()
for i in range(0, size):
value_type_pos = end + size * byte_width + i
value_type = FlexBufferType(self.buffer[value_type_pos] >> 2)
value_bytes = self.buffer[end + i * byte_width : end + (i + 1) * byte_width]
value_type_packed = self.buffer[value_type_pos]
value_type = FlexBufferType(value_type_packed >> 2)
value_bit_width = BitWidth(value_type_packed & 3)
value_byte_width = 1 << value_bit_width
value_bytes = self.buffer[end + i * byte_width : end + i * byte_width + value_byte_width]
if value_type == FlexBufferType.FBT_BOOL:
value = bool(value_bytes[0])
elif value_type == FlexBufferType.FBT_INT:
value = struct.unpack("<i", value_bytes)[0]
fmt = {1: "<b", 2: "<h", 4: "<i", 8: "<q"}[value_byte_width]
value = struct.unpack(fmt, value_bytes)[0]
elif value_type == FlexBufferType.FBT_UINT:
value = struct.unpack("<I", value_bytes)[0]
fmt = {1: "<B", 2: "<H", 4: "<I", 8: "<Q"}[value_byte_width]
value = struct.unpack(fmt, value_bytes)[0]
elif value_type == FlexBufferType.FBT_FLOAT:
value = struct.unpack("<f", value_bytes)[0]
fmt = {4: "<f", 8: "<d"}[value_byte_width]
value = struct.unpack(fmt, value_bytes)[0]
else:
raise Exception
values.append(value)
Expand All @@ -128,7 +130,8 @@ def decode_vector(self, end, size, byte_width):
def decode_map(self, end, byte_width, parent_byte_width):
"""Decodes the flexbuffer map and returns a dict"""
mid_loc = self.indirect_jump(end, parent_byte_width)
map_size = struct.unpack("<i", self.buffer[mid_loc - byte_width : mid_loc])[0]
size_fmt = {1: "<b", 2: "<h", 4: "<i", 8: "<q"}[byte_width]
map_size = struct.unpack(size_fmt, self.buffer[mid_loc - byte_width : mid_loc])[0]

# Find keys
keys_offset = mid_loc - byte_width * 3
Expand Down
223 changes: 165 additions & 58 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2832,7 +2832,9 @@ def convert_batch_matmul(self, op):
new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in shape_b]
max_rank = max(rank_a, rank_b)

batch_shape = [max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)]
batch_shape = [
max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)
]

a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])]
b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])]
Expand Down Expand Up @@ -3225,16 +3227,49 @@ def convert_dequantize(self, op):

def convert_detection_postprocess(self, op):
"""Convert TFLite_Detection_PostProcess"""
raise NotImplementedError(
"DETECTION_POSTPROCESS is not wired in this frontend yet: it still needs "
"Relax NMS / get_valid_counts / related vision helpers (see dead code below). "
"relax.vision.multibox_transform_loc exists; tracking: "
"https://github.com/apache/tvm/issues/18928"
)
flexbuffer = op.CustomOptionsAsNumpy().tobytes()
custom_options = FlexBufferDecoder(flexbuffer).decode()

use_regular_nms = "use_regular_nms" in custom_options and custom_options["use_regular_nms"]
use_regular_nms = bool(custom_options.get("use_regular_nms", False))

required_attrs = [
"num_classes",
"max_detections",
"detections_per_class",
"nms_iou_threshold",
"nms_score_threshold",
"x_scale",
"y_scale",
"w_scale",
"h_scale",
]
missing_attrs = [key for key in required_attrs if key not in custom_options]
if missing_attrs:
raise ValueError(
"DETECTION_POSTPROCESS custom options miss required attributes: "
+ ", ".join(missing_attrs)
)

num_classes = int(custom_options["num_classes"])
max_detections = int(custom_options["max_detections"])
detections_per_class = int(custom_options["detections_per_class"])
iou_threshold = float(custom_options["nms_iou_threshold"])
score_threshold = float(custom_options["nms_score_threshold"])
x_scale = float(custom_options["x_scale"])
y_scale = float(custom_options["y_scale"])
w_scale = float(custom_options["w_scale"])
h_scale = float(custom_options["h_scale"])

if num_classes <= 0:
raise ValueError("DETECTION_POSTPROCESS requires num_classes > 0.")
if max_detections <= 0:
raise ValueError("DETECTION_POSTPROCESS requires max_detections > 0.")
if detections_per_class <= 0:
raise ValueError("DETECTION_POSTPROCESS requires detections_per_class > 0.")
if not 0.0 <= iou_threshold <= 1.0:
raise ValueError("DETECTION_POSTPROCESS requires nms_iou_threshold in [0, 1].")
if x_scale <= 0.0 or y_scale <= 0.0 or w_scale <= 0.0 or h_scale <= 0.0:
raise ValueError("DETECTION_POSTPROCESS requires x/y/w/h_scale to be > 0.")

inputs = self.get_input_tensors(op)
assert len(inputs) == 3, "inputs length should be 3"
Expand Down Expand Up @@ -3296,67 +3331,139 @@ def convert_detection_postprocess(self, op):
# attributes for multibox_transform_loc
multibox_transform_loc_attrs = {}
multibox_transform_loc_attrs["clip"] = False
multibox_transform_loc_attrs["threshold"] = (
0.0 if use_regular_nms else custom_options["nms_score_threshold"]
)
multibox_transform_loc_attrs["threshold"] = 0.0 if use_regular_nms else score_threshold
multibox_transform_loc_attrs["variances"] = (
1 / custom_options["x_scale"],
1 / custom_options["y_scale"],
1 / custom_options["w_scale"],
1 / custom_options["h_scale"],
1 / x_scale,
1 / y_scale,
1 / w_scale,
1 / h_scale,
)
multibox_transform_loc_attrs["keep_background"] = use_regular_nms

ret = relax.op.vision.multibox_transform_loc(
# reshape cls_pred so it can be consumed by
# multibox_transform_loc
relax.op.permute_dims(cls_pred, [0, 2, 1]),
loc_prob,
anchor_expr,
**multibox_transform_loc_attrs,
multibox_res = self.bb.emit(
relax.op.vision.multibox_transform_loc(
# reshape cls_pred so it can be consumed by
# multibox_transform_loc
relax.op.permute_dims(cls_pred, [0, 2, 1]),
loc_prob,
anchor_expr,
**multibox_transform_loc_attrs,
)
)
transformed_boxes = self.bb.emit(relax.TupleGetItem(multibox_res, 0))
transformed_scores = self.bb.emit(relax.TupleGetItem(multibox_res, 1))

if use_regular_nms:
nms_out = self.bb.emit(
relax.op.vision.all_class_non_max_suppression(
transformed_boxes,
transformed_scores,
relax.const(detections_per_class, "int64"),
relax.const(iou_threshold, "float32"),
relax.const(score_threshold, "float32"),
output_format="tensorflow",
)
)
selected_indices = self.bb.emit(relax.TupleGetItem(nms_out, 0))
selected_scores = self.bb.emit(relax.TupleGetItem(nms_out, 1))
num_detections = self.bb.emit(relax.TupleGetItem(nms_out, 2))
class_id_from_score = None
else:
topk_res = self.bb.emit(
relax.op.topk(transformed_scores, k=1, axis=1, ret_type="both", largest=True)
)
max_scores = self.bb.emit(relax.TupleGetItem(topk_res, 0))
class_id_from_score = self.bb.emit(relax.TupleGetItem(topk_res, 1))
nms_out = self.bb.emit(
relax.op.vision.all_class_non_max_suppression(
transformed_boxes,
max_scores,
relax.const(max_detections, "int64"),
relax.const(iou_threshold, "float32"),
relax.const(score_threshold, "float32"),
output_format="tensorflow",
)
)
selected_indices = self.bb.emit(relax.TupleGetItem(nms_out, 0))
selected_scores = self.bb.emit(relax.TupleGetItem(nms_out, 1))
num_detections = self.bb.emit(relax.TupleGetItem(nms_out, 2))
class_id_from_score = relax.op.squeeze(class_id_from_score, axis=[1])

selected_score_slots = selected_scores.struct_info.shape.values[1]
selected_detection_positions = relax.op.expand_dims(
relax.op.arange(selected_score_slots, dtype="int64"), axis=0
)
selected_valid_detection_mask = relax.op.less(
selected_detection_positions, relax.op.expand_dims(num_detections, axis=1)
)
masked_selected_scores = relax.op.where(
selected_valid_detection_mask,
selected_scores,
relax.const(-1.0, "float32"),
)
topk_scores_res = self.bb.emit(
relax.op.topk(
masked_selected_scores, k=max_detections, axis=1, ret_type="both", largest=True
)
)
detection_scores = self.bb.emit(relax.TupleGetItem(topk_scores_res, 0))
top_positions = self.bb.emit(relax.TupleGetItem(topk_scores_res, 1))
num_detections = relax.op.minimum(
num_detections, relax.const([max_detections], dtype="int64")
)
detection_positions = relax.op.expand_dims(
relax.op.arange(max_detections, dtype="int64"), axis=0
Comment thread
Aharrypotter marked this conversation as resolved.
)
valid_detection_mask = relax.op.less(
detection_positions, relax.op.expand_dims(num_detections, axis=1)
)
top_positions_expanded = relax.op.expand_dims(top_positions, axis=2)
top_positions_for_pairs = relax.op.repeat(top_positions_expanded, 2, axis=2)
top_index_pairs = relax.op.gather_elements(
selected_indices, top_positions_for_pairs, axis=1
)
top_box_ids = relax.op.squeeze(
relax.op.strided_slice(top_index_pairs, axes=[2], begin=[1], end=[2]),
axis=[2],
)
top_box_ids_for_gather = relax.op.expand_dims(relax.op.astype(top_box_ids, "int64"), axis=2)
detection_boxes = relax.op.gather_nd(
transformed_boxes, top_box_ids_for_gather, batch_dims=1
)

if use_regular_nms:
# box coordinates need to be converted from ltrb to (ymin, xmin, ymax, xmax)
_, transformed_boxes = relax.op.split(ret[0], (2,), axis=2)
box_l, box_t, box_r, box_b = relax.op.split(transformed_boxes, 4, axis=2)
transformed_boxes = relax.op.concat([box_t, box_l, box_b, box_r], axis=2)

return relax.op.vision.regular_non_max_suppression(
boxes=transformed_boxes,
scores=cls_pred,
max_detections_per_class=custom_options["detections_per_class"],
max_detections=custom_options["max_detections"],
num_classes=custom_options["num_classes"],
iou_threshold=custom_options["nms_iou_threshold"],
score_threshold=custom_options["nms_score_threshold"],
detection_classes = relax.op.squeeze(
relax.op.strided_slice(top_index_pairs, axes=[2], begin=[0], end=[1]),
axis=[2],
)
detection_classes = relax.op.astype(detection_classes, "int32")
else:
top_box_ids_for_class = relax.op.expand_dims(
relax.op.astype(top_box_ids, "int64"), axis=2
)
detection_classes = relax.op.gather_nd(
class_id_from_score, top_box_ids_for_class, batch_dims=1
)

# attributes for non_max_suppression
non_max_suppression_attrs = {}
non_max_suppression_attrs["return_indices"] = False
non_max_suppression_attrs["iou_threshold"] = custom_options["nms_iou_threshold"]
non_max_suppression_attrs["force_suppress"] = True
non_max_suppression_attrs["top_k"] = anchor_boxes
non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"]
non_max_suppression_attrs["invalid_to_bottom"] = False

ret = relax.op.vision.non_max_suppression(
ret[0], ret[1], ret[1], **non_max_suppression_attrs
detection_mask = relax.op.expand_dims(valid_detection_mask, axis=2)
detection_boxes = relax.op.where(
detection_mask,
detection_boxes,
relax.op.zeros((batch_size, max_detections, 4), dtype="float32"),
)
detection_classes = relax.op.where(
valid_detection_mask,
detection_classes,
relax.op.zeros((batch_size, max_detections), dtype="int32"),
)
ret = relax.op.vision.get_valid_counts(ret, 0)
valid_count = ret[0]
# keep only the top 'max_detections' rows
ret = relax.op.strided_slice(
ret[1], [0, 0, 0], [batch_size, custom_options["max_detections"], 6]
detection_scores = relax.op.where(
valid_detection_mask,
detection_scores,
relax.op.zeros((batch_size, max_detections), dtype="float32"),
)
# the output needs some reshaping to match tflite
ret = relax.op.split(ret, 6, axis=2)
cls_ids = relax.op.reshape(ret[0], [batch_size, -1])
scores = relax.op.reshape(ret[1], [batch_size, -1])
boxes = relax.op.concat([ret[3], ret[2], ret[5], ret[4]], axis=2)
ret = relax.Tuple(relax.Tuple([boxes, cls_ids, scores, valid_count]), size=4)
return ret
detection_classes = relax.op.astype(detection_classes, "float32")
num_detections = relax.op.astype(num_detections, "float32")
return relax.Tuple([detection_boxes, detection_classes, detection_scores, num_detections])

def convert_nms_v5(self, op):
"""Convert TFLite NonMaxSuppressionV5"""
Expand Down
19 changes: 12 additions & 7 deletions python/tvm/relax/transform/legalize_ops/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,15 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E

Returns
-------
result : Tuple[Tensor, Tensor]
A tuple of (trimmed_indices, num_total_detections) where:
- trimmed_indices: Tensor of shape (num_total_detections, 3) containing only
valid detection indices (batch_id, class_id, box_id)
- num_total_detections: Tensor of shape (1,) with the count of valid detections
result : Expr
The legalized NMS result.

- For ONNX output format, returns a tuple of
`(trimmed_indices, num_total_detections)`, where `trimmed_indices`
contains only valid detection indices.
- For TensorFlow output format, returns the TOPI result directly to
preserve the `(selected_indices, selected_scores, num_detections)`
layout expected by the Relax op.
"""
boxes = call.args[0]
scores = call.args[1]
Expand Down Expand Up @@ -69,8 +73,9 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E
output_format,
)

# Dynamic output trimming using dynamic_strided_slice
# Extract selected_indices and num_total_detections from the NMS result
if output_format == "tensorflow":
return nms_result

selected_indices = block_builder.emit(TupleGetItem(nms_result, 0))
num_total_detections = block_builder.emit(TupleGetItem(nms_result, 1))

Expand Down
Loading
Loading