
Add missing .float() for bf16 inference for FSQ models#1

Open
NilanEkanayake wants to merge 1 commit into microsoft:main from NilanEkanayake:patch-1

Conversation


NilanEkanayake commented Dec 18, 2024

A simple change that enables bfloat16 inference, like so:

model.to(torch.bfloat16)
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    _, xrec, _ = model(input)

...although the loss calculation needn't be run during inference anyway.
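For context, a minimal sketch of the kind of fix this PR describes: inside an FSQ (Finite Scalar Quantization) quantizer, the bounded codes are upcast to float32 before rounding so that bf16 inference stays numerically stable, then cast back to the input dtype. The function name, level count, and tanh bounding below are illustrative assumptions, not the repository's actual code.

```python
import torch

def fsq_round(z: torch.Tensor, levels: int = 5) -> torch.Tensor:
    """Illustrative FSQ-style quantization step (not the repo's real code).

    The .float() upcast is the point of the PR: rounding bounded codes
    in bf16 can lose precision, so compute in fp32 and cast back.
    """
    half = (levels - 1) / 2
    z32 = torch.tanh(z.float()) * half   # bound codes to [-half, half] in fp32
    codes = torch.round(z32)             # snap to integer levels
    return codes.to(z.dtype)             # restore the caller's dtype (e.g. bf16)

# Usage: works on bf16 inputs without dtype errors in the rounding path.
x = torch.randn(2, 4, dtype=torch.bfloat16)
print(fsq_round(x).dtype)  # torch.bfloat16
```

The cast back at the end keeps the rest of the bf16 forward pass unchanged; only the quantization arithmetic runs in fp32.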

deeptimhe (Contributor)

Hi, thanks for the contribution! We are working on testing the loss for bfloat16 inference.

