Skip to content

[Update] A better sinkhorn implementation.#47

Open
ZhangZhiPku wants to merge 1 commit intoNVIDIA:mainfrom
ZhangZhiPku:patch-1
Open

[Update] A better sinkhorn implementation.#47
ZhangZhiPku wants to merge 1 commit intoNVIDIA:mainfrom
ZhangZhiPku:patch-1

Conversation

@ZhangZhiPku
Copy link

In the original implementation, the Sinkhorn operator processes n*n data blocks using a single block, which causes a large number of threads to remain idle when n is small. Therefore, introducing a new parallel dimension m can significantly improve operator performance.

In the original implementation, the Sinkhorn operator processes n*n data blocks using a single block, which causes a large number of threads to remain idle when n is small. Therefore, introducing a new parallel dimension m can significantly improve operator performance.
@copy-pr-bot
Copy link

copy-pr-bot bot commented Feb 2, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ZhangZhiPku
Copy link
Author

Test Code:

import torch
import cuda.tile as ct

@ct.kernel
def sinkhorn(M: ct.Array, O: ct.Array, iter: int, tile_size: ct.Constant):
    # block binding: M[B(block_y 1:1), S(block_x 1:1), K, K]
    block_x, block_y = ct.bid(0), ct.bid(1)
    tile = ct.load(M, (block_y, block_x, 0, 0), (1, 1, tile_size, tile_size))
    tile = tile.astype(ct.float32)
    tile = tile - ct.max(tile)
    tile = ct.exp(tile)

    for i in range(iter):
        tile = ct.truediv(tile, ct.sum(tile, axis=-1, keepdims=True) + 1e-7)
        tile = ct.truediv(tile, ct.sum(tile, axis=-2, keepdims=True) + 1e-7)

    ct.store(O, (block_y, block_x, 0, 0), tile.astype(M.dtype))

@ct.kernel
def sinkhorn2(
    M: ct.Array, O: ct.Array, iter: int, tileS: ct.Constant, tileK: ct.Constant
):
    # block binding: M[B(block_y 1:1), S(block_x 1:1), K, K]
    block_x, block_y = ct.bid(0), ct.bid(1)
    tile = ct.load(M, (block_y, block_x, 0, 0), (1, tileS, tileK, tileK))
    tile = tile.astype(ct.float32)
    tile = tile - ct.max(tile)
    tile = ct.exp(tile)

    for i in range(iter):
        tile = ct.truediv(tile, ct.sum(tile, axis=-1, keepdims=True) + 1e-7)
        tile = ct.truediv(tile, ct.sum(tile, axis=-2, keepdims=True) + 1e-7)

    ct.store(O, (block_y, block_x, 0, 0), tile.astype(M.dtype))

B, S, K = 1, 12800, 4
M = torch.randn(size=[B, S, K, K], device="cuda", dtype=torch.bfloat16)
O = torch.empty_like(M)
ct.launch(torch.cuda.current_stream(), (S, B), sinkhorn, (M, O, 20, K))
print(O)

ct.launch(torch.cuda.current_stream(), (ct.cdiv(S, 8), B), sinkhorn2, (M, O, 20, 8, K))
print(O)

Result:
sinkhorn: 85509ns
sinkhorn2: 15937ns

@hannahli-nv
Copy link
Collaborator

Hi @ZhangZhiPku , thank you for your contribution.
As this is your first time contributing to TileGym, please submit your signed CLA document as described in CONTRIBUTING.md.
Thank you very much.

@hannahli-nv
Copy link
Collaborator

/ok to test 97b525d

@xjmxyt
Copy link
Collaborator

xjmxyt commented Feb 6, 2026

Hello, we found test_op_mhc_sinkhorn failed. Could you please modify the code?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants

Comments