Skip to content

Commit fd00632

Browse files
committed
feat: option to enable vae tiling automatically for large images
1 parent 545fac4 commit fd00632

3 files changed

Lines changed: 53 additions & 9 deletions

File tree

examples/cli/main.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,8 @@ int main(int argc, const char* argv[]) {
750750
gen_params.pm_id_embed_path.c_str(),
751751
gen_params.pm_style_strength,
752752
}, // pm_params
753-
ctx_params.vae_tiling_params,
753+
ctx_params.get_tiling_params(gen_params.get_resolved_width(),
754+
gen_params.get_resolved_height()),
754755
gen_params.cache_params,
755756
};
756757

@@ -776,7 +777,8 @@ int main(int argc, const char* argv[]) {
776777
gen_params.seed,
777778
gen_params.video_frames,
778779
gen_params.vace_strength,
779-
ctx_params.vae_tiling_params,
780+
ctx_params.get_tiling_params(gen_params.get_resolved_width(),
781+
gen_params.get_resolved_height()),
780782
gen_params.cache_params,
781783
};
782784

examples/common/common.hpp

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ struct SDContextParams {
475475
prediction_t prediction = PREDICTION_COUNT;
476476
lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
477477

478+
int vae_tiling_threshold = 0;
478479
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
479480
bool force_sdxl_vae_conv_scale = false;
480481

@@ -584,10 +585,6 @@ struct SDContextParams {
584585
};
585586

586587
options.bool_options = {
587-
{"",
588-
"--vae-tiling",
589-
"process vae in tiles to reduce memory usage",
590-
true, &vae_tiling_params.enabled},
591588
{"",
592589
"--force-sdxl-vae-conv-scale",
593590
"force use of conv scale on sdxl vae",
@@ -724,6 +721,33 @@ struct SDContextParams {
724721
return 1;
725722
};
726723

724+
// Argument handler for `--vae-tiling [T]`.
//
// Looks at the argument following the flag (argv[index + 1]) and stores the
// resulting threshold in `vae_tiling_threshold`:
//   - no following argument, or a non-numeric one that starts with '-'
//     (likely another flag): threshold becomes 1 and 0 is returned;
//   - a non-negative integer: stored as the threshold and 1 is returned;
//     values above `max_threshold` are stored as 0;
//   - anything else (negative, trailing garbage, non-numeric text): -1 is
//     returned and `vae_tiling_threshold` is left untouched.
//
// NOTE(review): the return value presumably means "number of extra arguments
// consumed, or -1 on error", as with the sibling handlers - confirm against
// the option parser.
auto on_tiling_threshold = [&](int argc, const char** argv, int index) {
    // 32768 * 32768 == 2^30, so squaring an accepted threshold cannot overflow int.
    const int max_threshold = 32768;

    if (++index >= argc) {
        // flag given as the last argument, without a value
        vae_tiling_threshold = 1;
        return 0;
    }

    const std::string threshold_str = argv[index];
    size_t pos                      = 0;
    int result                      = -1;
    try {
        result = std::stoi(threshold_str, &pos);
    } catch (const std::invalid_argument&) {
        // check if it's likely to be another flag
        if (threshold_str.rfind("-", 0) == 0) {
            vae_tiling_threshold = 1;
            return 0;
        }
        return -1;
    } catch (const std::out_of_range&) {
        // The value does not fit in an int. A string made only of digits is
        // just a very large threshold, so handle it like any other value above
        // max_threshold instead of rejecting it.
        bool all_digits = !threshold_str.empty();
        for (char c : threshold_str) {
            if (c < '0' || c > '9') {
                all_digits = false;
                break;
            }
        }
        if (!all_digits) {
            return -1;
        }
        vae_tiling_threshold = 0;
        return 1;
    }

    if (pos != threshold_str.length() || result < 0) {
        return -1;
    }
    if (result > max_threshold) {
        // avoid overflow if the user disabled tiling by using a huge value
        result = 0;
    }
    vae_tiling_threshold = result;
    return 1;
};
750+
727751
auto on_tile_size_arg = [&](int argc, const char** argv, int index) {
728752
if (++index >= argc) {
729753
return -1;
@@ -796,6 +820,11 @@ struct SDContextParams {
796820
"but it usually offers faster inference speed and, in some cases, lower memory usage. "
797821
"The at_runtime mode, on the other hand, is exactly the opposite.",
798822
on_lora_apply_mode_arg},
823+
{"",
824+
"--vae-tiling",
825+
"process vae in tiles to reduce memory usage. Optionally receives a size threshold T, which will "
826+
"turn on tiling only for images larger than TxT.",
827+
on_tiling_threshold},
799828
{"",
800829
"--vae-tile-size",
801830
"tile size for vae tiling, format [X]x[Y] (default: 32x32)",
@@ -924,6 +953,7 @@ struct SDContextParams {
924953
<< vae_tiling_params.target_overlap << ", "
925954
<< vae_tiling_params.rel_size_x << ", "
926955
<< vae_tiling_params.rel_size_y << " },\n"
956+
<< " vae_tiling_threshold: " << vae_tiling_threshold << ",\n"
927957
<< " force_sdxl_vae_conv_scale: " << (force_sdxl_vae_conv_scale ? "true" : "false") << "\n"
928958
<< "}";
929959
return oss.str();
@@ -984,6 +1014,18 @@ struct SDContextParams {
9841014
};
9851015
return sd_ctx_params;
9861016
}
1017+
1018+
/**
 * Build the VAE tiling parameters to use for an image of the given size.
 *
 * Returns a copy of `vae_tiling_params` in which only `enabled` is overridden:
 *   - `vae_tiling_threshold <= 0`: tiling is disabled;
 *   - otherwise tiling is enabled when `width * height` is strictly greater
 *     than `vae_tiling_threshold * vae_tiling_threshold`.
 *
 * @param width   image width, in the same units as the threshold
 * @param height  image height, in the same units as the threshold
 * @return        copy of `vae_tiling_params` with `enabled` resolved for this size
 */
sd_tiling_params_t get_tiling_params(int width, int height) {
    sd_tiling_params_t params = vae_tiling_params;
    if (vae_tiling_threshold <= 0) {
        params.enabled = false;
    } else {
        // Widen before multiplying: an int * int product is undefined
        // behaviour once it exceeds INT_MAX.
        const long long area      = static_cast<long long>(width) * height;
        const long long threshold = static_cast<long long>(vae_tiling_threshold) * vae_tiling_threshold;
        params.enabled            = (area > threshold);
    }
    return params;
}
9871029
};
9881030

9891031
template <typename T>

examples/server/main.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ int main(int argc, const char** argv) {
525525
gen_params.pm_id_embed_path.c_str(),
526526
gen_params.pm_style_strength,
527527
}, // pm_params
528-
ctx_params.vae_tiling_params,
528+
ctx_params.get_tiling_params(gen_params.width, gen_params.height),
529529
gen_params.cache_params,
530530
};
531531

@@ -772,7 +772,7 @@ int main(int argc, const char** argv) {
772772
gen_params.pm_id_embed_path.c_str(),
773773
gen_params.pm_style_strength,
774774
}, // pm_params
775-
ctx_params.vae_tiling_params,
775+
ctx_params.get_tiling_params(get_resolved_width(), get_resolved_height()),
776776
gen_params.cache_params,
777777
};
778778

@@ -1088,7 +1088,7 @@ int main(int argc, const char** argv) {
10881088
gen_params.pm_id_embed_path.c_str(),
10891089
gen_params.pm_style_strength,
10901090
}, // pm_params
1091-
ctx_params.vae_tiling_params,
1091+
ctx_params.get_tiling_params(get_resolved_width(), get_resolved_height()),
10921092
gen_params.cache_params,
10931093
};
10941094

0 commit comments

Comments
 (0)