diff --git a/bench/small_test.py b/bench/small_test.py index 8131faf..4efb136 100644 --- a/bench/small_test.py +++ b/bench/small_test.py @@ -9,6 +9,7 @@ llama_1b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.2-1B-Instruct/snapshots/9213176726f574b556790deb65791e0c5aa438b6' llama_70b_path = '/scratch/avner/huggingface/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b' eagle_path = '/scratch/avner/huggingface/hub/models--lmsys--SGLang-EAGLE3-Llama-3.3-70B-Instruct-SpecForge/snapshots/63ebaa6585f96b89685adad8fdfa0da53be6a8fd' + phoenix_path = '/scratch/avner/huggingface/hub/models--togethercomputer--phoenix-Llama-3p2-1B-Instruct-tgt-Llama-3p3-70b-instruct-UNTRAINED/snapshots/3af59d71514388e14d8685f2b684f74e3e311717' # eagle_path = '/scratch/avner/huggingface/hub/models--yuhuili--EAGLE3-LLaMA3.3-Instruct-70B' assert os.path.isdir(llama_1b_path) assert os.path.isdir(llama_70b_path) @@ -18,6 +19,7 @@ parser.add_argument("--model", type=str, default=llama_1b_path) parser.add_argument("--draft", type=str, default=llama_1b_path) parser.add_argument("--eagle", action="store_true") + parser.add_argument("--phoenix", action="store_true") parser.add_argument("--k", type=int, default=7) parser.add_argument("--jit-speculate", action="store_true") parser.add_argument("--num-gpus", type=int, default=2) @@ -36,10 +38,18 @@ args.jit_speculate = True args.chat_template = True + if args.phoenix: + args.draft = phoenix_path + args.model = llama_70b_path + args.num_gpus = 5 + args.jit_speculate = True + args.chat_template = True + llm = LLM( model=args.model, draft=args.draft, use_eagle=args.eagle, + use_phoenix=args.phoenix, speculate_k=args.k, speculate=True, draft_async=True, diff --git a/ssd/config.py b/ssd/config.py index c031746..5d1c7ea 100644 --- a/ssd/config.py +++ b/ssd/config.py @@ -38,8 +38,9 @@ class Config: communicate_logits: bool = False communicate_cache_hits: bool = False - # eagle3 + # eagle3 / 
phoenix use_eagle: bool = False + use_phoenix: bool = False eagle_layers: list[int] | None = None d_model_target: int | None = None tokenizer_path: str | None = None @@ -53,6 +54,10 @@ class Config: def max_blocks(self): return (self.max_model_len + self.kvcache_block_size - 1) // self.kvcache_block_size + @property + def use_eagle_or_phoenix(self): + return self.use_eagle or self.use_phoenix + def __post_init__(self): model = self.model assert os.path.isdir(model) @@ -79,12 +84,16 @@ def __post_init__(self): if self.fan_out_list is None: self.fan_out_list = [self.async_fan_out] * (self.speculate_k + 1) self.MQ_LEN = sum(self.fan_out_list) - if self.fan_out_list_miss is None: - self.fan_out_list_miss = self.fan_out_list + if not self.jit_speculate: + print(f'[Config] Setting fan_out_list_miss to [sum(fan_out_list)] + [0] * speculate_k because jit_speculate is False', flush=True) + self.fan_out_list_miss = [sum(self.fan_out_list)] + [0] * self.speculate_k + elif self.fan_out_list_miss is None: + self.fan_out_list_miss = self.fan_out_list + assert sum(self.fan_out_list_miss) == sum(self.fan_out_list), "ERROR in Config: fan_out_list_miss must be the same as fan_out_list" - if self.use_eagle: - if self.eagle_layers is None: + if self.use_eagle_or_phoenix: + if self.use_eagle and self.eagle_layers is None: L = self.hf_config.num_hidden_layers # self.eagle_layers = [3, L//2, L-3] self.eagle_layers = [2, L//2, L-3] # [2, 16, 29] outputs, ie. 
[3, L//2+1, L-2] inputs diff --git a/ssd/engine/draft_runner.py b/ssd/engine/draft_runner.py index afb1af0..0765ece 100644 --- a/ssd/engine/draft_runner.py +++ b/ssd/engine/draft_runner.py @@ -33,8 +33,8 @@ def create_draft_config(cls, cfg: Config) -> Config: cfg, model=cfg.draft, gpu_memory_utilization = (0.75 if not cfg.draft_async else 0.8), # REMAINING SPACE if not draft_async - tokenizer_path=cfg.model if cfg.use_eagle else None, - d_model_target=cfg.hf_config.hidden_size if cfg.use_eagle and cfg.hf_config else None, + tokenizer_path=cfg.model if cfg.use_eagle_or_phoenix else None, + d_model_target=cfg.hf_config.hidden_size if cfg.use_eagle_or_phoenix and cfg.hf_config else None, ) return draft_cfg @@ -49,10 +49,6 @@ def __init__(self, draft_cfg: Config, rank: int = 0, init_q = None): self.target_rank = 0 self.communicate_logits = self.config.communicate_logits self.communicate_cache_hits = self.config.communicate_cache_hits - - if self.config.use_eagle: - assert self.config.jit_speculate, \ - "EAGLE requires jit_speculate=True (cache misses need draft activations)" if self.is_draft and self.draft_async: self._reset_tree_cache_tensors() @@ -67,8 +63,8 @@ def draft_async_prefill(self): if self.config.verbose: print(f'[{_ts()}] [draft_async_prefill] DRAFT ASYNC PREFILL STARTING', flush=True) - prefill_request = PrefillRequest.receive(self.async_pg, self.target_rank, self.device, metadata_buffer=self._prefill_metadata, tokenizer=self.tokenizer) - total_new_tokens, batch_size, max_blocks, use_eagle, eagle_act_dim = prefill_request.metadata.tolist() + prefill_request = PrefillRequest.receive(self.async_pg, self.target_rank, self.device, metadata_buffer=self._prefill_metadata) + total_new_tokens, batch_size, max_blocks, use_eagle_or_phoenix, eagle_phoenix_act_dim = prefill_request.metadata.tolist() input_ids = prefill_request.input_ids num_tokens = prefill_request.num_tokens draft_block_table = prefill_request.draft_block_table @@ -87,12 +83,16 @@ def 
draft_async_prefill(self): prefill_ctxt = self.prepare_prefill_ctxt(num_tokens, draft_block_table) - if use_eagle: - assert eagle_act_dim == 3 * self.config.d_model_target, ( - f"EAGLE activation dimension {eagle_act_dim} does not match expected dimension 3 * {self.config.d_model_target}" + if self.config.use_eagle: + assert eagle_phoenix_act_dim == 3 * self.config.d_model_target, ( + f"EAGLE activation dimension {eagle_phoenix_act_dim} does not match expected dimension 3 * {self.config.d_model_target}" + ) + elif self.config.use_phoenix: + assert eagle_phoenix_act_dim == self.config.d_model_target, ( + f"PHOENIX activation dimension {eagle_phoenix_act_dim} does not match expected dimension {self.config.d_model_target}" ) if self.config.verbose: - print(f'[{_ts()}] [draft_async_prefill] METADATA: total_new_tokens={total_new_tokens}, batch_size={batch_size}, max_blocks={max_blocks}, use_eagle={use_eagle}, eagle_act_dim={eagle_act_dim}', flush=True) + print(f'[{_ts()}] [draft_async_prefill] METADATA: total_new_tokens={total_new_tokens}, batch_size={batch_size}, max_blocks={max_blocks}, use_eagle_or_phoenix={use_eagle_or_phoenix}, eagle_phoenix_act_dim={eagle_phoenix_act_dim}', flush=True) # 5) set up context exactly like prepare_prefill() does: @@ -108,10 +108,7 @@ def draft_async_prefill(self): # 6) run the draft model in prefill mode positions = prefill_ctxt["positions"] - if self.config.use_eagle: - self.run_model(input_ids, positions, is_prefill=True, last_only=True, hidden_states=eagle_acts) - else: - self.run_model(input_ids, positions, is_prefill=True, last_only=True, hidden_states=eagle_acts) + self.run_model(input_ids, positions, is_prefill=True, last_only=True, hidden_states=eagle_acts) if self.config.verbose: print(f'[{_ts()}] [draft_async_prefill] DRAFT ASYNC PREFILL DONE', flush=True) @@ -155,11 +152,9 @@ def jit_speculate( draft_block_tables: torch.Tensor, target_recovery_activations: torch.Tensor = None, ): - input_ids = request_keys[:, -1] - 
pos_offset = -1 if self.config.use_eagle else 0 - positions = num_tokens - 1 + pos_offset # want to write rec token at post N-1 since [0, ..., N-2] filled by prefill - context_lens = num_tokens + pos_offset # N+1 + positions = num_tokens - 1 + context_lens = num_tokens # Calculate slot mapping vectorized block_idx = positions // self.block_size pos_in_block = positions % self.block_size @@ -168,13 +163,16 @@ def jit_speculate( hidden_states = None spec_activations = None - - if self.config.use_eagle: + + if self.config.use_eagle_or_phoenix: assert target_recovery_activations is not None - hidden_states = self.model.fc(target_recovery_activations.to(self.model.fc.weight.dtype)) + if self.config.use_eagle: + hidden_states = self.model.fc(target_recovery_activations.to(self.model.fc.weight.dtype)) + else: + hidden_states = target_recovery_activations spec_activations = torch.empty( input_ids.shape[0], self.config.speculate_k, - self.hf_config.hidden_size, + self.hidden_states_dim, dtype=self.hf_config.torch_dtype, device=self.device) for i in range(self.config.speculate_k): # we're going to glue after this anyways, and by sending the spec request target has verified we have K more slots left in our last page @@ -186,10 +184,13 @@ def jit_speculate( is_jit=True, ) - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: logits, prenorm = self.run_model(input_ids, positions, is_prefill=False, last_only=True, hidden_states=hidden_states) - spec_activations[:, i] = prenorm - hidden_states = prenorm + if self.config.use_eagle: + spec_activations[:, i] = prenorm + hidden_states = prenorm + else: + spec_activations[:, i] = hidden_states else: logits = self.run_model(input_ids, positions, is_prefill=False, last_only=True) @@ -221,12 +222,11 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta cache_hits = torch.zeros(B, dtype=torch.int64, device=self.device) assert request_keys.shape == (B, 3), f"ERROR in hit_cache: request_keys should 
be (B, 3), got {request_keys.shape}" - - hidden_size = self.hf_config.hidden_size + out_activations = torch.empty( - B, K, hidden_size, + B, K, self.hidden_states_dim, dtype=self.hf_config.torch_dtype, device=self.device - ) if self.config.use_eagle else None + ) if self.config.use_eagle_or_phoenix else None # Statistics ttl += int(B) @@ -274,7 +274,7 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta out_tokens[sel] = self.tree_cache_tokens[idx[sel]] # logits [T,K+1,V] out_logits[sel] = self.tree_cache_logits[idx[sel]] - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: out_activations[sel] = self.tree_cache_activations[idx[sel]] elif self.config.jit_speculate: # print(f'[hit_cache] found a cache miss, running jit speculate', flush=True) @@ -289,7 +289,7 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta draft_block_tables, target_recovery_activations ) # write into out_logits, out_tokens - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: out_activations = jit_acts elif self.config.jit_speculate: # Cache is empty (first iteration), must JIT all @@ -304,7 +304,7 @@ def hit_cache(self, request_keys, B, K, num_tokens, temperatures, draft_block_ta draft_block_tables, target_recovery_activations ) - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: out_activations = jit_acts rec_toks = request_keys[:, 2] @@ -422,8 +422,7 @@ def prepare_prefill_ctxt( def prepare_glue_decode_ctxt(self, num_tokens, input_ids, dbt, B): K = self.config.speculate_k - pos_offset = -1 if self.config.use_eagle else 0 - positions_start = (num_tokens - 1 + pos_offset).unsqueeze(-1) + positions_start = (num_tokens - 1).unsqueeze(-1) positions_grid = positions_start + self._arange_kp1 # Calculate block indices and offsets for ALL positions @@ -441,7 +440,7 @@ def prepare_glue_decode_ctxt(self, num_tokens, input_ids, dbt, B): positions_flat = positions_grid.reshape(-1).to(torch.int64) 
slot_map_flat = slot_map_grid.reshape(-1).to(torch.int32) - context_lens = (num_tokens + pos_offset + K).to(torch.int32) + context_lens = (num_tokens + K).to(torch.int32) seqlen_q = torch.full((B,), K + 1, dtype=torch.int32, device=self.device) cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device=self.device) cu_seqlens_q[1:] = torch.cumsum(seqlen_q, dim=0) @@ -514,9 +513,8 @@ def _construct_tree_decode_args(self, partial_tree_decode_args, rec_flat, dbt): seq_ids = partial_tree_decode_args["seq_ids"] seq_ids_expanded = seq_ids[b_flat] - pos_offset = -1 if self.config.use_eagle else 0 - positions = (partial_tree_decode_args["num_tokens"][b_flat] - 1 + pos_offset) + (K + 1) + fkp1_flat - rope_positions = (partial_tree_decode_args["num_tokens"][b_flat] - 1 + pos_offset) + j_idx_flat + 1 + positions = (partial_tree_decode_args["num_tokens"][b_flat] - 1) + (K + 1) + fkp1_flat + rope_positions = (partial_tree_decode_args["num_tokens"][b_flat] - 1) + j_idx_flat + 1 temperatures = partial_tree_decode_args["temperatures"][b_flat] tree_decode_args = { @@ -541,9 +539,8 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): dbt = partial_tree_decode_args["dbt"] cache_hits = partial_tree_decode_args["cache_hits"] cache_hits_list = cache_hits.tolist() - pos_offset = -1 if self.config.use_eagle else 0 - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: B = partial_tree_decode_args["num_tokens"].shape[0] extend_counts = partial_tree_decode_args.get("extend_counts") if extend_counts is None: @@ -552,8 +549,8 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): extend_token_ids_batch = partial_tree_decode_args.get("extend_token_ids") target_acts = partial_tree_decode_args["target_recovery_activations"] prev_acts = partial_tree_decode_args["previous_activations"] - hidden_size = self.hf_config.hidden_size - fc_dtype = self.model.fc.weight.dtype + hidden_size = self.hidden_states_dim + fc_dtype = 
self.model.fc.weight.dtype if self.config.use_eagle else self.hf_config.torch_dtype gd_view = glue_decode_input_ids.view(B, K + 1) rec_tok_ids = gd_view[:, 0] @@ -598,7 +595,10 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): fused_ids[is_rec] = rec_tok_ids[batch_idx[is_rec]] # Single batched fc call - fused_hs[is_target_conditioned] = self.model.fc(tc_acts) + if self.config.use_eagle: + fused_hs[is_target_conditioned] = self.model.fc(tc_acts) + elif self.config.use_phoenix: + fused_hs[is_target_conditioned] = tc_acts # Spec tokens: ids from spec_tok_ids, hs from prev_acts (self-conditioned, no fc) spec_j = local_off[is_spec] - n_ext_per_tok[is_spec] - 1 # 0..K-1 @@ -628,8 +628,8 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): N_pre = _pre_b_flat.shape[0] _pre_metadata_ints = (B, K, self.config.async_fan_out, N_pre) _pre_seq_ids_expanded = partial_tree_decode_args["seq_ids"][_pre_b_flat] - _pre_positions = (partial_tree_decode_args["num_tokens"][_pre_b_flat] - 1 + pos_offset) + (K + 1) + _pre_fkp1_flat - _pre_rope_positions = (partial_tree_decode_args["num_tokens"][_pre_b_flat] - 1 + pos_offset) + _pre_j_idx_flat + 1 + _pre_positions = (partial_tree_decode_args["num_tokens"][_pre_b_flat] - 1) + (K + 1) + _pre_fkp1_flat + _pre_rope_positions = (partial_tree_decode_args["num_tokens"][_pre_b_flat] - 1) + _pre_j_idx_flat + 1 _pre_temperatures = partial_tree_decode_args["temperatures"][_pre_b_flat] # --- Run glue decode forward --- @@ -643,7 +643,7 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): ) glue_prenorm = None - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: fused_hs_flat = glue_decode_ctxt["hidden_states"] glue_decode_logits_flat, glue_prenorm = self.run_model( glue_decode_ctxt["input_ids"], glue_decode_ctxt["positions"], @@ -662,7 +662,7 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): reset_context() # --- Extract 
K+1 logits/prenorms at rec+spec positions --- - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: # Packed layout: rec at cu_seqlens_q[b] + n_ext[b], spec follows cu_q = glue_decode_ctxt["cu_seqlens_q"] rec_offsets = cu_q[:-1].long() + extend_counts.long() # [B] @@ -679,6 +679,7 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): # --- Build tree hidden states from K+1 prenorms --- tree_hidden_states = None if glue_prenorm is not None: + assert self.config.use_eagle_or_phoenix, "ERROR in _build_tree_batch: use_eagle_or_phoenix must be True when glue_prenorm is not None." # Vectorized: for each (b, depth), repeat prenorm by fan_out[depth] # fan_out_t[depth] for hits, fan_out_t_miss[depth] for misses fan_hit = self.config.fan_out_t # [K+1] @@ -690,12 +691,20 @@ def _build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): fan_miss.unsqueeze(0).expand(B, K + 1), ) # [B, K+1] reps_flat = per_batch_fan.reshape(-1) # [B*(K+1)] - prenorms_flat = glue_prenorm_kp1.reshape(B * (K + 1), -1) # [B*(K+1), d] - tree_hidden_states = torch.repeat_interleave(prenorms_flat, reps_flat, dim=0) + + if self.config.use_eagle: + prenorms_flat = glue_prenorm_kp1.reshape(B * (K + 1), -1) # [B*(K+1), d] + tree_hidden_states = torch.repeat_interleave(prenorms_flat, reps_flat, dim=0) + else: + assert self.config.use_phoenix + # Phoenix conditions on target activations, not prenorms + target_acts_expanded = target_acts.unsqueeze(1).expand(B, K + 1, -1) # [B, K+1, target_dim] + acts_flat = target_acts_expanded.reshape(B * (K + 1), -1) # [B*(K+1), target_dim] + tree_hidden_states = torch.repeat_interleave(acts_flat, reps_flat, dim=0) # --- Fork tokens from K+1 logits --- # Need [B, K+1] input_ids for forking (rec + spec tokens) - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: gd_for_fork = gd_view # [B, K+1] already computed above else: gd_for_fork = glue_decode_input_ids.reshape(B, K + 1) @@ -719,6 +728,7 @@ def 
_build_tree_batch(self, partial_tree_decode_args, glue_decode_input_ids): "seq_ids_expanded": _pre_seq_ids_expanded, "cache_hits": cache_hits, "cache_hits_list": cache_hits_list, + "target_recovery_activations": partial_tree_decode_args["target_recovery_activations"], } tree_decode_args["hidden_states"] = tree_hidden_states return tree_decode_args @@ -743,7 +753,7 @@ def _compute_step_positions_and_slot_maps(self, initial_positions, initial_rope_ return step_positions, step_rope_positions, step_context_lens, step_slot_maps - def _decode_tree_step(self, depth, current_input_ids, step_rope_positions, step_slot_maps, step_context_lens, dbt, payload, spec_tokens, spec_logits, spec_activations): + def _decode_tree_step(self, depth, current_input_ids, step_rope_positions, step_slot_maps, step_context_lens, dbt, payload, spec_tokens, spec_logits, spec_activations, target_recovery_activations): """Execute a single tree decode step.""" # Use precomputed values for this step set_context( @@ -754,11 +764,15 @@ def _decode_tree_step(self, depth, current_input_ids, step_rope_positions, step_ ) hidden_states = payload.get("hidden_states") - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: logits, prenorm = self.run_model(current_input_ids, step_rope_positions[depth], is_prefill=False, last_only=False, tree_decode_step=depth, cache_hits=payload["cache_hits"], hidden_states=hidden_states) assert spec_activations is not None - spec_activations[:, depth] = prenorm - payload["hidden_states"] = prenorm + if self.config.use_eagle: + spec_activations[:, depth] = prenorm + payload["hidden_states"] = prenorm + else: + spec_activations[:, depth] = target_recovery_activations + payload["hidden_states"] = target_recovery_activations else: logits = self.run_model(current_input_ids, step_rope_positions[depth], is_prefill=False, last_only=False, tree_decode_step=depth, cache_hits=payload["cache_hits"]) @@ -785,9 +799,9 @@ def _decode_tree(self, payload): spec_logits = 
torch.empty( N, K, V, dtype=self.hf_config.torch_dtype, device=self.device) spec_activations = torch.empty( - N, K, self.hf_config.hidden_size, + N, K, self.hidden_states_dim, dtype=self.hf_config.torch_dtype, device=self.device - ) if self.config.use_eagle else None + ) if self.config.use_eagle_or_phoenix else None # Precompute all positions, context_lens, and slot_maps for all K steps # PERFORMANCE: no .clone() needed — these are not modified in-place @@ -795,6 +809,7 @@ def _decode_tree(self, payload): initial_rope_positions = payload["rope_positions"] # [N] current_input_ids = payload["input_ids"] # [N], the forked tokens dbt = payload["block_tables"] # [B, M] - constant across steps + target_recovery_activations = payload["target_recovery_activations"] # Use compiled function for batch-size independent computations _, step_rope_positions, step_context_lens, step_slot_maps = self._compute_step_positions_and_slot_maps( @@ -810,7 +825,7 @@ def _decode_tree(self, payload): _st = time.perf_counter() current_input_ids = self._decode_tree_step( depth, current_input_ids, step_rope_positions, step_slot_maps, - step_context_lens, dbt, payload, spec_tokens, spec_logits, spec_activations + step_context_lens, dbt, payload, spec_tokens, spec_logits, spec_activations, target_recovery_activations, ) if _prof or PROFILE_DRAFT: torch.cuda.synchronize() diff --git a/ssd/engine/helpers/cudagraph_helpers.py b/ssd/engine/helpers/cudagraph_helpers.py index 6c38eed..cbcd010 100644 --- a/ssd/engine/helpers/cudagraph_helpers.py +++ b/ssd/engine/helpers/cudagraph_helpers.py @@ -482,14 +482,17 @@ def capture_cudagraph(model_runner): is_jit = (model_runner.config.speculate and model_runner.config.draft_async and model_runner.is_draft) # Eagle models need special handling during CUDA graph capture - is_eagle_draft = config.use_eagle and model_runner.is_draft - is_eagle_target = config.use_eagle and not model_runner.is_draft + is_eagle_or_phoenix_draft = config.use_eagle_or_phoenix and 
model_runner.is_draft + is_eagle_or_phoenix_target = config.use_eagle_or_phoenix and not model_runner.is_draft hidden_states = None - if is_eagle_draft: - # Use hidden_size (d_model_draft) so CG captures the pass-through branch in Eagle3DraftForCausalLM.forward() - # All callers project target acts via fc() BEFORE passing to CG - hidden_states = torch.zeros(max_bs, hf_config.hidden_size, - dtype=hf_config.torch_dtype, device=input_ids.device) + if is_eagle_or_phoenix_draft: + # Note: For Eagle3, all callers project target acts via fc() BEFORE passing to CG + hidden_states = torch.zeros( + max_bs, + model_runner.hidden_states_dim, + dtype=hf_config.torch_dtype, + device=input_ids.device, + ) total_graphs = len(graph_bs_list) print(f'[capture_cudagraph] Starting capture of {total_graphs} graphs, bs list: {graph_bs_list[:5]}...{graph_bs_list[-3:]} max_bs={max_bs}', flush=True) @@ -498,10 +501,10 @@ def capture_cudagraph(model_runner): graph = torch.cuda.CUDAGraph() set_context( False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs], is_jit=is_jit) - if is_eagle_draft: + if is_eagle_or_phoenix_draft: outputs[:bs] = model_runner.model( input_ids[:bs], positions[:bs], hidden_states[:bs]) # warmup - elif is_eagle_target: + elif is_eagle_or_phoenix_target: out, _ = model_runner.model( input_ids[:bs], positions[:bs]) # warmup outputs[:bs] = out @@ -509,10 +512,10 @@ def capture_cudagraph(model_runner): outputs[:bs] = model_runner.model( input_ids[:bs], positions[:bs]) # warmup with torch.cuda.graph(graph, graph_pool): - if is_eagle_draft: + if is_eagle_or_phoenix_draft: outputs[:bs] = model_runner.model( input_ids[:bs], positions[:bs], hidden_states[:bs]) # capture - elif is_eagle_target: + elif is_eagle_or_phoenix_target: out, _ = model_runner.model( input_ids[:bs], positions[:bs]) # capture outputs[:bs] = out @@ -547,7 +550,7 @@ def capture_verify_cudagraph(model_runner): max_bs = min(model_runner.config.max_num_seqs, 512) 
k_plus_1 = model_runner.config.speculate_k + 1 - is_eagle_target = config.use_eagle and not model_runner.is_draft + is_eagle_or_phoenix_target = config.use_eagle_or_phoenix and not model_runner.is_draft # For verify, we need to handle k+1 tokens per sequence, and use cu_seqlens_q and max_seqlen_q input_ids = torch.zeros(max_bs * k_plus_1, dtype=torch.int64) @@ -559,12 +562,14 @@ def capture_verify_cudagraph(model_runner): outputs = torch.zeros(max_bs * k_plus_1, hf_config.hidden_size) cu_seqlens_q = torch.zeros(max_bs + 1, dtype=torch.int32) - # Eagle target: also capture eagle_acts from model forward + # Eagle/Phoenix target: also capture activations from model forward eagle_acts = None - if is_eagle_target: - # eagle_acts has shape [num_tokens, 3 * hidden_size] for 3 layers - eagle_acts = torch.zeros(max_bs * k_plus_1, 3 * hf_config.hidden_size, - dtype=hf_config.torch_dtype) + if is_eagle_or_phoenix_target: + eagle_acts = torch.zeros( + max_bs * k_plus_1, + model_runner.eagle_acts_dim, + dtype=hf_config.torch_dtype, + ) base = [1, 2, 4, 8] dynamic = list(range(16, max_bs+1, 16)) @@ -685,6 +690,7 @@ def run_glue_decode_cudagraph(model_runner, input_ids, positions, last_only, gra outputs = graph_vars["outputs"][:orig_flat] logits = model_runner.model.compute_logits(outputs, last_only) + assert logits.dim() == 2, "ERROR in run_glue_decode_cudagraph: logits must be 2D" if "eagle_hidden_states" in graph_vars: return logits, outputs return logits @@ -709,9 +715,14 @@ def capture_glue_decode_cudagraph(model_runner): outputs = torch.empty(max_flat, hf_config.hidden_size, device=model_runner.device) cu_seqlens_q = torch.zeros(max_bs + 1, dtype=torch.int32, device=model_runner.device) - eagle_hs = None - if config.use_eagle and model_runner.is_draft: - eagle_hs = torch.zeros(max_flat, hf_config.hidden_size, dtype=hf_config.torch_dtype, device=model_runner.device) + eagle_hidden_states = None + if config.use_eagle_or_phoenix and model_runner.is_draft: + eagle_hidden_states 
= torch.zeros( + max_flat, + model_runner.hidden_states_dim, + dtype=hf_config.torch_dtype, + device=model_runner.device, + ) graph_bs_list = [1] for bs in [2, 4, 8] + list(range(16, max_bs + 1, 16)): @@ -745,14 +756,14 @@ def capture_glue_decode_cudagraph(model_runner): block_tables=block_tables[:bs], ) - if eagle_hs is not None: - outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hs[:flat]) + if eagle_hidden_states is not None: + outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hidden_states[:flat]) else: outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat]) with torch.cuda.graph(graph, graph_pool): - if eagle_hs is not None: - outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hs[:flat]) + if eagle_hidden_states is not None: + outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat], eagle_hidden_states[:flat]) else: outputs[:flat] = model_runner.model(input_ids[:flat], positions[:flat]) @@ -771,8 +782,8 @@ def capture_glue_decode_cudagraph(model_runner): cu_seqlens_q=cu_seqlens_q, outputs=outputs, ) - if eagle_hs is not None: - graph_vars["eagle_hidden_states"] = eagle_hs + if eagle_hidden_states is not None: + graph_vars["eagle_hidden_states"] = eagle_hidden_states return graph_vars, graph_pool, graphs, graph_bs_list @@ -813,9 +824,13 @@ def capture_fi_tree_decode_cudagraph(model_runner): # All callers project target acts via fc() BEFORE passing to CG # MUST be outside the for-loop so all graphs share the same tensor fi_hidden_states = None - if config.use_eagle and model_runner.is_draft: - fi_hidden_states = torch.zeros(max_flat_batch_size, hf_config.hidden_size, - dtype=hf_config.torch_dtype, device=model_runner.device) + if config.use_eagle_or_phoenix and model_runner.is_draft: + fi_hidden_states = torch.zeros( + max_flat_batch_size, + model_runner.hidden_states_dim, + dtype=hf_config.torch_dtype, + device=model_runner.device, + ) 
print(f'[cuda_graph_helpers.capture_fi_tree_decode_cudagraph] About to capture FI cudagraphs for bs={graph_bs_list}', flush=True) diff --git a/ssd/engine/llm_engine.py b/ssd/engine/llm_engine.py index e99c648..0932989 100644 --- a/ssd/engine/llm_engine.py +++ b/ssd/engine/llm_engine.py @@ -298,8 +298,8 @@ def create_inference_step(self, config: Config) -> InferenceStep: draft_dtype=config.draft_hf_config.torch_dtype, kvcache_block_size=config.kvcache_block_size, max_model_len=config.max_model_len, - eagle=config.use_eagle, - eagle_act_dim=3 * config.hf_config.hidden_size if config.use_eagle else 0, + eagle=config.use_eagle_or_phoenix, + eagle_act_dim=self.model_runner.eagle_acts_dim if config.use_eagle_or_phoenix else 0, communicate_logits=config.communicate_logits, communicate_cache_hits=config.communicate_cache_hits, async_pg=self.model_runner.async_pg, @@ -328,7 +328,7 @@ def create_inference_step(self, config: Config) -> InferenceStep: scheduler=self.scheduler, speculator=speculator, verifier=verifier, - eagle=config.use_eagle, + eagle=config.use_eagle_or_phoenix, tokenizer=self.tokenizer, async_spec=config.draft_async, ) diff --git a/ssd/engine/model_runner.py b/ssd/engine/model_runner.py index b945522..8747eb5 100644 --- a/ssd/engine/model_runner.py +++ b/ssd/engine/model_runner.py @@ -14,6 +14,7 @@ from ssd.models.qwen3 import Qwen3ForCausalLM from ssd.models.llama3 import LlamaForCausalLM from ssd.models.eagle3_draft_llama3 import Eagle3DraftForCausalLM +from ssd.models.phoenix_draft_llama3 import PhoenixLlamaForCausalLM from ssd.layers.sampler import Sampler from ssd.utils.context import set_context, reset_context, get_context from ssd.utils.loader import load_model @@ -76,6 +77,7 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra self.world_size = config.num_gpus if should_use_dist else 1 self.rank = rank self.use_eagle = config.use_eagle + self.use_phoenix = config.use_phoenix if config.draft_async: self.draft_rank = 
config.num_gpus - 1 @@ -125,7 +127,7 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event], is_dra assert num_tp_gpus == 1, "ERROR in ModelRunner: draft should have tp_size=1" self.tp_pg = None # every rank is given an object from self.tp_pg, even tho draft doesnt participate it gets GROUP_NON_MEMBER object != None back, so we can't assert None here, we - print(f'[model_runner] about to setup and warmup model and cudagraphs, is use_eagle={self.use_eagle}', flush=True) + print(f'[model_runner] about to setup and warmup model and cudagraphs, is use_eagle={self.use_eagle}, is use_phoenix={self.use_phoenix}', flush=True) model_type = self.setup_and_warmup_model_and_cudagraphs(config, self.hf_config, init_q, is_draft) if self.verbose: print(f'-----CAPTURED {model_type}CUDAGRAPH----', flush=True) @@ -228,6 +230,9 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC if config.use_eagle and is_draft: print(f'[EAGLE3] Loading Eagle3DraftForCausalLM as model_class', flush=True) model_class = Eagle3DraftForCausalLM + elif config.use_phoenix and is_draft: + print(f'[PHOENIX] Loading PhoenixDraftForCausalLM as model_class', flush=True) + model_class = PhoenixLlamaForCausalLM elif hf_config.model_type == 'llama': model_class = LlamaForCausalLM elif hf_config.model_type == 'qwen3': @@ -247,11 +252,12 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC tp_size=self.num_tp_gpus, ) - if config.use_eagle: - kwargs['use_eagle'] = True + if config.use_eagle_or_phoenix: + kwargs['use_eagle'] = config.use_eagle + kwargs['use_phoenix'] = config.use_phoenix kwargs['eagle_layers'] = self.config.eagle_layers - - if model_class == Eagle3DraftForCausalLM: + + if model_class in [Eagle3DraftForCausalLM, PhoenixLlamaForCausalLM]: kwargs['d_model_target'] = config.d_model_target kwargs['debug_mode'] = config.debug_mode @@ -307,7 +313,7 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: 
AutoC self.graph_pools["decode"] = decode_graph_pool self.graphs["decode"] = decode_graphs self.graph_bs_list["decode"] = decode_graph_bs_list - if self.config.speculate and not (self.is_draft and self.config.use_eagle): # verify CG: target always, non-EAGLE draft for fan-out; EAGLE draft uses glue_decode CG instead + if self.config.speculate and not (self.is_draft and self.config.use_eagle_or_phoenix): # verify CG: target always, non-EAGLE/Phoenix draft for fan-out; EAGLE/Phoenix draft uses glue_decode CG instead verify_graph_vars, verify_graph_pool, verify_graphs, verify_graph_bs_list = capture_verify_cudagraph(self) self.graph_vars["verify"] = verify_graph_vars self.graph_pools["verify"] = verify_graph_pool @@ -319,7 +325,7 @@ def setup_and_warmup_model_and_cudagraphs(self, config: Config, hf_config: AutoC self.graph_pools["fi_tree_decode"] = fi_tree_decode_graph_pool self.graphs["fi_tree_decode"] = fi_tree_decode_graphs self.graph_bs_list["fi_tree_decode"] = fi_tree_decode_graph_bs_list - if self.config.speculate and self.is_draft and self.config.draft_async and self.config.use_eagle: + if self.config.speculate and self.is_draft and self.config.draft_async and self.config.use_eagle_or_phoenix: glue_gv, glue_pool, glue_graphs, glue_bs_list = capture_glue_decode_cudagraph(self) self.graph_vars["glue_decode"] = glue_gv self.graph_pools["glue_decode"] = glue_pool @@ -484,10 +490,15 @@ def warmup_model(self): seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)] hidden_states = None - if self.config.use_eagle and self.is_draft: + if self.config.use_eagle_or_phoenix and self.is_draft: num_tokens = num_seqs * max_model_len d_model_target = self.config.d_model_target or 4096 - hidden_states = torch.zeros(num_tokens, 3 * d_model_target, dtype=self.hf_config.torch_dtype, device=self.device) + if self.config.use_eagle: + hidden_states = torch.zeros(num_tokens, 3 * d_model_target, dtype=self.hf_config.torch_dtype, device=self.device) + elif self.config.use_phoenix: +
hidden_states = torch.zeros(num_tokens, d_model_target, dtype=self.hf_config.torch_dtype, device=self.device) + else: + raise ValueError(f"Unsupported model type: {self.config.use_eagle_or_phoenix}") self.run(seqs, True, hidden_states=hidden_states) torch.cuda.empty_cache() @@ -643,6 +654,21 @@ def eager_tree_decode_plan(self, input_ids, positions, step, cache_hits): kv_data_type=self.hf_config.torch_dtype, ) + @property + def hidden_states_dim(self): + # The dimension of the hidden states that are concatenated with the draft tokens embeddings + # as the input to the Eagle/Phoenix draft model. + assert self.config.use_eagle_or_phoenix and self.is_draft + return self.config.hf_config.hidden_size if self.config.use_eagle else self.config.d_model_target + + @property + def eagle_acts_dim(self): + assert self.config.use_eagle_or_phoenix and not self.is_draft + if self.config.eagle_layers: + return len(self.config.eagle_layers) * self.config.hf_config.hidden_size + else: + return self.config.hf_config.hidden_size + @torch.inference_mode() def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool, last_only: bool = True, tree_decode_step: int = -1, cache_hits: torch.Tensor | None = None, hidden_states: torch.Tensor | None = None): is_tree_decode = self.is_draft and self.config.draft_async and tree_decode_step >= 0 @@ -655,10 +681,10 @@ def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill if is_tree_decode: self.eager_tree_decode_plan(input_ids, positions, tree_decode_step, cache_hits) - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: if self.is_draft: assert hidden_states is not None, "hidden_states required for EAGLE draft" - assert isinstance(self.model, Eagle3DraftForCausalLM) + assert isinstance(self.model, Eagle3DraftForCausalLM) or isinstance(self.model, PhoenixLlamaForCausalLM) prenorm = self.model(input_ids, positions, hidden_states) logits = self.model.compute_logits(prenorm, last_only) 
return logits, prenorm # return prenorm as conditioning vector for next iteration @@ -708,7 +734,7 @@ def run( # Handle EAGLE returning (logits, conditioning_vector for next iter) conditioning = None - if self.config.use_eagle: + if self.config.use_eagle_or_phoenix: logits, conditioning = self.run_model( input_ids, positions, is_prefill, last_only, hidden_states=hidden_states) else: @@ -717,7 +743,7 @@ def run( if _pt: torch.cuda.synchronize() _r2 = time.perf_counter() - print(f"[PROFILE target_run] prepare_decode={(_r1-_r0)*1000:.2f}ms run_model={(_r2-_r1)*1000:.2f}ms eagle={self.config.use_eagle} n_ids={input_ids.shape[0]}", flush=True) + print(f"[PROFILE target_run] prepare_decode={(_r1-_r0)*1000:.2f}ms run_model={(_r2-_r1)*1000:.2f}ms eagle={self.config.use_eagle}, phoenix={self.config.use_phoenix}, n_ids={input_ids.shape[0]}", flush=True) if last_only: token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None @@ -730,5 +756,3 @@ def run( if conditioning is not None: return logits, conditioning return logits - - diff --git a/ssd/engine/speculator_async.py b/ssd/engine/speculator_async.py index a5e3abc..f61d121 100644 --- a/ssd/engine/speculator_async.py +++ b/ssd/engine/speculator_async.py @@ -75,18 +75,17 @@ def _prepare_prefill_request(self, seqs: list[Sequence], verify_result: VerifyRe eagle_acts = verify_result.eagle_acts input_id_list = [seq.token_ids for seq in seqs] - # EAGLE token-conditioning shift: token at position j gets conditioning - # from target act at position j-1. Skip first token per seq and drop - # last eagle_act per seq so they align correctly. + # EAGLE/Phoenix token-conditioning shift: we duplicate the first target activation for each sequence. + # [t0, h0], [t1, h0], [t2, h1], [t3, h2], ... 
if eagle_acts is not None: sliced = [] offset = 0 for ids in input_id_list: seq_len = len(ids) + sliced.append(eagle_acts[offset:offset + 1]) sliced.append(eagle_acts[offset:offset + seq_len - 1]) offset += seq_len eagle_acts = torch.cat(sliced, dim=0) - input_id_list = [ids[1:] for ids in input_id_list] max_blocks = (self.max_model_len + self.kvcache_block_size - 1) // self.kvcache_block_size input_ids_flat = [] diff --git a/ssd/layers/linear.py b/ssd/layers/linear.py index b258241..d605caa 100755 --- a/ssd/layers/linear.py +++ b/ssd/layers/linear.py @@ -89,6 +89,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data + if param_data.dim() == 1: # bias — no sharding needed + param_data.copy_(loaded_weight) + return shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) @@ -115,6 +118,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int): param_data = param.data + if param_data.dim() == 1: # bias — no sharding needed + param_data.copy_(loaded_weight) + return shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size shard_size = self.output_sizes[loaded_shard_id] // self.tp_size param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size) @@ -147,6 +153,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str): param_data = param.data + if param_data.dim() == 1: # bias — no sharding needed + param_data.copy_(loaded_weight) + return assert loaded_shard_id in ["q", "k", "v"] if loaded_shard_id == "q": shard_size = self.num_heads * self.head_size @@ -187,6 +196,9 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data + if param_data.dim() == 1: # bias — no sharding needed + param_data.copy_(loaded_weight) + 
return shard_size = param_data.size(self.tp_dim) start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size) diff --git a/ssd/models/eagle3_draft_llama3.py b/ssd/models/eagle3_draft_llama3.py index a74dd41..71c19a1 100644 --- a/ssd/models/eagle3_draft_llama3.py +++ b/ssd/models/eagle3_draft_llama3.py @@ -219,6 +219,7 @@ def __init__( draft: bool = False, speculate: bool = False, use_eagle: bool = False, + use_phoenix: bool = False, eagle_layers: list[int] | None = None, d_model_target: int = 4096, spec_k: int = 1, @@ -233,6 +234,7 @@ def __init__( assert draft, "ERROR in Eagle3DraftForLlama3: draft must be True" assert use_eagle, "ERROR in Eagle3DraftForLlama3: config.use_eagle must be True" assert eagle_layers is not None, "ERROR in Eagle3DraftForLlama3: eagle_layers must be set" + assert not use_phoenix, "ERROR in Eagle3DraftForLlama3: config.use_phoenix must be False" # this will be the draft that does tree decode, just needs a modified fwd pass that takes in hidden states and uses fc and dicts to sample, etc self.config = config diff --git a/ssd/models/llama3.py b/ssd/models/llama3.py index a9934ad..091df66 100755 --- a/ssd/models/llama3.py +++ b/ssd/models/llama3.py @@ -210,6 +210,7 @@ def __init__( async_fan_out: int = 1, draft_async: bool = False, use_eagle: bool = False, + use_phoenix: bool = False, eagle_layers: list[int] | None = None, tp_group: dist.ProcessGroup | None = None, tp_size: int = 1, @@ -221,8 +222,9 @@ def __init__( self.async_fan_out = async_fan_out self.draft_async = draft_async self.use_eagle = use_eagle + self.use_phoenix = use_phoenix self.eagle_layers = eagle_layers - print(f'[LlamaModel] use_eagle={use_eagle}, eagle_layers={eagle_layers}', flush=True) + print(f'[LlamaModel] use_eagle={use_eagle}, use_phoenix={use_phoenix}, eagle_layers={eagle_layers}', flush=True) self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -249,24 +251,33 @@ def forward( self, 
input_ids: torch.Tensor, positions: torch.Tensor, + hidden_states: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - hidden_states = self.embed_tokens(input_ids) # torch.Size([4096, 2560]) always through residual stream + if hidden_states is None: + hidden_states = self.embed_tokens(input_ids) residual = None # Collect activations if use_eagle - collected_acts = [] if self.use_eagle else None + collected_acts = [] if not self.draft and (self.use_eagle or self.use_phoenix) else None for layer_idx, layer in enumerate(self.layers): - if collected_acts is not None and layer_idx in self.eagle_layers: + if collected_acts is not None and self.eagle_layers is not None and layer_idx in self.eagle_layers: current_act = hidden_states if residual is None else hidden_states + residual collected_acts.append(current_act) hidden_states, residual = layer(positions, hidden_states, residual) - hidden_states, _ = self.norm(hidden_states, residual) - - if collected_acts: - eagle_acts = torch.cat(collected_acts, dim=-1) + + if not self.draft and self.use_phoenix: + assert self.eagle_layers is None, "ERROR in LlamaModel: use_phoenix and eagle_layers are not compatible" + collected_acts.append(hidden_states) + + if collected_acts is not None: + if len(collected_acts) > 1: + eagle_acts = torch.cat(collected_acts, dim=-1) + else: + assert len(collected_acts) == 1 + eagle_acts = collected_acts[0] print(f'[LlamaModel] eagle_acts shape={eagle_acts.shape}', flush=True) return hidden_states, eagle_acts else: @@ -284,9 +295,11 @@ class LlamaForCausalLM(nn.Module): def __init__( self, - config: LlamaConfig, draft: bool = False, + config: LlamaConfig, + draft: bool = False, speculate: bool = False, use_eagle: bool = False, + use_phoenix: bool = False, eagle_layers: list[int] | None = None, spec_k: int = 1, async_fan_out: int = 1, @@ -301,6 +314,7 @@ def __init__( self.async_fan_out = async_fan_out self.draft_async = draft_async self.use_eagle = use_eagle + 
self.use_phoenix = use_phoenix self.eagle_layers = eagle_layers self.tp_group = tp_group self.tp_size = tp_size @@ -310,7 +324,19 @@ def __init__( print(f'Starting LlamaForCausalLM init, draft={draft}, speculate={speculate}, spec_k={spec_k}') print(f'[LlamaForCausalLM] use_eagle={use_eagle}, eagle_layers={eagle_layers}', flush=True) - self.model = LlamaModel(config, draft, speculate, spec_k, async_fan_out, draft_async, use_eagle=use_eagle, eagle_layers=eagle_layers, tp_group=tp_group, tp_size=self.tp_size) + self.model = LlamaModel( + config, + draft, + speculate, + spec_k, + async_fan_out, + draft_async, + use_eagle=use_eagle, + use_phoenix=use_phoenix, + eagle_layers=eagle_layers, + tp_group=tp_group, + tp_size=self.tp_size, + ) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, diff --git a/ssd/models/phoenix_draft_llama3.py b/ssd/models/phoenix_draft_llama3.py new file mode 100644 index 0000000..2b25401 --- /dev/null +++ b/ssd/models/phoenix_draft_llama3.py @@ -0,0 +1,74 @@ +import torch +import torch.distributed as dist +from transformers import LlamaConfig + +from ssd.layers.linear import RowParallelLinear +from ssd.models.llama3 import LlamaForCausalLM + + +class PhoenixLlamaForCausalLM(LlamaForCausalLM): + def __init__( + self, + config: LlamaConfig, + draft: bool = True, + speculate: bool = True, + use_eagle: bool = False, + use_phoenix: bool = True, + eagle_layers: list[int] | None = None, + d_model_target: int = 4096, + spec_k: int = 1, + async_fan_out: int = 1, + draft_async: bool = False, + tp_group: dist.ProcessGroup | None = None, + tp_size: int = 1, + debug_mode: bool = False, + ) -> None: + assert draft, "ERROR in PhoenixLlamaForCausalLM: draft must be True" + assert use_phoenix, "ERROR in PhoenixLlamaForCausalLM: config.use_phoenix must be True" + assert not use_eagle, "ERROR in PhoenixLlamaForCausalLM: config.use_eagle must be False" + super().__init__( + config, + draft=True, + speculate=True, + use_eagle=False, + 
use_phoenix=True, + eagle_layers=None, + spec_k=spec_k, + async_fan_out=async_fan_out, + draft_async=draft_async, + tp_group=tp_group, + tp_size=tp_size, + ) + self.d_model_target = d_model_target + self.debug_mode = debug_mode + self.eh_proj = RowParallelLinear( + self.d_model_target + config.hidden_size, + config.hidden_size, + bias=True, + tp_group=tp_group, + tp_size=tp_size, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + input_embeds = self.model.embed_tokens(input_ids) + hidden_states = torch.cat((input_embeds, hidden_states), dim=-1) + hidden_states = self.eh_proj(hidden_states.to(self.eh_proj.weight.dtype)) + out = self.model(input_ids, positions, hidden_states) + return out + + def compute_logits( + self, + hidden_states: torch.Tensor, + last_only: bool = True, + ) -> torch.Tensor: + logits = self.lm_head(hidden_states, last_only=last_only) + + if logits.dim() == 3: + logits = logits.view(-1, logits.shape[-1]) + + return logits