Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 58 additions & 54 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(gen_params.width) + "x" + std::to_string(gen_params.height) + ", ";
parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", ";
parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
Expand Down Expand Up @@ -526,10 +526,10 @@ int main(int argc, const char* argv[]) {
}

bool vae_decode_only = true;
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t end_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t end_image = {0, 0, 3, nullptr};
sd_image_t control_image = {0, 0, 3, nullptr};
sd_image_t mask_image = {0, 0, 1, nullptr};
std::vector<sd_image_t> ref_images;
std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> control_frames;
Expand All @@ -556,57 +556,79 @@ int main(int argc, const char* argv[]) {
control_frames.clear();
};

if (gen_params.init_image_path.size() > 0) {
vae_decode_only = false;
// Loads the image at `path` into `image` and, when the generation size has not
// been fully specified yet, adopts the loaded image's dimensions as that size.
//
//   path             - file to load.
//   image            - destination; filled by load_sd_image_from_file.
//   resize_image     - when true and both width and height are already set,
//                      those values are passed as the expected size.
//                      Otherwise 0/0 is passed (presumably "keep the source
//                      size" - confirm against load_image_common).
//   expected_channel - channel count forwarded to the loader.
//
// Returns true on success. On failure it logs, calls release_all_resources()
// and returns false, so callers only need to propagate the error.
auto load_image_and_update_size = [&](const std::string& path,
                                      sd_image_t& image,
                                      bool resize_image    = true,
                                      int expected_channel = 3) -> bool {
    int expected_width  = 0;
    int expected_height = 0;
    if (resize_image && gen_params.width_and_height_are_set()) {
        expected_width  = gen_params.width;
        expected_height = gen_params.height;
    }

    if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) {
        LOG_ERROR("load image from '%s' failed", path.c_str());
        release_all_resources();
        return false;
    }

    // The first successfully loaded image decides the output size when the
    // user did not provide both dimensions.
    gen_params.set_width_and_height_if_unset(image.width, image.height);
    return true;
};

if (gen_params.init_image_path.size() > 0) {
vae_decode_only = false;
if (!load_image_and_update_size(gen_params.init_image_path, init_image)) {
return 1;
}
}

if (!gen_params.end_image_path.empty()) {
    // An input image has to be encoded, so the VAE cannot run decode-only.
    vae_decode_only = false;
    // Load the end frame from its own path; using init_image_path here would
    // silently make the end frame a copy of the init image.
    if (!load_image_and_update_size(gen_params.end_image_path, end_image)) {
        return 1;
    }
}

if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false;
for (auto& path : gen_params.ref_image_paths) {
sd_image_t ref_image = {0, 0, 3, nullptr};
if (!load_image_and_update_size(path, ref_image, false)) {
return 1;
}
ref_images.push_back(ref_image);
}
}

if (gen_params.mask_image_path.size() > 0) {
int c = 0;
int width = 0;
int height = 0;
mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1);
if (mask_image.data == nullptr) {
if (load_sd_image_from_file(&mask_image,
gen_params.mask_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
1)) {
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
release_all_resources();
return 1;
}
} else {
mask_image.data = (uint8_t*)malloc(gen_params.width * gen_params.height);
mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
if (mask_image.data == nullptr) {
LOG_ERROR("malloc mask image failed");
release_all_resources();
return 1;
}
memset(mask_image.data, 255, gen_params.width * gen_params.height);
mask_image.width = gen_params.get_resolved_width();
mask_image.height = gen_params.get_resolved_height();
memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
}

if (gen_params.control_image_path.size() > 0) {
int width = 0;
int height = 0;
control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height);
if (control_image.data == nullptr) {
if (load_sd_image_from_file(&control_image,
gen_params.control_image_path.c_str(),
gen_params.get_resolved_width(),
gen_params.get_resolved_height())) {
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
release_all_resources();
return 1;
Expand All @@ -621,29 +643,11 @@ int main(int argc, const char* argv[]) {
}
}

if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false;
for (auto& path : gen_params.ref_image_paths) {
int width = 0;
int height = 0;
uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height);
if (image_buffer == nullptr) {
LOG_ERROR("load image from '%s' failed", path.c_str());
release_all_resources();
return 1;
}
ref_images.push_back({(uint32_t)width,
(uint32_t)height,
3,
image_buffer});
}
}

if (!gen_params.control_video_path.empty()) {
if (!load_images_from_dir(gen_params.control_video_path,
control_frames,
gen_params.width,
gen_params.height,
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.video_frames,
cli_params.verbose)) {
release_all_resources();
Expand Down Expand Up @@ -717,8 +721,8 @@ int main(int argc, const char* argv[]) {
gen_params.auto_resize_ref_image,
gen_params.increase_ref_index,
mask_image,
gen_params.width,
gen_params.height,
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.sample_params,
gen_params.strength,
gen_params.seed,
Expand Down Expand Up @@ -748,8 +752,8 @@ int main(int argc, const char* argv[]) {
end_image,
control_frames.data(),
(int)control_frames.size(),
gen_params.width,
gen_params.height,
gen_params.get_resolved_width(),
gen_params.get_resolved_height(),
gen_params.sample_params,
gen_params.high_noise_sample_params,
gen_params.moe_boundary,
Expand Down
51 changes: 43 additions & 8 deletions examples/common/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,8 +1024,8 @@ struct SDGenerationParams {
std::string prompt_with_lora; // for metadata record only
std::string negative_prompt;
int clip_skip = -1; // <= 0 represents unspecified
int width = 512;
int height = 512;
int width = -1;
int height = -1;
int batch_count = 1;
std::string init_image_path;
std::string end_image_path;
Expand Down Expand Up @@ -1705,17 +1705,36 @@ struct SDGenerationParams {
}
}

bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
prompt_with_lora = prompt;
/// Adopts `w` as the width, but only when no positive width is stored yet
/// (a width <= 0 is treated as "unset"). An already configured width is kept.
void set_width_if_unset(int w) {
    if (width <= 0) {
        width = w;
    }
}

/// Adopts `h` as the height, but only when no positive height is stored yet
/// (a height <= 0 is treated as "unset"). An already configured height is kept.
void set_height_if_unset(int h) {
    if (height <= 0) {
        height = h;
    }
}

/// Reports whether both dimensions hold a positive value; a value <= 0 in
/// either one counts as "not set".
bool width_and_height_are_set() const {
    const bool width_missing  = width <= 0;
    const bool height_missing = height <= 0;
    return !(width_missing || height_missing);
}

/// Stores `w` x `h` as the generation size unless both dimensions are already
/// set. Note that when only one of the two is set, both are overwritten.
void set_width_and_height_if_unset(int w, int h) {
    if (width_and_height_are_set()) {
        return;  // a complete size is already configured; leave it alone
    }
    LOG_INFO("set width x height to %d x %d", w, h);
    width  = w;
    height = h;
}

/// Width to use for generation: the stored width when it is positive,
/// otherwise the fallback of 512.
int get_resolved_width() const {
    if (width > 0) {
        return width;
    }
    return 512;
}

/// Height to use for generation: the stored height when it is positive,
/// otherwise the fallback of 512.
int get_resolved_height() const {
    if (height > 0) {
        return height;
    }
    return 512;
}

bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
prompt_with_lora = prompt;

if (sample_params.sample_steps <= 0) {
LOG_ERROR("error: the sample_steps must be greater than 0\n");
Expand Down Expand Up @@ -2083,6 +2102,22 @@ uint8_t* load_image_from_file(const char* image_path,
return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
}

/**
 * Loads an image file into `image`.
 *
 * @param image            Destination. Its `data` member is always overwritten
 *                         (with nullptr on failure); `width` and `height` are
 *                         written only on success.
 * @param image_path       Path of the file to load.
 * @param expected_width   Forwarded to load_image_common (0 presumably means
 *                         "no target width" - confirm against that helper).
 * @param expected_height  Forwarded to load_image_common, same convention.
 * @param expected_channel Channel count forwarded to load_image_common.
 * @return true when pixel data was loaded, false otherwise (including when
 *         `image` is null).
 *
 * NOTE(review): only data/width/height are written here; any channel field of
 * `image` keeps whatever the caller initialised it to, so callers should keep
 * it consistent with `expected_channel`.
 */
bool load_sd_image_from_file(sd_image_t* image,
                             const char* image_path,
                             int expected_width   = 0,
                             int expected_height  = 0,
                             int expected_channel = 3) {
    if (image == nullptr) {
        return false;
    }

    // Initialised so a failing loader can never leave indeterminate values behind.
    int width  = 0;
    int height = 0;
    image->data = load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
    if (image->data == nullptr) {
        return false;
    }

    image->width  = static_cast<uint32_t>(width);
    image->height = static_cast<uint32_t>(height);
    return true;
}

uint8_t* load_image_from_memory(const char* image_bytes,
int len,
int& width,
Expand Down
Loading