diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp index 01524ccf0ed0a9..4a43972fdd970f 100644 --- a/be/src/vec/functions/function_string.cpp +++ b/be/src/vec/functions/function_string.cpp @@ -1399,7 +1399,8 @@ void register_function_string(SimpleFunctionFactory& factory) { factory.register_function(); factory.register_function(); factory.register_function(); - factory.register_function(); + factory.register_function>(); + factory.register_function>(); factory.register_function>(); factory.register_function< FunctionCountSubString>(); diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h index 943bf740c8e6f6..ffed3641797d3a 100644 --- a/be/src/vec/functions/function_string.h +++ b/be/src/vec/functions/function_string.h @@ -2032,39 +2032,17 @@ class FunctionSubstringIndex : public IFunction { } }; -class FunctionSplitByString : public IFunction { +class SplitByStringExecutor { public: - static constexpr auto name = "split_by_string"; - - static FunctionPtr create() { return std::make_shared(); } using NullMapType = PaddedPODArray; - String get_name() const override { return name; } - - bool is_variadic() const override { return false; } - - size_t get_number_of_arguments() const override { return 2; } - - DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { - DCHECK(is_string_type(arguments[0]->get_primitive_type())) - << "first argument for function: " << name << " should be string" - << " and arguments[0] is " << arguments[0]->get_name(); - DCHECK(is_string_type(arguments[1]->get_primitive_type())) - << "second argument for function: " << name << " should be string" - << " and arguments[1] is " << arguments[1]->get_name(); - return std::make_shared(make_nullable(arguments[0])); - } - - Status execute_impl(FunctionContext* /*context*/, Block& block, const ColumnNumbers& arguments, - uint32_t result, size_t input_rows_count) const override { - 
DCHECK_EQ(arguments.size(), 2); - + static Status execute_core(Block& block, const ColumnNumbers& arguments, uint32_t result, + size_t input_rows_count, Int32 limit_value) { const auto& [src_column, left_const] = unpack_if_const(block.get_by_position(arguments[0]).column); const auto& [right_column, right_const] = unpack_if_const(block.get_by_position(arguments[1]).column); - DataTypePtr right_column_type = block.get_by_position(arguments[1]).type; DataTypePtr src_column_type = block.get_by_position(arguments[0]).type; auto dest_column_ptr = ColumnArray::create(make_nullable(src_column_type)->create_column(), ColumnArray::ColumnOffsets::create()); @@ -2076,14 +2054,13 @@ class FunctionSplitByString : public IFunction { auto* dest_nested_column = dest_nullable_col.get_nested_column_ptr().get(); const auto* col_str = assert_cast(src_column.get()); - const auto* col_delimiter = assert_cast(right_column.get()); std::visit( [&](auto src_const, auto delimiter_const) { _execute(*col_str, *col_delimiter, *dest_nested_column, dest_offsets, - input_rows_count); + input_rows_count, limit_value); }, vectorized::make_bool_variant(left_const), vectorized::make_bool_variant(right_const)); @@ -2098,9 +2075,9 @@ class FunctionSplitByString : public IFunction { private: template - void _execute(const ColumnString& src_column_string, const ColumnString& delimiter_column, - IColumn& dest_nested_column, ColumnArray::Offsets64& dest_offsets, - size_t size) const { + static void _execute(const ColumnString& src_column_string, + const ColumnString& delimiter_column, IColumn& dest_nested_column, + ColumnArray::Offsets64& dest_offsets, size_t size, Int32 limit_value) { auto& dest_column_string = assert_cast(dest_nested_column); ColumnString::Chars& column_string_chars = dest_column_string.get_chars(); ColumnString::Offsets& column_string_offsets = dest_column_string.get_offsets(); @@ -2129,12 +2106,29 @@ class FunctionSplitByString : public IFunction { } if (delimiter_ref.size == 0) { 
split_empty_delimiter(str_ref, column_string_chars, column_string_offsets, - string_pos, dest_pos); + string_pos, dest_pos, limit_value); } else { if constexpr (!delimiter_const) { search.set_pattern(&delimiter_ref); } + Int32 split_count = 0; for (size_t str_pos = 0; str_pos <= str_ref.size;) { + // If limit reached, dump remainder as final token + if (limit_value > 0 && split_count == limit_value - 1) { + const size_t remaining = str_ref.size - str_pos; + const size_t old_size = column_string_chars.size(); + if (remaining > 0) { + const size_t new_size = old_size + remaining; + column_string_chars.resize(new_size); + memcpy_small_allow_read_write_overflow15( + column_string_chars.data() + old_size, str_ref.data + str_pos, + remaining); + string_pos += remaining; + } + column_string_offsets.push_back(string_pos); + dest_pos++; + break; + } const size_t str_offset = str_pos; const size_t old_size = column_string_chars.size(); // search first match delimter_ref index from src string among str_offset to end @@ -2155,6 +2149,7 @@ class FunctionSplitByString : public IFunction { column_string_offsets.push_back(string_pos); // array offset + 1 dest_pos++; + split_count++; // add src string str_pos to next search start str_pos += split_part_size + delimiter_ref.size; } @@ -2163,44 +2158,143 @@ class FunctionSplitByString : public IFunction { } } - void split_empty_delimiter(const StringRef& str_ref, ColumnString::Chars& column_string_chars, - ColumnString::Offsets& column_string_offsets, - ColumnArray::Offset64& string_pos, - ColumnArray::Offset64& dest_pos) const { + static void split_empty_delimiter(const StringRef& str_ref, + ColumnString::Chars& column_string_chars, + ColumnString::Offsets& column_string_offsets, + ColumnArray::Offset64& string_pos, + ColumnArray::Offset64& dest_pos, Int32 limit_value) { const size_t old_size = column_string_chars.size(); const size_t new_size = old_size + str_ref.size; column_string_chars.resize(new_size); 
memcpy(column_string_chars.data() + old_size, str_ref.data, str_ref.size); - if (simd::VStringFunctions::is_ascii(str_ref)) { - const auto size = str_ref.size; - - const auto nested_old_size = column_string_offsets.size(); - const auto nested_new_size = nested_old_size + size; - column_string_offsets.resize(nested_new_size); - std::iota(column_string_offsets.data() + nested_old_size, - column_string_offsets.data() + nested_new_size, string_pos + 1); - - string_pos += size; - dest_pos += size; - // The above code is equivalent to the code in the following comment. - // for (size_t i = 0; i < str_ref.size; i++) { - // string_pos++; - // column_string_offsets.push_back(string_pos); - // (*dest_nested_null_map).push_back(false); - // dest_pos++; - // } + + if (limit_value > 0) { + // With limit: split character by character up to limit-1, then remainder + Int32 split_count = 0; + size_t i = 0; + if (simd::VStringFunctions::is_ascii(str_ref)) { + for (; i < str_ref.size; i++) { + if (split_count == limit_value - 1) { + // remainder + string_pos += str_ref.size - i; + column_string_offsets.push_back(string_pos); + dest_pos++; + return; + } + string_pos++; + column_string_offsets.push_back(string_pos); + dest_pos++; + split_count++; + } + } else { + for (size_t utf8_char_len = 0; i < str_ref.size; i += utf8_char_len) { + utf8_char_len = UTF8_BYTE_LENGTH[(unsigned char)str_ref.data[i]]; + if (split_count == limit_value - 1) { + // remainder + string_pos += str_ref.size - i; + column_string_offsets.push_back(string_pos); + dest_pos++; + return; + } + string_pos += utf8_char_len; + column_string_offsets.push_back(string_pos); + dest_pos++; + split_count++; + } + } } else { - for (size_t i = 0, utf8_char_len = 0; i < str_ref.size; i += utf8_char_len) { - utf8_char_len = UTF8_BYTE_LENGTH[(unsigned char)str_ref.data[i]]; + // No limit: original behavior + if (simd::VStringFunctions::is_ascii(str_ref)) { + const auto size = str_ref.size; + + const auto nested_old_size = 
column_string_offsets.size(); + const auto nested_new_size = nested_old_size + size; + column_string_offsets.resize(nested_new_size); + std::iota(column_string_offsets.data() + nested_old_size, + column_string_offsets.data() + nested_new_size, string_pos + 1); + + string_pos += size; + dest_pos += size; + } else { + for (size_t i = 0, utf8_char_len = 0; i < str_ref.size; i += utf8_char_len) { + utf8_char_len = UTF8_BYTE_LENGTH[(unsigned char)str_ref.data[i]]; - string_pos += utf8_char_len; - column_string_offsets.push_back(string_pos); - dest_pos++; + string_pos += utf8_char_len; + column_string_offsets.push_back(string_pos); + dest_pos++; + } } } } }; +struct SplitByStringTwoArgImpl { + static DataTypes get_variadic_argument_types() { + return {std::make_shared(), std::make_shared()}; + } + + static Status execute_impl(FunctionContext* /*context*/, Block& block, + const ColumnNumbers& arguments, uint32_t result, + size_t input_rows_count) { + DCHECK_EQ(arguments.size(), 2); + return SplitByStringExecutor::execute_core(block, arguments, result, input_rows_count, -1); + } +}; + +struct SplitByStringThreeArgImpl { + static DataTypes get_variadic_argument_types() { + return {std::make_shared(), std::make_shared(), + std::make_shared()}; + } + + static Status execute_impl(FunctionContext* /*context*/, Block& block, + const ColumnNumbers& arguments, uint32_t result, + size_t input_rows_count) { + DCHECK_EQ(arguments.size(), 3); + const auto& [limit_column, limit_is_const] = + unpack_if_const(block.get_by_position(arguments[2]).column); + DCHECK(limit_is_const) << "limit argument of split_by_string must be a constant"; + auto limit_value = assert_cast(*limit_column).get_element(0); + return SplitByStringExecutor::execute_core(block, arguments, result, input_rows_count, + limit_value); + } +}; + +template +class FunctionSplitByString : public IFunction { +public: + static constexpr auto name = "split_by_string"; + + static FunctionPtr create() { return std::make_shared(); 
} + + String get_name() const override { return name; } + + bool is_variadic() const override { return true; } + + size_t get_number_of_arguments() const override { + return get_variadic_argument_types_impl().size(); + } + + DataTypes get_variadic_argument_types_impl() const override { + return Impl::get_variadic_argument_types(); + } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + DCHECK(is_string_type(arguments[0]->get_primitive_type())) + << "first argument for function: " << name << " should be string" + << " and arguments[0] is " << arguments[0]->get_name(); + DCHECK(is_string_type(arguments[1]->get_primitive_type())) + << "second argument for function: " << name << " should be string" + << " and arguments[1] is " << arguments[1]->get_name(); + return std::make_shared(make_nullable(arguments[0])); + } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + uint32_t result, size_t input_rows_count) const override { + return Impl::execute_impl(context, block, arguments, result, input_rows_count); + } +}; + enum class FunctionCountSubStringType { TWO_ARGUMENTS, THREE_ARGUMENTS }; template diff --git a/be/test/vec/function/function_string_test.cpp b/be/test/vec/function/function_string_test.cpp index 97e6f47ef957bd..a94e0f408e1753 100644 --- a/be/test/vec/function/function_string_test.cpp +++ b/be/test/vec/function/function_string_test.cpp @@ -23,6 +23,10 @@ #include "function_test_util.h" #include "util/encryption_util.h" +#include "vec/columns/column_array.h" +#include "vec/columns/column_const.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_string.h" #include "vec/core/field.h" #include "vec/core/types.h" #include "vec/data_types/data_type_number.h" @@ -3869,4 +3873,206 @@ TEST(function_string_test, function_unicode_normalize_invalid_mode) { EXPECT_NE(Status::OK(), st); } +// Helper: run split_by_string with 3 args (str, delimiter, limit) on a single row +// 
Returns the result column (Array>) +static ColumnPtr run_split_by_string_3arg(const std::string& str, const std::string& delimiter, + Int32 limit_val) { + Block block; + auto str_type = std::make_shared(); + auto int_type = std::make_shared(); + auto ret_type = std::make_shared(make_nullable(str_type)); + + // Build input columns with one row each + auto str_col = ColumnString::create(); + str_col->insert_data(str.data(), str.size()); + auto delim_col = ColumnString::create(); + delim_col->insert_data(delimiter.data(), delimiter.size()); + auto limit_col = ColumnInt32::create(); + limit_col->insert_value(limit_val); + auto const_limit_col = ColumnConst::create(std::move(limit_col), 1); + + block.insert({std::move(str_col), str_type, "str"}); + block.insert({std::move(delim_col), str_type, "delim"}); + block.insert({std::move(const_limit_col), int_type, "limit"}); + block.insert({nullptr, ret_type, "result"}); + + ColumnsWithTypeAndName arguments = {block.get_by_position(0), block.get_by_position(1), + block.get_by_position(2)}; + auto func = + SimpleFunctionFactory::instance().get_function("split_by_string", arguments, ret_type); + EXPECT_TRUE(func != nullptr); + auto st = func->execute(nullptr, block, {0, 1, 2}, 3, 1); + EXPECT_EQ(Status::OK(), st); + + return block.get_by_position(3).column; +} + +// Helper: run split_by_string with 2 args (str, delimiter) on a single row +static ColumnPtr run_split_by_string_2arg(const std::string& str, const std::string& delimiter) { + Block block; + auto str_type = std::make_shared(); + auto ret_type = std::make_shared(make_nullable(str_type)); + + auto str_col = ColumnString::create(); + str_col->insert_data(str.data(), str.size()); + auto delim_col = ColumnString::create(); + delim_col->insert_data(delimiter.data(), delimiter.size()); + + block.insert({std::move(str_col), str_type, "str"}); + block.insert({std::move(delim_col), str_type, "delim"}); + block.insert({nullptr, ret_type, "result"}); + + ColumnsWithTypeAndName 
arguments = {block.get_by_position(0), block.get_by_position(1)}; + auto func = + SimpleFunctionFactory::instance().get_function("split_by_string", arguments, ret_type); + EXPECT_TRUE(func != nullptr); + auto st = func->execute(nullptr, block, {0, 1}, 2, 1); + EXPECT_EQ(Status::OK(), st); + + return block.get_by_position(2).column; +} + +// Helper: extract array elements as vector of strings from row 0 of an array column +static std::vector get_array_strings(const ColumnPtr& col) { + const auto* array_col = assert_cast(col.get()); + const auto& offsets = array_col->get_offsets(); + size_t start = 0; + size_t end = offsets[0]; + + std::vector result; + const auto& nested = array_col->get_data(); + // nested is ColumnNullable + const auto* nullable_col = assert_cast(&nested); + const auto* str_col = assert_cast(&nullable_col->get_nested_column()); + + for (size_t i = start; i < end; i++) { + auto ref = str_col->get_data_at(i); + result.emplace_back(ref.data, ref.size); + } + return result; +} + +TEST(function_string_test, function_split_by_string_with_limit_test) { + // Basic limit functionality + { + auto col = run_split_by_string_3arg("one,two,three,", ",", 2); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 2); + EXPECT_EQ(arr[0], "one"); + EXPECT_EQ(arr[1], "two,three,"); + } + { + auto col = run_split_by_string_3arg("one,two,three,", ",", 3); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 3); + EXPECT_EQ(arr[0], "one"); + EXPECT_EQ(arr[1], "two"); + EXPECT_EQ(arr[2], "three,"); + } + // limit = 1: no split + { + auto col = run_split_by_string_3arg("one,two,three", ",", 1); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 1); + EXPECT_EQ(arr[0], "one,two,three"); + } + // limit >= parts: return all + { + auto col = run_split_by_string_3arg("a,b,c", ",", 10); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 3); + EXPECT_EQ(arr[0], "a"); + EXPECT_EQ(arr[1], "b"); + EXPECT_EQ(arr[2], "c"); + } + // Multi-char 
delimiter + limit + { + auto col = run_split_by_string_3arg("a::b::c::d", "::", 2); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 2); + EXPECT_EQ(arr[0], "a"); + EXPECT_EQ(arr[1], "b::c::d"); + } +} + +TEST(function_string_test, function_split_by_string_limit_empty_delim_test) { + // Empty delimiter + limit: splits by character (ASCII) + { + auto col = run_split_by_string_3arg("abcde", "", 3); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 3); + EXPECT_EQ(arr[0], "a"); + EXPECT_EQ(arr[1], "b"); + EXPECT_EQ(arr[2], "cde"); + } + { + auto col = run_split_by_string_3arg("abcde", "", 1); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 1); + EXPECT_EQ(arr[0], "abcde"); + } + { + auto col = run_split_by_string_3arg("abcde", "", 10); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 5); + EXPECT_EQ(arr[0], "a"); + EXPECT_EQ(arr[4], "e"); + } + // Empty delimiter + limit: UTF-8 + { + // "你好世" = 3 UTF-8 characters + std::string utf8_str = "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96"; + auto col = run_split_by_string_3arg(utf8_str, "", 2); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 2); + EXPECT_EQ(arr[0], "\xe4\xbd\xa0"); // 你 + EXPECT_EQ(arr[1], "\xe5\xa5\xbd\xe4\xb8\x96"); // 好世 + } +} + +TEST(function_string_test, function_split_by_string_limit_edge_cases_test) { + // limit <= 0: behaves like no limit + { + auto col = run_split_by_string_3arg("a,b,c", ",", -1); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 3); + EXPECT_EQ(arr[0], "a"); + EXPECT_EQ(arr[1], "b"); + EXPECT_EQ(arr[2], "c"); + } + { + auto col = run_split_by_string_3arg("a,b,c", ",", 0); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 3); + } + // Empty source string + { + auto col = run_split_by_string_3arg("", ",", 2); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 0); + } + // Consecutive delimiters + limit + { + auto col = run_split_by_string_3arg(",,,", ",", 2); + auto arr = get_array_strings(col); + 
ASSERT_EQ(arr.size(), 2); + EXPECT_EQ(arr[0], ""); + EXPECT_EQ(arr[1], ",,"); + } + // 2-arg version still works after refactoring + { + auto col = run_split_by_string_2arg("a,b,c", ","); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 3); + EXPECT_EQ(arr[0], "a"); + EXPECT_EQ(arr[1], "b"); + EXPECT_EQ(arr[2], "c"); + } + { + auto col = run_split_by_string_2arg("abcde", ""); + auto arr = get_array_strings(col); + ASSERT_EQ(arr.size(), 5); + } +} + } // namespace doris::vectorized diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/StringArithmetic.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/StringArithmetic.java index 0172c3b433940f..07919a0c360584 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/StringArithmetic.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/executable/StringArithmetic.java @@ -750,6 +750,60 @@ public static Expression splitByString(StringLikeLiteral first, StringLikeLitera return new ArrayLiteral(items); } + /** + * Executable arithmetic functions split_by_string with limit + */ + @ExecFunction(name = "split_by_string") + public static Expression splitByString(StringLikeLiteral first, StringLikeLiteral second, + IntegerLiteral limit) { + int maxParts = limit.getValue(); + if (maxParts <= 0) { + return splitByString(first, second); + } + if (first.getValue().isEmpty()) { + return new ArrayLiteral(ImmutableList.of(), ArrayType.of(first.getDataType())); + } + if (second.getValue().isEmpty()) { + List graphemes = splitByGrapheme(first); + List result = new ArrayList<>(); + if (maxParts >= graphemes.size()) { + for (String resultStr : graphemes) { + result.add(castStringLikeLiteral(first, resultStr)); + } + } else { + for (int i = 0; i < maxParts - 1; i++) { + result.add(castStringLikeLiteral(first, graphemes.get(i))); + } + StringBuilder 
remaining = new StringBuilder(); + for (int i = maxParts - 1; i < graphemes.size(); i++) { + remaining.append(graphemes.get(i)); + } + result.add(castStringLikeLiteral(first, remaining.toString())); + } + return new ArrayLiteral(result); + } + String[] parts = first.getValue().split(Pattern.quote(second.getValue()), -1); + List items = new ArrayList<>(); + if (maxParts >= parts.length) { + for (String s : parts) { + items.add(castStringLikeLiteral(first, s)); + } + } else { + for (int i = 0; i < maxParts - 1; i++) { + items.add(castStringLikeLiteral(first, parts[i])); + } + StringBuilder rest = new StringBuilder(); + for (int i = maxParts - 1; i < parts.length; i++) { + if (i > maxParts - 1) { + rest.append(second.getValue()); + } + rest.append(parts[i]); + } + items.add(castStringLikeLiteral(first, rest.toString())); + } + return new ArrayLiteral(items); + } + /** * Executable arithmetic functions split_part */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/SplitByString.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/SplitByString.java index 11d2346c496324..1b07c0cd5b5fac 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/SplitByString.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/SplitByString.java @@ -18,12 +18,14 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; -import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; 
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.StringType; import org.apache.doris.nereids.types.VarcharType; @@ -36,11 +38,13 @@ * ScalarFunction 'split_by_string'. This class is generated by GenerateFunction. */ public class SplitByString extends ScalarFunction - implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable { + implements ExplicitlyCastableSignature, PropagateNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(ArrayType.of(VarcharType.SYSTEM_DEFAULT)) - .args(StringType.INSTANCE, StringType.INSTANCE) + .args(StringType.INSTANCE, StringType.INSTANCE), + FunctionSignature.ret(ArrayType.of(VarcharType.SYSTEM_DEFAULT)) + .args(StringType.INSTANCE, StringType.INSTANCE, IntegerType.INSTANCE) ); /** @@ -50,6 +54,13 @@ public SplitByString(Expression arg0, Expression arg1) { super("split_by_string", arg0, arg1); } + /** + * constructor with 3 arguments. 
+ */ + public SplitByString(Expression arg0, Expression arg1, Expression arg2) { + super("split_by_string", arg0, arg1, arg2); + } + /** constructor for withChildren and reuse signature */ private SplitByString(ScalarFunctionParams functionParams) { super(functionParams); @@ -60,10 +71,20 @@ private SplitByString(ScalarFunctionParams functionParams) { */ @Override public SplitByString withChildren(List children) { - Preconditions.checkArgument(children.size() == 2); + Preconditions.checkArgument(children.size() == 2 || children.size() == 3); return new SplitByString(getFunctionParams(children)); } + @Override + public void checkLegalityBeforeTypeCoercion() { + if (children().size() == 3) { + if (!child(2).isConstant() || !(child(2) instanceof IntegerLikeLiteral)) { + throw new AnalysisException("the third parameter of " + + getName() + " function must be a constant integer: " + toSql()); + } + } + } + @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitSplitByString(this, context); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/executable/StringArithmeticSplitByStringTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/executable/StringArithmeticSplitByStringTest.java new file mode 100644 index 00000000000000..d85aeed2ef5f2f --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/executable/StringArithmeticSplitByStringTest.java @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.executable; + +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.StringType; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +public class StringArithmeticSplitByStringTest { + + private static ArrayLiteral makeArray(String... 
values) { + List items = Arrays.stream(values) + .map(StringLiteral::new) + .collect(Collectors.toList()); + return new ArrayLiteral(items); + } + + private static ArrayLiteral makeEmptyArray() { + return new ArrayLiteral(ImmutableList.of(), ArrayType.of(StringType.INSTANCE)); + } + + @Test + public void testSplitByStringWithLimitBasic() { + // limit < parts: "a,b,c,d" split by "," limit 2 -> ["a", "b,c,d"] + Expression result = StringArithmetic.splitByString( + new StringLiteral("a,b,c,d"), new StringLiteral(","), new IntegerLiteral(2)); + Assertions.assertEquals(makeArray("a", "b,c,d"), result); + + // limit = 3 + result = StringArithmetic.splitByString( + new StringLiteral("a,b,c,d"), new StringLiteral(","), new IntegerLiteral(3)); + Assertions.assertEquals(makeArray("a", "b", "c,d"), result); + + // limit = 1: no split at all + result = StringArithmetic.splitByString( + new StringLiteral("one,two,three"), new StringLiteral(","), new IntegerLiteral(1)); + Assertions.assertEquals(makeArray("one,two,three"), result); + + // multi-char delimiter + limit + result = StringArithmetic.splitByString( + new StringLiteral("a::b::c::d"), new StringLiteral("::"), new IntegerLiteral(2)); + Assertions.assertEquals(makeArray("a", "b::c::d"), result); + } + + @Test + public void testSplitByStringWithLimitExceedParts() { + // limit >= parts: "a,b,c" split by "," limit 10 -> ["a","b","c"] + Expression result = StringArithmetic.splitByString( + new StringLiteral("a,b,c"), new StringLiteral(","), new IntegerLiteral(10)); + Assertions.assertEquals(makeArray("a", "b", "c"), result); + + // limit == parts + result = StringArithmetic.splitByString( + new StringLiteral("a,b,c"), new StringLiteral(","), new IntegerLiteral(3)); + Assertions.assertEquals(makeArray("a", "b", "c"), result); + } + + @Test + public void testSplitByStringWithLimitZeroAndNegative() { + // limit = 0 -> delegates to 2-arg version + Expression result = StringArithmetic.splitByString( + new StringLiteral("a,b,c"), 
new StringLiteral(","), new IntegerLiteral(0)); + Assertions.assertEquals(makeArray("a", "b", "c"), result); + + // limit = -1 -> delegates to 2-arg version + result = StringArithmetic.splitByString( + new StringLiteral("a,b,c"), new StringLiteral(","), new IntegerLiteral(-1)); + Assertions.assertEquals(makeArray("a", "b", "c"), result); + + // limit = -100 -> delegates to 2-arg version + result = StringArithmetic.splitByString( + new StringLiteral("a,b,c"), new StringLiteral(","), new IntegerLiteral(-100)); + Assertions.assertEquals(makeArray("a", "b", "c"), result); + } + + @Test + public void testSplitByStringWithLimitEmptyFirst() { + // empty source string -> empty array + Expression result = StringArithmetic.splitByString( + new StringLiteral(""), new StringLiteral(","), new IntegerLiteral(2)); + Assertions.assertEquals(makeEmptyArray(), result); + + result = StringArithmetic.splitByString( + new StringLiteral(""), new StringLiteral(","), new IntegerLiteral(0)); + // limit <= 0 delegates to 2-arg, which also returns empty array for empty input + Assertions.assertEquals(makeEmptyArray(), result); + } + + @Test + public void testSplitByStringWithLimitEmptyDelimiter() { + // empty delimiter splits by character, with limit < chars + Expression result = StringArithmetic.splitByString( + new StringLiteral("abcde"), new StringLiteral(""), new IntegerLiteral(3)); + Assertions.assertEquals(makeArray("a", "b", "cde"), result); + + // limit = 1 -> entire string as single element + result = StringArithmetic.splitByString( + new StringLiteral("abcde"), new StringLiteral(""), new IntegerLiteral(1)); + Assertions.assertEquals(makeArray("abcde"), result); + } + + @Test + public void testSplitByStringWithLimitEmptyDelimiterExceed() { + // empty delimiter + limit >= chars -> all characters + Expression result = StringArithmetic.splitByString( + new StringLiteral("abcde"), new StringLiteral(""), new IntegerLiteral(10)); + Assertions.assertEquals(makeArray("a", "b", "c", "d", 
"e"), result); + + // exact match + result = StringArithmetic.splitByString( + new StringLiteral("abc"), new StringLiteral(""), new IntegerLiteral(3)); + Assertions.assertEquals(makeArray("a", "b", "c"), result); + } + + @Test + public void testSplitByStringWithLimitConsecutiveDelimiters() { + // consecutive delimiters produce empty strings + Expression result = StringArithmetic.splitByString( + new StringLiteral(",,,"), new StringLiteral(","), new IntegerLiteral(2)); + Assertions.assertEquals(makeArray("", ",,"), result); + + result = StringArithmetic.splitByString( + new StringLiteral(",,a,b,c,"), new StringLiteral(","), new IntegerLiteral(3)); + Assertions.assertEquals(makeArray("", "", "a,b,c,"), result); + } +} diff --git a/regression-test/data/nereids_p0/sql_functions/string_functions/test_split_by_string_limit.out b/regression-test/data/nereids_p0/sql_functions/string_functions/test_split_by_string_limit.out new file mode 100644 index 00000000000000..e2257040fcf649 --- /dev/null +++ b/regression-test/data/nereids_p0/sql_functions/string_functions/test_split_by_string_limit.out @@ -0,0 +1,81 @@ +-- This file is automatically generated. 
You should know what you did if you want to edit this +-- !limit1 -- +["one", "two,three,"] + +-- !limit2 -- +["one", "two", "three,"] + +-- !limit3 -- +["one", "two", "three", ""] + +-- !limit4 -- +["one", "two", "three", ""] + +-- !limit5 -- +["one,two,three"] + +-- !limit6 -- +["one", "two", "three", ""] + +-- !limit7 -- +["a", "b", "c"] + +-- !limit8 -- +[] + +-- !limit9 -- +["a", "b", "cde"] + +-- !limit10 -- +["abcde"] + +-- !limit11 -- +["a", "b", "c", "d", "e"] + +-- !limit12 -- +["a", "b::c::d"] + +-- !limit13 -- +["a", "b", "c::d"] + +-- !limit14 -- +["1", "2,3,,4,5,,abcde"] + +-- !limit15 -- +\N + +-- !limit16 -- +["你", "a", "好b世c界"] + +-- !limit17 -- +["", ",,"] + +-- !limit18 -- +["", "", "a,b,c,"] + +-- !table1 -- +1 ["a", "b,c,d"] +2 ["x", "y::z"] +3 ["hello"] +4 \N +5 ["a", "b,c,d,e"] + +-- !table2 -- +1 ["a", "b", "c,d"] +2 ["x", "y", "z"] +3 ["hello"] +4 \N +5 ["a", "b", "c,d,e"] + +-- !alias1 -- +["one", "two,three"] + +-- !alias2 -- +["a", "b::c"] + +-- !noregress1 -- +["a", "b", "c"] + +-- !noregress2 -- +["a", "b", "c", "d", "e"] + diff --git a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_string_arithmatic.groovy b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_string_arithmatic.groovy index 88863da910b32d..bab410e013b13e 100644 --- a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_string_arithmatic.groovy +++ b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_string_arithmatic.groovy @@ -987,6 +987,16 @@ suite("fold_constant_string_arithmatic") { testFoldConst("SELECT split_by_string('a..b\$\$c||d((e))f[[g{{h^^i??j**k++l\\\\m','++')") testFoldConst("SELECT split_by_string('a..b\$\$c||d((e))f[[g{{h^^i??j**k++l\\\\m','\\\\')") + // split_by_string with limit + testFoldConst("select split_by_string('a,b,c,d', ',', 2)") + testFoldConst("select split_by_string('a,b,c,d', ',', 3)") + testFoldConst("select split_by_string('a,b,c,d', ',', -1)") + 
testFoldConst("select split_by_string('a,b,c,d', ',', 0)") + testFoldConst("select split_by_string('abcde', '', 3)") + testFoldConst("select split_by_string('a::b::c', '::', 2)") + testFoldConst("select split_by_string('one,two,three,', ',', 1)") + testFoldConst("select split_by_string('', ',', 2)") + // split_part testFoldConst("select split_part('a,b,c', '', -2)") testFoldConst("select split_part('a,b,c', '', -1)") diff --git a/regression-test/suites/nereids_p0/sql_functions/string_functions/test_split_by_string_limit.groovy b/regression-test/suites/nereids_p0/sql_functions/string_functions/test_split_by_string_limit.groovy new file mode 100644 index 00000000000000..e88af18a68a36f --- /dev/null +++ b/regression-test/suites/nereids_p0/sql_functions/string_functions/test_split_by_string_limit.groovy @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ + suite("test_split_by_string_limit") { + // === Constant expression tests with limit: split_by_string(str, delimiter, limit) called on literal arguments === + + // Basic limit functionality: the result holds at most `limit` elements and the last element keeps the unsplit remainder (trailing delimiter included); a limit at or above the number of parts returns every part + qt_limit1 "select split_by_string('one,two,three,', ',', 2);" + qt_limit2 "select split_by_string('one,two,three,', ',', 3);" + qt_limit3 "select split_by_string('one,two,three,', ',', 4);" + qt_limit4 "select split_by_string('one,two,three,', ',', 10);" + qt_limit5 "select split_by_string('one,two,three', ',', 1);" + + // limit = -1 (no limit, same as 2-arg) + qt_limit6 "select split_by_string('one,two,three,', ',', -1);" + + // limit = 0 (no limit, same as 2-arg) + qt_limit7 "select split_by_string('a,b,c', ',', 0);" + + // Empty source string + limit (expected result is an empty array) + qt_limit8 "select split_by_string('', ',', 2);" + + // Empty delimiter + limit (split by character; once the limit is reached the remaining characters stay joined in the last element) + qt_limit9 "select split_by_string('abcde', '', 3);" + qt_limit10 "select split_by_string('abcde', '', 1);" + qt_limit11 "select split_by_string('abcde', '', 10);" + + // Multi-char delimiter + limit + qt_limit12 "select split_by_string('a::b::c::d', '::', 2);" + qt_limit13 "select split_by_string('a::b::c::d', '::', 3);" + qt_limit14 "select split_by_string('1,,2,3,,4,5,,abcde', ',,', 2);" + + // NULL handling (a NULL source string is expected to yield NULL) + qt_limit15 "select split_by_string(NULL, ',', 2);" + + // UTF-8 + limit (an empty delimiter splits per character, not per byte) + qt_limit16 "select split_by_string('你a好b世c界', '', 3);" + + // Edge cases: consecutive delimiters + limit (the empty strings between delimiters count toward the limit) + qt_limit17 "select split_by_string(',,,', ',', 2);" + qt_limit18 "select split_by_string(',,a,b,c,', ',', 3);" + + // === Table data tests (source and delimiter come from columns, limit is a constant; the row with NULL v1 is expected to yield NULL) === + sql """DROP TABLE IF EXISTS test_split_limit""" + sql """ + CREATE TABLE IF NOT EXISTS test_split_limit ( + `k1` int(11) NULL COMMENT "", + `v1` varchar(50) NULL COMMENT "", + `v2` varchar(10) NOT NULL COMMENT "" + ) ENGINE=OLAP + DUPLICATE KEY(`k1`) + DISTRIBUTED BY HASH(`k1`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "storage_format" = "V2" + ) + """ + sql """ INSERT INTO test_split_limit VALUES(1, 'a,b,c,d', ',') """ +
sql """ INSERT INTO test_split_limit VALUES(2, 'x::y::z', '::') """ + sql """ INSERT INTO test_split_limit VALUES(3, 'hello', ',') """ + sql """ INSERT INTO test_split_limit VALUES(4, null, ',') """ + sql """ INSERT INTO test_split_limit VALUES(5, 'a,b,c,d,e', ',') """ + + qt_table1 "SELECT k1, split_by_string(v1, v2, 2) FROM test_split_limit ORDER BY k1" + qt_table2 "SELECT k1, split_by_string(v1, v2, 3) FROM test_split_limit ORDER BY k1" + + // === split alias + limit === + qt_alias1 "select split('one,two,three', ',', 2);" + qt_alias2 "select split('a::b::c', '::', 2);" + + // === Verify 2-arg still works (no regression) === + qt_noregress1 "select split_by_string('a,b,c', ',');" + qt_noregress2 "select split_by_string('abcde', '');" +}