Skip to content

Commit 27c01ea

Browse files
committed
Fix Flux2 DreamBooth prior preservation prompt repeats
1 parent fbe8a75 commit 27c01ea

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1740,7 +1740,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17401740
prompt_embeds = prompt_embeds_cache[step]
17411741
text_ids = text_ids_cache[step]
17421742
else:
1743-
num_repeat_elements = len(prompts)
1743+
# With prior preservation, prompt_embeds already contains [instance, class] embeddings,
1744+
# while collate_fn doubles the prompts list. Repeat by the number of instance prompts only.
1745+
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
17441746
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
17451747
text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
17461748

examples/dreambooth/train_dreambooth_lora_flux2_klein.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1680,7 +1680,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16801680
prompt_embeds = prompt_embeds_cache[step]
16811681
text_ids = text_ids_cache[step]
16821682
else:
1683-
num_repeat_elements = len(prompts)
1683+
# With prior preservation, prompt_embeds already contains [instance, class] embeddings,
1684+
# while collate_fn doubles the prompts list. Repeat by the number of instance prompts only.
1685+
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
16841686
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
16851687
text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
16861688

0 commit comments

Comments
 (0)