Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 142 additions & 126 deletions lightllm/common/kv_trans_kernel/nixl_kv_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
@triton.jit
def _page_io(
mem_index_ptr,
token_num,
page_write_head_num,
k_page_ptr,
k_page_stride_size,
k_page_stride_layer_num,
Expand Down Expand Up @@ -45,88 +47,91 @@ def _page_io(
k_stride_size = tl.cast(k_stride_size, dtype=tl.int64)
v_stride_size = tl.cast(v_stride_size, dtype=tl.int64)

tid = tl.program_id(0)
kv_head_id = tl.program_id(1)
page_head_id = page_head_start + kv_head_id
start_index = tl.program_id(0)
grid_num = tl.num_programs(0)

mem_index = tl.load(mem_index_ptr + tid)
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
if NEED_MASK:
mask = off_dim < head_dim
else:
mask = None
for tid in tl.range(start_index, token_num, step=grid_num):
for kv_head_id in tl.range(page_write_head_num):

for layer_index in tl.range(layer_num, num_stages=3):
if IS_WRITE:
k_tensor = tl.load(
k_ptr
+ layer_index * k_stride_layer_num
+ mem_index * k_stride_size
+ kv_head_id * k_stride_head
+ off_dim * k_stride_dim,
mask=mask,
)
v_tensor = tl.load(
v_ptr
+ layer_index * v_stride_layer_num
+ mem_index * v_stride_size
+ kv_head_id * v_stride_head
+ off_dim * v_stride_dim,
mask=mask,
)
tl.store(
k_page_ptr
+ tid * k_page_stride_size
+ layer_index * k_page_stride_layer_num
+ page_head_id * k_page_stride_head
+ off_dim * k_page_stride_dim,
k_tensor,
mask=mask,
)
tl.store(
v_page_ptr
+ tid * v_page_stride_size
+ layer_index * v_page_stride_layer_num
+ page_head_id * v_page_stride_head
+ off_dim * v_page_stride_dim,
v_tensor,
mask=mask,
)
else:
k_page_tensor = tl.load(
k_page_ptr
+ tid * k_page_stride_size
+ layer_index * k_page_stride_layer_num
+ page_head_id * k_page_stride_head
+ off_dim * k_page_stride_dim,
mask=mask,
)
v_page_tensor = tl.load(
v_page_ptr
+ tid * v_page_stride_size
+ layer_index * v_page_stride_layer_num
+ page_head_id * v_page_stride_head
+ off_dim * v_page_stride_dim,
mask=mask,
)
tl.store(
k_ptr
+ layer_index * k_stride_layer_num
+ mem_index * k_stride_size
+ kv_head_id * k_stride_head
+ off_dim * k_stride_dim,
k_page_tensor,
mask=mask,
)
tl.store(
v_ptr
+ layer_index * v_stride_layer_num
+ mem_index * v_stride_size
+ kv_head_id * v_stride_head
+ off_dim * v_stride_dim,
v_page_tensor,
mask=mask,
)
page_head_id = page_head_start + kv_head_id
mem_index = tl.load(mem_index_ptr + tid)
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
if NEED_MASK:
mask = off_dim < head_dim
else:
mask = None

for layer_index in tl.range(layer_num, num_stages=3):
if IS_WRITE:
k_tensor = tl.load(
k_ptr
+ layer_index * k_stride_layer_num
+ mem_index * k_stride_size
+ kv_head_id * k_stride_head
+ off_dim,
mask=mask,
)
v_tensor = tl.load(
v_ptr
+ layer_index * v_stride_layer_num
+ mem_index * v_stride_size
+ kv_head_id * v_stride_head
+ off_dim,
mask=mask,
)
tl.store(
k_page_ptr
+ tid * k_page_stride_size
+ layer_index * k_page_stride_layer_num
+ page_head_id * k_page_stride_head
+ off_dim,
k_tensor,
mask=mask,
)
tl.store(
v_page_ptr
+ tid * v_page_stride_size
+ layer_index * v_page_stride_layer_num
+ page_head_id * v_page_stride_head
+ off_dim,
v_tensor,
mask=mask,
)
else:
k_page_tensor = tl.load(
k_page_ptr
+ tid * k_page_stride_size
+ layer_index * k_page_stride_layer_num
+ page_head_id * k_page_stride_head
+ off_dim,
mask=mask,
)
v_page_tensor = tl.load(
v_page_ptr
+ tid * v_page_stride_size
+ layer_index * v_page_stride_layer_num
+ page_head_id * v_page_stride_head
+ off_dim,
mask=mask,
)
tl.store(
k_ptr
+ layer_index * k_stride_layer_num
+ mem_index * k_stride_size
+ kv_head_id * k_stride_head
+ off_dim,
k_page_tensor,
mask=mask,
)
tl.store(
v_ptr
+ layer_index * v_stride_layer_num
+ mem_index * v_stride_size
+ kv_head_id * v_stride_head
+ off_dim,
v_page_tensor,
mask=mask,
)
return


Expand Down Expand Up @@ -169,10 +174,17 @@ def page_io(
page_head_start = tp_index * (page_write_head_num)

token_num = len(mem_indexes)
grid = (token_num, page_write_head_num)
grid = (128,)

assert k_page_tensor.stride(3) == 1
assert v_page_tensor.stride(3) == 1
assert k_buffer.stride(3) == 1
assert v_buffer.stride(3) == 1

_page_io[grid](
mem_index_ptr=mem_indexes,
token_num=token_num,
page_write_head_num=page_write_head_num,
k_page_ptr=k_page_tensor,
k_page_stride_size=k_page_tensor.stride(0),
k_page_stride_layer_num=k_page_tensor.stride(1),
Expand Down Expand Up @@ -207,6 +219,7 @@ def page_io(
@triton.jit
def _mla_page_io(
mem_index_ptr,
token_num,
page_ptr,
page_stride_size,
page_stride_layer_num,
Expand All @@ -227,52 +240,54 @@ def _mla_page_io(
kv_stride_layer_num = tl.cast(kv_stride_layer_num, dtype=tl.int64)
kv_stride_size = tl.cast(kv_stride_size, dtype=tl.int64)

tid = tl.program_id(0)
start_index = tl.program_id(0)
grid_num = tl.num_programs(0)

mem_index = tl.load(mem_index_ptr + tid)
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
if NEED_MASK:
mask = off_dim < head_dim
else:
mask = None

for layer_index in tl.range(layer_num, num_stages=3):
if IS_WRITE:
kv_tensor = tl.load(
kv_ptr
+ layer_index * kv_stride_layer_num
+ mem_index * kv_stride_size
+ 0 * kv_stride_head
+ off_dim * kv_stride_dim,
mask=mask,
)
tl.store(
page_ptr
+ tid * page_stride_size
+ layer_index * page_stride_layer_num
+ 0 * page_stride_head
+ off_dim * page_stride_dim,
kv_tensor,
mask=mask,
)
for tid in tl.range(start_index, token_num, step=grid_num):
mem_index = tl.load(mem_index_ptr + tid)
off_dim = tl.arange(0, HEAD_DIM_BLOCK)
if NEED_MASK:
mask = off_dim < head_dim
else:
page_tensor = tl.load(
page_ptr
+ tid * page_stride_size
+ layer_index * page_stride_layer_num
+ 0 * page_stride_head
+ off_dim * page_stride_dim,
mask=mask,
)
tl.store(
kv_ptr
+ layer_index * kv_stride_layer_num
+ mem_index * kv_stride_size
+ 0 * kv_stride_head
+ off_dim * kv_stride_dim,
page_tensor,
mask=mask,
)
mask = None

for layer_index in tl.range(layer_num, num_stages=3):
if IS_WRITE:
kv_tensor = tl.load(
kv_ptr
+ layer_index * kv_stride_layer_num
+ mem_index * kv_stride_size
+ 0 * kv_stride_head
+ off_dim * kv_stride_dim,
mask=mask,
)
tl.store(
page_ptr
+ tid * page_stride_size
+ layer_index * page_stride_layer_num
+ 0 * page_stride_head
+ off_dim * page_stride_dim,
kv_tensor,
mask=mask,
)
else:
page_tensor = tl.load(
page_ptr
+ tid * page_stride_size
+ layer_index * page_stride_layer_num
+ 0 * page_stride_head
+ off_dim * page_stride_dim,
mask=mask,
)
tl.store(
kv_ptr
+ layer_index * kv_stride_layer_num
+ mem_index * kv_stride_size
+ 0 * kv_stride_head
+ off_dim * kv_stride_dim,
page_tensor,
mask=mask,
)
return


Expand All @@ -290,10 +305,11 @@ def mla_page_io(mem_indexes: torch.Tensor, page_tensor: torch.Tensor, kv_buffer:
assert page_head_num == kv_head_num == 1

token_num = len(mem_indexes)
grid = (token_num,)
grid = (64,)

_mla_page_io[grid](
mem_index_ptr=mem_indexes,
token_num=token_num,
page_ptr=page_tensor,
page_stride_size=page_tensor.stride(0),
page_stride_layer_num=page_tensor.stride(1),
Expand Down
2 changes: 1 addition & 1 deletion lightllm/server/httpserver_for_pd_master/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ async def fetch_nixl_stream(
)

try:
await asyncio.wait_for(up_status_event.wait(), timeout=60)
await asyncio.wait_for(up_status_event.wait(), timeout=180)
except asyncio.TimeoutError:
logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.")
raise ServerBusyError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,17 @@ def _create_nixl_trans_task(
):
# 确定传输设备
if req_obj.nixl_trans_device_id == -1:
if not hasattr(self, "nixl_iter_device_id"):
self.nixl_iter_device_id = 0
req_obj.nixl_trans_device_id = self.nixl_iter_device_id
# only self.is_master_in_dp will be used.
req_obj.nixl_trans_device_id = random.randint(0, self.node_world_size - 1)
self.nixl_iter_device_id = (self.nixl_iter_device_id + 1) % self.node_world_size

trans_task = NIXLChunckedTransTask(
request_id=req_obj.req_id,
start_kv_index=kv_start_index,
end_kv_index=kv_end_index,
time_out_secs=80,
time_out_secs=180,
pd_master_node_id=req_obj.sampling_param.pd_master_node_id,
prefill_dp_index=None,
decode_dp_index=self.dp_rank_in_node,
Expand Down
Loading
Loading