diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 16a4909828b0..eb931011906f 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -620,7 +620,8 @@ void ParseBasicHeader(const CallHeaders& incoming_headers, std::string& username
                       std::string& password) {
   std::string encoded_credentials =
       FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
-  std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials));
+  ASSERT_OK_AND_ASSIGN(auto decoded, arrow::util::base64_decode(encoded_credentials));
+  std::stringstream decoded_stream(decoded);
   std::getline(decoded_stream, username, ':');
   std::getline(decoded_stream, password, ':');
 }
diff --git a/cpp/src/arrow/util/CMakeLists.txt b/cpp/src/arrow/util/CMakeLists.txt
index 4352716ebd76..deb3e9e3fbe4 100644
--- a/cpp/src/arrow/util/CMakeLists.txt
+++ b/cpp/src/arrow/util/CMakeLists.txt
@@ -49,6 +49,7 @@ add_arrow_test(utility-test
                SOURCES
                align_util_test.cc
                atfork_test.cc
+               base64_test.cc
                byte_size_test.cc
                byte_stream_split_test.cc
                cache_test.cc
diff --git a/cpp/src/arrow/util/base64.h b/cpp/src/arrow/util/base64.h
index 5b80e19d896b..a575fee45132 100644
--- a/cpp/src/arrow/util/base64.h
+++ b/cpp/src/arrow/util/base64.h
@@ -20,6 +20,7 @@
 #include <string>
 #include <string_view>
 
+#include "arrow/result.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -29,7 +30,11 @@ ARROW_EXPORT
 std::string base64_encode(std::string_view s);
 
+/// \brief Decode a base64 string.
+///
+/// Returns Status::Invalid when the input length is not a multiple of 4, when it
+/// contains a byte outside the base64 alphabet, or when '=' padding is malformed.
 ARROW_EXPORT
-std::string base64_decode(std::string_view s);
+arrow::Result<std::string> base64_decode(std::string_view s);
 
 }  // namespace util
 }  // namespace arrow
diff --git a/cpp/src/arrow/util/base64_test.cc b/cpp/src/arrow/util/base64_test.cc
new file mode 100644
index 000000000000..38f99ea5e6a1
--- /dev/null
+++ b/cpp/src/arrow/util/base64_test.cc
@@ -0,0 +1,84 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/util/base64.h"
+#include "arrow/testing/gtest_util.h"
+
+namespace arrow {
+namespace util {
+
+TEST(Base64DecodeTest, ValidInputs) {
+  ASSERT_OK_AND_ASSIGN(auto empty, arrow::util::base64_decode(""));
+  EXPECT_EQ(empty, "");
+
+  ASSERT_OK_AND_ASSIGN(auto two_paddings, arrow::util::base64_decode("Zg=="));
+  EXPECT_EQ(two_paddings, "f");
+
+  ASSERT_OK_AND_ASSIGN(auto one_padding, arrow::util::base64_decode("Zm8="));
+  EXPECT_EQ(one_padding, "fo");
+
+  ASSERT_OK_AND_ASSIGN(auto no_padding, arrow::util::base64_decode("Zm9v"));
+  EXPECT_EQ(no_padding, "foo");
+
+  ASSERT_OK_AND_ASSIGN(auto multiblock, arrow::util::base64_decode("SGVsbG8gd29ybGQ="));
+  EXPECT_EQ(multiblock, "Hello world");
+}
+
+TEST(Base64DecodeTest, BinaryOutput) {
+  // 'A' maps to index 0 — same zero value used for padding slots
+  // verifies the 'A' bug is not present
+  ASSERT_OK_AND_ASSIGN(auto all_A, arrow::util::base64_decode("AAAA"));
+  EXPECT_EQ(all_A, std::string("\x00\x00\x00", 3));
+
+  // Arbitrary non-ASCII output bytes
+  ASSERT_OK_AND_ASSIGN(auto binary, arrow::util::base64_decode("AP8A"));
+  EXPECT_EQ(binary, std::string("\x00\xff\x00", 3));
+}
+
+TEST(Base64DecodeTest, InvalidLength) {
+  ASSERT_RAISES_WITH_MESSAGE(
+      Invalid, "Invalid: Invalid base64 input: length is not a multiple of 4",
+      arrow::util::base64_decode("abc"));
+}
+
+TEST(Base64DecodeTest, InvalidCharacters) {
+  ASSERT_RAISES(Invalid, arrow::util::base64_decode("ab$="));
+
+  // Non-ASCII byte
+  std::string non_ascii = std::string("abc") + static_cast<char>(0xFF);
+  ASSERT_RAISES(Invalid, arrow::util::base64_decode(non_ascii));
+
+  // Corruption mid-string across multiple blocks
+  ASSERT_RAISES(Invalid, arrow::util::base64_decode("aGVs$G8gd29ybGQ="));
+}
+
+TEST(Base64DecodeTest, InvalidPadding) {
+  // Padding in wrong position within block
+  ASSERT_RAISES(Invalid, arrow::util::base64_decode("ab=c"));
+
+  // 3 padding characters — exceeds maximum of 2
+  ASSERT_RAISES(Invalid, arrow::util::base64_decode("a==="));
+
+  // 4 padding characters
+  ASSERT_RAISES(Invalid, arrow::util::base64_decode("===="));
+
+  // Padding in non-final block across multiple blocks
+  ASSERT_RAISES(Invalid, arrow::util::base64_decode("Zm8=Zm8="));
+}
+
+}  // namespace util
+}  // namespace arrow
diff --git a/cpp/src/arrow/vendored/base64.cpp b/cpp/src/arrow/vendored/base64.cpp
index 6f53c0524e71..c0348d985fba 100644
--- a/cpp/src/arrow/vendored/base64.cpp
+++ b/cpp/src/arrow/vendored/base64.cpp
@@ -30,6 +30,7 @@
 */
 
 #include "arrow/util/base64.h"
+#include "arrow/result.h"
 #include <iostream>
 
 namespace arrow {
@@ -40,11 +41,6 @@ static const std::string base64_chars =
              "abcdefghijklmnopqrstuvwxyz"
              "0123456789+/";
 
-
-static inline bool is_base64(unsigned char c) {
-  return (isalnum(c) || (c == '+') || (c == '/'));
-}
-
 static std::string base64_encode(unsigned char const* bytes_to_encode, unsigned int in_len) {
   std::string ret;
   int i = 0;
@@ -93,38 +89,74 @@ std::string base64_encode(std::string_view string_to_encode) {
   return base64_encode(bytes_to_encode, in_len);
 }
 
-std::string base64_decode(std::string_view encoded_string) {
+// Strict decoder: malformed input is reported as Status::Invalid instead of
+// silently ending the decode at the first unexpected byte.
+Result<std::string> base64_decode(std::string_view encoded_string) {
   size_t in_len = encoded_string.size();
   int i = 0;
-  int j = 0;
-  int in_ = 0;
+  std::string_view::size_type in_ = 0;
+  int padding_count = 0;
+  int block_padding = 0;
+  bool padding_started = false;
   unsigned char char_array_4[4], char_array_3[3];
   std::string ret;
 
-  while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
-    char_array_4[i++] = encoded_string[in_]; in_++;
-    if (i ==4) {
-      for (i = 0; i <4; i++)
-        char_array_4[i] = base64_chars.find(char_array_4[i]) & 0xff;
+  if (encoded_string.size() % 4 != 0) {
+    return Status::Invalid("Invalid base64 input: length is not a multiple of 4");
+  }
+
+  // Each 4-character block decodes to at most 3 bytes.
+  ret.reserve(in_len / 4 * 3);
 
-      char_array_3[0] = ( char_array_4[0] << 2 ) + ((char_array_4[1] & 0x30) >> 4);
-      char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
-      char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+  while (in_len--) {
+    unsigned char c = encoded_string[in_];
 
-      for (i = 0; (i < 3); i++)
-        ret += char_array_3[i];
-      i = 0;
+    if (c == '=') {
+      padding_started = true;
+      padding_count++;
+
+      if (padding_count > 2) {
+        return Status::Invalid("Invalid base64 input: too many padding characters");
+      }
+
+      // Padding slots are stored as 0 so the translation loop below skips them.
+      char_array_4[i++] = 0;
+    } else {
+      if (padding_started) {
+        return Status::Invalid("Invalid base64 input: padding characters must be at the end");
+      }
+
+      if (base64_chars.find(c) == std::string::npos) {
+        return Status::Invalid(
+            "Invalid base64 input: contains non-base64 byte at position " +
+            std::to_string(in_));
+      }
+
+      char_array_4[i++] = c;
     }
-  }
 
-  if (i) {
-    for (j = 0; j < i; j++)
-      char_array_4[j] = base64_chars.find(char_array_4[j]) & 0xff;
+    in_++;
+
+    if (i == 4) {
+      for (i = 0; i < 4; i++) {
+        if (char_array_4[i] != 0) {
+          char_array_4[i] = base64_chars.find(char_array_4[i]) & 0xff;
+        }
+      }
+
+      char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
+      char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+      char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
+
+      block_padding = padding_count;
 
-    char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
-    char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
+      // A block carrying N padding characters contributes only 3 - N bytes.
+      for (i = 0; i < 3 - block_padding; i++) {
+        ret += char_array_3[i];
+      }
 
-    for (j = 0; (j < i - 1); j++) ret += char_array_3[j];
+      i = 0;
+    }
   }
 
   return ret;
diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc
index 3eda4afadb03..4d8531f76e7f 100644
--- a/cpp/src/gandiva/gdv_function_stubs.cc
+++ b/cpp/src/gandiva/gdv_function_stubs.cc
@@ -269,7 +269,15 @@ const char* gdv_fn_base64_decode_utf8(int64_t context, const char* in, int32_t i
     return "";
   }
   // use arrow method to decode base64 string
-  std::string decoded_str = arrow::util::base64_decode(std::string_view(in, in_len));
+  auto result = arrow::util::base64_decode(std::string_view(in, in_len));
+  if (!result.ok()) {
+    gdv_fn_context_set_error_msg(context, result.status().message().c_str());
+    *out_len = 0;
+    return "";
+  }
+
+  std::string decoded_str = std::move(result).ValueOrDie();
+
   *out_len = static_cast<int32_t>(decoded_str.length());
   // allocate memory for response
   char* ret = reinterpret_cast<char*>(
diff --git a/cpp/src/parquet/arrow/fuzz_internal.cc b/cpp/src/parquet/arrow/fuzz_internal.cc
index dfbb8ae161fd..b223986f0886 100644
--- a/cpp/src/parquet/arrow/fuzz_internal.cc
+++ b/cpp/src/parquet/arrow/fuzz_internal.cc
@@ -83,8 +83,13 @@ class FuzzDecryptionKeyRetriever : public DecryptionKeyRetriever {
     }
     // Is it a key generated by MakeEncryptionKey?
     if (key_id.starts_with(kInlineKeyPrefix)) {
-      return SecureString(
-          ::arrow::util::base64_decode(key_id.substr(kInlineKeyPrefix.length())));
+      auto result =
+          ::arrow::util::base64_decode(key_id.substr(kInlineKeyPrefix.length()));
+      if (!result.ok()) {
+        throw ParquetException(result.status().message());
+      }
+
+      return SecureString(std::move(result).ValueOrDie());
     }
     throw ParquetException("Unknown fuzz encryption key_id");
   }
diff --git a/cpp/src/parquet/arrow/schema.cc b/cpp/src/parquet/arrow/schema.cc
index 11d5d13e4bcb..9c4c462c6b8c 100644
--- a/cpp/src/parquet/arrow/schema.cc
+++ b/cpp/src/parquet/arrow/schema.cc
@@ -953,7 +953,8 @@ Status GetOriginSchema(const std::shared_ptr<const KeyValueMetadata>& metadata,
   // The original Arrow schema was serialized using the store_schema option.
   // We deserialize it here and use it to inform read options such as
   // dictionary-encoded fields.
-  auto decoded = ::arrow::util::base64_decode(metadata->value(schema_index));
+  ARROW_ASSIGN_OR_RAISE(auto decoded,
+                        ::arrow::util::base64_decode(metadata->value(schema_index)));
   auto schema_buf = std::make_shared<Buffer>(decoded);
 
   ::arrow::ipc::DictionaryMemo dict_memo;
diff --git a/cpp/src/parquet/encryption/file_key_unwrapper.cc b/cpp/src/parquet/encryption/file_key_unwrapper.cc
index 4dc1492a0b76..dabf56dc5966 100644
--- a/cpp/src/parquet/encryption/file_key_unwrapper.cc
+++ b/cpp/src/parquet/encryption/file_key_unwrapper.cc
@@ -122,7 +122,12 @@ KeyWithMasterId FileKeyUnwrapper::GetDataEncryptionKey(const KeyMaterial& key_ma
         });
 
     // Decrypt the data key
-    std::string aad = ::arrow::util::base64_decode(encoded_kek_id);
+    auto result = ::arrow::util::base64_decode(encoded_kek_id);
+    if (!result.ok()) {
+      throw ParquetException(result.status().message());
+    }
+
+    std::string aad = std::move(result).ValueOrDie();
     data_key = internal::DecryptKeyLocally(encoded_wrapped_dek, kek_bytes, aad);
   }
 
diff --git a/cpp/src/parquet/encryption/key_toolkit_internal.cc b/cpp/src/parquet/encryption/key_toolkit_internal.cc
index d304041e3ea7..b987ac4f596a 100644
--- a/cpp/src/parquet/encryption/key_toolkit_internal.cc
+++ b/cpp/src/parquet/encryption/key_toolkit_internal.cc
@@ -52,7 +52,12 @@ std::string EncryptKeyLocally(const SecureString& key_bytes,
 
 SecureString DecryptKeyLocally(const std::string& encoded_encrypted_key,
                                const SecureString& master_key, const std::string& aad) {
-  std::string encrypted_key = ::arrow::util::base64_decode(encoded_encrypted_key);
+  auto result = ::arrow::util::base64_decode(encoded_encrypted_key);
+  if (!result.ok()) {
+    throw ParquetException(result.status().message());
+  }
+
+  std::string encrypted_key = std::move(result).ValueOrDie();
 
   AesDecryptor key_decryptor(ParquetCipher::AES_GCM_V1,
                              static_cast<int>(master_key.size()), false,