Skip to content

Commit 7ade90e

Browse files
authored
feat: add sdcpp api support (#1407)
1 parent 118489e commit 7ade90e

File tree

17 files changed

+3237
-1308
lines changed

17 files changed

+3237
-1308
lines changed

examples/cli/main.cpp

Lines changed: 46 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -192,17 +192,22 @@ struct SDCliParams {
192192
return options;
193193
};
194194

195-
bool process_and_check() {
196-
if (mode != METADATA && output_path.length() == 0) {
197-
LOG_ERROR("error: the following arguments are required: output_path");
198-
return false;
199-
}
200-
195+
bool resolve() {
201196
if (mode == CONVERT) {
202197
if (output_path == "output.png") {
203198
output_path = "output.gguf";
204199
}
205-
} else if (mode == METADATA) {
200+
}
201+
return true;
202+
}
203+
204+
bool validate() {
205+
if (mode != METADATA) {
206+
if (output_path.length() == 0) {
207+
LOG_ERROR("error: the following arguments are required: output_path");
208+
return false;
209+
}
210+
} else {
206211
if (image_path.empty()) {
207212
LOG_ERROR("error: metadata mode needs an image path (--image)");
208213
return false;
@@ -216,6 +221,16 @@ struct SDCliParams {
216221
return true;
217222
}
218223

224+
bool resolve_and_validate() {
225+
if (!resolve()) {
226+
return false;
227+
}
228+
if (!validate()) {
229+
return false;
230+
}
231+
return true;
232+
}
233+
219234
std::string to_string() const {
220235
std::ostringstream oss;
221236
oss << "SDCliParams {\n"
@@ -260,10 +275,10 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP
260275
exit(cli_params.normal_exit ? 0 : 1);
261276
}
262277

263-
bool valid = cli_params.process_and_check();
278+
bool valid = cli_params.resolve_and_validate();
264279
if (valid && cli_params.mode != METADATA) {
265-
valid = ctx_params.process_and_check(cli_params.mode) &&
266-
gen_params.process_and_check(cli_params.mode, ctx_params.lora_model_dir);
280+
valid = ctx_params.resolve_and_validate(cli_params.mode) &&
281+
gen_params.resolve_and_validate(cli_params.mode, ctx_params.lora_model_dir);
267282
}
268283

269284
if (!valid) {
@@ -278,7 +293,7 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
278293
}
279294

280295
bool load_images_from_dir(const std::string dir,
281-
SDImageVec& images,
296+
std::vector<SDImageOwner>& images,
282297
int expected_width = 0,
283298
int expected_height = 0,
284299
int max_image_num = 0,
@@ -315,10 +330,10 @@ bool load_images_from_dir(const std::string dir,
315330
return false;
316331
}
317332

318-
images.push_back({(uint32_t)width,
319-
(uint32_t)height,
320-
3,
321-
image_buffer});
333+
images.emplace_back(sd_image_t{(uint32_t)width,
334+
(uint32_t)height,
335+
3,
336+
image_buffer});
322337

323338
if (max_image_num > 0 && static_cast<int>(images.size()) >= max_image_num) {
324339
break;
@@ -558,13 +573,6 @@ int main(int argc, const char* argv[]) {
558573
}
559574

560575
bool vae_decode_only = true;
561-
SDImageOwner init_image({0, 0, 3, nullptr});
562-
SDImageOwner end_image({0, 0, 3, nullptr});
563-
SDImageOwner control_image({0, 0, 3, nullptr});
564-
SDImageOwner mask_image({0, 0, 1, nullptr});
565-
SDImageVec ref_images;
566-
SDImageVec pmid_images;
567-
SDImageVec control_frames;
568576

569577
auto load_image_and_update_size = [&](const std::string& path,
570578
SDImageOwner& image,
@@ -588,31 +596,32 @@ int main(int argc, const char* argv[]) {
588596

589597
if (gen_params.init_image_path.size() > 0) {
590598
vae_decode_only = false;
591-
if (!load_image_and_update_size(gen_params.init_image_path, init_image)) {
599+
if (!load_image_and_update_size(gen_params.init_image_path, gen_params.init_image)) {
592600
return 1;
593601
}
594602
}
595603

596604
if (gen_params.end_image_path.size() > 0) {
597605
vae_decode_only = false;
598-
if (!load_image_and_update_size(gen_params.end_image_path, end_image)) {
606+
if (!load_image_and_update_size(gen_params.end_image_path, gen_params.end_image)) {
599607
return 1;
600608
}
601609
}
602610

603611
if (gen_params.ref_image_paths.size() > 0) {
604612
vae_decode_only = false;
613+
gen_params.ref_images.clear();
605614
for (auto& path : gen_params.ref_image_paths) {
606615
SDImageOwner ref_image({0, 0, 3, nullptr});
607616
if (!load_image_and_update_size(path, ref_image, false)) {
608617
return 1;
609618
}
610-
ref_images.push_back(std::move(ref_image));
619+
gen_params.ref_images.push_back(std::move(ref_image));
611620
}
612621
}
613622

614623
if (gen_params.mask_image_path.size() > 0) {
615-
if (!load_sd_image_from_file(mask_image.put(),
624+
if (!load_sd_image_from_file(gen_params.mask_image.put(),
616625
gen_params.mask_image_path.c_str(),
617626
gen_params.get_resolved_width(),
618627
gen_params.get_resolved_height(),
@@ -630,19 +639,19 @@ int main(int argc, const char* argv[]) {
630639
generated_mask.width = gen_params.get_resolved_width();
631640
generated_mask.height = gen_params.get_resolved_height();
632641
memset(generated_mask.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
633-
mask_image.reset(generated_mask);
642+
gen_params.mask_image.reset(generated_mask);
634643
}
635644

636645
if (gen_params.control_image_path.size() > 0) {
637-
if (!load_sd_image_from_file(control_image.put(),
646+
if (!load_sd_image_from_file(gen_params.control_image.put(),
638647
gen_params.control_image_path.c_str(),
639648
gen_params.get_resolved_width(),
640649
gen_params.get_resolved_height())) {
641650
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
642651
return 1;
643652
}
644653
if (cli_params.canny_preprocess) { // apply preprocessor
645-
preprocess_canny(control_image.get(),
654+
preprocess_canny(gen_params.control_image.get(),
646655
0.08f,
647656
0.08f,
648657
0.8f,
@@ -652,8 +661,9 @@ int main(int argc, const char* argv[]) {
652661
}
653662

654663
if (!gen_params.control_video_path.empty()) {
664+
gen_params.control_frames.clear();
655665
if (!load_images_from_dir(gen_params.control_video_path,
656-
control_frames,
666+
gen_params.control_frames,
657667
gen_params.get_resolved_width(),
658668
gen_params.get_resolved_height(),
659669
gen_params.video_frames,
@@ -663,8 +673,9 @@ int main(int argc, const char* argv[]) {
663673
}
664674

665675
if (!gen_params.pm_id_images_dir.empty()) {
676+
gen_params.pm_id_images.clear();
666677
if (!load_images_from_dir(gen_params.pm_id_images_dir,
667-
pmid_images,
678+
gen_params.pm_id_images,
668679
0,
669680
0,
670681
0,
@@ -684,7 +695,7 @@ int main(int argc, const char* argv[]) {
684695

685696
if (cli_params.mode == UPSCALE) {
686697
num_results = 1;
687-
results.push_back(init_image.release());
698+
results.push_back(gen_params.init_image.release());
688699
} else {
689700
SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params));
690701

@@ -706,63 +717,13 @@ int main(int argc, const char* argv[]) {
706717
}
707718

708719
if (cli_params.mode == IMG_GEN) {
709-
sd_img_gen_params_t img_gen_params = {
710-
gen_params.lora_vec.data(),
711-
static_cast<uint32_t>(gen_params.lora_vec.size()),
712-
gen_params.prompt.c_str(),
713-
gen_params.negative_prompt.c_str(),
714-
gen_params.clip_skip,
715-
init_image.get(),
716-
ref_images.data(),
717-
(int)ref_images.size(),
718-
gen_params.auto_resize_ref_image,
719-
gen_params.increase_ref_index,
720-
mask_image.get(),
721-
gen_params.get_resolved_width(),
722-
gen_params.get_resolved_height(),
723-
gen_params.sample_params,
724-
gen_params.strength,
725-
gen_params.seed,
726-
gen_params.batch_count,
727-
control_image.get(),
728-
gen_params.control_strength,
729-
{
730-
pmid_images.data(),
731-
(int)pmid_images.size(),
732-
gen_params.pm_id_embed_path.c_str(),
733-
gen_params.pm_style_strength,
734-
}, // pm_params
735-
gen_params.vae_tiling_params,
736-
gen_params.cache_params,
737-
};
720+
sd_img_gen_params_t img_gen_params = gen_params.to_sd_img_gen_params_t();
738721

739722
num_results = gen_params.batch_count;
740723
results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results);
741724
} else if (cli_params.mode == VID_GEN) {
742-
sd_vid_gen_params_t vid_gen_params = {
743-
gen_params.lora_vec.data(),
744-
static_cast<uint32_t>(gen_params.lora_vec.size()),
745-
gen_params.prompt.c_str(),
746-
gen_params.negative_prompt.c_str(),
747-
gen_params.clip_skip,
748-
init_image.get(),
749-
end_image.get(),
750-
control_frames.data(),
751-
(int)control_frames.size(),
752-
gen_params.get_resolved_width(),
753-
gen_params.get_resolved_height(),
754-
gen_params.sample_params,
755-
gen_params.high_noise_sample_params,
756-
gen_params.moe_boundary,
757-
gen_params.strength,
758-
gen_params.seed,
759-
gen_params.video_frames,
760-
gen_params.vace_strength,
761-
gen_params.vae_tiling_params,
762-
gen_params.cache_params,
763-
};
764-
765-
sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results);
725+
sd_vid_gen_params_t vid_gen_params = gen_params.to_sd_vid_gen_params_t();
726+
sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results);
766727
results.adopt(generated_video, num_results);
767728
}
768729

0 commit comments

Comments
 (0)