Skip to content

Commit 118489e

Browse files
wbruna and professor-moody authored
chore: harden safetensors and gguf loading code (#1404)
Co-authored-by: professor-moody <keys@nimbus.lan>
1 parent be9f51b commit 118489e

2 files changed

Lines changed: 24 additions & 8 deletions

File tree

src/gguf_reader.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ class GGUFReader {
5959
if (!safe_read(fin, key_len))
6060
return false;
6161

62+
if (key_len > 4096)
63+
return false;
64+
6265
std::string key(key_len, '\0');
6366
if (!safe_read(fin, (char*)key.data(), key_len))
6467
return false;

src/model.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -315,8 +315,9 @@ bool is_safetensors_file(const std::string& file_path) {
315315
if (!file) {
316316
return false;
317317
}
318-
nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
319-
if (header_.is_discarded()) {
318+
try {
319+
nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
320+
} catch (const std::exception&) {
320321
return false;
321322
}
322323
return true;
@@ -511,7 +512,14 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
511512
return false;
512513
}
513514

514-
nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
515+
nlohmann::json header_;
516+
try {
517+
header_ = nlohmann::json::parse(header_buf.data());
518+
} catch (const std::exception&) {
519+
LOG_ERROR("parsing safetensors header failed", file_path.c_str());
520+
file_paths_.pop_back();
521+
return false;
522+
}
515523

516524
for (auto& item : header_.items()) {
517525
std::string name = item.key();
@@ -575,24 +583,29 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
575583

576584
size_t tensor_data_size = end - begin;
577585

586+
bool tensor_size_ok;
578587
if (dtype == "F8_E4M3") {
579588
tensor_storage.is_f8_e4m3 = true;
580589
// f8 -> f16
581-
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
590+
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2);
582591
} else if (dtype == "F8_E5M2") {
583592
tensor_storage.is_f8_e5m2 = true;
584593
// f8 -> f16
585-
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
594+
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size * 2);
586595
} else if (dtype == "F64") {
587596
tensor_storage.is_f64 = true;
588597
// f64 -> f32
589-
GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
598+
tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size);
590599
} else if (dtype == "I64") {
591600
tensor_storage.is_i64 = true;
592601
// i64 -> i32
593-
GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
602+
tensor_size_ok = (tensor_storage.nbytes() * 2 == tensor_data_size);
594603
} else {
595-
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
604+
tensor_size_ok = (tensor_storage.nbytes() == tensor_data_size);
605+
}
606+
if (!tensor_size_ok) {
607+
LOG_ERROR("size mismatch for tensor '%s' (%s)\n", name.c_str(), dtype.c_str());
608+
return false;
596609
}
597610

598611
add_tensor_storage(tensor_storage);

0 commit comments

Comments (0)