Add support for TransformerEngine flash attention in WAN #299
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
@cpersson-amd I've been out on PTO for a month. I'll take a closer look at this next week. Meanwhile, can you update your branch with the latest in main? Thanks.
entrpn left a comment:
In general the PR looks good, but I'm still unsure whether adding another axis, `fsdp_batch`, is really necessary. I would prefer not to add it. The other major thing is switching the `mesh_axes` from `data, fsdp, tensor` to `data, tensor, fsdp`.
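For readers unfamiliar with what this review comment refers to: in JAX-based trainers like this one, the device mesh is built from an ordered tuple of named axes, and the ordering determines how devices are laid out. The sketch below is only an illustration of that idea; the axis sizes and the `ici_parallelism` name are hypothetical, not the PR's actual config.

```python
import jax
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils

# Hypothetical axis sizes; in the real code these come from the run config.
ici_parallelism = {"data": 1, "tensor": 1, "fsdp": jax.device_count()}

# Reviewer's preference: order the axes as data, tensor, fsdp
# (rather than data, fsdp, tensor, and without a fourth fsdp_batch axis).
mesh_axes = ("data", "tensor", "fsdp")
devices = mesh_utils.create_device_mesh([ici_parallelism[a] for a in mesh_axes])
mesh = Mesh(devices, mesh_axes)

# Parameters and activations are then sharded by referring to these names,
# e.g. a hypothetical attention projection sharded over data and tensor axes:
example_spec = PartitionSpec("data", None, "tensor")
```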
@susanbao can you take a quick look at this PR?
@cpersson-amd please review @susanbao's comments above and rebase with main. We tested the PR internally and it looks good. Would you be willing to change the axis `fsdp` to `context`? If not, I can make the change after this PR is merged.
Thanks @cpersson-amd, this looks great. Can you run `ruff check --fix`?
@entrpn Sure, I ran `ruff check --fix` and had to manually fix some bare except statements. It should be good with the latest commit.
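For context, the ruff rule involved here (E722) flags bare `except:` clauses, which `--fix` does not rewrite automatically; the usual manual fix is to catch the specific exception you expect. A hypothetical before/after, not the actual diff from this PR:

```python
# Before: ruff E722 flags the bare except, which also swallows
# KeyboardInterrupt and SystemExit.
try:
    import transformer_engine  # noqa: F401
    has_te = True
except:
    has_te = False

# After: catch only the failure we actually expect.
try:
    import transformer_engine  # noqa: F401
    has_te = True
except ImportError:
    has_te = False
```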
@cpersson-amd Please review my PR (cpersson-amd#1), which fixes some of the unit tests. Once they pass, this can be merged.
@entrpn The PR looks good and is merged. I rebased with main and double-checked for errors with ruff. Hopefully it's good to go now.
Thanks @cpersson-amd, it's been merged.
Commit summary:

* add flash attn te support for wan
* add gpu optimized sharding parallelism
* sharding bugfixes
* generalize across sharding parallelisms
* fix issue with inference using fsdp + te flash attention
* revert fsdp_tpu name change
* update readme with wan2.1 gpu notes
* re-order parallelism axes and revert dynamic context parallel axes selection
* remove unused max_utils imports
* change mesh names to more accurately reflect sharding
* cleanup
* fix lint errors
* update configs for unit tests

Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
This PR implements the changes listed in the commit summary above: TransformerEngine flash attention support for WAN and GPU-optimized sharding parallelism. The code has been tested on WAN 2.1 (training and inference) and Flux (training only) on GPUs.
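As a rough sketch of the overall idea (not the PR's actual code), the attention layer can route to a TransformerEngine kernel when the library is installed and requested, and otherwise fall back to a reference JAX implementation. The import path and the shape of the dispatcher below are assumptions for illustration; the exact TE call depends on the installed TransformerEngine version and is deliberately left as a placeholder.

```python
import jax
import jax.numpy as jnp

try:
    # Assumed import path: TransformerEngine's JAX bindings live under
    # transformer_engine.jax in recent releases.
    import transformer_engine.jax as te  # noqa: F401
    HAS_TE = True
except ImportError:
    HAS_TE = False


def attention(q, k, v, use_te_flash_attention: bool):
    """Toy dispatcher: use a TE fused/flash kernel when requested and
    available, otherwise a reference scaled dot-product attention.

    q, k, v are assumed to be [..., seq, heads, head_dim] arrays.
    """
    if use_te_flash_attention and HAS_TE:
        # The PR routes this branch to TransformerEngine's fused/flash
        # attention; the exact call is omitted here because it varies
        # across TE versions.
        raise NotImplementedError("plug in the TE fused attention call here")
    # Reference path: standard scaled dot-product attention in plain JAX.
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("...qhd,...khd->...hqk", q, k) * scale
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("...hqk,...khd->...qhd", weights, v)
```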