@@ -2225,6 +2225,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
22252225 sample_params->scheduler = SCHEDULER_COUNT;
22262226 sample_params->sample_method = SAMPLE_METHOD_COUNT;
22272227 sample_params->sample_steps = 20 ;
2228+ sample_params->eta = INFINITY;
22282229 sample_params->custom_sigmas = nullptr ;
22292230 sample_params->custom_sigmas_count = 0 ;
22302231 sample_params->flow_shift = INFINITY;
@@ -2438,6 +2439,25 @@ static scheduler_t resolve_scheduler(sd_ctx_t* sd_ctx,
24382439 return scheduler;
24392440}
24402441
2442+ static float resolve_eta (sd_ctx_t * sd_ctx,
2443+ float eta,
2444+ enum sample_method_t sample_method) {
2445+ if (eta == INFINITY) {
2446+ switch (sample_method) {
2447+ case DDIM_TRAILING_SAMPLE_METHOD:
2448+ case TCD_SAMPLE_METHOD:
2449+ case RES_MULTISTEP_SAMPLE_METHOD:
2450+ case RES_2S_SAMPLE_METHOD:
2451+ return 0 .0f ;
2452+ case EULER_A_SAMPLE_METHOD:
2453+ case DPMPP2S_A_SAMPLE_METHOD:
2454+ return 1 .0f ;
2455+ default : ;
2456+ }
2457+ }
2458+ return eta;
2459+ }
2460+
24412461struct GenerationRequest {
24422462 std::string prompt;
24432463 std::string negative_prompt;
@@ -2586,6 +2606,7 @@ struct GenerationRequest {
25862606struct SamplePlan {
25872607 enum sample_method_t sample_method = SAMPLE_METHOD_COUNT;
25882608 enum sample_method_t high_noise_sample_method = SAMPLE_METHOD_COUNT;
2609+ float eta = 0 .f;
25892610 int sample_steps = 0 ;
25902611 int high_noise_sample_steps = 0 ;
25912612 int total_steps = 0 ;
@@ -2597,6 +2618,7 @@ struct SamplePlan {
25972618 const sd_img_gen_params_t * sd_img_gen_params,
25982619 const GenerationRequest& request) {
25992620 sample_method = sd_img_gen_params->sample_params .sample_method ;
2621+ eta = sd_img_gen_params->sample_params .eta ;
26002622 sample_steps = sd_img_gen_params->sample_params .sample_steps ;
26012623 resolve (sd_ctx, &request, &sd_img_gen_params->sample_params );
26022624 }
@@ -2644,6 +2666,8 @@ struct SamplePlan {
26442666 sd_ctx->sd ->version );
26452667 }
26462668
2669+ eta = resolve_eta (sd_ctx, eta, sample_method);
2670+
26472671 if (high_noise_sample_steps < 0 ) {
26482672 for (size_t i = 0 ; i < sigmas.size (); ++i) {
26492673 if (sigmas[i] < moe_boundary) {
@@ -3123,7 +3147,7 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
31233147 latents.control_image ,
31243148 request.control_strength ,
31253149 request.guidance ,
3126- request .eta ,
3150+ plan .eta ,
31273151 request.shifted_timestep ,
31283152 plan.sample_method ,
31293153 plan.sigmas ,
0 commit comments