diff --git a/extension/llm/runner/test/test_util.cpp b/extension/llm/runner/test/test_util.cpp index 3d66c212375..d99fac3099a 100644 --- a/extension/llm/runner/test/test_util.cpp +++ b/extension/llm/runner/test/test_util.cpp @@ -18,6 +18,8 @@ namespace { using ::executorch::aten::ScalarType; using ::executorch::extension::make_tensor_ptr; using ::executorch::extension::llm::convert_to_bfloat16; +using ::executorch::extension::llm::stop_safe_prefix_len; +using ::executorch::extension::llm::utf8_complete_prefix_len; class ConvertToBFloat16Test : public ::testing::Test { protected: @@ -63,4 +65,94 @@ TEST_F(ConvertToBFloat16Test, RejectsNonFloatTensor) { EXPECT_EQ(result.error(), ::executorch::runtime::Error::InvalidArgument); } +TEST(Utf8CompletePrefixLenTest, HandlesAsciiAndMultiByteBoundaries) { + EXPECT_EQ(utf8_complete_prefix_len(""), 0u); + EXPECT_EQ(utf8_complete_prefix_len("ascii"), 5u); + + // Complete multi-byte characters are fully consumed. + EXPECT_EQ(utf8_complete_prefix_len("\xc3\xa9"), 2u); // é (2-byte) + EXPECT_EQ(utf8_complete_prefix_len("\xe2\x82\xac"), 3u); // € (3-byte) + EXPECT_EQ(utf8_complete_prefix_len("\xf0\x9f\x98\x80"), 4u); // 😀 (4-byte) + + // A character split across the end is held back (not counted). + EXPECT_EQ(utf8_complete_prefix_len("\xc3"), 0u); // 1/2 of é + EXPECT_EQ(utf8_complete_prefix_len("\xe2\x82"), 0u); // 2/3 of € + EXPECT_EQ(utf8_complete_prefix_len("\xf0\x9f\x98"), 0u); // 3/4 of 😀 + + // A complete prefix followed by a split character keeps the complete part. + EXPECT_EQ(utf8_complete_prefix_len("hi\xe2\x82"), 2u); + EXPECT_EQ(utf8_complete_prefix_len("\xe2\x82\xac\xf0\x9f"), 3u); + + // An invalid lead byte counts as length 1 (emitted, not stalled). + EXPECT_EQ(utf8_complete_prefix_len("\x80"), 1u); +} + +TEST(StopSafePrefixLenTest, NoStopsEmitsEverything) { + bool hit = true; + EXPECT_EQ(stop_safe_prefix_len("hello world", {}, hit), 11u); + EXPECT_FALSE(hit); +} + +TEST(StopSafePrefixLenTest, SingleByteStopMissEmitsEverything) { + bool hit = true; + const std::string text = "caf\xc3\xa9"; + EXPECT_EQ(stop_safe_prefix_len(text, {"Z"}, hit), text.size()); + EXPECT_FALSE(hit); +} + +TEST(StopSafePrefixLenTest, EmptyStopsDoNotHoldBack) { + bool hit = true; + EXPECT_EQ(stop_safe_prefix_len("hello", {""}, hit), 5u); + EXPECT_FALSE(hit); +} + +TEST(StopSafePrefixLenTest, StopFoundReturnsEarliestOffsetAndExcludesIt) { + bool hit = false; + // "STOP" begins at offset 6; emit "Hello " (6 bytes), drop the stop and rest. + EXPECT_EQ(stop_safe_prefix_len("Hello STOP there", {"STOP"}, hit), 6u); + EXPECT_TRUE(hit); + // Earliest of several wins. + hit = false; + EXPECT_EQ(stop_safe_prefix_len("aXbY", {"Y", "X"}, hit), 1u); + EXPECT_TRUE(hit); +} + +TEST(StopSafePrefixLenTest, EarliestStopWinsEvenWhenLongerStopSetsHoldBack) { + bool hit = false; + EXPECT_EQ(stop_safe_prefix_len("abcXtail", {"LONGSTOP", "X"}, hit), 3u); + EXPECT_TRUE(hit); +} + +TEST(StopSafePrefixLenTest, HoldsBackPossiblePartialStopTail) { + bool hit = false; + // No full stop yet, but the trailing "ST" could become "STOP": hold back + // len("STOP")-1 == 3 bytes, so of "hi ST" (5 bytes) only "hi" (2) is safe. + EXPECT_EQ(stop_safe_prefix_len("hi ST", {"STOP"}, hit), 2u); + EXPECT_FALSE(hit); +} + +TEST(StopSafePrefixLenTest, HoldBackSnapsToUtf8Boundary) { + bool hit = false; + // "ab" + "€"(3 bytes). Stop "XX" => hold back 1 byte, which would land inside + // the euro sign; snap down so the multi-byte char isn't split. + const std::string text = "ab\xe2\x82\xac"; + const size_t safe = stop_safe_prefix_len(text, {"XX"}, hit); + EXPECT_FALSE(hit); + EXPECT_EQ(safe, 2u); // only "ab"; the € is held whole +} + +TEST(StopSafePrefixLenTest, HoldBackWithIncompleteUtf8TailSnapsToBoundary) { + bool hit = false; + const std::string text = "ab\xe2\x82"; + EXPECT_EQ(stop_safe_prefix_len(text, {"XX"}, hit), 2u); + EXPECT_FALSE(hit); +} + +TEST(StopSafePrefixLenTest, HoldZeroDoesNotEmitDanglingUtf8LeadByte) { + bool hit = false; + const std::string text = "ab\xc3"; + EXPECT_EQ(stop_safe_prefix_len(text, {"Z"}, hit), 2u); + EXPECT_FALSE(hit); +} + } // namespace diff --git a/extension/llm/runner/util.h b/extension/llm/runner/util.h index da15b60890b..f5bcb945dbf 100644 --- a/extension/llm/runner/util.h +++ b/extension/llm/runner/util.h @@ -13,7 +13,9 @@ #include #include #include +#include #include +#include #include #if defined(__linux__) || defined(__ANDROID__) || defined(__unix__) #include @@ -81,6 +83,116 @@ ET_EXPERIMENTAL void inline safe_printf(const char* piece) { printf("%s", piece); } +// Length of the longest prefix of `s` that does not end in the middle of a +// UTF-8 multi-byte sequence. A byte-level tokenizer can emit a token that is +// only part of a character (e.g. one byte of a 3-byte CJK codepoint or emoji), +// so a caller streaming text must hold the incomplete tail until it completes +// rather than decode the partial bytes. An invalid lead byte counts as length 1 +// (emitted, so the caller can replace it) rather than stalling output. +ET_EXPERIMENTAL size_t inline utf8_complete_prefix_len(const std::string& s) { + size_t i = 0; + const size_t n = s.size(); + while (i < n) { + const unsigned char c = static_cast(s[i]); + size_t len; + if (c < 0x80) { + len = 1; + } else if ((c >> 5) == 0x6) { + len = 2; + } else if ((c >> 4) == 0xE) { + len = 3; + } else if ((c >> 3) == 0x1E) { + len = 4; + } else { + len = 1; // invalid lead byte; emit it and let the caller replace it + } + if (i + len > n) { + break; // incomplete trailing sequence: hold it for more bytes + } + i += len; + } + return i; +} + +ET_EXPERIMENTAL size_t inline utf8_safe_prefix_len( + const std::string& s, + size_t len) { + len = std::min(len, s.size()); + if (len == 0) { + return 0; + } + const auto* data = reinterpret_cast(s.data()); + size_t i = len; + while (i > 0 && (data[i - 1] & 0xC0) == 0x80) { + --i; + } + if (i == 0) { + return 0; + } + const size_t lead_pos = i - 1; + const unsigned char lead = data[lead_pos]; + size_t need = 0; + if (lead < 0x80) { + need = 1; + } else if ((lead & 0xE0) == 0xC0) { + need = 2; + } else if ((lead & 0xF0) == 0xE0) { + need = 3; + } else if ((lead & 0xF8) == 0xF0) { + need = 4; + } else { + return lead_pos; + } + return len - lead_pos == need ? len : lead_pos; +} + +// How many leading bytes of `text` a streaming consumer may safely emit given a +// set of `stops` strings, and whether a stop was hit (`stop_hit`). +// * If any stop occurs, returns the byte offset of the EARLIEST occurrence +// and +// sets stop_hit=true — text before it is safe; the stop and everything +// after are dropped (the stop is excluded from output). +// * Otherwise returns the length minus the longest possible partial-stop tail +// (max(len(stop))-1 bytes), snapped DOWN to a UTF-8 boundary so a +// multi-byte character is never split; stop_hit=false. Holding back that +// conservative tail lets a stop that straddles the next piece still be +// caught without suffix-prefix matching each stop. +// `text` is expected to be complete-UTF-8 (e.g. the assembled output of +// utf8_complete_prefix_len) and stops are expected to be real text, so a found +// stop offset cannot split a UTF-8 character. Empty `stops` => emit everything, +// no hold-back. +ET_EXPERIMENTAL size_t inline stop_safe_prefix_len( + const std::string& text, + const std::vector& stops, + bool& stop_hit) { + stop_hit = false; + if (stops.empty()) { + return text.size(); + } + size_t earliest = std::string::npos; + size_t max_len = 0; + for (const auto& s : stops) { + if (s.empty()) { + continue; + } + max_len = std::max(max_len, s.size()); + const size_t p = text.find(s); + if (p != std::string::npos && + (earliest == std::string::npos || p < earliest)) { + earliest = p; + } + } + if (earliest != std::string::npos) { + stop_hit = true; + return earliest; + } + const size_t hold = max_len > 0 ? max_len - 1 : 0; + if (text.size() <= hold) { + return 0; + } + return utf8_safe_prefix_len(text, text.size() - hold); +} + // ---------------------------------------------------------------------------- // utilities: time