Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions bench/small_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
llama_1b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.2-1B-Instruct/snapshots/9213176726f574b556790deb65791e0c5aa438b6'
llama_70b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b'
eagle_path = '/scratch/avner/huggingface/hub/models--lmsys--SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge/snapshots/63ebaa6585f96b89685adad8fdfa0da53be6a8fd'
phoenix_path = '/scratch/avner/huggingface/hub/models--togethercomputer--phoenix-Llama-3p2-1B-Instruct-tgt-Llama-3p3-70b-instruct-UNTRAINED/snapshots/3af59d71514388e14d8685f2b684f74e3e311717'
# eagle_path = '/scratch/avner/huggingface/hub/models--yuhuili--EAGLE3-LLaMA3.3-Instruct-70B'
assert os.path.isdir(llama_1b_path)
assert os.path.isdir(llama_70b_path)
Expand All @@ -18,6 +19,7 @@
parser.add_argument("--model", type=str, default=llama_1b_path)
parser.add_argument("--draft", type=str, default=llama_1b_path)
parser.add_argument("--eagle", action="store_true")
parser.add_argument("--phoenix", action="store_true")
parser.add_argument("--k", type=int, default=7)
parser.add_argument("--jit-speculate", action="store_true")
parser.add_argument("--num-gpus", type=int, default=2)
Expand All @@ -36,10 +38,18 @@
args.jit_speculate = True
args.chat_template = True

if args.phoenix:
args.draft = phoenix_path
args.model = llama_70b_path
args.num_gpus = 5
args.jit_speculate = True
args.chat_template = True

llm = LLM(
model=args.model,
draft=args.draft,
use_eagle=args.eagle,
use_phoenix=args.phoenix,
speculate_k=args.k,
speculate=True,
draft_async=True,
Expand Down
19 changes: 14 additions & 5 deletions ssd/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ class Config:
communicate_logits: bool = False
communicate_cache_hits: bool = False

# eagle3
# eagle3 / phoenix
use_eagle: bool = False
use_phoenix: bool = False
eagle_layers: list[int] | None = None
d_model_target: int | None = None
tokenizer_path: str | None = None
Expand All @@ -53,6 +54,10 @@ class Config:
def max_blocks(self):
return (self.max_model_len + self.kvcache_block_size - 1) // self.kvcache_block_size

@property
def use_eagle_or_phoenix(self):
    """True when either Eagle3 or Phoenix speculative decoding is enabled.

    Convenience flag: most of the runner/cudagraph code paths treat the two
    draft-model styles identically, so they branch on this single property
    instead of checking both ``use_eagle`` and ``use_phoenix``.
    """
    eagle_enabled = self.use_eagle
    phoenix_enabled = self.use_phoenix
    return eagle_enabled or phoenix_enabled

def __post_init__(self):
model = self.model
assert os.path.isdir(model)
Expand All @@ -79,12 +84,16 @@ def __post_init__(self):
if self.fan_out_list is None:
self.fan_out_list = [self.async_fan_out] * (self.speculate_k + 1)
self.MQ_LEN = sum(self.fan_out_list)
if self.fan_out_list_miss is None:
self.fan_out_list_miss = self.fan_out_list
if not self.jit_speculate:
print(f'[Config] Setting fan_out_list_miss to [sum(fan_out_list)] + [0] * speculate_k because jit_speculate is False', flush=True)
self.fan_out_list_miss = [sum(self.fan_out_list)] + [0] * self.speculate_k
elif self.fan_out_list_miss is None:
self.fan_out_list_miss = self.fan_out_list

assert sum(self.fan_out_list_miss) == sum(self.fan_out_list), "ERROR in Config: fan_out_list_miss must be the same as fan_out_list"

if self.use_eagle:
if self.eagle_layers is None:
if self.use_eagle_or_phoenix:
if self.use_eagle and self.eagle_layers is None:
L = self.hf_config.num_hidden_layers
# self.eagle_layers = [3, L//2, L-3]
self.eagle_layers = [2, L//2, L-3] # [2, 16, 29] outputs, ie. [3, L//2+1, L-2] inputs
Expand Down
133 changes: 74 additions & 59 deletions ssd/engine/draft_runner.py

Large diffs are not rendered by default.

73 changes: 44 additions & 29 deletions ssd/engine/helpers/cudagraph_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,14 +482,17 @@ def capture_cudagraph(model_runner):
is_jit = (model_runner.config.speculate and model_runner.config.draft_async and model_runner.is_draft)

# Eagle models need special handling during CUDA graph capture
is_eagle_draft = config.use_eagle and model_runner.is_draft
is_eagle_target = config.use_eagle and not model_runner.is_draft
is_eagle_or_phoenix_draft = config.use_eagle_or_phoenix and model_runner.is_draft
is_eagle_or_phoenix_target = config.use_eagle_or_phoenix and not model_runner.is_draft
hidden_states = None
if is_eagle_draft:
# Use hidden_size (d_model_draft) so CG captures the pass-through branch in Eagle3DraftForCausalLM.forward()
# All callers project target acts via fc() BEFORE passing to CG
hidden_states = torch.zeros(max_bs, hf_config.hidden_size,
dtype=hf_config.torch_dtype, device=input_ids.device)
if is_eagle_or_phoenix_draft:
# Note: For Eagle3, all callers project target acts via fc() BEFORE passing to CG
hidden_states = torch.zeros(
max_bs,
model_runner.hidden_states_dim,
dtype=hf_config.torch_dtype,
device=input_ids.device,
)

total_graphs = len(graph_bs_list)
print(f'[capture_cudagraph] Starting capture of {total_graphs} graphs, bs list: {graph_bs_list[:5]}...{graph_bs_list[-3:]} max_bs={max_bs}', flush=True)
Expand All @@ -498,21 +501,21 @@ def capture_cudagraph(model_runner):
graph = torch.cuda.CUDAGraph()
set_context(
False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs], is_jit=is_jit)
if is_eagle_draft:
if is_eagle_or_phoenix_draft:
outputs[:bs] = model_runner.model(
input_ids[:bs], positions[:bs], hidden_states[:bs]) # warmup
elif is_eagle_target:
elif is_eagle_or_phoenix_target:
out, _ = model_runner.model(
input_ids[:bs], positions[:bs]) # warmup
outputs[:bs] = out
else:
outputs[:bs] = model_runner.model(
input_ids[:bs], positions[:bs]) # warmup
with torch.cuda.graph(graph, graph_pool):
if is_eagle_draft:
if is_eagle_or_phoenix_draft:
outputs[:bs] = model_runner.model(
input_ids[:bs], positions[:bs], hidden_states[:bs]) # capture
elif is_eagle_target:
elif is_eagle_or_phoenix_target:
out, _ = model_runner.model(
input_ids[:bs], positions[:bs]) # capture
outputs[:bs] = out
Expand Down Expand Up @@ -547,7 +550,7 @@ def capture_verify_cudagraph(model_runner):
max_bs = min(model_runner.config.max_num_seqs, 512)
k_plus_1 = model_runner.config.speculate_k + 1

is_eagle_target = config.use_eagle and not model_runner.is_draft
is_eagle_or_phoenix_target = config.use_eagle_or_phoenix and not model_runner.is_draft

# For verify, we need to handle k+1 tokens per sequence, and use cu_seqlens_q and max_seqlen_q
input_ids = torch.zeros(max_bs * k_plus_1, dtype=torch.int64)
Expand All @@ -559,12 +562,14 @@ def capture_verify_cudagraph(model_runner):
outputs = torch.zeros(max_bs * k_plus_1, hf_config.hidden_size)
cu_seqlens_q = torch.zeros(max_bs + 1, dtype=torch.int32)

# Eagle target: also capture eagle_acts from model forward
# Eagle/Phoenix target: also capture activations from model forward
eagle_acts = None
if is_eagle_target:
# eagle_acts has shape [num_tokens, 3 * hidden_size] for 3 layers
eagle_acts = torch.zeros(max_bs * k_plus_1, 3 * hf_config.hidden_size,
dtype=hf_config.torch_dtype)
if is_eagle_or_phoenix_target:
eagle_acts = torch.zeros(
max_bs * k_plus_1,
model_runner.eagle_acts_dim,
dtype=hf_config.torch_dtype,
)

base = [1, 2, 4, 8]
dynamic = list(range(16, max_bs+1, 16))
Expand Down Expand Up @@ -685,6 +690,7 @@ def run_glue_decode_cudagraph(model_runner, input_ids, positions, last_only, gra

outputs = graph_vars["outputs"][:orig_flat]
logits = model_runner.model.compute_logits(outputs, last_only)
assert logits.dim() == 2, "ERROR in run_glue_decode_cudagraph: logits must be 2D"
if "eagle_hidden_states" in graph_vars:
return logits, outputs
return logits
Expand All @@ -709,9 +715,14 @@ def capture_glue_decode_cudagraph(model_runner):
outputs = torch.empty(max_flat, hf_config.hidden_size, device=model_runner.device)
cu_seqlens_q = torch.zeros(max_bs + 1, dtype=torch.int32, device=model_runner.device)

eagle_hs = None
if config.use_eagle and model_runner.is_draft:
eagle_hs = torch.zeros(max_flat, hf_config.hidden_size, dtype=hf_config.torch_dtype, device=model_runner.device)
eagle_hidden_states = None
if config.use_eagle_or_phoenix and model_runner.is_draft:
eagle_hidden_states = torch.zeros(
max_flat,
model_runner.hidden_states_dim,
dtype=hf_config.torch_dtype,
device=model_runner.device,
)

graph_bs_list = [1]
for bs in [2, 4, 8] + list(range(16, max_bs + 1, 16)):
Expand Down Expand Up @@ -745,14 +756,14 @@ def capture_glue_decode_cudagraph(model_runner):
block_tables=block_tables[:bs],
)

if eagle_hs is not None:
outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hs[:flat])
if eagle_hidden_states is not None:
outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hidden_states[:flat])
else:
outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat])

with torch.cuda.graph(graph, graph_pool):
if eagle_hs is not None:
outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hs[:flat])
if eagle_hidden_states is not None:
outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hidden_states[:flat])
else:
outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat])

Expand All @@ -771,8 +782,8 @@ def capture_glue_decode_cudagraph(model_runner):
cu_seqlens_q=cu_seqlens_q,
outputs=outputs,
)
if eagle_hs is not None:
graph_vars["eagle_hidden_states"] = eagle_hs
if eagle_hidden_states is not None:
graph_vars["eagle_hidden_states"] = eagle_hidden_states

return graph_vars, graph_pool, graphs, graph_bs_list

Expand Down Expand Up @@ -813,9 +824,13 @@ def capture_fi_tree_decode_cudagraph(model_runner):
# All callers project target acts via fc() BEFORE passing to CG
# MUST be outside the for-loop so all graphs share the same tensor
fi_hidden_states = None
if config.use_eagle and model_runner.is_draft:
fi_hidden_states = torch.zeros(max_flat_batch_size, hf_config.hidden_size,
dtype=hf_config.torch_dtype, device=model_runner.device)
if config.use_eagle_or_phoenix and model_runner.is_draft:
fi_hidden_states = torch.zeros(
max_flat_batch_size,
model_runner.hidden_states_dim,
dtype=hf_config.torch_dtype,
device=model_runner.device,
)

print(f'[cuda_graph_helpers.capture_fi_tree_decode_cudagraph] About to capture FI cudagraphs for bs={graph_bs_list}', flush=True)

Expand Down
6 changes: 3 additions & 3 deletions ssd/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ def create_inference_step(self, config: Config) -> InferenceStep:
draft_dtype=config.draft_hf_config.torch_dtype,
kvcache_block_size=config.kvcache_block_size,
max_model_len=config.max_model_len,
eagle=config.use_eagle,
eagle_act_dim=3 * config.hf_config.hidden_size if config.use_eagle else 0,
eagle=config.use_eagle_or_phoenix,
eagle_act_dim=self.model_runner.eagle_acts_dim if config.use_eagle_or_phoenix else 0,
communicate_logits=config.communicate_logits,
communicate_cache_hits=config.communicate_cache_hits,
async_pg=self.model_runner.async_pg,
Expand Down Expand Up @@ -328,7 +328,7 @@ def create_inference_step(self, config: Config) -> InferenceStep:
scheduler=self.scheduler,
speculator=speculator,
verifier=verifier,
eagle=config.use_eagle,
eagle=config.use_eagle_or_phoenix,
tokenizer=self.tokenizer,
async_spec=config.draft_async,
)
Expand Down
Loading