Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 1 addition & 58 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,63 +225,6 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP
}
}

std::string get_image_params(const SDCliParams& cli_params, const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed) {
std::string parameter_string = gen_params.prompt_with_lora + "\n";
if (gen_params.negative_prompt.size() != 0) {
parameter_string += "Negative prompt: " + gen_params.negative_prompt + "\n";
}
parameter_string += "Steps: " + std::to_string(gen_params.sample_params.sample_steps) + ", ";
parameter_string += "CFG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", ";
if (gen_params.sample_params.guidance.slg.scale != 0 && gen_params.skip_layers.size() != 0) {
parameter_string += "SLG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", ";
parameter_string += "Skip layers: [";
for (const auto& layer : gen_params.skip_layers) {
parameter_string += std::to_string(layer) + ", ";
}
parameter_string += "], ";
parameter_string += "Skip layer start: " + std::to_string(gen_params.sample_params.guidance.slg.layer_start) + ", ";
parameter_string += "Skip layer end: " + std::to_string(gen_params.sample_params.guidance.slg.layer_end) + ", ";
}
parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(gen_params.width) + "x" + std::to_string(gen_params.height) + ", ";
parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", ";
}
parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method));
if (!gen_params.custom_sigmas.empty()) {
parameter_string += ", Custom Sigmas: [";
for (size_t i = 0; i < gen_params.custom_sigmas.size(); ++i) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(4) << gen_params.custom_sigmas[i];
parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() - 1 ? "" : ", ");
}
parameter_string += "]";
} else if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { // Only show schedule if not using custom sigmas
parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler));
}
parameter_string += ", ";
for (const auto& te : {ctx_params.clip_l_path, ctx_params.clip_g_path, ctx_params.t5xxl_path, ctx_params.llm_path, ctx_params.llm_vision_path}) {
if (!te.empty()) {
parameter_string += "TE: " + sd_basename(te) + ", ";
}
}
if (!ctx_params.diffusion_model_path.empty()) {
parameter_string += "Unet: " + sd_basename(ctx_params.diffusion_model_path) + ", ";
}
if (!ctx_params.vae_path.empty()) {
parameter_string += "VAE: " + sd_basename(ctx_params.vae_path) + ", ";
}
if (gen_params.clip_skip != -1) {
parameter_string += "Clip skip: " + std::to_string(gen_params.clip_skip) + ", ";
}
parameter_string += "Version: stable-diffusion.cpp";
return parameter_string;
}

void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
SDCliParams* cli_params = (SDCliParams*)data;
log_print(level, log, cli_params->verbose, cli_params->color);
Expand Down Expand Up @@ -411,7 +354,7 @@ bool save_results(const SDCliParams& cli_params,
if (!img.data)
return;

std::string params = get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + idx);
std::string params = get_image_params(ctx_params, gen_params, gen_params.seed + idx);
int ok = 0;
if (is_jpg) {
ok = stbi_write_jpg(path.string().c_str(), img.width, img.height, img.channel, img.data, 90, params.c_str());
Expand Down
63 changes: 63 additions & 0 deletions examples/common/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2092,3 +2092,66 @@ uint8_t* load_image_from_memory(const char* image_bytes,
int expected_channel = 3) {
return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel);
}

std::string get_image_params(const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed) {
std::string parameter_string;
if (gen_params.prompt_with_lora.size() != 0) {
parameter_string += gen_params.prompt_with_lora + "\n";
} else {
parameter_string += gen_params.prompt + "\n";
}
if (gen_params.negative_prompt.size() != 0) {
parameter_string += "Negative prompt: " + gen_params.negative_prompt + "\n";
}
parameter_string += "Steps: " + std::to_string(gen_params.sample_params.sample_steps) + ", ";
parameter_string += "CFG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", ";
if (gen_params.sample_params.guidance.slg.scale != 0 && gen_params.skip_layers.size() != 0) {
parameter_string += "SLG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", ";
parameter_string += "Skip layers: [";
for (const auto& layer : gen_params.skip_layers) {
parameter_string += std::to_string(layer) + ", ";
}
parameter_string += "], ";
parameter_string += "Skip layer start: " + std::to_string(gen_params.sample_params.guidance.slg.layer_start) + ", ";
parameter_string += "Skip layer end: " + std::to_string(gen_params.sample_params.guidance.slg.layer_end) + ", ";
}
parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(gen_params.width) + "x" + std::to_string(gen_params.height) + ", ";
parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", ";
}
parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method));
if (!gen_params.custom_sigmas.empty()) {
parameter_string += ", Custom Sigmas: [";
for (size_t i = 0; i < gen_params.custom_sigmas.size(); ++i) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(4) << gen_params.custom_sigmas[i];
parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() - 1 ? "" : ", ");
}
parameter_string += "]";
} else if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) { // Only show schedule if not using custom sigmas
parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler));
}
parameter_string += ", ";
for (const auto& te : {ctx_params.clip_l_path, ctx_params.clip_g_path, ctx_params.t5xxl_path, ctx_params.llm_path, ctx_params.llm_vision_path}) {
if (!te.empty()) {
parameter_string += "TE: " + sd_basename(te) + ", ";
}
}
if (!ctx_params.diffusion_model_path.empty()) {
parameter_string += "Unet: " + sd_basename(ctx_params.diffusion_model_path) + ", ";
}
if (!ctx_params.vae_path.empty()) {
parameter_string += "VAE: " + sd_basename(ctx_params.vae_path) + ", ";
}
if (gen_params.clip_skip != -1) {
parameter_string += "Clip skip: " + std::to_string(gen_params.clip_skip) + ", ";
}
parameter_string += "Version: stable-diffusion.cpp";
return parameter_string;
}

21 changes: 19 additions & 2 deletions examples/server/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,23 @@ std::string extract_and_remove_sd_cpp_extra_args(std::string& text) {
enum class ImageFormat { JPEG,
PNG };

static int stbi_ext_write_png_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int stride_bytes, const char* parameters)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about this function. Maybe I should change stb_image_write.h instead?

{
int len;
unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len, parameters);
if (png == NULL) return 0;
func(context, png, len);
STBIW_FREE(png);
return 1;
}

std::vector<uint8_t> write_image_to_vector(
ImageFormat format,
const uint8_t* image,
int width,
int height,
int channels,
std::string params = "",
int quality = 90) {
std::vector<uint8_t> buffer;

Expand All @@ -245,7 +256,7 @@ std::vector<uint8_t> write_image_to_vector(
result = stbi_write_jpg_to_func(c_func, &ctx, width, height, channels, image, quality);
break;
case ImageFormat::PNG:
result = stbi_write_png_to_func(c_func, &ctx, width, height, channels, image, width * channels);
result = stbi_ext_write_png_to_func(c_func, &ctx, width, height, channels, image, width * channels, params.c_str());
break;
default:
throw std::runtime_error("invalid image format");
Expand Down Expand Up @@ -464,11 +475,13 @@ int main(int argc, const char** argv) {
if (results[i].data == nullptr) {
continue;
}
std::string params = get_image_params(ctx_params, gen_params, gen_params.seed + i);
auto image_bytes = write_image_to_vector(output_format == "jpeg" ? ImageFormat::JPEG : ImageFormat::PNG,
results[i].data,
results[i].width,
results[i].height,
results[i].channel,
params,
output_compression);
if (image_bytes.empty()) {
LOG_ERROR("write image to mem failed");
Expand Down Expand Up @@ -684,11 +697,13 @@ int main(int argc, const char** argv) {
for (int i = 0; i < num_results; i++) {
if (results[i].data == nullptr)
continue;
std::string params = get_image_params(ctx_params, gen_params, gen_params.seed + i);
auto image_bytes = write_image_to_vector(output_format == "jpeg" ? ImageFormat::JPEG : ImageFormat::PNG,
results[i].data,
results[i].width,
results[i].height,
results[i].channel,
params,
output_compression);
std::string b64 = base64_encode(image_bytes);
json item;
Expand Down Expand Up @@ -938,11 +953,13 @@ int main(int argc, const char** argv) {
continue;
}

std::string params = get_image_params(ctx_params, gen_params, gen_params.seed + i);
auto image_bytes = write_image_to_vector(ImageFormat::PNG,
results[i].data,
results[i].width,
results[i].height,
results[i].channel);
results[i].channel,
params);

if (image_bytes.empty()) {
LOG_ERROR("write image to mem failed");
Expand Down
Loading