
feat(scheduler): ratio-based decode reservation, chunked prefill and retract preemption#7038

Open
Foriv wants to merge 14 commits into PaddlePaddle:develop from Foriv:feat-all

Conversation

@Foriv Foriv commented Mar 26, 2026

Scheduler Refactor (ResourceManagerV1): Ratio-based Decode Reservation, Chunked Prefill, and Retract-style Preemption

Changed Files

  • fastdeploy/engine/sched/resource_manager_v1.py
  • fastdeploy/envs.py
  • fastdeploy/scheduler/local_scheduler.py
  • fastdeploy/engine/sched/scheduler_metrics_logger.py

1) resource_manager_v1.py: Core Implementation Changes

1.1 Introduce active_chunked_prefill_req (single active chunked prefill)

  • Added state:
    • self.active_chunked_prefill_req: Request | None = None
  • Added helpers:
    • _num_active_running_requests() (counts active_chunked_prefill_req into running metrics)
    • _ensure_request_slot_allocated(request) (ensures tasks_list/stop_flags/req_dict slot correctness in mixed mode)
  • Integrated active_chunked_prefill_req handling into lifecycle paths:
    • preempted_all()
    • wait_worker_inflight_requests_finish()
    • pre_recycle_resource()
    • finish_requests()
    • clear_data()
    • update_metrics() and log_status() (running/waiting counters updated)

1.2 Decode reservation: replace fixed block reservation with new_token_ratio (ratio-based)

  • Removed the fixed per-request block reservation mechanism (reserve_output_block_num family):
    • init_reserve_output_block_num / decay_output_block_num / min_reserve_output_block_num
    • current_reserve_output_block_num / current_reserve_output_block_num_float
    • can_relax_prefill_strategy
  • Added ratio-based reservation fields:
    • init_new_token_ratio
    • min_new_token_ratio
    • new_token_ratio_decay
    • current_new_token_ratio
    • clip_max_new_tokens_estimation
  • Added/updated functions:
    • _calculate_decode_reserved_tokens_by_ratio() (ratio-based reserved token estimation for running decode-phase requests)
    • _calculate_decode_reserved_tokens_for_new_requests(new_decode_reserved_tokens) (per-cycle accumulation for newly admitted last-chunk requests)
    • _update_new_token_ratio_after_preemption() (update ratio after preemption based on remaining decode state)
    • reset_new_token_ratio_on_idle() (reset ratio back to init when system is fully idle)
  • Prefill admission threshold refactor:
    • _get_can_schedule_prefill_threshold_block(...) signature extended with:
      • is_last_chunk
      • new_decode_reserved_tokens
      • cached_running_decode_reserved
    • Threshold now accounts for:
      • current chunk tokens/blocks
      • future decode tokens for last chunk
      • reserved tokens from running decode requests (ratio-based)
      • reserved tokens from newly admitted last-chunk requests in the same cycle
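
The ratio-based reservation described above can be sketched as follows. This is a minimal illustration, not the PR's actual implementation: the function name, the dict-based request shape, and the `generated` field are all hypothetical; the clipping default matches `FD_CLIP_MAX_NEW_TOKENS_ESTIMATION` (4096).

```python
def calc_decode_reserved_tokens_by_ratio(running_decode_reqs, new_token_ratio,
                                         clip_max_new_tokens=4096):
    """Estimate tokens to reserve for running decode-phase requests (sketch).

    Each request's remaining output is estimated as max_new_tokens minus
    tokens generated so far, clipped to clip_max_new_tokens, then the total
    is scaled by the current new-token ratio.
    """
    reserved = 0
    for req in running_decode_reqs:
        remaining = max(req["max_new_tokens"] - req["generated"], 0)
        reserved += min(remaining, clip_max_new_tokens)
    return int(reserved * new_token_ratio)
```

A higher ratio reserves more room for in-flight decodes and so admits prefill more conservatively; the ratio then decays toward its minimum when decode-only steps dominate.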

1.3 Chunked prefill scheduling and budgets (SGLang-aligned)

  • Added per-step chunk budget and batch input budget:
    • chunked_prefill_size = envs.FD_CHUNKED_PREFILL_SIZE
    • rem_chunk_tokens = chunked_prefill_size
    • rem_input_tokens = envs.FD_REM_INPUT_TOKENS - running_decode_count
  • Added utilities:
    • _get_paged_prefill_tokens(num_new_tokens) (align prefill budget to block_size)
    • _is_last_prefill_chunk(request, num_new_tokens)
  • Reworked _get_num_new_tokens(...):
    • Signature changed from (request, token_budget) to:
      • (request, rem_chunk_tokens, rem_input_tokens, existing_prefill_in_batch, ignore_rem_input_budget)
    • Supports:
      • truncation under rem_chunk_tokens with block-aligned chunk sizing
      • rem_input_tokens gating (more strict once a prefill chunk is already admitted in the same cycle)
      • existing HPU alignment logic adapted to the new budget model
  • schedule() main flow changes:
    • If an unfinished prefill exists in running, migrate one into active_chunked_prefill_req (enforce single active unfinished chunked prefill)
    • Schedule one chunk for active_chunked_prefill_req first (if admission threshold passes)
    • Then schedule waiting under shared rem_chunk_tokens/rem_input_tokens budgets
    • If a waiting request becomes a non-last chunk (i.e., chunked), admit it and break (at most one newly chunked waiting request per cycle)
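
The dual-budget chunk sizing above can be illustrated with a small sketch. Function names and the `block_size=64` default are assumptions for illustration; the real `_get_num_new_tokens` takes more parameters (`existing_prefill_in_batch`, `ignore_rem_input_budget`) and includes HPU alignment logic.

```python
def get_num_new_tokens(remaining_prefill, rem_chunk_tokens, rem_input_tokens,
                       block_size=64):
    """Size the next prefill chunk under the dual budget (sketch).

    Takes the minimum of the request's remaining prefill tokens and both
    per-step budgets; if the request is truncated (i.e. this is not its
    last chunk), the chunk is floor-aligned to block_size.
    """
    num_new = min(remaining_prefill, rem_chunk_tokens, rem_input_tokens)
    if num_new < remaining_prefill:
        num_new = (num_new // block_size) * block_size
    return num_new


def is_last_prefill_chunk(remaining_prefill, num_new_tokens):
    # The chunk is the last one when it covers all remaining prefill tokens.
    return num_new_tokens >= remaining_prefill
```

With the defaults (`FD_CHUNKED_PREFILL_SIZE=8192`, `FD_REM_INPUT_TOKENS=16384`), a 10000-token prompt is first scheduled as one full 8192-token chunk rather than being split across many small fragments.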

1.4 Preemption refactor: retract-style decode preemption + eviction

  • Completely rewrote _trigger_preempt():
    • Preemption candidates restricted to decode-phase requests
    • Selection order changed to:
      • shortest output first + longest input first
      • sorted by (len(output_token_ids), -prompt_token_ids_len) and popped
    • Added cache eviction during preemption:
      • _evict_decode_kv_cache(remaining_req_count)
      • eviction executed before attempting retraction, and again after each retraction
    • After retraction:
      • preempted_req.is_retracted = True
      • _update_new_token_ratio_after_preemption() is invoked
  • Updated preempted_all() to include:
    • preempting active_chunked_prefill_req together with running
    • special handling of use_extend_tables requests to reinsert into active_chunked_prefill_req or running as appropriate
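
The victim-selection order can be sketched directly from the stated sort key. This is an illustrative helper (the name and dict-based request shape are hypothetical); the actual `_trigger_preempt` interleaves eviction and retraction rather than selecting all victims up front.

```python
def select_retract_victims(decode_reqs, num_victims):
    """Order decode-phase preemption candidates (sketch).

    Victims are taken shortest-output-first, breaking ties by longest
    input first, i.e. sorted by
    (len(output_token_ids), -prompt_token_ids_len).
    """
    ordered = sorted(
        decode_reqs,
        key=lambda r: (len(r["output_token_ids"]), -r["prompt_token_ids_len"]),
    )
    return ordered[:num_victims]
```

Retracting the request with the least generated output wastes the least decode work, while preferring longer prompts among ties frees more KV-cache blocks per retraction.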

1.5 Scheduling order and decode scheduling gate

  • Added internal _schedule_decode_requests() to handle decode-phase scheduling and block allocation/extend logic
  • Enforced scheduling order:
    1. active/running prefill (chunk)
    2. waiting prefill (budget + admission threshold controlled)
    3. decode: only scheduled when no prefill was admitted in this cycle and no preemption occurred
  • Ratio decay condition updated:
    • when decode requests exist, and no prefill is scheduled, and no preemption occurs, current_new_token_ratio decays linearly
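
A linear decay of this shape might look like the sketch below. The assumption that the floor is `init_ratio * FD_MIN_NEW_TOKEN_RATIO_FACTOR` is inferred from the env-var names and defaults listed later in this PR; the real decay schedule may differ.

```python
def decay_new_token_ratio(current_ratio, init_ratio=0.7,
                          min_ratio_factor=0.30, decay_steps=600):
    """Linearly decay the new-token ratio by one step (sketch).

    Per-step decrement is (init - floor) / decay_steps, clamped at the
    floor init_ratio * min_ratio_factor. Defaults mirror
    FD_INIT_NEW_TOKEN_RATIO, FD_MIN_NEW_TOKEN_RATIO_FACTOR and
    FD_NEW_TOKEN_RATIO_DECAY_STEPS.
    """
    min_ratio = init_ratio * min_ratio_factor
    step = (init_ratio - min_ratio) / decay_steps
    return max(current_ratio - step, min_ratio)
```

Because decay only runs on decode-only, preemption-free steps, the reservation stays conservative whenever prefill pressure exists and relaxes gradually once the batch is purely decoding.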

1.6 Queue behavior change for rescheduled (preempted) requests

  • reschedule_preempt_task():
    • changed from waiting.appendleft(request) to waiting.append(request) (FIFO append)

2) envs.py: Environment Variables Update

  • Removed fixed block reservation envs:
    • FD_RESERVE_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
    • FD_RESERVE_DECAY_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
    • FD_RESERVE_MIN_OUTPUT_BLOCK_NUM_FOR_DECODE_WHEN_SCHEDULE_NEW_PREFILL
  • Added ratio-based decode reservation + chunked prefill + budget envs:
    • FD_INIT_NEW_TOKEN_RATIO (default 0.7)
    • FD_MIN_NEW_TOKEN_RATIO_FACTOR (default 0.30)
    • FD_NEW_TOKEN_RATIO_DECAY_STEPS (default 600)
    • FD_RETRACT_DECODE_STEPS (default 20)
    • FD_CLIP_MAX_NEW_TOKENS_ESTIMATION (default 4096)
    • FD_CHUNKED_PREFILL_SIZE (default 8192)
    • FD_REM_INPUT_TOKENS (default 16384)
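
For reference, the new variables with their listed defaults would be read roughly as below. This is a hypothetical sketch of what the `envs.py` additions might look like, not the PR's actual code (FastDeploy's `envs.py` uses its own registration pattern).

```python
import os

# Ratio-based decode reservation (sketch; defaults as listed in this PR).
FD_INIT_NEW_TOKEN_RATIO = float(os.getenv("FD_INIT_NEW_TOKEN_RATIO", "0.7"))
FD_MIN_NEW_TOKEN_RATIO_FACTOR = float(os.getenv("FD_MIN_NEW_TOKEN_RATIO_FACTOR", "0.30"))
FD_NEW_TOKEN_RATIO_DECAY_STEPS = int(os.getenv("FD_NEW_TOKEN_RATIO_DECAY_STEPS", "600"))
FD_RETRACT_DECODE_STEPS = int(os.getenv("FD_RETRACT_DECODE_STEPS", "20"))
FD_CLIP_MAX_NEW_TOKENS_ESTIMATION = int(os.getenv("FD_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096"))

# Chunked prefill and per-step input budgets.
FD_CHUNKED_PREFILL_SIZE = int(os.getenv("FD_CHUNKED_PREFILL_SIZE", "8192"))
FD_REM_INPUT_TOKENS = int(os.getenv("FD_REM_INPUT_TOKENS", "16384"))
```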

3) local_scheduler.py: Remove block-accumulation admission filtering

  • In get_requests(), removed:
    • required_total_blocks/current_prefill_tokens accumulation and available_blocks break logic
    • long/short partial prefill counters and limits
  • Kept:
    • basic request selection (requests.append(request.raw)), leaving admission control and chunking to resource_manager_v1

4) scheduler_metrics_logger.py: Decode logging interval change

  • DEFAULT_DECODE_LOG_INTERVAL: 5 -> 1
  • Decode logging trigger changed from a batch-count modulo check to an elapsed-time threshold check
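
An elapsed-time gate of this kind can be sketched as below. The class name and the interpretation of `DEFAULT_DECODE_LOG_INTERVAL` as seconds are assumptions for illustration.

```python
import time


class DecodeLogGate:
    """Elapsed-time decode-log trigger replacing batch-count modulo (sketch).

    should_log() fires once at least `interval` seconds have passed since
    the last log, regardless of how many decode batches ran in between.
    """

    def __init__(self, interval=1.0):
        self.interval = interval
        self.last_log_time = time.monotonic()

    def should_log(self):
        now = time.monotonic()
        if now - self.last_log_time >= self.interval:
            self.last_log_time = now
            return True
        return False
```

Compared with a batch-count modulo, a time-based gate keeps the log rate stable whether decode steps are fast or slow.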

Foriv added 14 commits March 22, 2026 05:07
…ned token-ratio mechanism

- Remove can_relax_prefill_strategy, reserve_output_block_num and fixed block reservation system
- Add dynamic current_new_token_ratio with init/decay/preemption-reset (SGLang-aligned)
- Prefill threshold: use current chunk blocks only, not full remaining prefill
- Remove cross-chunk prefill block reservation (items 5 & 6 in threshold)
- Merge token_budget into rem_input_tokens (= FD_REM_INPUT_TOKENS - running_decode_count)
- Dual budget: rem_chunk_tokens (8192) + rem_input_tokens (16384 - decode_count)
- _get_num_new_tokens: min(remaining, rem_chunk_tokens, rem_input_tokens) with floor-align
- _trigger_preempt: SGLang retract_decode (shortest output first, evict prefix cache)
- envs.py: add FD_INIT_NEW_TOKEN_RATIO, FD_MIN_NEW_TOKEN_RATIO_FACTOR, FD_NEW_TOKEN_RATIO_DECAY_STEPS, FD_RETRACT_DECODE_STEPS, FD_CLIP_MAX_NEW_TOKENS_ESTIMATION; remove 7 obsolete vars
…unded prefill admission

Move cached_running_decode_reserved and scheduled_new_decode_reserved_tokens
init before RUNNING loop so both loops share the same reservation state.
Add _get_can_schedule_prefill_threshold_block check in RUNNING loop prefill
path, aligned with WAITING loop, to enforce decode reservation constraint
on chunked prefill continuation. Track last-chunk decode reservation in
RUNNING loop and propagate to WAITING loop.
…chitecture

RUNNING loop: only continue ONE chunked prefill then break, instead of
iterating all running prefills sharing rem_chunk_tokens budget (which
caused fragmentation: 8 prefills each getting ~1024 tokens instead of
one getting full 8192).

WAITING loop: split token consumption - chunked (non-last) requests
consume rem_chunk_tokens and trigger break after one admit; non-chunked
(last-chunk/short) requests consume rem_input_tokens only and continue
to be admitted. Applied to both WAITING and PREEMPTED branches.

Remove threshold check from RUNNING loop (no longer needed with single
prefill break). Move cached_running_decode_reserved back to WAITING
section init.
- Remove prealloc_dec_block_slot_num_threshold and enc_dec_block_num
- Allocate exactly 1 block only when current block is exhausted
- Evict decode KV cache before preemption
- Self-preempt to avoid livelock when only 1 request remains
Decay ratio only when all conditions are met:
- has_decode_requests: decode requests running
- not has_running_prefill: no chunked prefill in progress in running queue
- not self.waiting: waiting queue is empty (no new prefill admitted)
- not has_scheduled_prefill: no prefill was scheduled this round
- not preempted_reqs: no preemption occurred this round
This aligns with SGLang's is_extend_mode = has_running_prefill or bool(waiting),
ensuring ratio stays high when prefill is active (running or waiting).
- Add has_scheduled_running_prefill guard in RUNNING loop to prevent
  multiple in-flight prefill requests from being scheduled in one step
- Gate WAITING loop on not has_scheduled_running_prefill so new prefill
  admission is skipped when an in-flight prefill chunk is scheduled
- Use independent rem_chunk_tokens_waiting budget for WAITING loop so
  RUNNING in-flight prefill consumption cannot starve new requests
- Replace break with req_index += 1 after in-flight prefill to allow
  decode requests to continue being scheduled in the same step
- Priority reorder: move in-flight prefill to self.running[0] each step
- Fix _fetch_request to use max_prefill_batch directly instead of
  available_batch() which returns 0 during GPU execution
- Split #running-req log field into #running-decode and #running-prefill
- Fix decode log #token to show actual tokens generated per step instead
  of KV cache block-granularity usage (jumps of block_size=64)
- Remove unused tokens_used parameter from both log methods
- Add iteration summary: decode-only windows now healthy with 540+ tok/s throughput
- Document 4 remaining issues: dead reset_new_token_ratio_on_idle(), active_chunked_prefill_req missing in preempted_all()/wait_worker_inflight_requests_finish(), coarse last-chunk judgment, missing prefill-side priority preemption
- Prioritize fixing issues 1 and 2 for next iteration
The max_new variable was already clipped via min() when defined,
but scheduled_new_decode_reserved_tokens was applying min() again,
causing double clipping. Changed to simply add max_new.
Copilot AI review requested due to automatic review settings March 26, 2026 19:10

paddle-bot bot commented Mar 26, 2026

Thanks for your contribution!

@paddle-bot paddle-bot bot added the contributor External developers label Mar 26, 2026
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
