From 85130d7042a1c2177afff107a3d96bdc1eeadad6 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 18 Jan 2026 17:05:03 +0800 Subject: [PATCH 1/2] Use image width and height when not explicitly set --- examples/cli/main.cpp | 112 +++++++++++++++++++------------------ examples/common/common.hpp | 51 ++++++++++++++--- 2 files changed, 101 insertions(+), 62 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index ddc282817..cd23f9b5a 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -245,7 +245,7 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", "; parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", "; parameter_string += "Seed: " + std::to_string(seed) + ", "; - parameter_string += "Size: " + std::to_string(gen_params.width) + "x" + std::to_string(gen_params.height) + ", "; + parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", "; parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", "; parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", "; if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) { @@ -526,10 +526,10 @@ int main(int argc, const char* argv[]) { } bool vae_decode_only = true; - sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; - sd_image_t end_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; - sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; - sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr}; + sd_image_t init_image = {0, 0, 3, nullptr}; + sd_image_t end_image = {0, 0, 3, nullptr}; + sd_image_t control_image = {0, 0, 3, nullptr}; + sd_image_t mask_image = {0, 0, 1, nullptr}; std::vector ref_images; std::vector pmid_images; std::vector control_frames; @@ -556,57 +556,79 @@ int main(int argc, const char* argv[]) { control_frames.clear(); }; - if (gen_params.init_image_path.size() > 0) { - vae_decode_only = false; + auto load_image_and_update_size = [&](const std::string& path, + sd_image_t& image, + bool resize_image = true, + int expected_channel = 3) -> bool { + int expected_width = 0; + int expected_height = 0; + if (resize_image && gen_params.width_and_height_are_set()) { + expected_width = gen_params.width; + expected_height = gen_params.height; + } - int width = 0; - int height = 0; - init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height); - if (init_image.data == nullptr) { - LOG_ERROR("load image from '%s' failed", gen_params.init_image_path.c_str()); + if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) { + LOG_ERROR("load image from '%s' failed", path.c_str()); release_all_resources(); + return false; + } + + gen_params.set_width_and_height_if_unset(image.width, image.height); + return true; + }; + + if (gen_params.init_image_path.size() > 0) { + vae_decode_only = false; + if (!load_image_and_update_size(gen_params.init_image_path, init_image)) { return 1; } } if (gen_params.end_image_path.size() > 0) { vae_decode_only = false; - - int width = 0; - int height = 0; - end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height); - if (end_image.data == nullptr) { - LOG_ERROR("load image from '%s' failed", gen_params.end_image_path.c_str()); - release_all_resources(); + if (!load_image_and_update_size(gen_params.init_image_path, end_image)) { return 1; } } + if (gen_params.ref_image_paths.size() > 0) { + vae_decode_only = false; + for (auto& path : gen_params.ref_image_paths) { + sd_image_t ref_image = {0, 0, 3, nullptr}; + if (!load_image_and_update_size(path, ref_image, false)) { + return 1; + } + ref_images.push_back(ref_image); + } + } + if (gen_params.mask_image_path.size() > 0) { - int c = 0; - int width = 0; - int height = 0; - mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1); - if (mask_image.data == nullptr) { + if (load_sd_image_from_file(&mask_image, + gen_params.mask_image_path.c_str(), + gen_params.get_resolved_width(), + gen_params.get_resolved_height(), + 1)) { LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str()); release_all_resources(); return 1; } } else { - mask_image.data = (uint8_t*)malloc(gen_params.width * gen_params.height); + mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height()); if (mask_image.data == nullptr) { LOG_ERROR("malloc mask image failed"); release_all_resources(); return 1; } - memset(mask_image.data, 255, gen_params.width * gen_params.height); + mask_image.width = gen_params.get_resolved_width(); + mask_image.height = gen_params.get_resolved_height(); + memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height()); } if (gen_params.control_image_path.size() > 0) { - int width = 0; - int height = 0; - control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height); - if (control_image.data == nullptr) { + if (load_sd_image_from_file(&control_image, + gen_params.control_image_path.c_str(), + gen_params.get_resolved_width(), + gen_params.get_resolved_height())) { LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str()); release_all_resources(); return 1; @@ -621,29 +643,11 @@ int main(int argc, const char* argv[]) { } } - if (gen_params.ref_image_paths.size() > 0) { - vae_decode_only = false; - for (auto& path : gen_params.ref_image_paths) { - int width = 0; - int height = 0; - uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height); - if (image_buffer == nullptr) { - LOG_ERROR("load image from '%s' failed", path.c_str()); - release_all_resources(); - return 1; - } - ref_images.push_back({(uint32_t)width, - (uint32_t)height, - 3, - image_buffer}); - } - } - if (!gen_params.control_video_path.empty()) { if (!load_images_from_dir(gen_params.control_video_path, control_frames, - gen_params.width, - gen_params.height, + gen_params.get_resolved_width(), + gen_params.get_resolved_height(), gen_params.video_frames, cli_params.verbose)) { release_all_resources(); @@ -717,8 +721,8 @@ int main(int argc, const char* argv[]) { gen_params.auto_resize_ref_image, gen_params.increase_ref_index, mask_image, - gen_params.width, - gen_params.height, + gen_params.get_resolved_width(), + gen_params.get_resolved_height(), gen_params.sample_params, gen_params.strength, gen_params.seed, @@ -748,8 +752,8 @@ int main(int argc, const char* argv[]) { end_image, control_frames.data(), (int)control_frames.size(), - gen_params.width, - gen_params.height, + gen_params.get_resolved_width(), + gen_params.get_resolved_height(), gen_params.sample_params, gen_params.high_noise_sample_params, gen_params.moe_boundary, diff --git a/examples/common/common.hpp b/examples/common/common.hpp index d299da50c..dd8a22649 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -1024,8 +1024,8 @@ struct SDGenerationParams { std::string prompt_with_lora; // for metadata record only std::string negative_prompt; int clip_skip = -1; // <= 0 represents unspecified - int width = 512; - int height = 512; + int width = -1; + int height = -1; int batch_count = 1; std::string init_image_path; std::string end_image_path; @@ -1705,17 +1705,36 @@ struct SDGenerationParams { } } - bool process_and_check(SDMode mode, const std::string& lora_model_dir) { - prompt_with_lora = prompt; + void set_width_if_unset(int w) { if (width <= 0) { - LOG_ERROR("error: the width must be greater than 0\n"); - return false; + width = w; } + } + void set_height_if_unset(int h) { if (height <= 0) { - LOG_ERROR("error: the height must be greater than 0\n"); - return false; + height = h; } + } + + bool width_and_height_are_set() const { + return width > 0 && height > 0; + } + + void set_width_and_height_if_unset(int w, int h) { + if (!width_and_height_are_set()) { + LOG_INFO("set width x height to %d x %d", w, h); + width = w; + height = h; + } + } + + int get_resolved_width() const { return (width > 0) ? width : 512; } + + int get_resolved_height() const { return (height > 0) ? height : 512; } + + bool process_and_check(SDMode mode, const std::string& lora_model_dir) { + prompt_with_lora = prompt; if (sample_params.sample_steps <= 0) { LOG_ERROR("error: the sample_steps must be greater than 0\n"); @@ -2083,6 +2102,22 @@ uint8_t* load_image_from_file(const char* image_path, return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel); } +bool load_sd_image_from_file(sd_image_t* image, + const char* image_path, + int expected_width = 0, + int expected_height = 0, + int expected_channel = 3) { + int width; + int height; + image->data = load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel); + if (image->data == nullptr) { + return false; + } + image->width = width; + image->height = height; + return true; +} + uint8_t* load_image_from_memory(const char* image_bytes, int len, int& width, From 1b56fa20326b239c20b1ecbc50cb66d7b9f26750 Mon Sep 17 00:00:00 2001 From: leejet Date: Thu, 22 Jan 2026 23:38:07 +0800 Subject: [PATCH 2/2] remove unused methods --- examples/common/common.hpp | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/examples/common/common.hpp b/examples/common/common.hpp index dd8a22649..e7b9c20df 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -1705,18 +1705,6 @@ struct SDGenerationParams { } } - void set_width_if_unset(int w) { - if (width <= 0) { - width = w; - } - } - - void set_height_if_unset(int h) { - if (height <= 0) { - height = h; - } - } - bool width_and_height_are_set() const { return width > 0 && height > 0; }