@@ -192,17 +192,22 @@ struct SDCliParams {
192192 return options;
193193 };
194194
195- bool process_and_check () {
196- if (mode != METADATA && output_path.length () == 0 ) {
197- LOG_ERROR (" error: the following arguments are required: output_path" );
198- return false ;
199- }
200-
195+ bool resolve () {
201196 if (mode == CONVERT) {
202197 if (output_path == " output.png" ) {
203198 output_path = " output.gguf" ;
204199 }
205- } else if (mode == METADATA) {
200+ }
201+ return true ;
202+ }
203+
204+ bool validate () {
205+ if (mode != METADATA) {
206+ if (output_path.length () == 0 ) {
207+ LOG_ERROR (" error: the following arguments are required: output_path" );
208+ return false ;
209+ }
210+ } else {
206211 if (image_path.empty ()) {
207212 LOG_ERROR (" error: metadata mode needs an image path (--image)" );
208213 return false ;
@@ -216,6 +221,16 @@ struct SDCliParams {
216221 return true ;
217222 }
218223
224+ bool resolve_and_validate () {
225+ if (!resolve ()) {
226+ return false ;
227+ }
228+ if (!validate ()) {
229+ return false ;
230+ }
231+ return true ;
232+ }
233+
219234 std::string to_string () const {
220235 std::ostringstream oss;
221236 oss << " SDCliParams {\n "
@@ -260,10 +275,10 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP
260275 exit (cli_params.normal_exit ? 0 : 1 );
261276 }
262277
263- bool valid = cli_params.process_and_check ();
278+ bool valid = cli_params.resolve_and_validate ();
264279 if (valid && cli_params.mode != METADATA) {
265- valid = ctx_params.process_and_check (cli_params.mode ) &&
266- gen_params.process_and_check (cli_params.mode , ctx_params.lora_model_dir );
280+ valid = ctx_params.resolve_and_validate (cli_params.mode ) &&
281+ gen_params.resolve_and_validate (cli_params.mode , ctx_params.lora_model_dir );
267282 }
268283
269284 if (!valid) {
@@ -278,7 +293,7 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
278293}
279294
280295bool load_images_from_dir (const std::string dir,
281- SDImageVec & images,
296+ std::vector<SDImageOwner> & images,
282297 int expected_width = 0 ,
283298 int expected_height = 0 ,
284299 int max_image_num = 0 ,
@@ -315,10 +330,10 @@ bool load_images_from_dir(const std::string dir,
315330 return false ;
316331 }
317332
318- images.push_back ( {(uint32_t )width,
319- (uint32_t )height,
320- 3 ,
321- image_buffer});
333+ images.emplace_back ( sd_image_t {(uint32_t )width,
334+ (uint32_t )height,
335+ 3 ,
336+ image_buffer});
322337
323338 if (max_image_num > 0 && static_cast <int >(images.size ()) >= max_image_num) {
324339 break ;
@@ -558,13 +573,6 @@ int main(int argc, const char* argv[]) {
558573 }
559574
560575 bool vae_decode_only = true ;
561- SDImageOwner init_image ({0 , 0 , 3 , nullptr });
562- SDImageOwner end_image ({0 , 0 , 3 , nullptr });
563- SDImageOwner control_image ({0 , 0 , 3 , nullptr });
564- SDImageOwner mask_image ({0 , 0 , 1 , nullptr });
565- SDImageVec ref_images;
566- SDImageVec pmid_images;
567- SDImageVec control_frames;
568576
569577 auto load_image_and_update_size = [&](const std::string& path,
570578 SDImageOwner& image,
@@ -588,31 +596,32 @@ int main(int argc, const char* argv[]) {
588596
589597 if (gen_params.init_image_path .size () > 0 ) {
590598 vae_decode_only = false ;
591- if (!load_image_and_update_size (gen_params.init_image_path , init_image)) {
599+ if (!load_image_and_update_size (gen_params.init_image_path , gen_params. init_image )) {
592600 return 1 ;
593601 }
594602 }
595603
596604 if (gen_params.end_image_path .size () > 0 ) {
597605 vae_decode_only = false ;
598- if (!load_image_and_update_size (gen_params.end_image_path , end_image)) {
606+ if (!load_image_and_update_size (gen_params.end_image_path , gen_params. end_image )) {
599607 return 1 ;
600608 }
601609 }
602610
603611 if (gen_params.ref_image_paths .size () > 0 ) {
604612 vae_decode_only = false ;
613+ gen_params.ref_images .clear ();
605614 for (auto & path : gen_params.ref_image_paths ) {
606615 SDImageOwner ref_image ({0 , 0 , 3 , nullptr });
607616 if (!load_image_and_update_size (path, ref_image, false )) {
608617 return 1 ;
609618 }
610- ref_images.push_back (std::move (ref_image));
619+ gen_params. ref_images .push_back (std::move (ref_image));
611620 }
612621 }
613622
614623 if (gen_params.mask_image_path .size () > 0 ) {
615- if (!load_sd_image_from_file (mask_image.put (),
624+ if (!load_sd_image_from_file (gen_params. mask_image .put (),
616625 gen_params.mask_image_path .c_str (),
617626 gen_params.get_resolved_width (),
618627 gen_params.get_resolved_height (),
@@ -630,19 +639,19 @@ int main(int argc, const char* argv[]) {
630639 generated_mask.width = gen_params.get_resolved_width ();
631640 generated_mask.height = gen_params.get_resolved_height ();
632641 memset (generated_mask.data , 255 , gen_params.get_resolved_width () * gen_params.get_resolved_height ());
633- mask_image.reset (generated_mask);
642+ gen_params. mask_image .reset (generated_mask);
634643 }
635644
636645 if (gen_params.control_image_path .size () > 0 ) {
637- if (!load_sd_image_from_file (control_image.put (),
646+ if (!load_sd_image_from_file (gen_params. control_image .put (),
638647 gen_params.control_image_path .c_str (),
639648 gen_params.get_resolved_width (),
640649 gen_params.get_resolved_height ())) {
641650 LOG_ERROR (" load image from '%s' failed" , gen_params.control_image_path .c_str ());
642651 return 1 ;
643652 }
644653 if (cli_params.canny_preprocess ) { // apply preprocessor
645- preprocess_canny (control_image.get (),
654+ preprocess_canny (gen_params. control_image .get (),
646655 0 .08f ,
647656 0 .08f ,
648657 0 .8f ,
@@ -652,8 +661,9 @@ int main(int argc, const char* argv[]) {
652661 }
653662
654663 if (!gen_params.control_video_path .empty ()) {
664+ gen_params.control_frames .clear ();
655665 if (!load_images_from_dir (gen_params.control_video_path ,
656- control_frames,
666+ gen_params. control_frames ,
657667 gen_params.get_resolved_width (),
658668 gen_params.get_resolved_height (),
659669 gen_params.video_frames ,
@@ -663,8 +673,9 @@ int main(int argc, const char* argv[]) {
663673 }
664674
665675 if (!gen_params.pm_id_images_dir .empty ()) {
676+ gen_params.pm_id_images .clear ();
666677 if (!load_images_from_dir (gen_params.pm_id_images_dir ,
667- pmid_images ,
678+ gen_params. pm_id_images ,
668679 0 ,
669680 0 ,
670681 0 ,
@@ -684,7 +695,7 @@ int main(int argc, const char* argv[]) {
684695
685696 if (cli_params.mode == UPSCALE) {
686697 num_results = 1 ;
687- results.push_back (init_image.release ());
698+ results.push_back (gen_params. init_image .release ());
688699 } else {
689700 SDCtxPtr sd_ctx (new_sd_ctx (&sd_ctx_params));
690701
@@ -706,63 +717,13 @@ int main(int argc, const char* argv[]) {
706717 }
707718
708719 if (cli_params.mode == IMG_GEN) {
709- sd_img_gen_params_t img_gen_params = {
710- gen_params.lora_vec .data (),
711- static_cast <uint32_t >(gen_params.lora_vec .size ()),
712- gen_params.prompt .c_str (),
713- gen_params.negative_prompt .c_str (),
714- gen_params.clip_skip ,
715- init_image.get (),
716- ref_images.data (),
717- (int )ref_images.size (),
718- gen_params.auto_resize_ref_image ,
719- gen_params.increase_ref_index ,
720- mask_image.get (),
721- gen_params.get_resolved_width (),
722- gen_params.get_resolved_height (),
723- gen_params.sample_params ,
724- gen_params.strength ,
725- gen_params.seed ,
726- gen_params.batch_count ,
727- control_image.get (),
728- gen_params.control_strength ,
729- {
730- pmid_images.data (),
731- (int )pmid_images.size (),
732- gen_params.pm_id_embed_path .c_str (),
733- gen_params.pm_style_strength ,
734- }, // pm_params
735- gen_params.vae_tiling_params ,
736- gen_params.cache_params ,
737- };
720+ sd_img_gen_params_t img_gen_params = gen_params.to_sd_img_gen_params_t ();
738721
739722 num_results = gen_params.batch_count ;
740723 results.adopt (generate_image (sd_ctx.get (), &img_gen_params), num_results);
741724 } else if (cli_params.mode == VID_GEN) {
742- sd_vid_gen_params_t vid_gen_params = {
743- gen_params.lora_vec .data (),
744- static_cast <uint32_t >(gen_params.lora_vec .size ()),
745- gen_params.prompt .c_str (),
746- gen_params.negative_prompt .c_str (),
747- gen_params.clip_skip ,
748- init_image.get (),
749- end_image.get (),
750- control_frames.data (),
751- (int )control_frames.size (),
752- gen_params.get_resolved_width (),
753- gen_params.get_resolved_height (),
754- gen_params.sample_params ,
755- gen_params.high_noise_sample_params ,
756- gen_params.moe_boundary ,
757- gen_params.strength ,
758- gen_params.seed ,
759- gen_params.video_frames ,
760- gen_params.vace_strength ,
761- gen_params.vae_tiling_params ,
762- gen_params.cache_params ,
763- };
764-
765- sd_image_t * generated_video = generate_video (sd_ctx.get (), &vid_gen_params, &num_results);
725+ sd_vid_gen_params_t vid_gen_params = gen_params.to_sd_vid_gen_params_t ();
726+ sd_image_t * generated_video = generate_video (sd_ctx.get (), &vid_gen_params, &num_results);
766727 results.adopt (generated_video, num_results);
767728 }
768729
0 commit comments