Skip to content

Commit 738bc8e

Browse files
committed
feat: add support for the eta parameter to ancestral samplers
1 parent 09b12d5 commit 738bc8e

File tree

5 files changed

+39
-13
lines changed

5 files changed

+39
-13
lines changed

examples/cli/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,15 @@ Generation Options:
114114
medium
115115
--skip-layer-start <float> SLG enabling point (default: 0.01)
116116
--skip-layer-end <float> SLG disabling point (default: 0.2)
117-
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
117+
--eta <float> noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)
118118
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
119119
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
120120
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
121121
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
122122
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
123123
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
124124
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
125-
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
125+
--high-noise-eta <float> (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)
126126
--strength <float> strength for noising/unnoising (default: 0.75)
127127
--pm-style-strength <float>
128128
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image

examples/common/common.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,7 +1131,7 @@ struct SDGenerationParams {
11311131
&sample_params.guidance.slg.layer_end},
11321132
{"",
11331133
"--eta",
1134-
"eta in DDIM, only for DDIM and TCD (default: 0)",
1134+
"noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)",
11351135
&sample_params.eta},
11361136
{"",
11371137
"--flow-shift",
@@ -1163,7 +1163,7 @@ struct SDGenerationParams {
11631163
&high_noise_sample_params.guidance.slg.layer_end},
11641164
{"",
11651165
"--high-noise-eta",
1166-
"(high noise) eta in DDIM, only for DDIM and TCD (default: 0)",
1166+
"(high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)",
11671167
&high_noise_sample_params.eta},
11681168
{"",
11691169
"--strength",

examples/server/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,15 +189,15 @@ Default Generation Options:
189189
medium
190190
--skip-layer-start <float> SLG enabling point (default: 0.01)
191191
--skip-layer-end <float> SLG disabling point (default: 0.2)
192-
--eta <float> eta in DDIM, only for DDIM and TCD (default: 0)
192+
--eta <float> noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)
193193
--flow-shift <float> shift value for Flow models like SD3.x or WAN (default: auto)
194194
--high-noise-cfg-scale <float> (high noise) unconditional guidance scale: (default: 7.0)
195195
--high-noise-img-cfg-scale <float> (high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)
196196
--high-noise-guidance <float> (high noise) distilled guidance scale for models with guidance input (default: 3.5)
197197
--high-noise-slg-scale <float> (high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)
198198
--high-noise-skip-layer-start <float> (high noise) SLG enabling point (default: 0.01)
199199
--high-noise-skip-layer-end <float> (high noise) SLG disabling point (default: 0.2)
200-
--high-noise-eta <float> (high noise) eta in DDIM, only for DDIM and TCD (default: 0)
200+
--high-noise-eta <float> (high noise) noise multiplier (default: 0 for ddim_trailing, tcd, res_multistep and res_2s; 1 for euler_a and dpm++2s_a)
201201
--strength <float> strength for noising/unnoising (default: 0.75)
202202
--pm-style-strength <float>
203203
--control-strength <float> strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image

src/denoiser.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,8 @@ static std::pair<float, float> get_ancestral_step(float sigma_from,
789789
static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
790790
sd::Tensor<float> x,
791791
const std::vector<float>& sigmas,
792-
std::shared_ptr<RNG> rng) {
792+
std::shared_ptr<RNG> rng,
793+
float eta) {
793794
int steps = static_cast<int>(sigmas.size()) - 1;
794795
for (int i = 0; i < steps; i++) {
795796
float sigma = sigmas[i];
@@ -799,7 +800,7 @@ static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
799800
}
800801
sd::Tensor<float> denoised = std::move(denoised_opt);
801802
sd::Tensor<float> d = (x - denoised) / sigma;
802-
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1]);
803+
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1], eta);
803804
x += d * (sigma_down - sigmas[i]);
804805
if (sigmas[i + 1] > 0) {
805806
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
@@ -885,7 +886,8 @@ static sd::Tensor<float> sample_dpm2(denoise_cb_t model,
885886
static sd::Tensor<float> sample_dpmpp_2s_ancestral(denoise_cb_t model,
886887
sd::Tensor<float> x,
887888
const std::vector<float>& sigmas,
888-
std::shared_ptr<RNG> rng) {
889+
std::shared_ptr<RNG> rng,
890+
float eta) {
889891
auto t_fn = [](float sigma) -> float { return -log(sigma); };
890892
auto sigma_fn = [](float t) -> float { return exp(-t); };
891893

@@ -896,7 +898,7 @@ static sd::Tensor<float> sample_dpmpp_2s_ancestral(denoise_cb_t model,
896898
return {};
897899
}
898900
sd::Tensor<float> denoised = std::move(denoised_opt);
899-
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1]);
901+
auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1], eta);
900902

901903
if (sigma_down == 0) {
902904
x = denoised;
@@ -1371,15 +1373,15 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
13711373
float eta) {
13721374
switch (method) {
13731375
case EULER_A_SAMPLE_METHOD:
1374-
return sample_euler_ancestral(model, std::move(x), sigmas, rng);
1376+
return sample_euler_ancestral(model, std::move(x), sigmas, rng, eta);
13751377
case EULER_SAMPLE_METHOD:
13761378
return sample_euler(model, std::move(x), sigmas);
13771379
case HEUN_SAMPLE_METHOD:
13781380
return sample_heun(model, std::move(x), sigmas);
13791381
case DPM2_SAMPLE_METHOD:
13801382
return sample_dpm2(model, std::move(x), sigmas);
13811383
case DPMPP2S_A_SAMPLE_METHOD:
1382-
return sample_dpmpp_2s_ancestral(model, std::move(x), sigmas, rng);
1384+
return sample_dpmpp_2s_ancestral(model, std::move(x), sigmas, rng, eta);
13831385
case DPMPP2M_SAMPLE_METHOD:
13841386
return sample_dpmpp_2m(model, std::move(x), sigmas);
13851387
case DPMPP2Mv2_SAMPLE_METHOD:

src/stable-diffusion.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2225,6 +2225,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
22252225
sample_params->scheduler = SCHEDULER_COUNT;
22262226
sample_params->sample_method = SAMPLE_METHOD_COUNT;
22272227
sample_params->sample_steps = 20;
2228+
sample_params->eta = INFINITY;
22282229
sample_params->custom_sigmas = nullptr;
22292230
sample_params->custom_sigmas_count = 0;
22302231
sample_params->flow_shift = INFINITY;
@@ -2438,6 +2439,25 @@ static scheduler_t resolve_scheduler(sd_ctx_t* sd_ctx,
24382439
return scheduler;
24392440
}
24402441

2442+
static float resolve_eta(sd_ctx_t* sd_ctx,
2443+
float eta,
2444+
enum sample_method_t sample_method) {
2445+
if (eta == INFINITY) {
2446+
switch(sample_method) {
2447+
case DDIM_TRAILING_SAMPLE_METHOD:
2448+
case TCD_SAMPLE_METHOD:
2449+
case RES_MULTISTEP_SAMPLE_METHOD:
2450+
case RES_2S_SAMPLE_METHOD:
2451+
return 0.0f;
2452+
case EULER_A_SAMPLE_METHOD:
2453+
case DPMPP2S_A_SAMPLE_METHOD:
2454+
return 1.0f;
2455+
default: ;
2456+
}
2457+
}
2458+
return eta;
2459+
}
2460+
24412461
struct GenerationRequest {
24422462
std::string prompt;
24432463
std::string negative_prompt;
@@ -2586,6 +2606,7 @@ struct GenerationRequest {
25862606
struct SamplePlan {
25872607
enum sample_method_t sample_method = SAMPLE_METHOD_COUNT;
25882608
enum sample_method_t high_noise_sample_method = SAMPLE_METHOD_COUNT;
2609+
float eta = 0.f;
25892610
int sample_steps = 0;
25902611
int high_noise_sample_steps = 0;
25912612
int total_steps = 0;
@@ -2597,6 +2618,7 @@ struct SamplePlan {
25972618
const sd_img_gen_params_t* sd_img_gen_params,
25982619
const GenerationRequest& request) {
25992620
sample_method = sd_img_gen_params->sample_params.sample_method;
2621+
eta = sd_img_gen_params->sample_params.eta;
26002622
sample_steps = sd_img_gen_params->sample_params.sample_steps;
26012623
resolve(sd_ctx, &request, &sd_img_gen_params->sample_params);
26022624
}
@@ -2644,6 +2666,8 @@ struct SamplePlan {
26442666
sd_ctx->sd->version);
26452667
}
26462668

2669+
eta = resolve_eta(sd_ctx, eta, sample_method);
2670+
26472671
if (high_noise_sample_steps < 0) {
26482672
for (size_t i = 0; i < sigmas.size(); ++i) {
26492673
if (sigmas[i] < moe_boundary) {
@@ -3123,7 +3147,7 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
31233147
latents.control_image,
31243148
request.control_strength,
31253149
request.guidance,
3126-
request.eta,
3150+
plan.eta,
31273151
request.shifted_timestep,
31283152
plan.sample_method,
31293153
plan.sigmas,

0 commit comments

Comments
 (0)