Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions extension/llm/runner/test/test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ namespace {
using ::executorch::aten::ScalarType;
using ::executorch::extension::make_tensor_ptr;
using ::executorch::extension::llm::convert_to_bfloat16;
using ::executorch::extension::llm::stop_safe_prefix_len;
using ::executorch::extension::llm::utf8_complete_prefix_len;

class ConvertToBFloat16Test : public ::testing::Test {
protected:
Expand Down Expand Up @@ -63,4 +65,94 @@ TEST_F(ConvertToBFloat16Test, RejectsNonFloatTensor) {
EXPECT_EQ(result.error(), ::executorch::runtime::Error::InvalidArgument);
}

TEST(Utf8CompletePrefixLenTest, HandlesAsciiAndMultiByteBoundaries) {
  // Reference characters, one per multi-byte UTF-8 width.
  const std::string e_acute = "\xc3\xa9"; // é (2-byte)
  const std::string euro = "\xe2\x82\xac"; // € (3-byte)
  const std::string grin = "\xf0\x9f\x98\x80"; // 😀 (4-byte)

  // Pure single-byte input, including the empty string.
  EXPECT_EQ(utf8_complete_prefix_len(std::string()), 0u);
  EXPECT_EQ(utf8_complete_prefix_len(std::string("ascii")), 5u);

  // A whole multi-byte character is consumed in full.
  EXPECT_EQ(utf8_complete_prefix_len(e_acute), 2u);
  EXPECT_EQ(utf8_complete_prefix_len(euro), 3u);
  EXPECT_EQ(utf8_complete_prefix_len(grin), 4u);

  // A character truncated at the end of the input is not counted at all.
  EXPECT_EQ(utf8_complete_prefix_len(e_acute.substr(0, 1)), 0u); // 1/2 of é
  EXPECT_EQ(utf8_complete_prefix_len(euro.substr(0, 2)), 0u); // 2/3 of €
  EXPECT_EQ(utf8_complete_prefix_len(grin.substr(0, 3)), 0u); // 3/4 of 😀

  // Complete text in front of a truncated character is still reported.
  EXPECT_EQ(utf8_complete_prefix_len("hi" + euro.substr(0, 2)), 2u);
  EXPECT_EQ(utf8_complete_prefix_len(euro + grin.substr(0, 2)), 3u);

  // A byte that cannot start a sequence is counted as one byte, not held.
  EXPECT_EQ(utf8_complete_prefix_len(std::string("\x80")), 1u);
}
Comment thread
mergennachin marked this conversation as resolved.

TEST(StopSafePrefixLenTest, NoStopsEmitsEverything) {
  // Start from `true` so the test also proves the callee clears the flag.
  bool stop_seen = true;
  const size_t safe_len = stop_safe_prefix_len("hello world", {}, stop_seen);
  EXPECT_EQ(safe_len, 11u);
  EXPECT_FALSE(stop_seen);
}
Comment thread
mergennachin marked this conversation as resolved.

TEST(StopSafePrefixLenTest, SingleByteStopMissEmitsEverything) {
  // "café" with a 2-byte é at the end; the 1-byte stop "Z" never occurs.
  const std::string input = "caf\xc3\xa9";
  // Start from `true` so the test also proves the callee clears the flag.
  bool stop_seen = true;
  const size_t safe_len = stop_safe_prefix_len(input, {"Z"}, stop_seen);
  EXPECT_EQ(safe_len, input.size());
  EXPECT_FALSE(stop_seen);
}

TEST(StopSafePrefixLenTest, EmptyStopsDoNotHoldBack) {
  // The stop list is non-empty but its only entry is the empty string.
  bool stop_seen = true;
  const size_t safe_len = stop_safe_prefix_len("hello", {""}, stop_seen);
  EXPECT_EQ(safe_len, 5u);
  EXPECT_FALSE(stop_seen);
}

TEST(StopSafePrefixLenTest, StopFoundReturnsEarliestOffsetAndExcludesIt) {
  {
    // "STOP" starts at byte 6, so only "Hello " (6 bytes) may be emitted; the
    // stop and everything after it are excluded.
    bool stop_seen = false;
    EXPECT_EQ(
        stop_safe_prefix_len("Hello STOP there", {"STOP"}, stop_seen), 6u);
    EXPECT_TRUE(stop_seen);
  }
  {
    // With several stops present, the one occurring first in the text wins
    // regardless of its position in the stop list.
    bool stop_seen = false;
    EXPECT_EQ(stop_safe_prefix_len("aXbY", {"Y", "X"}, stop_seen), 1u);
    EXPECT_TRUE(stop_seen);
  }
}

TEST(StopSafePrefixLenTest, EarliestStopWinsEvenWhenLongerStopSetsHoldBack) {
  // "LONGSTOP" never matches; the match on "X" at byte 3 decides the result.
  bool stop_seen = false;
  const size_t safe_len =
      stop_safe_prefix_len("abcXtail", {"LONGSTOP", "X"}, stop_seen);
  EXPECT_EQ(safe_len, 3u);
  EXPECT_TRUE(stop_seen);
}

TEST(StopSafePrefixLenTest, HoldsBackPossiblePartialStopTail) {
  // No complete stop is present, but the trailing "ST" might grow into "STOP".
  // The hold-back is len("STOP") - 1 == 3 bytes, leaving "hi" (2 bytes) of the
  // 5-byte "hi ST" safe to emit.
  bool stop_seen = false;
  const size_t safe_len = stop_safe_prefix_len("hi ST", {"STOP"}, stop_seen);
  EXPECT_EQ(safe_len, 2u);
  EXPECT_FALSE(stop_seen);
}

TEST(StopSafePrefixLenTest, HoldBackSnapsToUtf8Boundary) {
  // "ab" followed by "€" (3 bytes). The stop "XX" asks for a 1-byte hold-back,
  // which would cut inside the euro sign, so the cut must move down to the
  // start of that character.
  const std::string input = "ab\xe2\x82\xac";
  bool stop_seen = false;
  const size_t safe_len = stop_safe_prefix_len(input, {"XX"}, stop_seen);
  EXPECT_FALSE(stop_seen);
  EXPECT_EQ(safe_len, 2u); // only "ab"; the € is held whole
}

TEST(StopSafePrefixLenTest, HoldBackWithIncompleteUtf8TailSnapsToBoundary) {
  // "ab" followed by only the first two bytes of a 3-byte character.
  const std::string input = "ab\xe2\x82";
  bool stop_seen = false;
  const size_t safe_len = stop_safe_prefix_len(input, {"XX"}, stop_seen);
  EXPECT_EQ(safe_len, 2u);
  EXPECT_FALSE(stop_seen);
}

TEST(StopSafePrefixLenTest, HoldZeroDoesNotEmitDanglingUtf8LeadByte) {
  // A 1-byte stop gives a hold-back of zero, yet the trailing 2-byte lead
  // 0xC3 has no continuation byte and must not be reported as safe.
  const std::string input = "ab\xc3";
  bool stop_seen = false;
  const size_t safe_len = stop_safe_prefix_len(input, {"Z"}, stop_seen);
  EXPECT_EQ(safe_len, 2u);
  EXPECT_FALSE(stop_seen);
}

} // namespace
112 changes: 112 additions & 0 deletions extension/llm/runner/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
#include <executorch/runtime/platform/compiler.h>
#include <stdio.h>
#include <time.h>
#include <algorithm>
#include <cctype>
#include <string>
#include <vector>
#if defined(__linux__) || defined(__ANDROID__) || defined(__unix__)
#include <sys/resource.h>
Expand Down Expand Up @@ -81,6 +83,116 @@ ET_EXPERIMENTAL void inline safe_printf(const char* piece) {
printf("%s", piece);
}

// Length of the longest prefix of `s` that does not end in the middle of a
// UTF-8 multi-byte sequence. A byte-level tokenizer can emit a token that is
// only part of a character (e.g. one byte of a 3-byte CJK codepoint or emoji),
// so a caller streaming text must hold the incomplete tail until it completes
// rather than decode the partial bytes.
//
// Malformed input is emitted rather than held, so it cannot stall output:
//   * a byte that cannot start a sequence (a stray continuation byte or
//     0xF8..0xFF) counts as length 1;
//   * a lead byte whose following byte(s) are present but are not
//     continuation bytes (10xxxxxx) also counts as length 1. It is never
//     allowed to swallow the bytes after it, which may begin a real
//     character.
// Only a sequence that is cut off by the end of `s`, with every byte seen so
// far being a valid continuation, is treated as incomplete and held back.
//
// Returns the number of leading bytes of `s` that are safe to emit.
ET_EXPERIMENTAL size_t inline utf8_complete_prefix_len(const std::string& s) {
  const size_t n = s.size();
  size_t i = 0;
  while (i < n) {
    const unsigned char c = static_cast<unsigned char>(s[i]);
    // Sequence length announced by the lead byte; 1 for ASCII and for bytes
    // that cannot start a sequence.
    size_t len = 1;
    if ((c >> 5) == 0x6) {
      len = 2;
    } else if ((c >> 4) == 0xE) {
      len = 3;
    } else if ((c >> 3) == 0x1E) {
      len = 4;
    }
    // Check the continuation bytes that are actually present.
    bool truncated = false;
    for (size_t k = 1; k < len; ++k) {
      if (i + k >= n) {
        truncated = true; // ran out of input: the sequence may still complete
        break;
      }
      if ((static_cast<unsigned char>(s[i + k]) & 0xC0) != 0x80) {
        len = 1; // broken sequence: emit the lead byte alone
        break;
      }
    }
    if (truncated) {
      break; // incomplete trailing sequence: hold it for more bytes
    }
    i += len;
  }
  return i;
}

// Snaps a proposed cut position `len` in `s` DOWN to a position that does not
// fall inside a UTF-8 multi-byte sequence.
//
// `len` is first clamped to `s.size()`. The function then looks at the last
// character that starts before the cut:
//   * if that character ends exactly at the cut, the cut is returned as is;
//   * otherwise (the character extends past the cut, the byte is not a valid
//     lead byte, or extra continuation bytes trail it) the position of that
//     character's first byte is returned.
// Returns 0 when the cut is 0 or when only continuation bytes (10xxxxxx)
// precede it.
ET_EXPERIMENTAL size_t inline utf8_safe_prefix_len(
    const std::string& s,
    size_t len) {
  const size_t cut = std::min(len, s.size());

  // Step back over trailing continuation bytes to just past the byte that
  // opens the last character before the cut.
  size_t pos = cut;
  while (pos > 0 && (static_cast<unsigned char>(s[pos - 1]) >> 6) == 0x2) {
    --pos;
  }
  if (pos == 0) {
    return 0; // nothing before the cut, or continuation bytes only
  }

  const size_t start = pos - 1;
  const unsigned char first = static_cast<unsigned char>(s[start]);

  // Number of bytes the opening byte announces for its character.
  size_t expected = 0;
  if (first < 0x80) {
    expected = 1;
  } else if ((first >> 5) == 0x6) {
    expected = 2;
  } else if ((first >> 4) == 0xE) {
    expected = 3;
  } else if ((first >> 3) == 0x1E) {
    expected = 4;
  } else {
    return start; // not a valid lead byte: cut in front of it
  }

  // Keep the cut only when the last character ends exactly on it.
  if (cut - start == expected) {
    return cut;
  }
  return start;
}

// How many leading bytes of `text` a streaming consumer may safely emit given a
// set of `stops` strings, and whether a stop was hit (`stop_hit`).
// * If any stop occurs, returns the byte offset of the EARLIEST occurrence
// and
// sets stop_hit=true — text before it is safe; the stop and everything
// after are dropped (the stop is excluded from output).
// * Otherwise returns the length minus the longest possible partial-stop tail
// (max(len(stop))-1 bytes), snapped DOWN to a UTF-8 boundary so a
// multi-byte character is never split; stop_hit=false. Holding back that
// conservative tail lets a stop that straddles the next piece still be
// caught without suffix-prefix matching each stop.
// `text` is expected to be complete-UTF-8 (e.g. the assembled output of
// utf8_complete_prefix_len) and stops are expected to be real text, so a found
// stop offset cannot split a UTF-8 character. Empty `stops` => emit everything,
// no hold-back.
Comment thread
mergennachin marked this conversation as resolved.
ET_EXPERIMENTAL size_t inline stop_safe_prefix_len(
const std::string& text,
const std::vector<std::string>& stops,
bool& stop_hit) {
stop_hit = false;
if (stops.empty()) {
return text.size();
}
size_t earliest = std::string::npos;
size_t max_len = 0;
for (const auto& s : stops) {
if (s.empty()) {
continue;
}
max_len = std::max(max_len, s.size());
const size_t p = text.find(s);
if (p != std::string::npos &&
(earliest == std::string::npos || p < earliest)) {
earliest = p;
}
}
if (earliest != std::string::npos) {
stop_hit = true;
return earliest;
}
const size_t hold = max_len > 0 ? max_len - 1 : 0;
if (text.size() <= hold) {
return 0;
}
return utf8_safe_prefix_len(text, text.size() - hold);
}

// ----------------------------------------------------------------------------
// utilities: time

Expand Down
Loading