From 3a6f29ef3803f801cd1780d5c6b7f71d4bd7e5c4 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Mon, 8 Jun 2026 22:46:10 +0400 Subject: [PATCH 01/28] merge Signed-off-by: lilithgrigoryan --- .../cache_aware_rnnt.yaml | 14 +- .../pipelines/cache_aware_rnnt_pipeline.py | 260 ++++++++++++++- .../streaming/state/cache_aware_rnnt_state.py | 49 +++ .../submodules/rnnt_malsd_batched_computer.py | 309 ++++++++++++++++++ 4 files changed, 630 insertions(+), 2 deletions(-) diff --git a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml index 0b0616bb15c3..8c5f754d0c97 100644 --- a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml +++ b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml @@ -8,7 +8,7 @@ asr: compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32' use_amp: true # Enable Automatic Mixed Precision decoding: - strategy: "greedy_batch" + strategy: "greedy_batch" preserve_alignments: false fused_batch_size: -1 greedy: @@ -30,6 +30,18 @@ asr: source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer), # used with `key_phrases_file` and `key_phrases_list` boosting_tree_alpha: 0.0 # Weight of the boosting tree + beam: + beam_size: 4 + return_best_hypothesis: true + score_norm: true + allow_cuda_graphs: true + # n-gram LM (off by default) + ngram_lm_model: null + key_phrases_file: null + key_phrases_list: null + key_phrase_items_list: null + source_lang: "en" + boosting_tree_alpha: 0.0 # ========================================== # Inverse Text Normalization Configuration diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 39ba0446b2dd..b3a9142c1a9c 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -238,6 +238,239 @@ def preprocess(self, buffers: list[Tensor], right_paddings: list[int] | None = N feature_buffers = torch.cat(feature_buffers).to(self.device) return feature_buffers, feature_buffer_lens + def _streaming_step( + self, + states: list[CacheAwareRNNTStreamingState], + feature_buffers: Tensor, + feature_buffer_lens: Tensor, + context, + previous_hypotheses: list[Hypothesis | None], + drop_extra_pre_encoded: int, + keep_all_outputs: bool, + prompt_vectors: Tensor | None, + biasing_enabled: bool, + ) -> tuple[list[Hypothesis], object]: + """ + Dispatcher between the greedy single-shot path and the MALSD beam path. + + For greedy (``self.decoding_computer is None``) this just calls the existing + ``asr_model.stream_step``. For MALSD it runs the encoder once and drives + :class:`ModifiedALSDBatchedRNNTComputer` with the per-stream beam carry. + """ + if self.decoding_computer is None: + return self.asr_model.stream_step( + processed_signal=feature_buffers, + processed_signal_length=feature_buffer_lens, + context=context, + previous_hypotheses=previous_hypotheses, + drop_extra_pre_encoded=drop_extra_pre_encoded, + keep_all_outputs=keep_all_outputs, + drop_left_context=self.drop_left_context, + valid_out_len=self.valid_out_len, + prompt_vectors=prompt_vectors, + ) + return self._malsd_stream_step( + states=states, + feature_buffers=feature_buffers, + feature_buffer_lens=feature_buffer_lens, + context=context, + drop_extra_pre_encoded=drop_extra_pre_encoded, + keep_all_outputs=keep_all_outputs, + biasing_enabled=biasing_enabled, + ) + + def _malsd_stream_step( + self, + states: list[CacheAwareRNNTMALSDStreamingState], + feature_buffers: Tensor, + feature_buffer_lens: Tensor, + context, + drop_extra_pre_encoded: int, + keep_all_outputs: bool, + biasing_enabled: bool, + ) -> tuple[list[Hypothesis], object]: + """ + One streaming step for the MALSD beam-search path: + + 1. Encoder-only pass - the decoder is driven by this pipeline, not by + the model's built-in decoding wrapper. + 2. Merge per-stream ``MALSDStateItem``s into a batched MALSD state. + 3. Run :class:`ModifiedALSDBatchedRNNTComputer` for this chunk. + 4. Update per-stream windowed-beam tracking from this chunk's emissions. + 5. Split the batched MALSD state back into per-stream carries. + 6. Build a cumulative ``Hypothesis`` per stream from + ``window_committed + window_beam_tokens[top1]``. + + Collapse to the chunk's top-1 is NOT performed here - beams stay + diverged across chunks and are collapsed per-stream at the EOU + boundary inside :meth:`run_malsd_decoder`. + + Returns a list of cumulative ``Hypothesis`` per stream and the new + encoder cache context, matching the shape of ``stream_step``. + """ + # Per-stream multi-biasing ids: not yet supported on the MALSD streaming + # path. Greedy-side per-stream biasing knobs stay independent. + if biasing_enabled: + logging.warning( + "Per-stream biasing is not yet wired up on the MALSD cache-aware " + "streaming path; ignoring biasing requests for this chunk." + ) + + # Merge per-stream carries into a batched MALSD state. ``None`` entries + # (fresh streams) are filled with the after-SOS state inside ``merge_to_batched_state``. + carries = [state.hyp_decoding_state for state in states] + if all(c is None for c in carries): + batched_state = None + else: + batched_state = self.decoding_computer.merge_to_batched_state(carries) + + # All MALSD GPU work (encoder, decoder, windowed walk, split) shares one + # ``inference_mode`` region: ``split_batched_state`` mutates the inference + # tensors returned by ``decoding_computer(...)`` in place, which is illegal + # once we've left the captured ``inference_mode`` region. + with ( + torch.amp.autocast( + device_type=self.asr_model.device_str, + dtype=self.asr_model.compute_dtype, + enabled=self.asr_model.use_amp, + ), + torch.inference_mode(), + ): + encoded, encoded_len, new_context = self.asr_model.encode_step( + processed_signal=feature_buffers, + processed_signal_length=feature_buffer_lens, + context=context, + drop_extra_pre_encoded=drop_extra_pre_encoded, + keep_all_outputs=keep_all_outputs, + drop_left_context=self.drop_left_context, + valid_out_len=self.valid_out_len, + ) + # ``encoded`` from the encoder wrapper is shaped [B, D, T]; the MALSD + # computer expects [B, T, D] (matches the rest of the decoding stack). + encs_dim_last = encoded.transpose(1, 2).contiguous() + + best_batched_hyps, batched_state = self.decoding_computer( + encs_dim_last, encoded_len, batched_state + ) + + self._update_windowed_beam_state(states=states, best_batched_hyps=best_batched_hyps) + + # Per-stream top-1 beam slot. Indexes ``window_beam_tokens`` (which was + # just appended against the diverged beam slots) to build the publishable + # cumulative hypothesis below. + beam_indices_cpu = best_batched_hyps.scores.argmax(dim=-1).detach().cpu().tolist() + scores_cpu = best_batched_hyps.scores.detach().cpu() + + carry_items = self.decoding_computer.split_batched_state(batched_state) + for state, carry in zip(states, carry_items): + state.hyp_decoding_state = carry + + # Build per-stream cumulative ``Hypothesis`` from the windowed state. + # Collapse + window promotion is deferred to ``run_malsd_decoder`` and + # triggered by EOU, so the published hyp is the current top-1's path + # but the K-beam state continues to diverge across chunks. + hyps: list[Hypothesis] = [] + for b, state in enumerate(states): + top1_slot = beam_indices_cpu[b] + window_tokens = state.window_beam_tokens[top1_slot] if state.window_beam_tokens else [] + window_ts = state.window_beam_timestamps[top1_slot] if state.window_beam_timestamps else [] + cum_tokens = state.window_committed_tokens + list(window_tokens) + cum_ts = state.window_committed_timestamps + list(window_ts) + + hyps.append( + Hypothesis( + score=float(scores_cpu[b, top1_slot].item()), + y_sequence=cum_tokens, + timestamp=cum_ts, + length=len(cum_tokens), + ) + ) + + return hyps, new_context + + def _update_windowed_beam_state( + self, + states: list[CacheAwareRNNTMALSDStreamingState], + best_batched_hyps: BatchedBeamHyps, + ) -> None: + """ + Extend each state's per-slot ``window_beam_tokens[k]`` with the chunk-local + emissions of the slot that originated from carry slot ``k`` at chunk start. + + The helper exposes per-(batch, beam) chunk-local tokens/timestamps and the + chunk-start -> chunk-end descent map; the permute-then-append windowed-beam + policy lives here. + """ + chunk_tokens, chunk_timestamps, root_ptrs = export_batched_beam_hyps_to_cpu_lists(best_batched_hyps) + beam_size = best_batched_hyps.beam_size + for state, ct, cts, rp in zip(states, chunk_tokens, chunk_timestamps, root_ptrs): + prev_t = state.window_beam_tokens or [[] for _ in range(beam_size)] + prev_ts = state.window_beam_timestamps or [[] for _ in range(beam_size)] + state.window_beam_tokens = [prev_t[int(rp[k])] + ct[k] for k in range(beam_size)] + state.window_beam_timestamps = [prev_ts[int(rp[k])] + cts[k] for k in range(beam_size)] + + def run_malsd_decoder( + self, state: CacheAwareRNNTMALSDStreamingState, request: Request, hyp: Hypothesis + ) -> bool: + """ + MALSD counterpart to :meth:`run_greedy_decoder`. + + Reuses the greedy decoder for EOU detection, label-buffer rolling and + offset bookkeeping. Then RESYNCS ``state.tokens`` / ``state.timesteps`` / + ``state.confidences`` with the current top-1's cumulative slice + (``hyp.y_sequence[_malsd_utterance_start:]``). + + The resync is the load-bearing step that distinguishes MALSD from + greedy: between chunks, MALSD's raw-argmax top-1 can switch beams with + incompatible token histories (beam A: ``["I"]`` at chunk t, beam B: + ``["I", "I"]`` at chunk t+1). ``run_greedy_decoder`` appends + ``hyp.y_sequence[offset:]`` onto whatever was already in ``state.tokens``, + which would splice A's prefix with B's new tokens into a Frankenstein + transcript. Overwriting with the actual current top-1 belief keeps the + published transcript consistent with whichever beam currently wins. + + On EOU we bump ``_malsd_utterance_start`` to the current cumulative + length so the next utterance's resync slice starts past the cleared + previous utterance, then collapse the per-stream MALSD carry to its + top-1 beam: the K-beam state diverges intra-utterance and snaps to the + chosen path at the natural utterance boundary. + """ + eou_detected = self.run_greedy_decoder(state, request, hyp) + + # Resync state.tokens / state.timesteps / state.confidences with the + # current top-1's cumulative slice for this utterance. + all_tokens = list(hyp.y_sequence) if hyp.y_sequence is not None else [] + all_timestamps = list(hyp.timestamp) if hyp.timestamp is not None else [] + start = max(0, int(state._malsd_utterance_start)) + start = min(start, len(all_tokens)) + tokens_list = all_tokens[start:] + timestamps_list = all_timestamps[start:] + + state.tokens = list(tokens_list) + state.timesteps = list(timestamps_list) + state.confidences = [0.0] * len(tokens_list) + if tokens_list: + state.last_token = tokens_list[-1] + state.last_token_idx = timestamps_list[-1] if timestamps_list else None + + if eou_detected: + # Mark the boundary so the next utterance's slice starts past the + # tokens we just finalised. + state._malsd_utterance_start = len(all_tokens) + + # EOU-driven collapse: promote the chosen window into the committed + # prefix and replicate the winning beam across all K slots of the + # per-stream carry. The predictor stays warm at the top-1's last + # label so the next utterance benefits from cross-utterance context. + if state.hyp_decoding_state is not None: + top1 = int(state.hyp_decoding_state.score.argmax().item()) + self.decoding_computer.collapse_state_item_to_top1_(state.hyp_decoding_state, top1) + state.window_committed_tokens = list(all_tokens) + state.window_committed_timestamps = list(all_timestamps) + state.window_beam_tokens = None + state.window_beam_timestamps = None + return eou_detected + def run_greedy_decoder(self, state: CacheAwareRNNTStreamingState, request: Request, hyp: Hypothesis) -> bool: """ Run the greedy RNNT decoder on the hypothesis and update the state @@ -384,9 +617,30 @@ def cache_aware_transcribe_step( if request.is_last and state.has_biasing_request(): if state.options.biasing_cfg.auto_manage_multi_model: state.options.biasing_cfg.remove_from_multi_model( - biasing_multi_model=decoding_computer.biasing_multi_model + biasing_multi_model=self.greedy_decoding_computer.biasing_multi_model ) + def _debug_print_finals(self, ready_state_ids: set) -> None: + """DEBUG: print finalised transcripts so greedy vs MALSD logs can be diffed.""" + strategy = "malsd" if self.decoding_computer is not None else "greedy" + for sid in sorted(ready_state_ids): + state = self.get_state(sid) + print( + f"[CMP][FINAL] strategy={strategy} stream={sid} text={state.final_transcript!r}", + flush=True, + ) + + def _debug_print_partials(self, requests: list[Request]) -> None: + """DEBUG: print partial / current-step transcripts so greedy vs MALSD logs can be diffed.""" + strategy = "malsd" if self.decoding_computer is not None else "greedy" + for req in requests: + state = self.get_state(req.stream_id) + print( + f"[CMP][PARTIAL] strategy={strategy} stream={req.stream_id} " + f"partial={state.partial_transcript!r} step={state.current_step_transcript!r}", + flush=True, + ) + def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: """ Transcribes the feature buffers in a streaming manner. @@ -425,9 +679,11 @@ def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> if len(ready_state_ids) > 0: self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) + self._debug_print_finals(ready_state_ids) ready_state_ids.clear() self.update_partial_transcript(fbuffers, self.tokenizer, self.leading_regex_pattern) + self._debug_print_partials(fbuffers) def transcribe_step_for_frames(self, frames: list[Frame]) -> None: """ @@ -471,9 +727,11 @@ def transcribe_step_for_frames(self, frames: list[Frame]) -> None: # post-process the ready states if len(ready_state_ids) > 0: self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) + self._debug_print_finals(ready_state_ids) ready_state_ids.clear() self.update_partial_transcript(frames, self.tokenizer, self.leading_regex_pattern) + self._debug_print_partials(frames) def get_request_generator(self) -> ContinuousBatchedRequestStreamer: """ diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index c9374c37ba26..b30cc7e55aa0 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -64,3 +64,52 @@ def reset_previous_hypothesis(self) -> None: Reset the previous hypothesis to None """ self.previous_hypothesis = None + + +class CacheAwareRNNTMALSDStreamingState(CacheAwareRNNTStreamingState): + """ + Cache-aware RNNT state with MALSD beam-search per-stream bookkeeping. + + Adds the following fields on top of the greedy state: + + - ``hyp_decoding_state``: per-stream beam carry (``MALSDStateItem``-like) + shuttled between :meth:`merge_to_batched_state` and :meth:`split_batched_state`. + - ``window_committed_tokens`` / ``window_committed_timestamps``: cumulative + prefix shared by all surviving beams at the most recent EOU boundary. + - ``window_beam_tokens`` / ``window_beam_timestamps``: per-slot chunk-local + cumulative emissions since the last EOU (one list per beam slot). Beams + stay diverged across chunks; the chosen path is committed at EOU. + - ``_malsd_utterance_start``: position in the cumulative ``hyp.y_sequence`` + where the current utterance begins, so EOU + ``cleanup_after_eou`` can + correctly slice past previously emitted (and cleared) utterances. + """ + + def _additional_params_reset(self) -> None: + """ + Reset MALSD per-stream carry on top of the greedy state. + """ + super()._additional_params_reset() + self.hyp_decoding_state: Any = None + self.window_committed_tokens: list[int] = [] + self.window_committed_timestamps: list[int] = [] + self.window_beam_tokens: list[list[int]] | None = None + self.window_beam_timestamps: list[list[int]] | None = None + self._malsd_utterance_start: int = 0 + + def reset_previous_hypothesis(self) -> None: + """ + Reset the previous hypothesis and all MALSD beam-search bookkeeping. + + Called at end-of-stream. Zeroes out the MALSD per-stream carry so the + next utterance starts from SOS with an empty windowed-beam state. + """ + super().reset_previous_hypothesis() + self.hyp_decoding_state = None + self.window_committed_tokens = [] + self.window_committed_timestamps = [] + self.window_beam_tokens = None + self.window_beam_timestamps = None + # NB: ``_malsd_utterance_start`` is intentionally NOT reset here because + # the cumulative ``hyp.y_sequence`` it indexes is owned by the pipeline + # and bumped after the call when the previous utterance is being + # finalised. The pipeline bumps it explicitly after publishing. diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index 133f1dc82fb3..ebeca4a521e9 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -1378,6 +1378,315 @@ def _create_decoding_state( **self.state.batched_hyps.export_cross_chunk_state(batch_size=current_batch_size), ) + def _get_state_item_after_sos(self, device: torch.device | str) -> MALSDStateItem: + """ + Per-stream after-SOS state. Used by :meth:`merge_to_batched_state` to fill + ``None`` items (fresh streams that joined the batch mid-flight). + + Built by constructing a ``batch_size=1`` batched after-SOS state and + taking the first item out of :meth:`split_batched_state` - mirrors the + greedy ``_get_decoding_state_item_after_sos`` pattern. + """ + batched = self._get_batched_state_after_sos(device=device, batch_size=1) + return self.split_batched_state(batched)[0] + + def _get_batched_state_after_sos( + self, device: torch.device | str, batch_size: int + ) -> BatchedBeamState: + """ + Build a fresh batched MALSD state after ````. + + Shapes follow the contract consumed by :meth:`split_batched_state`: + predictor state/outputs are flat ``[B*K, ...]``; per-beam fields are + ``[B, K]``; fusion states are ``[B, K, ...]``; ``decoded_lengths`` is ``[B]``. + Slot 0 starts active (``score=0.0``); slots ``1..K-1`` start inactive so the + next chunk's top-k expands the surviving beam. + """ + beam_size = self.beam_size + total = batch_size * beam_size + + sos_labels = torch.full([total], fill_value=self._SOS, dtype=torch.long, device=device) + decoder_output, predictor_state, *_ = self.decoder.predict( + sos_labels.unsqueeze(1), None, add_sos=False, batch_size=total + ) + decoder_output = self.joint.project_prednet(decoder_output) # [B*K, 1, D] + + scores = torch.full( + [batch_size, beam_size], fill_value=INACTIVE_SCORE, dtype=decoder_output.dtype, device=device + ) + scores[:, 0] = 0.0 + + fusion_states_list: list[torch.Tensor] = [] + if self.fusion_models is not None: + for fm in self.fusion_models: + fs = fm.get_init_states(batch_size=total, bos=True).to(device) + fusion_states_list.append(fs.reshape(batch_size, beam_size, *fs.shape[1:])) + + def zeros_bk() -> torch.Tensor: + return torch.zeros([batch_size, beam_size], dtype=torch.long, device=device) + + return BatchedBeamState( + predictor_states=predictor_state, + predictor_outputs=decoder_output, + labels=sos_labels.view(batch_size, beam_size), + decoded_lengths=torch.zeros([batch_size], dtype=torch.long, device=device), + fusion_states_list=fusion_states_list, + time_jumps=None, + scores=scores, + transcript_hash=zeros_bk(), + current_lengths_nb=zeros_bk(), + last_timestamp_lasts=zeros_bk(), + transcript_prefix_hash=None, + ) + + def split_batched_state(self, state: BatchedBeamState) -> list[MALSDStateItem]: + """ + Split a batched MALSD state into per-stream ``MALSDStateItem``s. + + Mirrors ``GreedyBatchedLabelLoopingComputerBase.split_batched_state`` for + beam-search shapes: + + - the predictor state was created with batch dimension ``B * beam_size``; + we slice it into ``B`` groups of ``beam_size`` consecutive rows and + re-batch each group with ``decoder.batch_unsplit_states``. + - ``labels`` / ``decoded_lengths`` and per-beam cross-chunk scalars are + split along the batch axis. + - ``fusion_states_list`` has each element as ``[B, beam_size, ...]``. + """ + if state is None: + return [] + batch_size = state.labels.shape[0] + beam_size = self.beam_size + + per_row_states = self.decoder.batch_split_states(state.predictor_states) + if len(per_row_states) != batch_size * beam_size: + raise AssertionError( + f"Expected predictor states with batch dim {batch_size * beam_size}, " + f"got {len(per_row_states)} per-row items" + ) + + items: list[MALSDStateItem] = [] + for i in range(batch_size): + stream_predictor_state = self.decoder.batch_unsplit_states( + per_row_states[i * beam_size : (i + 1) * beam_size] + ) + # ``state.fusion_states_list[k]`` is stored as ``[B, K]`` (see + # ``modified_alsd_torch``'s ``s.view(batch_size, self.beam_size)`` step). + fusion_state_list = ( + [fs[i].clone() for fs in state.fusion_states_list] if state.fusion_states_list else [] + ) + items.append( + MALSDStateItem( + predictor_state=stream_predictor_state, + predictor_output=state.predictor_outputs[i * beam_size : (i + 1) * beam_size].clone(), + label=state.labels[i].clone(), + decoded_length=state.decoded_lengths[i].clone(), + score=state.scores[i].clone() if state.scores is not None else None, + transcript_hash=( + state.transcript_hash[i].clone() if state.transcript_hash is not None else None + ), + current_lengths_nb=( + state.current_lengths_nb[i].clone() if state.current_lengths_nb is not None else None + ), + last_timestamp_lasts=( + state.last_timestamp_lasts[i].clone() if state.last_timestamp_lasts is not None else None + ), + transcript_prefix_hash=( + state.transcript_prefix_hash[i].clone() + if state.transcript_prefix_hash is not None + else None + ), + fusion_state_list=fusion_state_list, + ) + ) + return items + + def merge_to_batched_state(self, state_items: list[Optional[MALSDStateItem]]) -> BatchedBeamState: + """ + Merge a list of per-stream ``MALSDStateItem``s into a single batched MALSD state. + + ``None`` entries (e.g. fresh streams that joined a batch mid-flight) are + replaced with a freshly-initialised after-SOS state. + """ + if any(item is None for item in state_items): + not_none_item = next(item for item in state_items if item is not None) + device = not_none_item.predictor_output.device + start_item = self._get_state_item_after_sos(device=device) + state_items = [item if item is not None else start_item for item in state_items] + + per_row_states: list[Any] = [] + for item in state_items: + per_row_states.extend(self.decoder.batch_split_states(item.predictor_state)) + batched_predictor_state = self.decoder.batch_unsplit_states(per_row_states) + + predictor_outputs = torch.cat([item.predictor_output for item in state_items], dim=0) + labels = torch.stack([item.label for item in state_items], dim=0) + decoded_lengths = torch.stack([item.decoded_length for item in state_items], dim=0) + scores = torch.stack([item.score for item in state_items], dim=0) + transcript_hash = torch.stack([item.transcript_hash for item in state_items], dim=0) + current_lengths_nb = torch.stack([item.current_lengths_nb for item in state_items], dim=0) + last_timestamp_lasts = ( + torch.stack([item.last_timestamp_lasts for item in state_items], dim=0) + if state_items[0].last_timestamp_lasts is not None + else None + ) + transcript_prefix_hash = ( + torch.stack([item.transcript_prefix_hash for item in state_items], dim=0) + if state_items[0].transcript_prefix_hash is not None + else None + ) + + num_fusion = len(state_items[0].fusion_state_list) + # Per-stream ``fusion_state_list[fusion_idx]`` is ``[K]``; stack along a new dim 0 + # to produce ``[B, K]`` (NOT ``cat`` which would give the flat ``[B*K]`` shape used + # by ``predictor_*`` and would trip downstream shape mismatches). + fusion_states_list = [ + torch.stack([item.fusion_state_list[fi] for item in state_items], dim=0) for fi in range(num_fusion) + ] + + return BatchedBeamState( + predictor_states=batched_predictor_state, + predictor_outputs=predictor_outputs, + labels=labels, + decoded_lengths=decoded_lengths, + fusion_states_list=fusion_states_list, + time_jumps=None, + scores=scores, + transcript_hash=transcript_hash, + current_lengths_nb=current_lengths_nb, + last_timestamp_lasts=last_timestamp_lasts, + transcript_prefix_hash=transcript_prefix_hash, + ) + + def collapse_batched_state_to_beams_( + self, + state: BatchedBeamState, + batched_hyps: BatchedBeamHyps, + beam_indices: torch.Tensor, + ) -> None: + """ + In-place: collapse each row of a batched MALSD state and its associated + :class:`BatchedBeamHyps` to a single surviving beam, replicated across all + ``beam_size`` slots. + + After the call, every per-beam tensor on ``state`` and on ``batched_hyps`` + carries the chosen beam's value at slot 0 and identical clones at slots + 1..beam_size-1; ``scores[:, 1:]`` is set to ``INACTIVE_SCORE`` so the next + chunk's top-k repopulates them through normal expansion of the surviving beam. + + Args: + state: batched MALSD state to collapse in place. + batched_hyps: prefix-tree object returned alongside ``state``. Mutated + in place via :meth:`BatchedBeamHyps.keep_beam_`. + beam_indices: ``[batch_size]`` long tensor giving the beam to keep per row. + """ + batch_size = state.labels.shape[0] + beam_size = self.beam_size + if beam_indices.shape != (batch_size,): + raise ValueError( + f"beam_indices must have shape [batch_size={batch_size}], got {tuple(beam_indices.shape)}" + ) + + device = state.labels.device + beam_indices = beam_indices.to(dtype=torch.long, device=device) + + row_offsets = torch.arange(batch_size, device=device, dtype=torch.long) * beam_size + chosen_flat_idx = row_offsets + beam_indices # [B] + flat_perm = chosen_flat_idx.unsqueeze(-1).expand(batch_size, beam_size).reshape(-1) # [B*K] + + per_row = self.decoder.batch_split_states(state.predictor_states) + if len(per_row) != batch_size * beam_size: + raise AssertionError( + f"Expected predictor states with batch dim {batch_size * beam_size}, " + f"got {len(per_row)} per-row items" + ) + replicated_per_row = [per_row[int(idx)] for idx in flat_perm.tolist()] + state.predictor_states = self.decoder.batch_unsplit_states(replicated_per_row) + + state.predictor_outputs = state.predictor_outputs.index_select(0, flat_perm).contiguous() + + beam_perm = beam_indices.unsqueeze(-1).expand(batch_size, beam_size) + state.labels = torch.gather(state.labels, dim=1, index=beam_perm).contiguous() + if state.scores is not None: + state.scores = torch.gather(state.scores, dim=1, index=beam_perm).contiguous() + state.scores[:, 1:].fill_(INACTIVE_SCORE) + if state.transcript_hash is not None: + state.transcript_hash = torch.gather(state.transcript_hash, dim=1, index=beam_perm).contiguous() + if state.current_lengths_nb is not None: + state.current_lengths_nb = torch.gather(state.current_lengths_nb, dim=1, index=beam_perm).contiguous() + if state.last_timestamp_lasts is not None: + state.last_timestamp_lasts = torch.gather( + state.last_timestamp_lasts, dim=1, index=beam_perm + ).contiguous() + if state.transcript_prefix_hash is not None: + state.transcript_prefix_hash = torch.gather( + state.transcript_prefix_hash, dim=1, index=beam_perm + ).contiguous() + + if state.fusion_states_list: + # Fusion states are reshaped to ``[B, K]`` inside ``modified_alsd_torch`` + # so use the per-stream ``beam_perm`` gather along the beam axis. + for fs in state.fusion_states_list: + if fs.ndim != 2: + raise NotImplementedError( + f"collapse_batched_state_to_beams_ only supports rank-2 [B, K] " + f"fusion states; got shape {tuple(fs.shape)}" + ) + state.fusion_states_list = [ + torch.gather(fs, dim=1, index=beam_perm).contiguous() for fs in state.fusion_states_list + ] + + batched_hyps.keep_beam_(beam_indices) + + def collapse_state_item_to_top1_(self, item: MALSDStateItem, beam_index: int) -> None: + """ + In-place per-stream variant of :meth:`collapse_batched_state_to_beams_`. + + Replicates beam ``beam_index`` across all ``beam_size`` slots of ``item`` + and sets ``score[1:] = INACTIVE_SCORE`` so the next chunk's top-k expands + the surviving beam. Used by streaming pipelines to collapse a single + stream's MALSD carry at its EOU boundary without disturbing other rows + of a batched run. + + Wraps mutations in :func:`torch.inference_mode` so it can be called from + outside the encoder/decoder inference region (the per-stream tensors are + inference tensors produced by :meth:`split_batched_state`). + """ + beam_size = self.beam_size + if not 0 <= beam_index < beam_size: + raise ValueError(f"beam_index must be in [0, {beam_size}), got {beam_index}") + + with torch.inference_mode(): + per_row = self.decoder.batch_split_states(item.predictor_state) + if len(per_row) != beam_size: + raise AssertionError( + f"Expected per-stream predictor state with batch dim {beam_size}, got {len(per_row)}" + ) + item.predictor_state = self.decoder.batch_unsplit_states([per_row[beam_index]] * beam_size) + + item.predictor_output = ( + item.predictor_output[beam_index : beam_index + 1] + .expand(beam_size, *item.predictor_output.shape[1:]) + .contiguous() + ) + + idx = torch.full([beam_size], fill_value=beam_index, dtype=torch.long, device=item.label.device) + item.label = item.label.index_select(0, idx).contiguous() + if item.score is not None: + item.score = item.score.index_select(0, idx).contiguous() + item.score[1:].fill_(INACTIVE_SCORE) + if item.transcript_hash is not None: + item.transcript_hash = item.transcript_hash.index_select(0, idx).contiguous() + if item.current_lengths_nb is not None: + item.current_lengths_nb = item.current_lengths_nb.index_select(0, idx).contiguous() + if item.last_timestamp_lasts is not None: + item.last_timestamp_lasts = item.last_timestamp_lasts.index_select(0, idx).contiguous() + if item.transcript_prefix_hash is not None: + item.transcript_prefix_hash = item.transcript_prefix_hash.index_select(0, idx).contiguous() + + for fi, fs in enumerate(item.fusion_state_list): + item.fusion_state_list[fi] = fs.index_select(0, idx).contiguous() + def __call__( self, x: torch.Tensor, From 03937b0d6a13db1a795ff087b869214fc29d929a Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Mon, 8 Jun 2026 22:30:05 +0400 Subject: [PATCH 02/28] add n-chunk reseting working Signed-off-by: lilithgrigoryan --- .../cache_aware_rnnt_inference_wrapper.py | 65 ++++-- .../pipelines/cache_aware_rnnt_pipeline.py | 207 +++++++++++------- .../streaming/state/cache_aware_rnnt_state.py | 23 +- .../submodules/rnnt_malsd_batched_computer.py | 72 ++---- .../utils/batched_beam_decoding_utils.py | 206 +++++++++++------ .../asr/parts/utils/streaming_utils.py | 4 +- 6 files changed, 354 insertions(+), 223 deletions(-) diff --git a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py index 3aedcab10abc..4855c3868c4d 100644 --- a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py +++ b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py @@ -73,32 +73,24 @@ def get_vocabulary(self) -> list[str]: """ return self.asr_model.joint.vocabulary - def execute_step( + def encode_step( self, processed_signal: Tensor, processed_signal_length: Tensor, context: CacheAwareContext, - previous_hypotheses: list[Hypothesis] | None, drop_extra_pre_encoded: int | None, keep_all_outputs: bool, drop_left_context: int | None = None, valid_out_len: int | None = None, - prompt_vectors: Tensor | None = None, - ) -> tuple[list[Hypothesis], CacheAwareContext]: + ) -> tuple[Tensor, Tensor, CacheAwareContext]: """ - Executes a single streaming step. - Args: - processed_signal: (Tensor) input signal tensor. - processed_signal_length: (Tensor) input signal length tensor. - context: (CacheAwareContext) context object. - previous_hypotheses: (list[Hypothesis] | None) list of previous hypotheses for RNNT decoding. - drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop. - keep_all_outputs: (bool) whether to keep all outputs or not. - drop_left_context: (int | None) number of left context frames to drop. - valid_out_len: (int | None) number of valid output frames. - prompt_vectors: (Tensor | None) Optional prompt vectors of shape [B, num_prompts]. - Returns: - (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context. + Run the cache-aware encoder for one streaming chunk, returning the (trimmed) + encoder output and updated streaming context. Decoder is NOT invoked. + + Used by :meth:`execute_step` (greedy decoder runs right after) and by + beam-search pipelines that drive the decoder themselves with a + per-stream beam carry (they call ``encode_step`` directly inside their + own ``autocast`` + ``inference_mode`` region). """ ( encoded, @@ -131,6 +123,45 @@ def execute_step( encoded = encoded[:, :, :valid_out_len] encoded_len = torch.ones_like(encoded_len) * valid_out_len + return encoded, encoded_len, new_context + + def execute_step( + self, + processed_signal: Tensor, + processed_signal_length: Tensor, + context: CacheAwareContext, + previous_hypotheses: list[Hypothesis] | None, + drop_extra_pre_encoded: int | None, + keep_all_outputs: bool, + drop_left_context: int | None = None, + valid_out_len: int | None = None, + prompt_vectors: Tensor | None = None, + ) -> tuple[list[Hypothesis], CacheAwareContext]: + """ + Executes a single streaming step. + Args: + processed_signal: (Tensor) input signal tensor. + processed_signal_length: (Tensor) input signal length tensor. + context: (CacheAwareContext) context object. + previous_hypotheses: (list[Hypothesis] | None) list of previous hypotheses for RNNT decoding. + drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop. + keep_all_outputs: (bool) whether to keep all outputs or not. + drop_left_context: (int | None) number of left context frames to drop. + valid_out_len: (int | None) number of valid output frames. + prompt_vectors: (Tensor | None) Optional prompt vectors of shape [B, num_prompts]. + Returns: + (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context. + """ + encoded, encoded_len, new_context = self.encode_step( + processed_signal=processed_signal, + processed_signal_length=processed_signal_length, + context=context, + drop_extra_pre_encoded=drop_extra_pre_encoded, + keep_all_outputs=keep_all_outputs, + drop_left_context=drop_left_context, + valid_out_len=valid_out_len, + ) + best_hyp = self.asr_model.decoding.rnnt_decoder_predictions_tensor( encoded, encoded_len, return_hypotheses=True, partial_hypotheses=previous_hypotheses ) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index b3a9142c1a9c..f648dc94fb5f 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -31,7 +31,10 @@ from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedRequestStreamer from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame, Request from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions -from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import CacheAwareRNNTStreamingState +from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import ( + CacheAwareRNNTMALSDStreamingState, + CacheAwareRNNTStreamingState, +) from nemo.collections.asr.inference.utils.endpointing_utils import millisecond_to_frames from nemo.collections.asr.inference.utils.enums import RequestType from nemo.collections.asr.inference.utils.pipeline_utils import ( @@ -39,6 +42,11 @@ drop_trailing_features, get_confidence_utils, ) +from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import ( + BatchedBeamHyps, + export_batched_beam_hyps_to_cpu_lists, +) from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.utils import logging @@ -76,8 +84,31 @@ def __init__( self.init_endpointer() self.init_text_processor(cfg, itn_model) self.init_nmt_model(nmt_model) + self.init_decoding_computer() super().__init__() + def init_decoding_computer(self) -> None: + """ + Probe the model's decoding stack once and stash the resulting computer + on ``self`` so per-chunk code can branch on it without re-doing the + attribute-chain dive. + + Exactly one of ``self.decoding_computer`` (MALSD beam-search) and + ``self.greedy_decoding_computer`` (greedy, used for per-stream biasing + detection) is non-``None`` for any supported decoding stack; both are + ``None`` if the stack exposes no ``decoding_computer`` at all. + """ + try: + decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer + except AttributeError: + decoding_computer = None + if isinstance(decoding_computer, ModifiedALSDBatchedRNNTComputer): + self.decoding_computer: ModifiedALSDBatchedRNNTComputer | None = decoding_computer + self.greedy_decoding_computer = None + else: + self.decoding_computer = None + self.greedy_decoding_computer = decoding_computer + def init_parameters(self, cfg: DictConfig) -> None: """ Initialize the parameters. @@ -149,6 +180,13 @@ def init_parameters(self, cfg: DictConfig) -> None: self.request_type = RequestType.from_str(cfg.streaming.request_type) + # MALSD beam-search streaming knobs. ``chunks_per_beam_reset == 1`` collapses + # the K-beam state down to a single top-1 hypothesis after every chunk, which + # matches the original cache-aware behaviour. Higher values preserve beam + # diversity across multiple chunks before collapsing - currently only the + # ``== 1`` path is fully wired up; larger values fall back to it. + self.chunks_per_beam_reset = int(cfg.streaming.get("chunks_per_beam_reset", 1)) + def init_greedy_rnnt_decoder(self) -> None: """Initialize the RNNT decoder.""" check_existance_of_required_attributes(self, ['vocabulary', 'conf_func']) @@ -179,9 +217,12 @@ def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingSta Args: options: (ASRRequestOptions) Request options for particular stream. Returns: - (CacheAwareRNNTStreamingState) New empty state. + (CacheAwareRNNTStreamingState) New empty state. Returns the MALSD subclass + when the pipeline is configured for beam-search decoding. """ - state = CacheAwareRNNTStreamingState() + state = ( + CacheAwareRNNTMALSDStreamingState() if self.decoding_computer is not None else CacheAwareRNNTStreamingState() + ) state.set_global_offset(0) new_options = options.fill_defaults( default_enable_itn=self.text_processor.itn_enabled, @@ -297,19 +338,17 @@ def _malsd_stream_step( 2. Merge per-stream ``MALSDStateItem``s into a batched MALSD state. 3. Run :class:`ModifiedALSDBatchedRNNTComputer` for this chunk. 4. Update per-stream windowed-beam tracking from this chunk's emissions. - 5. Split the batched MALSD state back into per-stream carries. - 6. Build a cumulative ``Hypothesis`` per stream from + 5. Optionally collapse to the chunk's top-1 (current default behaviour). + 6. Split the batched MALSD state back into per-stream carries. + 7. Build a cumulative ``Hypothesis`` per stream from ``window_committed + window_beam_tokens[top1]``. - Collapse to the chunk's top-1 is NOT performed here - beams stay - diverged across chunks and are collapsed per-stream at the EOU - boundary inside :meth:`run_malsd_decoder`. - Returns a list of cumulative ``Hypothesis`` per stream and the new encoder cache context, matching the shape of ``stream_step``. """ # Per-stream multi-biasing ids: not yet supported on the MALSD streaming # path. Greedy-side per-stream biasing knobs stay independent. + multi_biasing_ids = None if biasing_enabled: logging.warning( "Per-stream biasing is not yet wired up on the MALSD cache-aware " @@ -324,10 +363,11 @@ def _malsd_stream_step( else: batched_state = self.decoding_computer.merge_to_batched_state(carries) - # All MALSD GPU work (encoder, decoder, windowed walk, split) shares one - # ``inference_mode`` region: ``split_batched_state`` mutates the inference - # tensors returned by ``decoding_computer(...)`` in place, which is illegal - # once we've left the captured ``inference_mode`` region. + # All MALSD GPU work (encoder, decoder, windowed walk, collapse, split) + # shares one ``inference_mode`` region: ``collapse_batched_state_to_beams_`` + # and ``split_batched_state`` mutate the inference tensors returned by + # ``decoding_computer(...)`` in place, which is illegal once we've left + # the captured ``inference_mode`` region. with ( torch.amp.autocast( device_type=self.asr_model.device_str, @@ -355,37 +395,65 @@ def _malsd_stream_step( self._update_windowed_beam_state(states=states, best_batched_hyps=best_batched_hyps) - # Per-stream top-1 beam slot. Indexes ``window_beam_tokens`` (which was - # just appended against the diverged beam slots) to build the publishable - # cumulative hypothesis below. + # Capture pre-collapse argmax + scores. After ``collapse_batched_state_to_beams_`` + # runs, ``scores[:, 1:]`` is forced to ``INACTIVE_SCORE`` and ``scores[:, 0]`` + # carries the winner - so any post-collapse argmax returns 0 unconditionally. + # We need the PRE-collapse slot index to index ``window_beam_tokens`` (which + # was just computed against the diverged pre-collapse slots). beam_indices_cpu = best_batched_hyps.scores.argmax(dim=-1).detach().cpu().tolist() - scores_cpu = best_batched_hyps.scores.detach().cpu() + scores_pre_collapse = best_batched_hyps.scores.detach().cpu() + + # Collapse the K-beam state at the configured cadence. For now we always + # collapse every chunk (``chunks_per_beam_reset == 1``); the multi-chunk + # window is a follow-up that needs full prefix-tree carry across chunks. + for state in states: + state._malsd_chunk_count += 1 + do_collapse = self.chunks_per_beam_reset <= 1 or any( + state._malsd_chunk_count >= self.chunks_per_beam_reset for state in states + ) + if do_collapse: + beam_indices = best_batched_hyps.scores.argmax(dim=-1).to(torch.long) + self.decoding_computer.collapse_batched_state_to_beams_( + batched_state, best_batched_hyps, beam_indices + ) carry_items = self.decoding_computer.split_batched_state(batched_state) for state, carry in zip(states, carry_items): state.hyp_decoding_state = carry - # Build per-stream cumulative ``Hypothesis`` from the windowed state. - # Collapse + window promotion is deferred to ``run_malsd_decoder`` and - # triggered by EOU, so the published hyp is the current top-1's path - # but the K-beam state continues to diverge across chunks. + # Build per-stream cumulative ``Hypothesis`` from the windowed state, + # then (on collapse chunks) promote the chosen beam's window tokens into + # the committed prefix and clear the window. The published hypothesis + # is identical pre/post-collapse promotion - just with everything moved + # into ``committed`` afterwards. hyps: list[Hypothesis] = [] for b, state in enumerate(states): top1_slot = beam_indices_cpu[b] - window_tokens = state.window_beam_tokens[top1_slot] if state.window_beam_tokens else [] - window_ts = state.window_beam_timestamps[top1_slot] if state.window_beam_timestamps else [] + window_tokens = ( + state.window_beam_tokens[top1_slot] if state.window_beam_tokens else [] + ) + window_ts = ( + state.window_beam_timestamps[top1_slot] if state.window_beam_timestamps else [] + ) cum_tokens = state.window_committed_tokens + list(window_tokens) cum_ts = state.window_committed_timestamps + list(window_ts) hyps.append( Hypothesis( - score=float(scores_cpu[b, top1_slot].item()), + score=float(scores_pre_collapse[b, top1_slot].item()), y_sequence=cum_tokens, timestamp=cum_ts, length=len(cum_tokens), ) ) + if do_collapse: + state._malsd_chunk_count = 0 + state.window_committed_tokens = list(cum_tokens) + state.window_committed_timestamps = list(cum_ts) + state.window_beam_tokens = None + state.window_beam_timestamps = None + return hyps, new_context def _update_windowed_beam_state( @@ -431,9 +499,7 @@ def run_malsd_decoder( On EOU we bump ``_malsd_utterance_start`` to the current cumulative length so the next utterance's resync slice starts past the cleared - previous utterance, then collapse the per-stream MALSD carry to its - top-1 beam: the K-beam state diverges intra-utterance and snaps to the - chosen path at the natural utterance boundary. + previous utterance. """ eou_detected = self.run_greedy_decoder(state, request, hyp) @@ -454,21 +520,9 @@ def run_malsd_decoder( state.last_token_idx = timestamps_list[-1] if timestamps_list else None if eou_detected: - # Mark the boundary so the next utterance's slice starts past the - # tokens we just finalised. + # mark the boundary so the next utterance's slice starts past the + # tokens we just finalised state._malsd_utterance_start = len(all_tokens) - - # EOU-driven collapse: promote the chosen window into the committed - # prefix and replicate the winning beam across all K slots of the - # per-stream carry. The predictor stays warm at the top-1's last - # label so the next utterance benefits from cross-utterance context. - if state.hyp_decoding_state is not None: - top1 = int(state.hyp_decoding_state.score.argmax().item()) - self.decoding_computer.collapse_state_item_to_top1_(state.hyp_decoding_state, top1) - state.window_committed_tokens = list(all_tokens) - state.window_committed_timestamps = list(all_timestamps) - state.window_beam_tokens = None - state.window_beam_timestamps = None return eou_detected def run_greedy_decoder(self, state: CacheAwareRNNTStreamingState, request: Request, hyp: Hypothesis) -> bool: @@ -543,12 +597,13 @@ def cache_aware_transcribe_step( previous_hypotheses = [state.get_previous_hypothesis() for state in states] - try: - decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer - biasing_enabled = decoding_computer.per_stream_biasing_enabled - except AttributeError: - decoding_computer = None - biasing_enabled = False + # Per-stream biasing is only wired up on the greedy decoder. When MALSD + # is active ``self.greedy_decoding_computer`` is ``None`` (see + # :meth:`init_decoding_computer`) so ``biasing_enabled`` falls back to + # ``False`` and the warning in ``_malsd_stream_step`` covers the rest. + biasing_enabled = ( + self.greedy_decoding_computer is not None and self.greedy_decoding_computer.per_stream_biasing_enabled + ) if not biasing_enabled and any(state.has_biasing_request() for state in states): logging.warning("Biasing request is not empty, but decoder does not support per-stream biasing. Skipping") @@ -561,7 +616,7 @@ def cache_aware_transcribe_step( if state.options.biasing_cfg.auto_manage_multi_model: state.options.biasing_cfg.add_to_multi_model( tokenizer=self.asr_model.tokenizer, - biasing_multi_model=decoding_computer.biasing_multi_model, + biasing_multi_model=self.greedy_decoding_computer.biasing_multi_model, ) else: logging.warning( @@ -579,38 +634,49 @@ def cache_aware_transcribe_step( prompt_vectors = self._build_prompt_vectors(states) drop_extra_pre_encoded = 0 if not self.use_cache else self.asr_model.drop_extra_pre_encoded - best_hyp, new_context = self.asr_model.stream_step( - processed_signal=feature_buffers, - processed_signal_length=feature_buffer_lens, + best_hyp, new_context = self._streaming_step( + states=states, + feature_buffers=feature_buffers, + feature_buffer_lens=feature_buffer_lens, context=context, previous_hypotheses=previous_hypotheses, drop_extra_pre_encoded=drop_extra_pre_encoded, keep_all_outputs=keep_all_outputs, - drop_left_context=self.drop_left_context, - valid_out_len=self.valid_out_len, prompt_vectors=prompt_vectors, + biasing_enabled=biasing_enabled, ) # update the cache and reset the cache slots for the streams that has ended self.context_manager.update_cache(stream_ids, new_context, mapping) self.context_manager.reset_slots(stream_ids, eos_flags) - # update the previous hypothesis and reset the previous hypothesis for the streams that has ended + # update the previous hypothesis for non-eos streams. For greedy this is the + # ``Hypothesis`` returned by ``rnnt_decoder_predictions_tensor``; for MALSD + # it is the cumulative ``Hypothesis`` built in ``_malsd_stream_step``. The + # eos reset is deferred to *after* the per-request decoder loop below so + # that ``run_malsd_decoder`` can still see the current utterance start. for state, hyp, eos in zip(states, best_hyp, eos_flags): - if eos: - state.reset_previous_hypothesis() - else: + if not eos: state.set_previous_hypothesis(hyp) - # run greedy decoder for each request-state-hypothesis tuple + # run per-request decoder for each request-state-hypothesis tuple for request, state, hyp in zip(requests, states, best_hyp): - eou_detected = self.run_greedy_decoder(state, request, hyp) + if self.decoding_computer is not None: + eou_detected = self.run_malsd_decoder(state, request, hyp) + else: + eou_detected = self.run_greedy_decoder(state, request, hyp) if eou_detected: self.bpe_decoder.decode_bpe_tokens(state) state.cleanup_after_eou() ready_state_ids.add(request.stream_id) - # Cleanup per-stream biasing models when stream ends + # Deferred eos reset - now safe to clear MALSD per-stream carry too. + for state, eos in zip(states, eos_flags): + if eos: + state.reset_previous_hypothesis() + + # Cleanup per-stream biasing models when stream ends (greedy path only; + # ``biasing_enabled`` is True only when ``self.greedy_decoding_computer`` is set). if biasing_enabled: for request, state in zip(requests, states): # only the first request contains biasing options; biasing options for the stream are stored in state @@ -620,27 +686,6 @@ def cache_aware_transcribe_step( biasing_multi_model=self.greedy_decoding_computer.biasing_multi_model ) - def _debug_print_finals(self, ready_state_ids: set) -> None: - """DEBUG: print finalised transcripts so greedy vs MALSD logs can be diffed.""" - strategy = "malsd" if self.decoding_computer is not None else "greedy" - for sid in sorted(ready_state_ids): - state = self.get_state(sid) - print( - f"[CMP][FINAL] strategy={strategy} stream={sid} text={state.final_transcript!r}", - flush=True, - ) - - def _debug_print_partials(self, requests: list[Request]) -> None: - """DEBUG: print partial / current-step transcripts so greedy vs MALSD logs can be diffed.""" - strategy = "malsd" if self.decoding_computer is not None else "greedy" - for req in requests: - state = self.get_state(req.stream_id) - print( - f"[CMP][PARTIAL] strategy={strategy} stream={req.stream_id} " - f"partial={state.partial_transcript!r} step={state.current_step_transcript!r}", - flush=True, - ) - def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: """ Transcribes the feature buffers in a streaming manner. diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index b30cc7e55aa0..74880fc94774 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -13,13 +13,19 @@ # limitations under the License. +from typing import Any + from nemo.collections.asr.inference.streaming.state.cache_aware_state import CacheAwareStreamingState from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis class CacheAwareRNNTStreamingState(CacheAwareStreamingState): """ - State of the cache aware RNNT streaming pipelines + State of the cache aware RNNT streaming pipelines (greedy decoder). + + Extends :class:`CacheAwareStreamingState` with greedy-decoding bookkeeping + (``previous_hypothesis``). The MALSD beam-search variant adds its own + per-stream carry in :class:`CacheAwareRNNTMALSDStreamingState`. """ def __init__(self): @@ -61,7 +67,7 @@ def get_previous_hypothesis(self) -> Hypothesis | None: def reset_previous_hypothesis(self) -> None: """ - Reset the previous hypothesis to None + Reset the previous hypothesis. Called at utterance end (EOU). """ self.previous_hypothesis = None @@ -75,10 +81,11 @@ class CacheAwareRNNTMALSDStreamingState(CacheAwareRNNTStreamingState): - ``hyp_decoding_state``: per-stream beam carry (``MALSDStateItem``-like) shuttled between :meth:`merge_to_batched_state` and :meth:`split_batched_state`. - ``window_committed_tokens`` / ``window_committed_timestamps``: cumulative - prefix shared by all surviving beams at the most recent EOU boundary. + prefix shared by all surviving beams at the most recent collapse boundary. - ``window_beam_tokens`` / ``window_beam_timestamps``: per-slot chunk-local - cumulative emissions since the last EOU (one list per beam slot). Beams - stay diverged across chunks; the chosen path is committed at EOU. + cumulative emissions since the last collapse (one list per beam slot). + - ``_malsd_chunk_count``: number of MALSD chunks processed since the last + collapse - used by ``chunks_per_beam_reset`` to decide when to collapse. - ``_malsd_utterance_start``: position in the cumulative ``hyp.y_sequence`` where the current utterance begins, so EOU + ``cleanup_after_eou`` can correctly slice past previously emitted (and cleared) utterances. @@ -94,14 +101,15 @@ def _additional_params_reset(self) -> None: self.window_committed_timestamps: list[int] = [] self.window_beam_tokens: list[list[int]] | None = None self.window_beam_timestamps: list[list[int]] | None = None + self._malsd_chunk_count: int = 0 self._malsd_utterance_start: int = 0 def reset_previous_hypothesis(self) -> None: """ Reset the previous hypothesis and all MALSD beam-search bookkeeping. - Called at end-of-stream. Zeroes out the MALSD per-stream carry so the - next utterance starts from SOS with an empty windowed-beam state. + Called at utterance end (EOU). Zeroes out the MALSD per-stream carry so + the next utterance starts from SOS with an empty windowed-beam state. """ super().reset_previous_hypothesis() self.hyp_decoding_state = None @@ -109,6 +117,7 @@ def reset_previous_hypothesis(self) -> None: self.window_committed_timestamps = [] self.window_beam_tokens = None self.window_beam_timestamps = None + self._malsd_chunk_count = 0 # NB: ``_malsd_utterance_start`` is intentionally NOT reset here because # the cumulative ``hyp.y_sequence`` it indexes is owned by the pipeline # and bumped after the call when the previous utterance is being diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index ebeca4a521e9..cdec9b06657e 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -204,6 +204,29 @@ class SeparateGraphsMALSD: loop_update_decoder: torch.cuda.CUDAGraph = field(default_factory=torch.cuda.CUDAGraph) +@dataclass +class MALSDStateItem: + """ + Per-stream decoding state for ``ModifiedALSDBatchedRNNTComputer``. + + Used by streaming pipelines that maintain per-stream state. Mirrors + ``LabelLoopingStateItem`` (greedy) with beam-shaped tensors + (``[beam_size, ...]`` instead of scalar/``[D]``) plus the cross-chunk + per-beam fields needed to seed the next MALSD chunk. + """ + + predictor_state: Any # opaque per-stream predictor state of size beam_size + predictor_output: torch.Tensor # [beam_size, 1, D] + label: torch.Tensor # [beam_size] + decoded_length: torch.Tensor # scalar + score: torch.Tensor # [beam_size] + transcript_hash: torch.Tensor # [beam_size] + current_lengths_nb: torch.Tensor # [beam_size] + last_timestamp_lasts: Optional[torch.Tensor] = None # [beam_size] or None + transcript_prefix_hash: Optional[torch.Tensor] = None # [beam_size] or None + fusion_state_list: list[torch.Tensor] = field(default_factory=list) # each [beam_size, ...] + + class ModifiedALSDBatchedRNNTComputer(WithOptionalCudaGraphs, ConfidenceMethodMixin): """ Batched Alignment-Length Synchronous Decoding implementation. Callable. @@ -1638,55 +1661,6 @@ def collapse_batched_state_to_beams_( batched_hyps.keep_beam_(beam_indices) - def collapse_state_item_to_top1_(self, item: MALSDStateItem, beam_index: int) -> None: - """ - In-place per-stream variant of :meth:`collapse_batched_state_to_beams_`. - - Replicates beam ``beam_index`` across all ``beam_size`` slots of ``item`` - and sets ``score[1:] = INACTIVE_SCORE`` so the next chunk's top-k expands - the surviving beam. Used by streaming pipelines to collapse a single - stream's MALSD carry at its EOU boundary without disturbing other rows - of a batched run. - - Wraps mutations in :func:`torch.inference_mode` so it can be called from - outside the encoder/decoder inference region (the per-stream tensors are - inference tensors produced by :meth:`split_batched_state`). - """ - beam_size = self.beam_size - if not 0 <= beam_index < beam_size: - raise ValueError(f"beam_index must be in [0, {beam_size}), got {beam_index}") - - with torch.inference_mode(): - per_row = self.decoder.batch_split_states(item.predictor_state) - if len(per_row) != beam_size: - raise AssertionError( - f"Expected per-stream predictor state with batch dim {beam_size}, got {len(per_row)}" - ) - item.predictor_state = self.decoder.batch_unsplit_states([per_row[beam_index]] * beam_size) - - item.predictor_output = ( - item.predictor_output[beam_index : beam_index + 1] - .expand(beam_size, *item.predictor_output.shape[1:]) - .contiguous() - ) - - idx = torch.full([beam_size], fill_value=beam_index, dtype=torch.long, device=item.label.device) - item.label = item.label.index_select(0, idx).contiguous() - if item.score is not None: - item.score = item.score.index_select(0, idx).contiguous() - item.score[1:].fill_(INACTIVE_SCORE) - if item.transcript_hash is not None: - item.transcript_hash = item.transcript_hash.index_select(0, idx).contiguous() - if item.current_lengths_nb is not None: - item.current_lengths_nb = item.current_lengths_nb.index_select(0, idx).contiguous() - if item.last_timestamp_lasts is not None: - item.last_timestamp_lasts = item.last_timestamp_lasts.index_select(0, idx).contiguous() - if item.transcript_prefix_hash is not None: - item.transcript_prefix_hash = item.transcript_prefix_hash.index_select(0, idx).contiguous() - - for fi, fs in enumerate(item.fusion_state_list): - item.fusion_state_list[fi] = fs.index_select(0, idx).contiguous() - def __call__( self, x: torch.Tensor, diff --git a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py index af7ac09619ef..43683f8a40bf 100644 --- a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py +++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py @@ -323,6 +323,32 @@ def clone(self, batch_size: Optional[int] = None) -> "BatchedBeamHyps": new_hyps.token_durations.copy_(self.token_durations[:out_batch]) return new_hyps + def keep_beam_(self, beam_indices: torch.Tensor) -> None: + """ + In-place: collapse each row to a single surviving beam, replicated across all + ``beam_size`` slots, with the other slots' scores set to ``INACTIVE_SCORE``. + + Used by streaming pipelines to commit the per-chunk best beam as the + definitive history before the next chunk, so the carried predictor state and + the published transcript stay consistent. + + Args: + beam_indices: ``[batch_size]`` long tensor giving the beam to keep for + each row in the batch. + """ + if self.beam_size <= 1: + return + permutation = ( + beam_indices.to(dtype=torch.long, device=self.device) + .unsqueeze(-1) + .expand(self.batch_size, self.beam_size) + .contiguous() + ) + self._flatten_with_permutation_(permutation) + # Mark all but the first slot as inactive so the next iteration's top-k repopulates them. + self.scores[:, 1:].fill_(INACTIVE_SCORE) + + def get_last_labels(self, pad_id: int = -1) -> torch.Tensor: """ Get last labels for each hypothesis in the beam. @@ -608,57 +634,20 @@ def recombine_prefixes(self, label_logps: torch.Tensor, active_mask: torch.Tenso to_update_mask = torch.logical_and(active_mask, self.scores != INACTIVE_SCORE) self.scores = torch.where(to_update_mask, torch.logaddexp(self.scores, prefix_label_logps), self.scores) - def _export_hypothesis_timestamps( - self, - beam_timestamps: torch.Tensor, - beam_durations: Optional[torch.Tensor], - mask: torch.Tensor, - ) -> tuple: - """Convert internal beam timestamps into Hypothesis timestamp fields.""" - end_times = beam_timestamps[mask] - if self.model_type == ASRModelTypeEnum.TDT: - durations = beam_durations[mask] - start_times = end_times - durations - return ( - start_times.cpu().detach().numpy(), - durations.cpu().detach().numpy(), - ) - return end_times.cpu().detach().numpy(), None - def to_hyps_list(self, score_norm: bool = True) -> list[Hypothesis]: """ - Converts the batched beam search results into a list of signle best hypotheses for each batch. + Converts the batched beam search results into a list of single best hypotheses for each batch. Args: score_norm (bool): If True, normalize the scores before sorting. Defaults to True. Returns: list[Hypothesis]: A list where each element corresponds to a batch and contains best hypothesis. """ - self.flatten_sort_(score_norm) - - scores = self.scores[self.batch_indices, 0].tolist() - - max_idx = self.current_lengths_wb.max() - 1 - timestamps = self.timestamps[..., 0, : max_idx + 1] - transcripts = self.transcript_wb[..., 0, : max_idx + 1] - durations = self.token_durations[..., 0, : max_idx + 1] if self.model_type == ASRModelTypeEnum.TDT else None - hypotheses = [] - for batch_idx in range(self.batch_size): - mask = self._create_transcripts_mask(transcripts[batch_idx]) - timestamp, token_duration = self._export_hypothesis_timestamps( - timestamps[batch_idx], durations[batch_idx] if durations is not None else None, mask - ) - hypotheses.append( - Hypothesis( - score=scores[batch_idx], - y_sequence=transcripts[batch_idx][mask].cpu().detach().numpy(), - timestamp=timestamp, - token_duration=token_duration, - alignments=None, - dec_state=None, - ) - ) - return hypotheses + scores, transcripts, timestamps, durations, _ = self._export(sort=True, score_norm=score_norm) + return [ + self._hypothesis_from_flat(b, 0, scores, transcripts, timestamps, durations) + for b in range(self.batch_size) + ] def to_nbest_hyps_list(self, score_norm: bool = True) -> list[NBestHypotheses]: """ @@ -669,41 +658,80 @@ def to_nbest_hyps_list(self, score_norm: bool = True) -> list[NBestHypotheses]: list[NBestHypotheses]: A list where each element corresponds to a batch and contains N-best hypotheses. """ - - self.flatten_sort_(score_norm) - - scores = self.scores.tolist() - - max_idx = self.current_lengths_wb.max() - 1 - transcripts = self.transcript_wb[..., : max_idx + 1] - timestamps = self.timestamps[..., : max_idx + 1] - durations = self.token_durations[..., : max_idx + 1] if self.model_type == ASRModelTypeEnum.TDT else None + scores, transcripts, timestamps, durations, _ = self._export(sort=True, score_norm=score_norm) hypotheses = [] for batch_idx in range(self.batch_size): nbest = [] for beam_idx in range(self.beam_size): if scores[batch_idx][beam_idx] <= INACTIVE_SCORE: continue - mask = self._create_transcripts_mask(transcripts[batch_idx][beam_idx]) - timestamp, token_duration = self._export_hypothesis_timestamps( - timestamps[batch_idx][beam_idx], - durations[batch_idx][beam_idx] if durations is not None else None, - mask, - ) nbest.append( - Hypothesis( - score=scores[batch_idx][beam_idx], - y_sequence=transcripts[batch_idx][beam_idx][mask].cpu().detach().numpy(), - timestamp=timestamp, - token_duration=token_duration, - alignments=None, - dec_state=None, - ) + self._hypothesis_from_flat(batch_idx, beam_idx, scores, transcripts, timestamps, durations) ) hypotheses.append(NBestHypotheses(nbest)) return hypotheses - def flatten_sort_(self, score_norm: bool = True): + def _export( + self, sort: bool = True, score_norm: bool = True + ) -> tuple[list[list[float]], torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + """ + Flatten the prefix tree and return per-(batch, beam) views. + + Args: + sort: if True, flatten by descending (normalized) score; otherwise + flatten while preserving slot order. + score_norm: passed to :meth:`flatten_sort_` when ``sort=True``. + + Returns: + (scores, transcripts, timestamps, durations, root_ptrs). The first four + are inputs for :meth:`_hypothesis_from_flat`; ``root_ptrs`` is the + chunk-start -> chunk-end slot descent map (``[batch, beam]`` long + tensor) for the current beam ordering. + """ + if sort: + root_ptrs = self.flatten_sort_(score_norm) + else: + root_ptrs = self.flatten_() + scores = self.scores.tolist() + max_idx = self.current_lengths_wb.max() - 1 + transcripts = self.transcript_wb[..., : max_idx + 1] + timestamps = self.timestamps[..., : max_idx + 1] + durations = ( + self.token_durations[..., : max_idx + 1] if self.model_type == ASRModelTypeEnum.TDT else None + ) + return scores, transcripts, timestamps, durations, root_ptrs + + def _hypothesis_from_flat( + self, + batch_idx: int, + beam_idx: int, + scores: list[list[float]], + transcripts: torch.Tensor, + timestamps: torch.Tensor, + durations: Optional[torch.Tensor], + ) -> Hypothesis: + """Build one ``Hypothesis`` from already-flattened per-(batch, beam) views.""" + transcript = transcripts[batch_idx][beam_idx] + mask = self._create_transcripts_mask(transcript) + end_times = timestamps[batch_idx][beam_idx][mask] + if durations is not None: + # TDT: report per-token start times and durations. + token_duration = durations[batch_idx][beam_idx][mask] + timestamp = (end_times - token_duration).cpu().detach().numpy() + token_duration = token_duration.cpu().detach().numpy() + else: + timestamp = end_times.cpu().detach().numpy() + token_duration = None + return Hypothesis( + score=scores[batch_idx][beam_idx], + y_sequence=transcript[mask].cpu().detach().numpy(), + timestamp=timestamp, + token_duration=token_duration, + alignments=None, + dec_state=None, + ) + + def flatten_sort_(self, score_norm: bool = True) -> torch.Tensor: """ Sorts and flattens the tree structure of hypotheses in a batched beam search decoding process. Args: @@ -715,6 +743,11 @@ def flatten_sort_(self, score_norm: bool = True): 3. Iteratively reconstructs the tokens and timestamps for each hypothesis in reverse order. 4. Updates the internal state of the object, including transcripts, timestamps, scores, lengths, labels, and other metadata, based on the sorted order. + + Returns: + ``root_ptrs`` of shape ``[batch_size, beam_size]``: the chunk-start beam index + (before the first ``add_results_*`` write) from which each sorted output beam + descends. Same semantics as :meth:`flatten_`, but for the sorted ordering. """ # add one for consistency with non-batched decodings, that use SOS. @@ -722,7 +755,7 @@ def flatten_sort_(self, score_norm: bool = True): self.scores / (self.current_lengths_nb.to(self.scores.dtype) + 1) if score_norm else self.scores ) _, indices = torch.sort(normalized_scores, dim=-1, descending=True) - self._flatten_with_permutation_(indices) + return self._flatten_with_permutation_(indices) def flatten_(self) -> torch.Tensor: """ @@ -974,3 +1007,40 @@ def merge_( self.last_timestamp_lasts.copy_(other.last_timestamp_lasts) return self + + +def export_batched_beam_hyps_to_cpu_lists( + bbh: BatchedBeamHyps, +) -> tuple[list[list[list[int]]], list[list[list[int]]], list[list[int]]]: + """ + Streaming-pipeline helper: flatten ``bbh`` in-place (identity permutation) and + return CPU-side per-(batch, beam) chunk-local emissions plus the chunk-start + descent map. Intended for windowed-beam aggregation outside the engine. + + Returns: + (tokens, timestamps, root_ptrs): + * ``tokens``: ``[batch_size][beam_size]`` non-blank/non-padding token + IDs for this chunk. + * ``timestamps``: ``[batch_size][beam_size]`` matching step indices. + * ``root_ptrs``: ``[batch_size][beam_size]`` chunk-start beam index + from which each current slot descends. + """ + _, transcripts, timestamps, _, root_ptrs = bbh._export(sort=False) + # One sync to CPU; per-slot masking + .tolist() stays on CPU. + root_ptrs_list = root_ptrs.detach().cpu().tolist() + transcripts_cpu = transcripts.detach().cpu() + timestamps_cpu = timestamps.detach().cpu() + + tokens: list[list[list[int]]] = [] + timestamps_out: list[list[list[int]]] = [] + for b in range(bbh.batch_size): + bt: list[list[int]] = [] + bts: list[list[int]] = [] + for k in range(bbh.beam_size): + t = transcripts_cpu[b, k] + mask = bbh._create_transcripts_mask(t) + bt.append(t[mask].tolist()) + bts.append(timestamps_cpu[b, k][mask].tolist()) + tokens.append(bt) + timestamps_out.append(bts) + return tokens, timestamps_out, root_ptrs_list diff --git a/nemo/collections/asr/parts/utils/streaming_utils.py b/nemo/collections/asr/parts/utils/streaming_utils.py index 34c18463549d..5a61e0870fc9 100644 --- a/nemo/collections/asr/parts/utils/streaming_utils.py +++ b/nemo/collections/asr/parts/utils/streaming_utils.py @@ -2464,7 +2464,9 @@ def append_no_checks_(self, data: torch.Tensor, lengths: torch.Tensor | None = N indices = torch.arange(other_len, device=self.device) shifted_indices = self.lengths[:, None] + indices[None, :] # add trailing len(dim_shape) axes to shifted_indices - shifted_indices = shifted_indices[..., *[None for _ in range(len(self.dim_shape))]] + # NB: ``a[..., *unpack]`` subscript-unpacking is Python 3.11+; loop ``unsqueeze`` for 3.10. + for _ in range(len(self.dim_shape)): + shifted_indices = shifted_indices.unsqueeze(-1) self.data.scatter_(dim=1, index=shifted_indices.expand([-1, -1] + self.dim_shape), src=data) if lengths is None: self.lengths += other_len From e889eea29a9cbca37124f0813b39dcbbdcb992cb Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Mon, 8 Jun 2026 23:40:13 +0400 Subject: [PATCH 03/28] saving config Signed-off-by: lilithgrigoryan --- .../cache_aware_rnnt.yaml | 14 +- .../cache_aware_rnnt_malsd.yaml | 147 ++++++++++++++++++ .../pipelines/cache_aware_rnnt_pipeline.py | 4 - 3 files changed, 156 insertions(+), 9 deletions(-) create mode 100644 examples/asr/conf/asr_streaming_inference/cache_aware_rnnt_malsd.yaml diff --git a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml index 8c5f754d0c97..3758408c167c 100644 --- a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml +++ b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml @@ -37,11 +37,15 @@ asr: allow_cuda_graphs: true # n-gram LM (off by default) ngram_lm_model: null - key_phrases_file: null - key_phrases_list: null - key_phrase_items_list: null - source_lang: "en" - boosting_tree_alpha: 0.0 + ngram_lm_alpha: 0.0 + # phrase boosting (off by default) + boosting_tree: + model_path: null + key_phrases_file: null + key_phrases_list: null + key_phrase_items_list: null + source_lang: "en" + boosting_tree_alpha: 0.0 # ========================================== # Inverse Text Normalization Configuration diff --git a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt_malsd.yaml b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt_malsd.yaml new file mode 100644 index 000000000000..233301e8d7ae --- /dev/null +++ b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt_malsd.yaml @@ -0,0 +1,147 @@ +# ================================ +# ASR Configuration (MALSD beam-search variant of cache_aware_rnnt.yaml) +# ================================ +asr: + model_name: nvidia/nemotron-speech-streaming-en-0.6b # Pre-trained CTC/hybrid model from NGC/HuggingFace or local .nemo file path + device: cuda # Device for inference: 'cuda' or 'cpu' + device_id: 0 # GPU device ID + compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32' + use_amp: true # Enable Automatic Mixed Precision + decoding: + strategy: "malsd_batch" + preserve_alignments: false + fused_batch_size: -1 + beam: + beam_size: 4 + return_best_hypothesis: true + score_norm: true + allow_cuda_graphs: true + # n-gram LM (off by default) + ngram_lm_model: null + ngram_lm_alpha: 0.0 + # phrase boosting (off by default) + boosting_tree: + model_path: null + key_phrases_file: null + key_phrases_list: null + key_phrase_items_list: null + source_lang: "en" + boosting_tree_alpha: 0.0 + +# ========================================== +# Inverse Text Normalization Configuration +# ========================================== +itn: + input_case: lower_cased # Input text case handling: 'lower_cased', 'cased' + whitelist: null # Custom whitelist for ITN processing + overwrite_cache: false # Whether to overwrite existing cache files + max_number_of_permutations_per_split: 729 # Maximum permutations allowed per text split during ITN processing + left_padding_size: 4 # Padding size (#spans) for ITN context + batch_size: 32 # Batch size for ITN inference + n_jobs: 16 # Number of parallel jobs for ITN processing + + +# ================================ +# Neural Machine Translation Configuration +# ================================ +nmt: + model_name: "utter-project/EuroLLM-1.7B-Instruct" # vLLM-supported model name + source_language: "English" # Source language code + target_language: "Russian" # Target language code + waitk: -1 # Max allowed lag (in words) between ASR transcript and translation; -1 disables it and uses only the longest common prefix between current and previous translations. + device: cuda # Device for translation: 'cuda'. 'cpu' is not supported. + device_id: 1 # GPU device ID for translation + batch_size: 16 # Batch size for translation, if -1, the batch size is equal to the ASR batch size + llm_params: # See https://docs.vllm.ai/en/v0.8.1/api/offline_inference/llm.html for more details + dtype: "auto" # Compute precision + seed: 42 # The seed to initialize the random number generator for sampling + sampling_params: # See https://docs.vllm.ai/en/v0.6.4/dev/sampling_params.html for more details + max_tokens: 100 # Maximum number of tokens to generate with LLM + temperature: 0.0 # LLM sampling temperature, default for translation is 0 (greedy) + top_p: 0.9 # The cumulative probability threshold for nucleus sampling + seed: 42 # The seed to initialize the random number generator for sampling + + +# ======================== +# Confidence estimation +# ======================== +confidence: + exclude_blank: true # Exclude blank tokens when calculating confidence + aggregation: mean # Aggregation method for confidence across time steps + method_cfg: + name: entropy # Confidence estimation method: 'max_prob' or 'entropy' + entropy_type: tsallis + alpha: 0.5 + entropy_norm: exp + + +# ======================== +# Endpointing settings +# ======================== +endpointing: + stop_history_eou: 800 # Time window (ms) for evaluating EoU + residue_tokens_at_end: 2 # Number of residual tokens used for EoU + + +# ======================== +# Streaming configuration +# ======================== +streaming: + sample_rate: 16000 # Audio sample rate in Hz + batch_size: 64 # Number of audio frames per batch + word_boundary_tolerance: 4 # Tolerance for word boundaries + att_context_size: [70,13] # Attention context size: [70,13],[70,6],[70,1],[70,0] + use_cache: true # Whether to use cache for streaming + use_feat_cache: true # Whether to cache mel-spec features, set false to re-calculate all mel-spec features in audio buffer + chunk_size_in_secs: null # Amount of audio to load for each streaming step, e.g., 0.08s for FastConformer. Set to `null` for using default size equal to 1+lookahead frames. + request_type: frame # Type of request: frame or feature_buffer + num_slots: 256 # Number of slots in the context manager: must be >= batch_size + chunks_per_beam_reset: 20 # MALSD: number of chunks between full K-beam collapses to a single top-1 hypothesis. + # 1 (default) reproduces per-chunk collapse behaviour; higher values preserve beam + # diversity across multiple chunks before collapsing. + + +# ======================== +# Pipeline settings +# ======================== +matmul_precision: high # Matrix multiplication precision: highest, high, medium +log_level: 20 # Logging level: 0 (NOTSET), 10 (DEBUG), 20 (INFO), 30 (WARNING), 40 (ERROR), 50 (CRITICAL) +pipeline_type: cache_aware # Pipeline type: buffered, cache_aware +asr_decoding_type: rnnt # Decoding method: ctc or rnnt + + +# ======================== +# Runtime arguments defined at runtime via command line +# ======================== +audio_file: null # Path to audio file, directory, or manifest JSON +output_filename: null # Path to output transcription JSON file +output_dir: null # Directory to save time-aligned output +enable_itn: false # Whether to apply inverse text normalization +enable_nmt: false # Whether to apply neural machine translation +asr_output_granularity: segment # Output granularity: word or segment +cache_dir: null # Directory to store cache (e.g., .far files) +lang: null # Language code for ASR model +return_tail_result: false # Whether to return the tail labels left in the right padded side of the buffer +calculate_wer: true # Whether to calculate WER +calculate_bleu: true # Whether to calculate BLEU score +warmup_steps: 0 # Number of warmup steps for RTFx and LAAL calculation +run_steps: 1 # Number of run steps for RTFx and LAAL calculation + + +# ======================== +# Metrics +# ======================== +metrics: + asr: + gt_text_attr_name: text # Attribute name for ground truth text + clean_groundtruth_text: false # Whether to clean ground truth text + langid: en # Language code for text normalization; only "en" is supported + use_cer: false # Whether to use character error rate + ignore_capitalization: true # Whether to ignore capitalization + ignore_punctuation: true # Whether to ignore punctuation + strip_punc_space: false # Whether to strip punctuation and space + nmt: + gt_text_attr_name: answer # Attribute name for ground truth text + ignore_capitalization: false # Whether to ignore capitalization + ignore_punctuation: false # Whether to ignore punctuation + strip_punc_space: false # Whether to strip punctuation and space diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index f648dc94fb5f..154c60a0b3b8 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -724,11 +724,9 @@ def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> if len(ready_state_ids) > 0: self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) - self._debug_print_finals(ready_state_ids) ready_state_ids.clear() self.update_partial_transcript(fbuffers, self.tokenizer, self.leading_regex_pattern) - self._debug_print_partials(fbuffers) def transcribe_step_for_frames(self, frames: list[Frame]) -> None: """ @@ -772,11 +770,9 @@ def transcribe_step_for_frames(self, frames: list[Frame]) -> None: # post-process the ready states if len(ready_state_ids) > 0: self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) - self._debug_print_finals(ready_state_ids) ready_state_ids.clear() self.update_partial_transcript(frames, self.tokenizer, self.leading_regex_pattern) - self._debug_print_partials(frames) def get_request_generator(self) -> ContinuousBatchedRequestStreamer: """ From e51ee3c4eb07cede92534c797c335ac780bcba55 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Mon, 8 Jun 2026 22:30:05 +0400 Subject: [PATCH 04/28] add n-chunk reseting working Signed-off-by: lilithgrigoryan From 3765accbd3c5526a1b88162cb4ebd9df816304c4 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Mon, 8 Jun 2026 22:46:10 +0400 Subject: [PATCH 05/28] add eou resetting Signed-off-by: lilithgrigoryan --- .../cache_aware_rnnt_malsd.yaml | 3 - .../pipelines/cache_aware_rnnt_pipeline.py | 121 +++++++++--------- .../streaming/state/cache_aware_rnnt_state.py | 13 +- .../submodules/rnnt_malsd_batched_computer.py | 49 +++++++ 4 files changed, 117 insertions(+), 69 deletions(-) diff --git a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt_malsd.yaml b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt_malsd.yaml index 233301e8d7ae..c77c78426b39 100644 --- a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt_malsd.yaml +++ b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt_malsd.yaml @@ -96,9 +96,6 @@ streaming: chunk_size_in_secs: null # Amount of audio to load for each streaming step, e.g., 0.08s for FastConformer. Set to `null` for using default size equal to 1+lookahead frames. request_type: frame # Type of request: frame or feature_buffer num_slots: 256 # Number of slots in the context manager: must be >= batch_size - chunks_per_beam_reset: 20 # MALSD: number of chunks between full K-beam collapses to a single top-1 hypothesis. - # 1 (default) reproduces per-chunk collapse behaviour; higher values preserve beam - # diversity across multiple chunks before collapsing. # ======================== diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 154c60a0b3b8..b49dd260bb67 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -180,13 +180,6 @@ def init_parameters(self, cfg: DictConfig) -> None: self.request_type = RequestType.from_str(cfg.streaming.request_type) - # MALSD beam-search streaming knobs. ``chunks_per_beam_reset == 1`` collapses - # the K-beam state down to a single top-1 hypothesis after every chunk, which - # matches the original cache-aware behaviour. Higher values preserve beam - # diversity across multiple chunks before collapsing - currently only the - # ``== 1`` path is fully wired up; larger values fall back to it. - self.chunks_per_beam_reset = int(cfg.streaming.get("chunks_per_beam_reset", 1)) - def init_greedy_rnnt_decoder(self) -> None: """Initialize the RNNT decoder.""" check_existance_of_required_attributes(self, ['vocabulary', 'conf_func']) @@ -338,17 +331,19 @@ def _malsd_stream_step( 2. Merge per-stream ``MALSDStateItem``s into a batched MALSD state. 3. Run :class:`ModifiedALSDBatchedRNNTComputer` for this chunk. 4. Update per-stream windowed-beam tracking from this chunk's emissions. - 5. Optionally collapse to the chunk's top-1 (current default behaviour). - 6. Split the batched MALSD state back into per-stream carries. - 7. Build a cumulative ``Hypothesis`` per stream from + 5. Split the batched MALSD state back into per-stream carries. + 6. Build a cumulative ``Hypothesis`` per stream from ``window_committed + window_beam_tokens[top1]``. + Collapse to the chunk's top-1 is NOT performed here - beams stay + diverged across chunks and are collapsed per-stream at the EOU + boundary inside :meth:`run_malsd_decoder`. + Returns a list of cumulative ``Hypothesis`` per stream and the new encoder cache context, matching the shape of ``stream_step``. """ # Per-stream multi-biasing ids: not yet supported on the MALSD streaming # path. Greedy-side per-stream biasing knobs stay independent. - multi_biasing_ids = None if biasing_enabled: logging.warning( "Per-stream biasing is not yet wired up on the MALSD cache-aware " @@ -363,11 +358,10 @@ def _malsd_stream_step( else: batched_state = self.decoding_computer.merge_to_batched_state(carries) - # All MALSD GPU work (encoder, decoder, windowed walk, collapse, split) - # shares one ``inference_mode`` region: ``collapse_batched_state_to_beams_`` - # and ``split_batched_state`` mutate the inference tensors returned by - # ``decoding_computer(...)`` in place, which is illegal once we've left - # the captured ``inference_mode`` region. + # All MALSD GPU work (encoder, decoder, windowed walk, split) shares one + # ``inference_mode`` region: ``split_batched_state`` mutates the inference + # tensors returned by ``decoding_computer(...)`` in place, which is illegal + # once we've left the captured ``inference_mode`` region. with ( torch.amp.autocast( device_type=self.asr_model.device_str, @@ -395,65 +389,37 @@ def _malsd_stream_step( self._update_windowed_beam_state(states=states, best_batched_hyps=best_batched_hyps) - # Capture pre-collapse argmax + scores. After ``collapse_batched_state_to_beams_`` - # runs, ``scores[:, 1:]`` is forced to ``INACTIVE_SCORE`` and ``scores[:, 0]`` - # carries the winner - so any post-collapse argmax returns 0 unconditionally. - # We need the PRE-collapse slot index to index ``window_beam_tokens`` (which - # was just computed against the diverged pre-collapse slots). + # Per-stream top-1 beam slot. Indexes ``window_beam_tokens`` (which was + # just appended against the diverged beam slots) to build the publishable + # cumulative hypothesis below. beam_indices_cpu = best_batched_hyps.scores.argmax(dim=-1).detach().cpu().tolist() - scores_pre_collapse = best_batched_hyps.scores.detach().cpu() - - # Collapse the K-beam state at the configured cadence. For now we always - # collapse every chunk (``chunks_per_beam_reset == 1``); the multi-chunk - # window is a follow-up that needs full prefix-tree carry across chunks. - for state in states: - state._malsd_chunk_count += 1 - do_collapse = self.chunks_per_beam_reset <= 1 or any( - state._malsd_chunk_count >= self.chunks_per_beam_reset for state in states - ) - if do_collapse: - beam_indices = best_batched_hyps.scores.argmax(dim=-1).to(torch.long) - self.decoding_computer.collapse_batched_state_to_beams_( - batched_state, best_batched_hyps, beam_indices - ) + scores_cpu = best_batched_hyps.scores.detach().cpu() carry_items = self.decoding_computer.split_batched_state(batched_state) for state, carry in zip(states, carry_items): state.hyp_decoding_state = carry - # Build per-stream cumulative ``Hypothesis`` from the windowed state, - # then (on collapse chunks) promote the chosen beam's window tokens into - # the committed prefix and clear the window. The published hypothesis - # is identical pre/post-collapse promotion - just with everything moved - # into ``committed`` afterwards. + # Build per-stream cumulative ``Hypothesis`` from the windowed state. + # Collapse + window promotion is deferred to ``run_malsd_decoder`` and + # triggered by EOU, so the published hyp is the current top-1's path + # but the K-beam state continues to diverge across chunks. hyps: list[Hypothesis] = [] for b, state in enumerate(states): top1_slot = beam_indices_cpu[b] - window_tokens = ( - state.window_beam_tokens[top1_slot] if state.window_beam_tokens else [] - ) - window_ts = ( - state.window_beam_timestamps[top1_slot] if state.window_beam_timestamps else [] - ) + window_tokens = state.window_beam_tokens[top1_slot] if state.window_beam_tokens else [] + window_ts = state.window_beam_timestamps[top1_slot] if state.window_beam_timestamps else [] cum_tokens = state.window_committed_tokens + list(window_tokens) cum_ts = state.window_committed_timestamps + list(window_ts) hyps.append( Hypothesis( - score=float(scores_pre_collapse[b, top1_slot].item()), + score=float(scores_cpu[b, top1_slot].item()), y_sequence=cum_tokens, timestamp=cum_ts, length=len(cum_tokens), ) ) - if do_collapse: - state._malsd_chunk_count = 0 - state.window_committed_tokens = list(cum_tokens) - state.window_committed_timestamps = list(cum_ts) - state.window_beam_tokens = None - state.window_beam_timestamps = None - return hyps, new_context def _update_windowed_beam_state( @@ -499,7 +465,9 @@ def run_malsd_decoder( On EOU we bump ``_malsd_utterance_start`` to the current cumulative length so the next utterance's resync slice starts past the cleared - previous utterance. + previous utterance, then collapse the per-stream MALSD carry to its + top-1 beam: the K-beam state diverges intra-utterance and snaps to the + chosen path at the natural utterance boundary. """ eou_detected = self.run_greedy_decoder(state, request, hyp) @@ -520,9 +488,21 @@ def run_malsd_decoder( state.last_token_idx = timestamps_list[-1] if timestamps_list else None if eou_detected: - # mark the boundary so the next utterance's slice starts past the - # tokens we just finalised + # Mark the boundary so the next utterance's slice starts past the + # tokens we just finalised. state._malsd_utterance_start = len(all_tokens) + + # EOU-driven collapse: promote the chosen window into the committed + # prefix and replicate the winning beam across all K slots of the + # per-stream carry. The predictor stays warm at the top-1's last + # label so the next utterance benefits from cross-utterance context. + if state.hyp_decoding_state is not None: + top1 = int(state.hyp_decoding_state.score.argmax().item()) + self.decoding_computer.collapse_state_item_to_top1_(state.hyp_decoding_state, top1) + state.window_committed_tokens = list(all_tokens) + state.window_committed_timestamps = list(all_timestamps) + state.window_beam_tokens = None + state.window_beam_timestamps = None return eou_detected def run_greedy_decoder(self, state: CacheAwareRNNTStreamingState, request: Request, hyp: Hypothesis) -> bool: @@ -686,6 +666,27 @@ def cache_aware_transcribe_step( biasing_multi_model=self.greedy_decoding_computer.biasing_multi_model ) + def _debug_print_finals(self, ready_state_ids: set) -> None: + """DEBUG: print finalised transcripts so greedy vs MALSD logs can be diffed.""" + strategy = "malsd" if self.decoding_computer is not None else "greedy" + for sid in sorted(ready_state_ids): + state = self.get_state(sid) + print( + f"[CMP][FINAL] strategy={strategy} stream={sid} text={state.final_transcript!r}", + flush=True, + ) + + def _debug_print_partials(self, requests: list[Request]) -> None: + """DEBUG: print partial / current-step transcripts so greedy vs MALSD logs can be diffed.""" + strategy = "malsd" if self.decoding_computer is not None else "greedy" + for req in requests: + state = self.get_state(req.stream_id) + print( + f"[CMP][PARTIAL] strategy={strategy} stream={req.stream_id} " + f"partial={state.partial_transcript!r} step={state.current_step_transcript!r}", + flush=True, + ) + def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: """ Transcribes the feature buffers in a streaming manner. @@ -724,9 +725,11 @@ def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> if len(ready_state_ids) > 0: self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) + self._debug_print_finals(ready_state_ids) ready_state_ids.clear() self.update_partial_transcript(fbuffers, self.tokenizer, self.leading_regex_pattern) + self._debug_print_partials(fbuffers) def transcribe_step_for_frames(self, frames: list[Frame]) -> None: """ @@ -770,9 +773,11 @@ def transcribe_step_for_frames(self, frames: list[Frame]) -> None: # post-process the ready states if len(ready_state_ids) > 0: self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) + self._debug_print_finals(ready_state_ids) ready_state_ids.clear() self.update_partial_transcript(frames, self.tokenizer, self.leading_regex_pattern) + self._debug_print_partials(frames) def get_request_generator(self) -> ContinuousBatchedRequestStreamer: """ diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index 74880fc94774..9729394abc08 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -81,11 +81,10 @@ class CacheAwareRNNTMALSDStreamingState(CacheAwareRNNTStreamingState): - ``hyp_decoding_state``: per-stream beam carry (``MALSDStateItem``-like) shuttled between :meth:`merge_to_batched_state` and :meth:`split_batched_state`. - ``window_committed_tokens`` / ``window_committed_timestamps``: cumulative - prefix shared by all surviving beams at the most recent collapse boundary. + prefix shared by all surviving beams at the most recent EOU boundary. - ``window_beam_tokens`` / ``window_beam_timestamps``: per-slot chunk-local - cumulative emissions since the last collapse (one list per beam slot). - - ``_malsd_chunk_count``: number of MALSD chunks processed since the last - collapse - used by ``chunks_per_beam_reset`` to decide when to collapse. + cumulative emissions since the last EOU (one list per beam slot). Beams + stay diverged across chunks; the chosen path is committed at EOU. - ``_malsd_utterance_start``: position in the cumulative ``hyp.y_sequence`` where the current utterance begins, so EOU + ``cleanup_after_eou`` can correctly slice past previously emitted (and cleared) utterances. @@ -101,15 +100,14 @@ def _additional_params_reset(self) -> None: self.window_committed_timestamps: list[int] = [] self.window_beam_tokens: list[list[int]] | None = None self.window_beam_timestamps: list[list[int]] | None = None - self._malsd_chunk_count: int = 0 self._malsd_utterance_start: int = 0 def reset_previous_hypothesis(self) -> None: """ Reset the previous hypothesis and all MALSD beam-search bookkeeping. - Called at utterance end (EOU). Zeroes out the MALSD per-stream carry so - the next utterance starts from SOS with an empty windowed-beam state. + Called at end-of-stream. Zeroes out the MALSD per-stream carry so the + next utterance starts from SOS with an empty windowed-beam state. """ super().reset_previous_hypothesis() self.hyp_decoding_state = None @@ -117,7 +115,6 @@ def reset_previous_hypothesis(self) -> None: self.window_committed_timestamps = [] self.window_beam_tokens = None self.window_beam_timestamps = None - self._malsd_chunk_count = 0 # NB: ``_malsd_utterance_start`` is intentionally NOT reset here because # the cumulative ``hyp.y_sequence`` it indexes is owned by the pipeline # and bumped after the call when the previous utterance is being diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index cdec9b06657e..e6e3518a9c55 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -1661,6 +1661,55 @@ def collapse_batched_state_to_beams_( batched_hyps.keep_beam_(beam_indices) + def collapse_state_item_to_top1_(self, item: MALSDStateItem, beam_index: int) -> None: + """ + In-place per-stream variant of :meth:`collapse_batched_state_to_beams_`. + + Replicates beam ``beam_index`` across all ``beam_size`` slots of ``item`` + and sets ``score[1:] = INACTIVE_SCORE`` so the next chunk's top-k expands + the surviving beam. Used by streaming pipelines to collapse a single + stream's MALSD carry at its EOU boundary without disturbing other rows + of a batched run. + + Wraps mutations in :func:`torch.inference_mode` so it can be called from + outside the encoder/decoder inference region (the per-stream tensors are + inference tensors produced by :meth:`split_batched_state`). + """ + beam_size = self.beam_size + if not 0 <= beam_index < beam_size: + raise ValueError(f"beam_index must be in [0, {beam_size}), got {beam_index}") + + with torch.inference_mode(): + per_row = self.decoder.batch_split_states(item.predictor_state) + if len(per_row) != beam_size: + raise AssertionError( + f"Expected per-stream predictor state with batch dim {beam_size}, got {len(per_row)}" + ) + item.predictor_state = self.decoder.batch_unsplit_states([per_row[beam_index]] * beam_size) + + item.predictor_output = ( + item.predictor_output[beam_index : beam_index + 1] + .expand(beam_size, *item.predictor_output.shape[1:]) + .contiguous() + ) + + idx = torch.full([beam_size], fill_value=beam_index, dtype=torch.long, device=item.label.device) + item.label = item.label.index_select(0, idx).contiguous() + if item.score is not None: + item.score = item.score.index_select(0, idx).contiguous() + item.score[1:].fill_(INACTIVE_SCORE) + if item.transcript_hash is not None: + item.transcript_hash = item.transcript_hash.index_select(0, idx).contiguous() + if item.current_lengths_nb is not None: + item.current_lengths_nb = item.current_lengths_nb.index_select(0, idx).contiguous() + if item.last_timestamp_lasts is not None: + item.last_timestamp_lasts = item.last_timestamp_lasts.index_select(0, idx).contiguous() + if item.transcript_prefix_hash is not None: + item.transcript_prefix_hash = item.transcript_prefix_hash.index_select(0, idx).contiguous() + + for fi, fs in enumerate(item.fusion_state_list): + item.fusion_state_list[fi] = fs.index_select(0, idx).contiguous() + def __call__( self, x: torch.Tensor, From 0e11a4f0c34560ba1bc1bb8d99eb6615fc0160be Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Tue, 9 Jun 2026 01:30:16 +0400 Subject: [PATCH 06/28] clean up debug prints Signed-off-by: lilithgrigoryan --- .../cache_aware_rnnt_malsd.yaml | 144 ------------------ .../pipelines/cache_aware_rnnt_pipeline.py | 25 --- 2 files changed, 169 deletions(-) delete mode 100644 examples/asr/conf/asr_streaming_inference/cache_aware_rnnt_malsd.yaml diff --git a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt_malsd.yaml b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt_malsd.yaml deleted file mode 100644 index c77c78426b39..000000000000 --- a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt_malsd.yaml +++ /dev/null @@ -1,144 +0,0 @@ -# ================================ -# ASR Configuration (MALSD beam-search variant of cache_aware_rnnt.yaml) -# ================================ -asr: - model_name: nvidia/nemotron-speech-streaming-en-0.6b # Pre-trained CTC/hybrid model from NGC/HuggingFace or local .nemo file path - device: cuda # Device for inference: 'cuda' or 'cpu' - device_id: 0 # GPU device ID - compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32' - use_amp: true # Enable Automatic Mixed Precision - decoding: - strategy: "malsd_batch" - preserve_alignments: false - fused_batch_size: -1 - beam: - beam_size: 4 - return_best_hypothesis: true - score_norm: true - allow_cuda_graphs: true - # n-gram LM (off by default) - ngram_lm_model: null - ngram_lm_alpha: 0.0 - # phrase boosting (off by default) - boosting_tree: - model_path: null - key_phrases_file: null - key_phrases_list: null - key_phrase_items_list: null - source_lang: "en" - boosting_tree_alpha: 0.0 - -# ========================================== -# Inverse Text Normalization Configuration -# ========================================== -itn: - input_case: lower_cased # Input text case handling: 'lower_cased', 'cased' - whitelist: null # Custom whitelist for ITN processing - overwrite_cache: false # Whether to overwrite existing cache files - max_number_of_permutations_per_split: 729 # Maximum permutations allowed per text split during ITN processing - left_padding_size: 4 # Padding size (#spans) for ITN context - batch_size: 32 # Batch size for ITN inference - n_jobs: 16 # Number of parallel jobs for ITN processing - - -# ================================ -# Neural Machine Translation Configuration -# ================================ -nmt: - model_name: "utter-project/EuroLLM-1.7B-Instruct" # vLLM-supported model name - source_language: "English" # Source language code - target_language: "Russian" # Target language code - waitk: -1 # Max allowed lag (in words) between ASR transcript and translation; -1 disables it and uses only the longest common prefix between current and previous translations. - device: cuda # Device for translation: 'cuda'. 'cpu' is not supported. - device_id: 1 # GPU device ID for translation - batch_size: 16 # Batch size for translation, if -1, the batch size is equal to the ASR batch size - llm_params: # See https://docs.vllm.ai/en/v0.8.1/api/offline_inference/llm.html for more details - dtype: "auto" # Compute precision - seed: 42 # The seed to initialize the random number generator for sampling - sampling_params: # See https://docs.vllm.ai/en/v0.6.4/dev/sampling_params.html for more details - max_tokens: 100 # Maximum number of tokens to generate with LLM - temperature: 0.0 # LLM sampling temperature, default for translation is 0 (greedy) - top_p: 0.9 # The cumulative probability threshold for nucleus sampling - seed: 42 # The seed to initialize the random number generator for sampling - - -# ======================== -# Confidence estimation -# ======================== -confidence: - exclude_blank: true # Exclude blank tokens when calculating confidence - aggregation: mean # Aggregation method for confidence across time steps - method_cfg: - name: entropy # Confidence estimation method: 'max_prob' or 'entropy' - entropy_type: tsallis - alpha: 0.5 - entropy_norm: exp - - -# ======================== -# Endpointing settings -# ======================== -endpointing: - stop_history_eou: 800 # Time window (ms) for evaluating EoU - residue_tokens_at_end: 2 # Number of residual tokens used for EoU - - -# ======================== -# Streaming configuration -# ======================== -streaming: - sample_rate: 16000 # Audio sample rate in Hz - batch_size: 64 # Number of audio frames per batch - word_boundary_tolerance: 4 # Tolerance for word boundaries - att_context_size: [70,13] # Attention context size: [70,13],[70,6],[70,1],[70,0] - use_cache: true # Whether to use cache for streaming - use_feat_cache: true # Whether to cache mel-spec features, set false to re-calculate all mel-spec features in audio buffer - chunk_size_in_secs: null # Amount of audio to load for each streaming step, e.g., 0.08s for FastConformer. Set to `null` for using default size equal to 1+lookahead frames. - request_type: frame # Type of request: frame or feature_buffer - num_slots: 256 # Number of slots in the context manager: must be >= batch_size - - -# ======================== -# Pipeline settings -# ======================== -matmul_precision: high # Matrix multiplication precision: highest, high, medium -log_level: 20 # Logging level: 0 (NOTSET), 10 (DEBUG), 20 (INFO), 30 (WARNING), 40 (ERROR), 50 (CRITICAL) -pipeline_type: cache_aware # Pipeline type: buffered, cache_aware -asr_decoding_type: rnnt # Decoding method: ctc or rnnt - - -# ======================== -# Runtime arguments defined at runtime via command line -# ======================== -audio_file: null # Path to audio file, directory, or manifest JSON -output_filename: null # Path to output transcription JSON file -output_dir: null # Directory to save time-aligned output -enable_itn: false # Whether to apply inverse text normalization -enable_nmt: false # Whether to apply neural machine translation -asr_output_granularity: segment # Output granularity: word or segment -cache_dir: null # Directory to store cache (e.g., .far files) -lang: null # Language code for ASR model -return_tail_result: false # Whether to return the tail labels left in the right padded side of the buffer -calculate_wer: true # Whether to calculate WER -calculate_bleu: true # Whether to calculate BLEU score -warmup_steps: 0 # Number of warmup steps for RTFx and LAAL calculation -run_steps: 1 # Number of run steps for RTFx and LAAL calculation - - -# ======================== -# Metrics -# ======================== -metrics: - asr: - gt_text_attr_name: text # Attribute name for ground truth text - clean_groundtruth_text: false # Whether to clean ground truth text - langid: en # Language code for text normalization; only "en" is supported - use_cer: false # Whether to use character error rate - ignore_capitalization: true # Whether to ignore capitalization - ignore_punctuation: true # Whether to ignore punctuation - strip_punc_space: false # Whether to strip punctuation and space - nmt: - gt_text_attr_name: answer # Attribute name for ground truth text - ignore_capitalization: false # Whether to ignore capitalization - ignore_punctuation: false # Whether to ignore punctuation - strip_punc_space: false # Whether to strip punctuation and space diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index b49dd260bb67..7895a11095f6 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -666,27 +666,6 @@ def cache_aware_transcribe_step( biasing_multi_model=self.greedy_decoding_computer.biasing_multi_model ) - def _debug_print_finals(self, ready_state_ids: set) -> None: - """DEBUG: print finalised transcripts so greedy vs MALSD logs can be diffed.""" - strategy = "malsd" if self.decoding_computer is not None else "greedy" - for sid in sorted(ready_state_ids): - state = self.get_state(sid) - print( - f"[CMP][FINAL] strategy={strategy} stream={sid} text={state.final_transcript!r}", - flush=True, - ) - - def _debug_print_partials(self, requests: list[Request]) -> None: - """DEBUG: print partial / current-step transcripts so greedy vs MALSD logs can be diffed.""" - strategy = "malsd" if self.decoding_computer is not None else "greedy" - for req in requests: - state = self.get_state(req.stream_id) - print( - f"[CMP][PARTIAL] strategy={strategy} stream={req.stream_id} " - f"partial={state.partial_transcript!r} step={state.current_step_transcript!r}", - flush=True, - ) - def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: """ Transcribes the feature buffers in a streaming manner. @@ -725,11 +704,9 @@ def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> if len(ready_state_ids) > 0: self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) - self._debug_print_finals(ready_state_ids) ready_state_ids.clear() self.update_partial_transcript(fbuffers, self.tokenizer, self.leading_regex_pattern) - self._debug_print_partials(fbuffers) def transcribe_step_for_frames(self, frames: list[Frame]) -> None: """ @@ -773,11 +750,9 @@ def transcribe_step_for_frames(self, frames: list[Frame]) -> None: # post-process the ready states if len(ready_state_ids) > 0: self.text_processor.process([self.get_state(stream_id) for stream_id in ready_state_ids]) - self._debug_print_finals(ready_state_ids) ready_state_ids.clear() self.update_partial_transcript(frames, self.tokenizer, self.leading_regex_pattern) - self._debug_print_partials(frames) def get_request_generator(self) -> ContinuousBatchedRequestStreamer: """ From 6dc14230a456af32219c6b99d504e97dd77387a5 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Tue, 16 Jun 2026 20:19:38 +0400 Subject: [PATCH 07/28] typecast fix Signed-off-by: lilithgrigoryan --- .../asr/inference/pipelines/cache_aware_rnnt_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 7895a11095f6..b5e7662708bb 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -370,6 +370,7 @@ def _malsd_stream_step( ), torch.inference_mode(), ): + feature_buffers = feature_buffers.to(self.asr_model.cast_dtype) encoded, encoded_len, new_context = self.asr_model.encode_step( processed_signal=feature_buffers, processed_signal_length=feature_buffer_lens, From 0a69dee8ecab681ce79c0ad975ac89a6cc9808e4 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Tue, 16 Jun 2026 21:47:04 +0400 Subject: [PATCH 08/28] clean up Signed-off-by: lilithgrigoryan --- .../cache_aware_rnnt.yaml | 4 +- .../cache_aware_rnnt_inference_wrapper.py | 9 +---- .../pipelines/cache_aware_rnnt_pipeline.py | 2 +- .../streaming/state/cache_aware_rnnt_state.py | 38 ++++++++++--------- 4 files changed, 25 insertions(+), 28 deletions(-) diff --git a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml index 5b0d8dff4150..ef8bc1d2b4d0 100644 --- a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml +++ b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml @@ -8,7 +8,7 @@ asr: compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32' use_amp: false # Enable Automatic Mixed Precision decoding: - strategy: "greedy_batch" + strategy: "greedy_batch" preserve_alignments: false fused_batch_size: -1 greedy: @@ -32,8 +32,6 @@ asr: boosting_tree_alpha: 0.0 # Weight of the boosting tree beam: beam_size: 4 - return_best_hypothesis: true - score_norm: true allow_cuda_graphs: true # n-gram LM (off by default) ngram_lm_model: null diff --git a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py index c5afc78e1b71..ccb6c1faf533 100644 --- a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py +++ b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py @@ -76,7 +76,7 @@ def get_vocabulary(self) -> list[str]: """ return self.asr_model.joint.vocabulary - def encode_step( + def encoder_step( self, processed_signal: Tensor, processed_signal_length: Tensor, @@ -89,11 +89,6 @@ def encode_step( """ Run the cache-aware encoder for one streaming chunk, returning the (trimmed) encoder output and updated streaming context. Decoder is NOT invoked. - - Used by :meth:`execute_step` (greedy decoder runs right after) and by - beam-search pipelines that drive the decoder themselves with a - per-stream beam carry (they call ``encode_step`` directly inside their - own ``autocast`` + ``inference_mode`` region). """ ( encoded, @@ -155,7 +150,7 @@ def execute_step( Returns: (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context. """ - encoded, encoded_len, new_context = self.encode_step( + encoded, encoded_len, new_context = self.encoder_step( processed_signal=processed_signal, processed_signal_length=processed_signal_length, context=context, diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index b5e7662708bb..0689b526b5a2 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -371,7 +371,7 @@ def _malsd_stream_step( torch.inference_mode(), ): feature_buffers = feature_buffers.to(self.asr_model.cast_dtype) - encoded, encoded_len, new_context = self.asr_model.encode_step( + encoded, encoded_len, new_context = self.asr_model.encoder_step( processed_signal=feature_buffers, processed_signal_length=feature_buffer_lens, context=context, diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index 9729394abc08..34e1e36c36a5 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -13,11 +13,14 @@ # limitations under the License. -from typing import Any +from typing import TYPE_CHECKING from nemo.collections.asr.inference.streaming.state.cache_aware_state import CacheAwareStreamingState from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +if TYPE_CHECKING: + from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import MALSDStateItem + class CacheAwareRNNTStreamingState(CacheAwareStreamingState): """ @@ -67,27 +70,20 @@ def get_previous_hypothesis(self) -> Hypothesis | None: def reset_previous_hypothesis(self) -> None: """ - Reset the previous hypothesis. Called at utterance end (EOU). + Reset the previous hypothesis. """ self.previous_hypothesis = None class CacheAwareRNNTMALSDStreamingState(CacheAwareRNNTStreamingState): """ - Cache-aware RNNT state with MALSD beam-search per-stream bookkeeping. - - Adds the following fields on top of the greedy state: - - - ``hyp_decoding_state``: per-stream beam carry (``MALSDStateItem``-like) - shuttled between :meth:`merge_to_batched_state` and :meth:`split_batched_state`. - - ``window_committed_tokens`` / ``window_committed_timestamps``: cumulative - prefix shared by all surviving beams at the most recent EOU boundary. - - ``window_beam_tokens`` / ``window_beam_timestamps``: per-slot chunk-local - cumulative emissions since the last EOU (one list per beam slot). Beams - stay diverged across chunks; the chosen path is committed at EOU. - - ``_malsd_utterance_start``: position in the cumulative ``hyp.y_sequence`` - where the current utterance begins, so EOU + ``cleanup_after_eou`` can - correctly slice past previously emitted (and cleared) utterances. + Cache-aware RNNT state for MALSD beam-search streaming. + + Transcript assembly is ``committed prefix + live beam suffix``. Beams may + disagree within an utterance; at EOU the top-1 path is promoted into the + committed prefix and per-beam suffixes are cleared. + + See :class:`CacheAwareRNNTPipeline` (``_malsd_stream_step``, ``run_malsd_decoder``). """ def _additional_params_reset(self) -> None: @@ -95,11 +91,19 @@ def _additional_params_reset(self) -> None: Reset MALSD per-stream carry on top of the greedy state. """ super()._additional_params_reset() - self.hyp_decoding_state: Any = None + # Per-stream MALSD decoder carry (``MALSDStateItem``); Shuttled through + # ``merge_to_batched_state`` / ``split_batched_state`` each chunk. + self.hyp_decoding_state: "MALSDStateItem | None" = None + # Finalized transcript prefix at the last EOU; identical for every beam slot. self.window_committed_tokens: list[int] = [] + # Frame timestamps aligned with ``window_committed_tokens``. self.window_committed_timestamps: list[int] = [] + # Per-beam suffix since last EOU; slot k may differ while beams compete. self.window_beam_tokens: list[list[int]] | None = None + # Per-beam frame timestamps aligned with ``window_beam_tokens`` (same slot layout). self.window_beam_timestamps: list[list[int]] | None = None + # Index into cumulative ``hyp.y_sequence`` where the current utterance starts + # (skips tokens from prior utterances still present in the cumulative hyp). self._malsd_utterance_start: int = 0 def reset_previous_hypothesis(self) -> None: From e051b12b53b38768d62ef3c7fff44cabd40f4746 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Tue, 16 Jun 2026 21:49:20 +0400 Subject: [PATCH 09/28] isort and black + clean up Signed-off-by: lilithgrigoryan --- .../pipelines/cache_aware_rnnt_pipeline.py | 12 +++++------ .../streaming/state/cache_aware_rnnt_state.py | 8 ++------ .../submodules/rnnt_malsd_batched_computer.py | 20 +++++-------------- 3 files changed, 12 insertions(+), 28 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 0689b526b5a2..6c6ea0a44f46 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -214,7 +214,9 @@ def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingSta when the pipeline is configured for beam-search decoding. """ state = ( - CacheAwareRNNTMALSDStreamingState() if self.decoding_computer is not None else CacheAwareRNNTStreamingState() + CacheAwareRNNTMALSDStreamingState() + if self.decoding_computer is not None + else CacheAwareRNNTStreamingState() ) state.set_global_offset(0) new_options = options.fill_defaults( @@ -384,9 +386,7 @@ def _malsd_stream_step( # computer expects [B, T, D] (matches the rest of the decoding stack). encs_dim_last = encoded.transpose(1, 2).contiguous() - best_batched_hyps, batched_state = self.decoding_computer( - encs_dim_last, encoded_len, batched_state - ) + best_batched_hyps, batched_state = self.decoding_computer(encs_dim_last, encoded_len, batched_state) self._update_windowed_beam_state(states=states, best_batched_hyps=best_batched_hyps) @@ -444,9 +444,7 @@ def _update_windowed_beam_state( state.window_beam_tokens = [prev_t[int(rp[k])] + ct[k] for k in range(beam_size)] state.window_beam_timestamps = [prev_ts[int(rp[k])] + cts[k] for k in range(beam_size)] - def run_malsd_decoder( - self, state: CacheAwareRNNTMALSDStreamingState, request: Request, hyp: Hypothesis - ) -> bool: + def run_malsd_decoder(self, state: CacheAwareRNNTMALSDStreamingState, request: Request, hyp: Hypothesis) -> bool: """ MALSD counterpart to :meth:`run_greedy_decoder`. diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index 34e1e36c36a5..b839ed279eda 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -24,11 +24,7 @@ class CacheAwareRNNTStreamingState(CacheAwareStreamingState): """ - State of the cache aware RNNT streaming pipelines (greedy decoder). - - Extends :class:`CacheAwareStreamingState` with greedy-decoding bookkeeping - (``previous_hypothesis``). The MALSD beam-search variant adds its own - per-stream carry in :class:`CacheAwareRNNTMALSDStreamingState`. + State of the cache aware RNNT streaming pipelines """ def __init__(self): @@ -70,7 +66,7 @@ def get_previous_hypothesis(self) -> Hypothesis | None: def reset_previous_hypothesis(self) -> None: """ - Reset the previous hypothesis. + Reset the previous hypothesis to None """ self.previous_hypothesis = None diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index 3fdf6cff22fc..2ac13cee2f99 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -1511,9 +1511,7 @@ def _get_state_item_after_sos(self, device: torch.device | str) -> MALSDStateIte batched = self._get_batched_state_after_sos(device=device, batch_size=1) return self.split_batched_state(batched)[0] - def _get_batched_state_after_sos( - self, device: torch.device | str, batch_size: int - ) -> BatchedBeamState: + def _get_batched_state_after_sos(self, device: torch.device | str, batch_size: int) -> BatchedBeamState: """ Build a fresh batched MALSD state after ````. @@ -1593,9 +1591,7 @@ def split_batched_state(self, state: BatchedBeamState) -> list[MALSDStateItem]: ) # ``state.fusion_states_list[k]`` is stored as ``[B, K]`` (see # ``modified_alsd_torch``'s ``s.view(batch_size, self.beam_size)`` step). - fusion_state_list = ( - [fs[i].clone() for fs in state.fusion_states_list] if state.fusion_states_list else [] - ) + fusion_state_list = [fs[i].clone() for fs in state.fusion_states_list] if state.fusion_states_list else [] items.append( MALSDStateItem( predictor_state=stream_predictor_state, @@ -1603,9 +1599,7 @@ def split_batched_state(self, state: BatchedBeamState) -> list[MALSDStateItem]: label=state.labels[i].clone(), decoded_length=state.decoded_lengths[i].clone(), score=state.scores[i].clone() if state.scores is not None else None, - transcript_hash=( - state.transcript_hash[i].clone() if state.transcript_hash is not None else None - ), + transcript_hash=(state.transcript_hash[i].clone() if state.transcript_hash is not None else None), current_lengths_nb=( state.current_lengths_nb[i].clone() if state.current_lengths_nb is not None else None ), @@ -1613,9 +1607,7 @@ def split_batched_state(self, state: BatchedBeamState) -> list[MALSDStateItem]: state.last_timestamp_lasts[i].clone() if state.last_timestamp_lasts is not None else None ), transcript_prefix_hash=( - state.transcript_prefix_hash[i].clone() - if state.transcript_prefix_hash is not None - else None + state.transcript_prefix_hash[i].clone() if state.transcript_prefix_hash is not None else None ), fusion_state_list=fusion_state_list, ) @@ -1736,9 +1728,7 @@ def collapse_batched_state_to_beams_( if state.current_lengths_nb is not None: state.current_lengths_nb = torch.gather(state.current_lengths_nb, dim=1, index=beam_perm).contiguous() if state.last_timestamp_lasts is not None: - state.last_timestamp_lasts = torch.gather( - state.last_timestamp_lasts, dim=1, index=beam_perm - ).contiguous() + state.last_timestamp_lasts = torch.gather(state.last_timestamp_lasts, dim=1, index=beam_perm).contiguous() if state.transcript_prefix_hash is not None: state.transcript_prefix_hash = torch.gather( state.transcript_prefix_hash, dim=1, index=beam_perm From 4dce1e6215a8eee3d32383aa314f76afabc70c9c Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Tue, 16 Jun 2026 22:01:52 +0400 Subject: [PATCH 10/28] clean up Signed-off-by: lilithgrigoryan --- .../pipelines/cache_aware_rnnt_pipeline.py | 86 ++++++++----------- .../streaming/state/cache_aware_rnnt_state.py | 4 +- 2 files changed, 38 insertions(+), 52 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 6c6ea0a44f46..6060cf81d1f5 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -32,7 +32,7 @@ from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame, Request from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import ( - CacheAwareRNNTMALSDStreamingState, + CacheAwareRNNTBeamStreamingState, CacheAwareRNNTStreamingState, ) from nemo.collections.asr.inference.utils.endpointing_utils import millisecond_to_frames @@ -88,26 +88,19 @@ def __init__( super().__init__() def init_decoding_computer(self) -> None: - """ - Probe the model's decoding stack once and stash the resulting computer - on ``self`` so per-chunk code can branch on it without re-doing the - attribute-chain dive. - - Exactly one of ``self.decoding_computer`` (MALSD beam-search) and - ``self.greedy_decoding_computer`` (greedy, used for per-stream biasing - detection) is non-``None`` for any supported decoding stack; both are - ``None`` if the stack exposes no ``decoding_computer`` at all. - """ + """Initialize ``decoding_computer``.""" + self.decoding_computer = None try: - decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer + self.decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer except AttributeError: - decoding_computer = None - if isinstance(decoding_computer, ModifiedALSDBatchedRNNTComputer): - self.decoding_computer: ModifiedALSDBatchedRNNTComputer | None = decoding_computer - self.greedy_decoding_computer = None - else: - self.decoding_computer = None - self.greedy_decoding_computer = decoding_computer + pass + + @property + def malsd_decoding_computer(self) -> ModifiedALSDBatchedRNNTComputer | None: + """Return ``decoding_computer`` when beam-search MALSD is active.""" + if isinstance(self.decoding_computer, ModifiedALSDBatchedRNNTComputer): + return self.decoding_computer + return None def init_parameters(self, cfg: DictConfig) -> None: """ @@ -210,12 +203,12 @@ def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingSta Args: options: (ASRRequestOptions) Request options for particular stream. Returns: - (CacheAwareRNNTStreamingState) New empty state. Returns the MALSD subclass + (CacheAwareRNNTStreamingState) New empty state. Returns the beam-search subclass when the pipeline is configured for beam-search decoding. """ state = ( - CacheAwareRNNTMALSDStreamingState() - if self.decoding_computer is not None + CacheAwareRNNTBeamStreamingState() + if self.malsd_decoding_computer is not None else CacheAwareRNNTStreamingState() ) state.set_global_offset(0) @@ -287,13 +280,10 @@ def _streaming_step( biasing_enabled: bool, ) -> tuple[list[Hypothesis], object]: """ - Dispatcher between the greedy single-shot path and the MALSD beam path. - - For greedy (``self.decoding_computer is None``) this just calls the existing - ``asr_model.stream_step``. For MALSD it runs the encoder once and drives - :class:`ModifiedALSDBatchedRNNTComputer` with the per-stream beam carry. + Run one cache-aware encode/decode step for the current chunk. + Returns per-stream hypotheses and the updated encoder cache context. """ - if self.decoding_computer is None: + if self.malsd_decoding_computer is None: return self.asr_model.stream_step( processed_signal=feature_buffers, processed_signal_length=feature_buffer_lens, @@ -306,6 +296,7 @@ def _streaming_step( prompt_vectors=prompt_vectors, ) return self._malsd_stream_step( + malsd_computer=self.malsd_decoding_computer, states=states, feature_buffers=feature_buffers, feature_buffer_lens=feature_buffer_lens, @@ -317,7 +308,8 @@ def _streaming_step( def _malsd_stream_step( self, - states: list[CacheAwareRNNTMALSDStreamingState], + malsd_computer: ModifiedALSDBatchedRNNTComputer, + states: list[CacheAwareRNNTBeamStreamingState], feature_buffers: Tensor, feature_buffer_lens: Tensor, context, @@ -358,12 +350,8 @@ def _malsd_stream_step( if all(c is None for c in carries): batched_state = None else: - batched_state = self.decoding_computer.merge_to_batched_state(carries) - - # All MALSD GPU work (encoder, decoder, windowed walk, split) shares one - # ``inference_mode`` region: ``split_batched_state`` mutates the inference - # tensors returned by ``decoding_computer(...)`` in place, which is illegal - # once we've left the captured ``inference_mode`` region. + batched_state = malsd_computer.merge_to_batched_state(carries) + with ( torch.amp.autocast( device_type=self.asr_model.device_str, @@ -386,7 +374,7 @@ def _malsd_stream_step( # computer expects [B, T, D] (matches the rest of the decoding stack). encs_dim_last = encoded.transpose(1, 2).contiguous() - best_batched_hyps, batched_state = self.decoding_computer(encs_dim_last, encoded_len, batched_state) + best_batched_hyps, batched_state = malsd_computer(encs_dim_last, encoded_len, batched_state) self._update_windowed_beam_state(states=states, best_batched_hyps=best_batched_hyps) @@ -396,7 +384,7 @@ def _malsd_stream_step( beam_indices_cpu = best_batched_hyps.scores.argmax(dim=-1).detach().cpu().tolist() scores_cpu = best_batched_hyps.scores.detach().cpu() - carry_items = self.decoding_computer.split_batched_state(batched_state) + carry_items = malsd_computer.split_batched_state(batched_state) for state, carry in zip(states, carry_items): state.hyp_decoding_state = carry @@ -425,7 +413,7 @@ def _malsd_stream_step( def _update_windowed_beam_state( self, - states: list[CacheAwareRNNTMALSDStreamingState], + states: list[CacheAwareRNNTBeamStreamingState], best_batched_hyps: BatchedBeamHyps, ) -> None: """ @@ -444,7 +432,7 @@ def _update_windowed_beam_state( state.window_beam_tokens = [prev_t[int(rp[k])] + ct[k] for k in range(beam_size)] state.window_beam_timestamps = [prev_ts[int(rp[k])] + cts[k] for k in range(beam_size)] - def run_malsd_decoder(self, state: CacheAwareRNNTMALSDStreamingState, request: Request, hyp: Hypothesis) -> bool: + def run_malsd_decoder(self, state: CacheAwareRNNTBeamStreamingState, request: Request, hyp: Hypothesis) -> bool: """ MALSD counterpart to :meth:`run_greedy_decoder`. @@ -497,7 +485,7 @@ def run_malsd_decoder(self, state: CacheAwareRNNTMALSDStreamingState, request: R # label so the next utterance benefits from cross-utterance context. if state.hyp_decoding_state is not None: top1 = int(state.hyp_decoding_state.score.argmax().item()) - self.decoding_computer.collapse_state_item_to_top1_(state.hyp_decoding_state, top1) + self.malsd_decoding_computer.collapse_state_item_to_top1_(state.hyp_decoding_state, top1) state.window_committed_tokens = list(all_tokens) state.window_committed_timestamps = list(all_timestamps) state.window_beam_tokens = None @@ -576,12 +564,11 @@ def cache_aware_transcribe_step( previous_hypotheses = [state.get_previous_hypothesis() for state in states] - # Per-stream biasing is only wired up on the greedy decoder. When MALSD - # is active ``self.greedy_decoding_computer`` is ``None`` (see - # :meth:`init_decoding_computer`) so ``biasing_enabled`` falls back to - # ``False`` and the warning in ``_malsd_stream_step`` covers the rest. + # Per-stream biasing is only wired up on the greedy decoder. biasing_enabled = ( - self.greedy_decoding_computer is not None and self.greedy_decoding_computer.per_stream_biasing_enabled + self.decoding_computer is not None + and self.malsd_decoding_computer is None + and self.decoding_computer.per_stream_biasing_enabled ) if not biasing_enabled and any(state.has_biasing_request() for state in states): @@ -595,7 +582,7 @@ def cache_aware_transcribe_step( if state.options.biasing_cfg.auto_manage_multi_model: state.options.biasing_cfg.add_to_multi_model( tokenizer=self.asr_model.tokenizer, - biasing_multi_model=self.greedy_decoding_computer.biasing_multi_model, + biasing_multi_model=self.decoding_computer.biasing_multi_model, ) else: logging.warning( @@ -640,7 +627,7 @@ def cache_aware_transcribe_step( # run per-request decoder for each request-state-hypothesis tuple for request, state, hyp in zip(requests, states, best_hyp): - if self.decoding_computer is not None: + if self.malsd_decoding_computer is not None: eou_detected = self.run_malsd_decoder(state, request, hyp) else: eou_detected = self.run_greedy_decoder(state, request, hyp) @@ -654,15 +641,14 @@ def cache_aware_transcribe_step( if eos: state.reset_previous_hypothesis() - # Cleanup per-stream biasing models when stream ends (greedy path only; - # ``biasing_enabled`` is True only when ``self.greedy_decoding_computer`` is set). + # Cleanup per-stream biasing models when stream ends (greedy path only). if biasing_enabled: for request, state in zip(requests, states): # only the first request contains biasing options; biasing options for the stream are stored in state if request.is_last and state.has_biasing_request(): if state.options.biasing_cfg.auto_manage_multi_model: state.options.biasing_cfg.remove_from_multi_model( - biasing_multi_model=self.greedy_decoding_computer.biasing_multi_model + biasing_multi_model=self.decoding_computer.biasing_multi_model ) def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index b839ed279eda..1f26f8a0a7ac 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -71,9 +71,9 @@ def reset_previous_hypothesis(self) -> None: self.previous_hypothesis = None -class CacheAwareRNNTMALSDStreamingState(CacheAwareRNNTStreamingState): +class CacheAwareRNNTBeamStreamingState(CacheAwareRNNTStreamingState): """ - Cache-aware RNNT state for MALSD beam-search streaming. + Cache-aware RNNT state for beam-search streaming. Transcript assembly is ``committed prefix + live beam suffix``. Beams may disagree within an utterance; at EOU the top-1 path is promoted into the From 893f656e99c7c032422114e476dda3a73086708d Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Tue, 16 Jun 2026 22:12:43 +0400 Subject: [PATCH 11/28] clean up Signed-off-by: lilithgrigoryan --- .../pipelines/cache_aware_rnnt_pipeline.py | 57 ++++++++++--------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 6060cf81d1f5..5ca470db9802 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -96,8 +96,8 @@ def init_decoding_computer(self) -> None: pass @property - def malsd_decoding_computer(self) -> ModifiedALSDBatchedRNNTComputer | None: - """Return ``decoding_computer`` when beam-search MALSD is active.""" + def beam_decoder_computer(self) -> ModifiedALSDBatchedRNNTComputer | None: + """Return ``decoding_computer`` when beam-search decoding is active.""" if isinstance(self.decoding_computer, ModifiedALSDBatchedRNNTComputer): return self.decoding_computer return None @@ -203,12 +203,11 @@ def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingSta Args: options: (ASRRequestOptions) Request options for particular stream. Returns: - (CacheAwareRNNTStreamingState) New empty state. Returns the beam-search subclass - when the pipeline is configured for beam-search decoding. + (CacheAwareRNNTStreamingState) New empty state. New empty state. """ state = ( CacheAwareRNNTBeamStreamingState() - if self.malsd_decoding_computer is not None + if self.beam_decoder_computer is not None else CacheAwareRNNTStreamingState() ) state.set_global_offset(0) @@ -283,7 +282,7 @@ def _streaming_step( Run one cache-aware encode/decode step for the current chunk. Returns per-stream hypotheses and the updated encoder cache context. """ - if self.malsd_decoding_computer is None: + if self.beam_decoder_computer is None: return self.asr_model.stream_step( processed_signal=feature_buffers, processed_signal_length=feature_buffer_lens, @@ -296,7 +295,7 @@ def _streaming_step( prompt_vectors=prompt_vectors, ) return self._malsd_stream_step( - malsd_computer=self.malsd_decoding_computer, + malsd_computer=self.beam_decoder_computer, states=states, feature_buffers=feature_buffers, feature_buffer_lens=feature_buffer_lens, @@ -318,23 +317,27 @@ def _malsd_stream_step( biasing_enabled: bool, ) -> tuple[list[Hypothesis], object]: """ - One streaming step for the MALSD beam-search path: - - 1. Encoder-only pass - the decoder is driven by this pipeline, not by - the model's built-in decoding wrapper. - 2. Merge per-stream ``MALSDStateItem``s into a batched MALSD state. - 3. Run :class:`ModifiedALSDBatchedRNNTComputer` for this chunk. - 4. Update per-stream windowed-beam tracking from this chunk's emissions. - 5. Split the batched MALSD state back into per-stream carries. - 6. Build a cumulative ``Hypothesis`` per stream from - ``window_committed + window_beam_tokens[top1]``. - - Collapse to the chunk's top-1 is NOT performed here - beams stay - diverged across chunks and are collapsed per-stream at the EOU - boundary inside :meth:`run_malsd_decoder`. - - Returns a list of cumulative ``Hypothesis`` per stream and the new - encoder cache context, matching the shape of ``stream_step``. + Cache-aware encode/decode step for MALSD beam search. + + Greedy decoding uses ``asr_model.stream_step``, which fuses encoder and + label-looping decode and returns a single-path ``Hypothesis``. Beam search + instead merges per-stream ``hyp_decoding_state`` into a batched MALSD state, + decodes the chunk, and splits the carry back out. That lifecycle is owned + here rather than by the model wrapper, so this method calls ``encoder_step`` + and then ``malsd_computer`` on the encoded frames. + + The MALSD computer returns chunk-local emissions per beam slot, not a full + cross-chunk transcript. Beams may diverge within an utterance and the + score argmax top-1 can change between chunks. Text is therefore assembled + on the CPU: ``window_committed_*`` stores the prefix frozen at the last + EOU; ``window_beam_*`` stores per-beam suffixes since then (permuted each + chunk in :meth:`_update_windowed_beam_state`). Each returned ``Hypothesis`` + is ``window_committed + window_beam[top1]``. GPU beams are left diverged + until EOU, when :meth:`run_malsd_decoder` collapses ``hyp_decoding_state``. + + Returns: + Per-stream hypotheses and the updated encoder cache context (same + contract as ``stream_step``). """ # Per-stream multi-biasing ids: not yet supported on the MALSD streaming # path. Greedy-side per-stream biasing knobs stay independent. @@ -485,7 +488,7 @@ def run_malsd_decoder(self, state: CacheAwareRNNTBeamStreamingState, request: Re # label so the next utterance benefits from cross-utterance context. if state.hyp_decoding_state is not None: top1 = int(state.hyp_decoding_state.score.argmax().item()) - self.malsd_decoding_computer.collapse_state_item_to_top1_(state.hyp_decoding_state, top1) + self.beam_decoder_computer.collapse_state_item_to_top1_(state.hyp_decoding_state, top1) state.window_committed_tokens = list(all_tokens) state.window_committed_timestamps = list(all_timestamps) state.window_beam_tokens = None @@ -567,7 +570,7 @@ def cache_aware_transcribe_step( # Per-stream biasing is only wired up on the greedy decoder. biasing_enabled = ( self.decoding_computer is not None - and self.malsd_decoding_computer is None + and self.beam_decoder_computer is None and self.decoding_computer.per_stream_biasing_enabled ) @@ -627,7 +630,7 @@ def cache_aware_transcribe_step( # run per-request decoder for each request-state-hypothesis tuple for request, state, hyp in zip(requests, states, best_hyp): - if self.malsd_decoding_computer is not None: + if self.beam_decoder_computer is not None: eou_detected = self.run_malsd_decoder(state, request, hyp) else: eou_detected = self.run_greedy_decoder(state, request, hyp) From 664a2462ee3590ab3fa0f766a3009b890897a983 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 17 Jun 2026 00:05:57 +0400 Subject: [PATCH 12/28] isort and black Signed-off-by: lilithgrigoryan --- .../asr/inference/pipelines/cache_aware_rnnt_pipeline.py | 2 +- .../asr/parts/utils/batched_beam_decoding_utils.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 5ca470db9802..309e25124122 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -354,7 +354,7 @@ def _malsd_stream_step( batched_state = None else: batched_state = malsd_computer.merge_to_batched_state(carries) - + with ( torch.amp.autocast( device_type=self.asr_model.device_str, diff --git a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py index 43683f8a40bf..a32b0ac31b2f 100644 --- a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py +++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py @@ -348,7 +348,6 @@ def keep_beam_(self, beam_indices: torch.Tensor) -> None: # Mark all but the first slot as inactive so the next iteration's top-k repopulates them. self.scores[:, 1:].fill_(INACTIVE_SCORE) - def get_last_labels(self, pad_id: int = -1) -> torch.Tensor: """ Get last labels for each hypothesis in the beam. @@ -696,9 +695,7 @@ def _export( max_idx = self.current_lengths_wb.max() - 1 transcripts = self.transcript_wb[..., : max_idx + 1] timestamps = self.timestamps[..., : max_idx + 1] - durations = ( - self.token_durations[..., : max_idx + 1] if self.model_type == ASRModelTypeEnum.TDT else None - ) + durations = self.token_durations[..., : max_idx + 1] if self.model_type == ASRModelTypeEnum.TDT else None return scores, transcripts, timestamps, durations, root_ptrs def _hypothesis_from_flat( From b342a16b5410c8b332ce7c431f4ed2b6c7aebb28 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 17 Jun 2026 20:57:52 +0400 Subject: [PATCH 13/28] add per-stream biasing Signed-off-by: lilithgrigoryan --- .../pipelines/cache_aware_rnnt_pipeline.py | 126 +++++++++++------- .../submodules/rnnt_malsd_batched_computer.py | 28 ++-- 2 files changed, 95 insertions(+), 59 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 309e25124122..a5a5bec88654 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -241,6 +241,23 @@ def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingSta return state + def close_session(self) -> None: + """Close the session and release per-stream biasing models held in the decoder.""" + if self.decoding_computer is not None and self.decoding_computer.per_stream_biasing_enabled: + biasing_multi_model = self.decoding_computer.biasing_multi_model + active_model_ids = [ + model_id + for model_id in range(biasing_multi_model.num_models) + if biasing_multi_model.model2active[model_id].item() + ] + with torch.inference_mode(): + for model_id in sorted(active_model_ids, reverse=True): + biasing_multi_model.remove_model(model_id) + for state in self._state_pool.values(): + if state.has_biasing_request(): + state.options.biasing_cfg.multi_model_id = None + super().close_session() + def get_sep(self) -> str: """Return the separator for the text processor.""" return self.sep @@ -276,7 +293,7 @@ def _streaming_step( drop_extra_pre_encoded: int, keep_all_outputs: bool, prompt_vectors: Tensor | None, - biasing_enabled: bool, + multi_biasing_ids: Tensor | None = None, ) -> tuple[list[Hypothesis], object]: """ Run one cache-aware encode/decode step for the current chunk. @@ -302,9 +319,57 @@ def _streaming_step( context=context, drop_extra_pre_encoded=drop_extra_pre_encoded, keep_all_outputs=keep_all_outputs, - biasing_enabled=biasing_enabled, + multi_biasing_ids=multi_biasing_ids, ) + def _prepare_per_stream_biasing( + self, + states: list[CacheAwareRNNTStreamingState], + previous_hypotheses: list[Hypothesis | None], + device: torch.device, + ) -> tuple[list[Hypothesis | None], Tensor | None]: + if self.decoding_computer is None or not self.decoding_computer.per_stream_biasing_enabled: + if any(state.has_biasing_request() for state in states): + logging.warning( + "Biasing request is not empty, but decoder does not support per-stream biasing. Skipping" + ) + return previous_hypotheses, None + + biasing_multi_model = self.decoding_computer.biasing_multi_model + multi_biasing_ids_np = np.full([len(states)], fill_value=-1) + for i, (state, previous_hyp) in enumerate(zip(states, previous_hypotheses)): + if not state.has_biasing_request(): + continue + + biasing_cfg = state.options.biasing_cfg + model_id = biasing_cfg.multi_model_id + if model_id is not None and not biasing_multi_model.model2active[model_id].item(): + model_id = biasing_cfg.multi_model_id = None + + if model_id is None: + if biasing_cfg.auto_manage_multi_model: + with torch.inference_mode(): + biasing_cfg.add_to_multi_model( + tokenizer=self.asr_model.tokenizer, + biasing_multi_model=biasing_multi_model, + ) + else: + logging.warning("Biasing request is not empty, not auto managed and not compiled. Skipping") + continue + + multi_biasing_ids_np[i] = biasing_cfg.multi_model_id + + if self.beam_decoder_computer is None: + if previous_hyp is None: + previous_hypotheses[i] = Hypothesis.empty_with_biasing_cfg(biasing_cfg) + else: + previous_hyp.biasing_cfg = biasing_cfg + + multi_biasing_ids = None + if self.beam_decoder_computer is not None: + multi_biasing_ids = torch.from_numpy(multi_biasing_ids_np).to(device=device) + return previous_hypotheses, multi_biasing_ids + def _malsd_stream_step( self, malsd_computer: ModifiedALSDBatchedRNNTComputer, @@ -314,7 +379,7 @@ def _malsd_stream_step( context, drop_extra_pre_encoded: int, keep_all_outputs: bool, - biasing_enabled: bool, + multi_biasing_ids: Tensor | None = None, ) -> tuple[list[Hypothesis], object]: """ Cache-aware encode/decode step for MALSD beam search. @@ -339,14 +404,6 @@ def _malsd_stream_step( Per-stream hypotheses and the updated encoder cache context (same contract as ``stream_step``). """ - # Per-stream multi-biasing ids: not yet supported on the MALSD streaming - # path. Greedy-side per-stream biasing knobs stay independent. - if biasing_enabled: - logging.warning( - "Per-stream biasing is not yet wired up on the MALSD cache-aware " - "streaming path; ignoring biasing requests for this chunk." - ) - # Merge per-stream carries into a batched MALSD state. ``None`` entries # (fresh streams) are filled with the after-SOS state inside ``merge_to_batched_state``. carries = [state.hyp_decoding_state for state in states] @@ -377,7 +434,9 @@ def _malsd_stream_step( # computer expects [B, T, D] (matches the rest of the decoding stack). encs_dim_last = encoded.transpose(1, 2).contiguous() - best_batched_hyps, batched_state = malsd_computer(encs_dim_last, encoded_len, batched_state) + best_batched_hyps, batched_state = malsd_computer( + encs_dim_last, encoded_len, batched_state, multi_biasing_ids=multi_biasing_ids + ) self._update_windowed_beam_state(states=states, best_batched_hyps=best_batched_hyps) @@ -567,35 +626,12 @@ def cache_aware_transcribe_step( previous_hypotheses = [state.get_previous_hypothesis() for state in states] - # Per-stream biasing is only wired up on the greedy decoder. - biasing_enabled = ( - self.decoding_computer is not None - and self.beam_decoder_computer is None - and self.decoding_computer.per_stream_biasing_enabled + previous_hypotheses, multi_biasing_ids = self._prepare_per_stream_biasing( + states=states, + previous_hypotheses=previous_hypotheses, + device=feature_buffers.device, ) - if not biasing_enabled and any(state.has_biasing_request() for state in states): - logging.warning("Biasing request is not empty, but decoder does not support per-stream biasing. Skipping") - - # Handle per-stream biasing: add biasing models to multi_model if needed - if biasing_enabled: - for i, (request, state, previous_hyp) in enumerate(zip(requests, states, previous_hypotheses)): - if state.has_biasing_request(): - if state.options.biasing_cfg.multi_model_id is None: - if state.options.biasing_cfg.auto_manage_multi_model: - state.options.biasing_cfg.add_to_multi_model( - tokenizer=self.asr_model.tokenizer, - biasing_multi_model=self.decoding_computer.biasing_multi_model, - ) - else: - logging.warning( - "Biasing request is not empty, not auto managed and not compiled. Skipping" - ) - if previous_hyp is None: - previous_hypotheses[i] = Hypothesis.empty_with_biasing_cfg(state.options.biasing_cfg) - else: - previous_hyp.biasing_cfg = state.options.biasing_cfg - context, mapping = self.context_manager.get_context(stream_ids) prompt_vectors = None @@ -612,7 +648,7 @@ def cache_aware_transcribe_step( drop_extra_pre_encoded=drop_extra_pre_encoded, keep_all_outputs=keep_all_outputs, prompt_vectors=prompt_vectors, - biasing_enabled=biasing_enabled, + multi_biasing_ids=multi_biasing_ids, ) # update the cache and reset the cache slots for the streams that has ended @@ -644,16 +680,6 @@ def cache_aware_transcribe_step( if eos: state.reset_previous_hypothesis() - # Cleanup per-stream biasing models when stream ends (greedy path only). - if biasing_enabled: - for request, state in zip(requests, states): - # only the first request contains biasing options; biasing options for the stream are stored in state - if request.is_last and state.has_biasing_request(): - if state.options.biasing_cfg.auto_manage_multi_model: - state.options.biasing_cfg.remove_from_multi_model( - biasing_multi_model=self.decoding_computer.biasing_multi_model - ) - def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: """ Transcribes the feature buffers in a streaming manner. diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index 2ac13cee2f99..56d1377587c3 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -1536,8 +1536,8 @@ def _get_batched_state_after_sos(self, device: torch.device | str, batch_size: i scores[:, 0] = 0.0 fusion_states_list: list[torch.Tensor] = [] - if self.fusion_models is not None: - for fm in self.fusion_models: + if self.has_fusion_models: + for fm in self._all_fusion_models(): fs = fm.get_init_states(batch_size=total, bos=True).to(device) fusion_states_list.append(fs.reshape(batch_size, beam_size, *fs.shape[1:])) @@ -1649,13 +1649,23 @@ def merge_to_batched_state(self, state_items: list[Optional[MALSDStateItem]]) -> else None ) - num_fusion = len(state_items[0].fusion_state_list) - # Per-stream ``fusion_state_list[fusion_idx]`` is ``[K]``; stack along a new dim 0 - # to produce ``[B, K]`` (NOT ``cat`` which would give the flat ``[B*K]`` shape used - # by ``predictor_*`` and would trip downstream shape mismatches). - fusion_states_list = [ - torch.stack([item.fusion_state_list[fi] for item in state_items], dim=0) for fi in range(num_fusion) - ] + num_fusion = max(len(item.fusion_state_list) for item in state_items) + if num_fusion > 0: + sos_fusion_template: list[torch.Tensor] | None = None + for item in state_items: + if len(item.fusion_state_list) < num_fusion: + if sos_fusion_template is None: + sos_fusion_template = self._get_state_item_after_sos( + device=item.predictor_output.device + ).fusion_state_list + for fi in range(len(item.fusion_state_list), num_fusion): + item.fusion_state_list.append(sos_fusion_template[fi].clone()) + + fusion_states_list = [ + torch.stack([item.fusion_state_list[fi] for item in state_items], dim=0) for fi in range(num_fusion) + ] + else: + fusion_states_list = [] return BatchedBeamState( predictor_states=batched_predictor_state, From b9a31a4074778f2346ff7c8ba1635fa89d673ce3 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 17 Jun 2026 21:11:34 +0400 Subject: [PATCH 14/28] clean up Signed-off-by: lilithgrigoryan --- .../pipelines/cache_aware_rnnt_pipeline.py | 75 +---------------- .../streaming/state/cache_aware_rnnt_state.py | 37 ++------ .../submodules/rnnt_malsd_batched_computer.py | 84 ++----------------- .../utils/batched_beam_decoding_utils.py | 45 +--------- 4 files changed, 20 insertions(+), 221 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index a5a5bec88654..6a49949e7332 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -381,31 +381,7 @@ def _malsd_stream_step( keep_all_outputs: bool, multi_biasing_ids: Tensor | None = None, ) -> tuple[list[Hypothesis], object]: - """ - Cache-aware encode/decode step for MALSD beam search. - - Greedy decoding uses ``asr_model.stream_step``, which fuses encoder and - label-looping decode and returns a single-path ``Hypothesis``. Beam search - instead merges per-stream ``hyp_decoding_state`` into a batched MALSD state, - decodes the chunk, and splits the carry back out. That lifecycle is owned - here rather than by the model wrapper, so this method calls ``encoder_step`` - and then ``malsd_computer`` on the encoded frames. - - The MALSD computer returns chunk-local emissions per beam slot, not a full - cross-chunk transcript. Beams may diverge within an utterance and the - score argmax top-1 can change between chunks. Text is therefore assembled - on the CPU: ``window_committed_*`` stores the prefix frozen at the last - EOU; ``window_beam_*`` stores per-beam suffixes since then (permuted each - chunk in :meth:`_update_windowed_beam_state`). Each returned ``Hypothesis`` - is ``window_committed + window_beam[top1]``. GPU beams are left diverged - until EOU, when :meth:`run_malsd_decoder` collapses ``hyp_decoding_state``. - - Returns: - Per-stream hypotheses and the updated encoder cache context (same - contract as ``stream_step``). - """ - # Merge per-stream carries into a batched MALSD state. ``None`` entries - # (fresh streams) are filled with the after-SOS state inside ``merge_to_batched_state``. + """Cache-aware MALSD encode/decode step for one chunk.""" carries = [state.hyp_decoding_state for state in states] if all(c is None for c in carries): batched_state = None @@ -430,8 +406,6 @@ def _malsd_stream_step( drop_left_context=self.drop_left_context, valid_out_len=self.valid_out_len, ) - # ``encoded`` from the encoder wrapper is shaped [B, D, T]; the MALSD - # computer expects [B, T, D] (matches the rest of the decoding stack). encs_dim_last = encoded.transpose(1, 2).contiguous() best_batched_hyps, batched_state = malsd_computer( @@ -478,14 +452,7 @@ def _update_windowed_beam_state( states: list[CacheAwareRNNTBeamStreamingState], best_batched_hyps: BatchedBeamHyps, ) -> None: - """ - Extend each state's per-slot ``window_beam_tokens[k]`` with the chunk-local - emissions of the slot that originated from carry slot ``k`` at chunk start. - - The helper exposes per-(batch, beam) chunk-local tokens/timestamps and the - chunk-start -> chunk-end descent map; the permute-then-append windowed-beam - policy lives here. - """ + """Append chunk-local beam emissions to each stream's windowed-beam state.""" chunk_tokens, chunk_timestamps, root_ptrs = export_batched_beam_hyps_to_cpu_lists(best_batched_hyps) beam_size = best_batched_hyps.beam_size for state, ct, cts, rp in zip(states, chunk_tokens, chunk_timestamps, root_ptrs): @@ -496,32 +463,11 @@ def _update_windowed_beam_state( def run_malsd_decoder(self, state: CacheAwareRNNTBeamStreamingState, request: Request, hyp: Hypothesis) -> bool: """ - MALSD counterpart to :meth:`run_greedy_decoder`. - - Reuses the greedy decoder for EOU detection, label-buffer rolling and - offset bookkeeping. Then RESYNCS ``state.tokens`` / ``state.timesteps`` / - ``state.confidences`` with the current top-1's cumulative slice - (``hyp.y_sequence[_malsd_utterance_start:]``). - - The resync is the load-bearing step that distinguishes MALSD from - greedy: between chunks, MALSD's raw-argmax top-1 can switch beams with - incompatible token histories (beam A: ``["I"]`` at chunk t, beam B: - ``["I", "I"]`` at chunk t+1). ``run_greedy_decoder`` appends - ``hyp.y_sequence[offset:]`` onto whatever was already in ``state.tokens``, - which would splice A's prefix with B's new tokens into a Frankenstein - transcript. Overwriting with the actual current top-1 belief keeps the - published transcript consistent with whichever beam currently wins. - - On EOU we bump ``_malsd_utterance_start`` to the current cumulative - length so the next utterance's resync slice starts past the cleared - previous utterance, then collapse the per-stream MALSD carry to its - top-1 beam: the K-beam state diverges intra-utterance and snaps to the - chosen path at the natural utterance boundary. + Run greedy EOU/label logic, then resync ``state.tokens`` from the current + top-1 cumulative hyp. On EOU, collapse the MALSD carry and commit the window. """ eou_detected = self.run_greedy_decoder(state, request, hyp) - # Resync state.tokens / state.timesteps / state.confidences with the - # current top-1's cumulative slice for this utterance. all_tokens = list(hyp.y_sequence) if hyp.y_sequence is not None else [] all_timestamps = list(hyp.timestamp) if hyp.timestamp is not None else [] start = max(0, int(state._malsd_utterance_start)) @@ -537,14 +483,8 @@ def run_malsd_decoder(self, state: CacheAwareRNNTBeamStreamingState, request: Re state.last_token_idx = timestamps_list[-1] if timestamps_list else None if eou_detected: - # Mark the boundary so the next utterance's slice starts past the - # tokens we just finalised. state._malsd_utterance_start = len(all_tokens) - # EOU-driven collapse: promote the chosen window into the committed - # prefix and replicate the winning beam across all K slots of the - # per-stream carry. The predictor stays warm at the top-1's last - # label so the next utterance benefits from cross-utterance context. if state.hyp_decoding_state is not None: top1 = int(state.hyp_decoding_state.score.argmax().item()) self.beam_decoder_computer.collapse_state_item_to_top1_(state.hyp_decoding_state, top1) @@ -655,16 +595,10 @@ def cache_aware_transcribe_step( self.context_manager.update_cache(stream_ids, new_context, mapping) self.context_manager.reset_slots(stream_ids, eos_flags) - # update the previous hypothesis for non-eos streams. For greedy this is the - # ``Hypothesis`` returned by ``rnnt_decoder_predictions_tensor``; for MALSD - # it is the cumulative ``Hypothesis`` built in ``_malsd_stream_step``. The - # eos reset is deferred to *after* the per-request decoder loop below so - # that ``run_malsd_decoder`` can still see the current utterance start. for state, hyp, eos in zip(states, best_hyp, eos_flags): if not eos: state.set_previous_hypothesis(hyp) - # run per-request decoder for each request-state-hypothesis tuple for request, state, hyp in zip(requests, states, best_hyp): if self.beam_decoder_computer is not None: eou_detected = self.run_malsd_decoder(state, request, hyp) @@ -675,7 +609,6 @@ def cache_aware_transcribe_step( state.cleanup_after_eou() ready_state_ids.add(request.stream_id) - # Deferred eos reset - now safe to clear MALSD per-stream carry too. for state, eos in zip(states, eos_flags): if eos: state.reset_previous_hypothesis() diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index 1f26f8a0a7ac..0b40b3357417 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -12,15 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from typing import TYPE_CHECKING - from nemo.collections.asr.inference.streaming.state.cache_aware_state import CacheAwareStreamingState from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis -if TYPE_CHECKING: - from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import MALSDStateItem - class CacheAwareRNNTStreamingState(CacheAwareStreamingState): """ @@ -72,50 +66,31 @@ def reset_previous_hypothesis(self) -> None: class CacheAwareRNNTBeamStreamingState(CacheAwareRNNTStreamingState): - """ - Cache-aware RNNT state for beam-search streaming. - - Transcript assembly is ``committed prefix + live beam suffix``. Beams may - disagree within an utterance; at EOU the top-1 path is promoted into the - committed prefix and per-beam suffixes are cleared. - - See :class:`CacheAwareRNNTPipeline` (``_malsd_stream_step``, ``run_malsd_decoder``). - """ + """Cache-aware RNNT state for MALSD beam-search streaming.""" def _additional_params_reset(self) -> None: - """ - Reset MALSD per-stream carry on top of the greedy state. - """ super()._additional_params_reset() - # Per-stream MALSD decoder carry (``MALSDStateItem``); Shuttled through - # ``merge_to_batched_state`` / ``split_batched_state`` each chunk. self.hyp_decoding_state: "MALSDStateItem | None" = None # Finalized transcript prefix at the last EOU; identical for every beam slot. self.window_committed_tokens: list[int] = [] # Frame timestamps aligned with ``window_committed_tokens``. self.window_committed_timestamps: list[int] = [] + # Per-beam suffix since last EOU; slot k may differ while beams compete. self.window_beam_tokens: list[list[int]] | None = None - # Per-beam frame timestamps aligned with ``window_beam_tokens`` (same slot layout). + # Per-beam frame timestamps aligned with ``window_beam_tokens``. self.window_beam_timestamps: list[list[int]] | None = None + # Index into cumulative ``hyp.y_sequence`` where the current utterance starts # (skips tokens from prior utterances still present in the cumulative hyp). self._malsd_utterance_start: int = 0 def reset_previous_hypothesis(self) -> None: - """ - Reset the previous hypothesis and all MALSD beam-search bookkeeping. - - Called at end-of-stream. Zeroes out the MALSD per-stream carry so the - next utterance starts from SOS with an empty windowed-beam state. - """ + """Reset carry and windowed-beam state at end-of-stream.""" super().reset_previous_hypothesis() self.hyp_decoding_state = None self.window_committed_tokens = [] self.window_committed_timestamps = [] self.window_beam_tokens = None self.window_beam_timestamps = None - # NB: ``_malsd_utterance_start`` is intentionally NOT reset here because - # the cumulative ``hyp.y_sequence`` it indexes is owned by the pipeline - # and bumped after the call when the previous utterance is being - # finalised. The pipeline bumps it explicitly after publishing. + # _malsd_utterance_start is bumped by the pipeline on EOU, not here. diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index 56d1377587c3..adf28a42363c 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -213,14 +213,7 @@ class SeparateGraphsMALSD: @dataclass class MALSDStateItem: - """ - Per-stream decoding state for ``ModifiedALSDBatchedRNNTComputer``. - - Used by streaming pipelines that maintain per-stream state. Mirrors - ``LabelLoopingStateItem`` (greedy) with beam-shaped tensors - (``[beam_size, ...]`` instead of scalar/``[D]``) plus the cross-chunk - per-beam fields needed to seed the next MALSD chunk. - """ + """Per-stream MALSD carry for cache-aware streaming (beam-shaped tensors).""" predictor_state: Any # opaque per-stream predictor state of size beam_size predictor_output: torch.Tensor # [beam_size, 1, D] @@ -1500,27 +1493,12 @@ def _create_decoding_state( ) def _get_state_item_after_sos(self, device: torch.device | str) -> MALSDStateItem: - """ - Per-stream after-SOS state. Used by :meth:`merge_to_batched_state` to fill - ``None`` items (fresh streams that joined the batch mid-flight). - - Built by constructing a ``batch_size=1`` batched after-SOS state and - taking the first item out of :meth:`split_batched_state` - mirrors the - greedy ``_get_decoding_state_item_after_sos`` pattern. - """ + """After-SOS per-stream state; used to fill ``None`` entries in merge.""" batched = self._get_batched_state_after_sos(device=device, batch_size=1) return self.split_batched_state(batched)[0] def _get_batched_state_after_sos(self, device: torch.device | str, batch_size: int) -> BatchedBeamState: - """ - Build a fresh batched MALSD state after ````. - - Shapes follow the contract consumed by :meth:`split_batched_state`: - predictor state/outputs are flat ``[B*K, ...]``; per-beam fields are - ``[B, K]``; fusion states are ``[B, K, ...]``; ``decoded_lengths`` is ``[B]``. - Slot 0 starts active (``score=0.0``); slots ``1..K-1`` start inactive so the - next chunk's top-k expands the surviving beam. - """ + """Fresh batched MALSD state after ```` (slot 0 active, others inactive).""" beam_size = self.beam_size total = batch_size * beam_size @@ -1559,19 +1537,7 @@ def zeros_bk() -> torch.Tensor: ) def split_batched_state(self, state: BatchedBeamState) -> list[MALSDStateItem]: - """ - Split a batched MALSD state into per-stream ``MALSDStateItem``s. - - Mirrors ``GreedyBatchedLabelLoopingComputerBase.split_batched_state`` for - beam-search shapes: - - - the predictor state was created with batch dimension ``B * beam_size``; - we slice it into ``B`` groups of ``beam_size`` consecutive rows and - re-batch each group with ``decoder.batch_unsplit_states``. - - ``labels`` / ``decoded_lengths`` and per-beam cross-chunk scalars are - split along the batch axis. - - ``fusion_states_list`` has each element as ``[B, beam_size, ...]``. - """ + """Split a batched MALSD state into per-stream items.""" if state is None: return [] batch_size = state.labels.shape[0] @@ -1589,8 +1555,6 @@ def split_batched_state(self, state: BatchedBeamState) -> list[MALSDStateItem]: stream_predictor_state = self.decoder.batch_unsplit_states( per_row_states[i * beam_size : (i + 1) * beam_size] ) - # ``state.fusion_states_list[k]`` is stored as ``[B, K]`` (see - # ``modified_alsd_torch``'s ``s.view(batch_size, self.beam_size)`` step). fusion_state_list = [fs[i].clone() for fs in state.fusion_states_list] if state.fusion_states_list else [] items.append( MALSDStateItem( @@ -1615,12 +1579,7 @@ def split_batched_state(self, state: BatchedBeamState) -> list[MALSDStateItem]: return items def merge_to_batched_state(self, state_items: list[Optional[MALSDStateItem]]) -> BatchedBeamState: - """ - Merge a list of per-stream ``MALSDStateItem``s into a single batched MALSD state. - - ``None`` entries (e.g. fresh streams that joined a batch mid-flight) are - replaced with a freshly-initialised after-SOS state. - """ + """Merge per-stream items into one batched state; ``None`` entries get after-SOS fillers.""" if any(item is None for item in state_items): not_none_item = next(item for item in state_items if item is not None) device = not_none_item.predictor_output.device @@ -1687,22 +1646,7 @@ def collapse_batched_state_to_beams_( batched_hyps: BatchedBeamHyps, beam_indices: torch.Tensor, ) -> None: - """ - In-place: collapse each row of a batched MALSD state and its associated - :class:`BatchedBeamHyps` to a single surviving beam, replicated across all - ``beam_size`` slots. - - After the call, every per-beam tensor on ``state`` and on ``batched_hyps`` - carries the chosen beam's value at slot 0 and identical clones at slots - 1..beam_size-1; ``scores[:, 1:]`` is set to ``INACTIVE_SCORE`` so the next - chunk's top-k repopulates them through normal expansion of the surviving beam. - - Args: - state: batched MALSD state to collapse in place. - batched_hyps: prefix-tree object returned alongside ``state``. Mutated - in place via :meth:`BatchedBeamHyps.keep_beam_`. - beam_indices: ``[batch_size]`` long tensor giving the beam to keep per row. - """ + """Collapse each batch row to one beam, replicated across all slots.""" batch_size = state.labels.shape[0] beam_size = self.beam_size if beam_indices.shape != (batch_size,): @@ -1745,8 +1689,6 @@ def collapse_batched_state_to_beams_( ).contiguous() if state.fusion_states_list: - # Fusion states are reshaped to ``[B, K]`` inside ``modified_alsd_torch`` - # so use the per-stream ``beam_perm`` gather along the beam axis. for fs in state.fusion_states_list: if fs.ndim != 2: raise NotImplementedError( @@ -1760,19 +1702,7 @@ def collapse_batched_state_to_beams_( batched_hyps.keep_beam_(beam_indices) def collapse_state_item_to_top1_(self, item: MALSDStateItem, beam_index: int) -> None: - """ - In-place per-stream variant of :meth:`collapse_batched_state_to_beams_`. - - Replicates beam ``beam_index`` across all ``beam_size`` slots of ``item`` - and sets ``score[1:] = INACTIVE_SCORE`` so the next chunk's top-k expands - the surviving beam. Used by streaming pipelines to collapse a single - stream's MALSD carry at its EOU boundary without disturbing other rows - of a batched run. - - Wraps mutations in :func:`torch.inference_mode` so it can be called from - outside the encoder/decoder inference region (the per-stream tensors are - inference tensors produced by :meth:`split_batched_state`). - """ + """In-place per-stream collapse to one beam (used at EOU in streaming).""" beam_size = self.beam_size if not 0 <= beam_index < beam_size: raise ValueError(f"beam_index must be in [0, {beam_size}), got {beam_index}") diff --git a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py index a32b0ac31b2f..5de814d5185a 100644 --- a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py +++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py @@ -81,21 +81,7 @@ def seed_batched_hyps_from_state( state: BatchedBeamState, batch_size: Optional[int] = None, ) -> None: - """Copy cross-chunk per-beam fields from a :class:`BatchedBeamState` snapshot - into ``hyps`` (in-place). Inverse of - :meth:`BatchedBeamHyps.export_cross_chunk_state`. - - Used by streaming beam-search decoders to seed a ``BatchedBeamHyps`` from the previous - chunk's snapshot. Chunk-local buffers (prefix tree / timestamps / write cursor) - and the per-beam time cursor are NOT touched -- the caller is responsible for - wiping them. - - Args: - hyps: destination ``BatchedBeamHyps`` (modified in place). - state: source snapshot. No-op when ``state.scores`` is ``None`` (first chunk). - batch_size: optional number of leading rows to copy. Defaults to - ``state.scores.shape[0]``. - """ + """Seed ``hyps`` cross-chunk fields from a ``BatchedBeamState`` snapshot (in-place).""" if state.scores is None: return bs = state.scores.shape[0] if batch_size is None else batch_size @@ -324,18 +310,7 @@ def clone(self, batch_size: Optional[int] = None) -> "BatchedBeamHyps": return new_hyps def keep_beam_(self, beam_indices: torch.Tensor) -> None: - """ - In-place: collapse each row to a single surviving beam, replicated across all - ``beam_size`` slots, with the other slots' scores set to ``INACTIVE_SCORE``. - - Used by streaming pipelines to commit the per-chunk best beam as the - definitive history before the next chunk, so the carried predictor state and - the published transcript stay consistent. - - Args: - beam_indices: ``[batch_size]`` long tensor giving the beam to keep for - each row in the batch. - """ + """Collapse each row to one beam, replicated across all slots (in-place).""" if self.beam_size <= 1: return permutation = ( @@ -345,7 +320,6 @@ def keep_beam_(self, beam_indices: torch.Tensor) -> None: .contiguous() ) self._flatten_with_permutation_(permutation) - # Mark all but the first slot as inactive so the next iteration's top-k repopulates them. self.scores[:, 1:].fill_(INACTIVE_SCORE) def get_last_labels(self, pad_id: int = -1) -> torch.Tensor: @@ -1009,21 +983,8 @@ def merge_( def export_batched_beam_hyps_to_cpu_lists( bbh: BatchedBeamHyps, ) -> tuple[list[list[list[int]]], list[list[list[int]]], list[list[int]]]: - """ - Streaming-pipeline helper: flatten ``bbh`` in-place (identity permutation) and - return CPU-side per-(batch, beam) chunk-local emissions plus the chunk-start - descent map. Intended for windowed-beam aggregation outside the engine. - - Returns: - (tokens, timestamps, root_ptrs): - * ``tokens``: ``[batch_size][beam_size]`` non-blank/non-padding token - IDs for this chunk. - * ``timestamps``: ``[batch_size][beam_size]`` matching step indices. - * ``root_ptrs``: ``[batch_size][beam_size]`` chunk-start beam index - from which each current slot descends. - """ + """Export chunk-local per-beam tokens/timestamps and beam descent map to CPU lists.""" _, transcripts, timestamps, _, root_ptrs = bbh._export(sort=False) - # One sync to CPU; per-slot masking + .tolist() stays on CPU. root_ptrs_list = root_ptrs.detach().cpu().tolist() transcripts_cpu = transcripts.detach().cpu() timestamps_cpu = timestamps.detach().cpu() From 0ed3d923fd4263407bc79481f7ee30d04c555cb5 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 17 Jun 2026 21:14:51 +0400 Subject: [PATCH 15/28] clean up Signed-off-by: lilithgrigoryan --- .../pipelines/cache_aware_rnnt_pipeline.py | 10 ++++++---- .../streaming/state/cache_aware_rnnt_state.py | 13 ++++++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 6a49949e7332..179aa1f7ca15 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -90,10 +90,12 @@ def __init__( def init_decoding_computer(self) -> None: """Initialize ``decoding_computer``.""" self.decoding_computer = None - try: - self.decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer - except AttributeError: - pass + asr_model = getattr(self.asr_model, "asr_model", None) + if asr_model is None: + return + decoding = getattr(getattr(asr_model, "decoding", None), "decoding", None) + if decoding is not None: + self.decoding_computer = getattr(decoding, "decoding_computer", None) @property def beam_decoder_computer(self) -> ModifiedALSDBatchedRNNTComputer | None: diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index 0b40b3357417..c9d7475f1128 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -12,9 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + from nemo.collections.asr.inference.streaming.state.cache_aware_state import CacheAwareStreamingState from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +if TYPE_CHECKING: + from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import MALSDStateItem + class CacheAwareRNNTStreamingState(CacheAwareStreamingState): """ @@ -70,17 +77,17 @@ class CacheAwareRNNTBeamStreamingState(CacheAwareRNNTStreamingState): def _additional_params_reset(self) -> None: super()._additional_params_reset() - self.hyp_decoding_state: "MALSDStateItem | None" = None + self.hyp_decoding_state: MALSDStateItem | None = None # Finalized transcript prefix at the last EOU; identical for every beam slot. self.window_committed_tokens: list[int] = [] # Frame timestamps aligned with ``window_committed_tokens``. self.window_committed_timestamps: list[int] = [] - + # Per-beam suffix since last EOU; slot k may differ while beams compete. self.window_beam_tokens: list[list[int]] | None = None # Per-beam frame timestamps aligned with ``window_beam_tokens``. self.window_beam_timestamps: list[list[int]] | None = None - + # Index into cumulative ``hyp.y_sequence`` where the current utterance starts # (skips tokens from prior utterances still present in the cumulative hyp). self._malsd_utterance_start: int = 0 From 56816dfc59580a3d2605c1250ebeefb87501baa8 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 18 Jun 2026 20:58:02 +0400 Subject: [PATCH 16/28] refactor, separate state Signed-off-by: lilithgrigoryan --- .../pipelines/cache_aware_rnnt_pipeline.py | 101 +++------------ .../streaming/state/cache_aware_rnnt_state.py | 118 +++++++++++++++--- 2 files changed, 118 insertions(+), 101 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 179aa1f7ca15..262da0a48d5d 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -43,10 +43,7 @@ get_confidence_utils, ) from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer -from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import ( - BatchedBeamHyps, - export_batched_beam_hyps_to_cpu_lists, -) +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import export_batched_beam_hyps_to_cpu_lists from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.utils import logging @@ -414,87 +411,30 @@ def _malsd_stream_step( encs_dim_last, encoded_len, batched_state, multi_biasing_ids=multi_biasing_ids ) - self._update_windowed_beam_state(states=states, best_batched_hyps=best_batched_hyps) - - # Per-stream top-1 beam slot. Indexes ``window_beam_tokens`` (which was - # just appended against the diverged beam slots) to build the publishable - # cumulative hypothesis below. - beam_indices_cpu = best_batched_hyps.scores.argmax(dim=-1).detach().cpu().tolist() + chunk_tokens, chunk_timestamps, root_ptrs = export_batched_beam_hyps_to_cpu_lists(best_batched_hyps) + beam_indices = best_batched_hyps.scores.argmax(dim=-1).detach().cpu().tolist() scores_cpu = best_batched_hyps.scores.detach().cpu() carry_items = malsd_computer.split_batched_state(batched_state) - for state, carry in zip(states, carry_items): + for state, ct, cts, rp, top1, carry in zip( + states, chunk_tokens, chunk_timestamps, root_ptrs, beam_indices, carry_items + ): + state.append_chunk_beam_(ct, cts, rp, best_batched_hyps.beam_size, top1) state.hyp_decoding_state = carry - # Build per-stream cumulative ``Hypothesis`` from the windowed state. - # Collapse + window promotion is deferred to ``run_malsd_decoder`` and - # triggered by EOU, so the published hyp is the current top-1's path - # but the K-beam state continues to diverge across chunks. - hyps: list[Hypothesis] = [] - for b, state in enumerate(states): - top1_slot = beam_indices_cpu[b] - window_tokens = state.window_beam_tokens[top1_slot] if state.window_beam_tokens else [] - window_ts = state.window_beam_timestamps[top1_slot] if state.window_beam_timestamps else [] - cum_tokens = state.window_committed_tokens + list(window_tokens) - cum_ts = state.window_committed_timestamps + list(window_ts) - - hyps.append( - Hypothesis( - score=float(scores_cpu[b, top1_slot].item()), - y_sequence=cum_tokens, - timestamp=cum_ts, - length=len(cum_tokens), - ) - ) - + hyps = [ + state.get_hypothesis(float(scores_cpu[b, beam_indices[b]].item())) + for b, state in enumerate(states) + ] return hyps, new_context - def _update_windowed_beam_state( - self, - states: list[CacheAwareRNNTBeamStreamingState], - best_batched_hyps: BatchedBeamHyps, - ) -> None: - """Append chunk-local beam emissions to each stream's windowed-beam state.""" - chunk_tokens, chunk_timestamps, root_ptrs = export_batched_beam_hyps_to_cpu_lists(best_batched_hyps) - beam_size = best_batched_hyps.beam_size - for state, ct, cts, rp in zip(states, chunk_tokens, chunk_timestamps, root_ptrs): - prev_t = state.window_beam_tokens or [[] for _ in range(beam_size)] - prev_ts = state.window_beam_timestamps or [[] for _ in range(beam_size)] - state.window_beam_tokens = [prev_t[int(rp[k])] + ct[k] for k in range(beam_size)] - state.window_beam_timestamps = [prev_ts[int(rp[k])] + cts[k] for k in range(beam_size)] - - def run_malsd_decoder(self, state: CacheAwareRNNTBeamStreamingState, request: Request, hyp: Hypothesis) -> bool: - """ - Run greedy EOU/label logic, then resync ``state.tokens`` from the current - top-1 cumulative hyp. On EOU, collapse the MALSD carry and commit the window. - """ - eou_detected = self.run_greedy_decoder(state, request, hyp) - - all_tokens = list(hyp.y_sequence) if hyp.y_sequence is not None else [] - all_timestamps = list(hyp.timestamp) if hyp.timestamp is not None else [] - start = max(0, int(state._malsd_utterance_start)) - start = min(start, len(all_tokens)) - tokens_list = all_tokens[start:] - timestamps_list = all_timestamps[start:] - - state.tokens = list(tokens_list) - state.timesteps = list(timestamps_list) - state.confidences = [0.0] * len(tokens_list) - if tokens_list: - state.last_token = tokens_list[-1] - state.last_token_idx = timestamps_list[-1] if timestamps_list else None - - if eou_detected: - state._malsd_utterance_start = len(all_tokens) - - if state.hyp_decoding_state is not None: - top1 = int(state.hyp_decoding_state.score.argmax().item()) - self.beam_decoder_computer.collapse_state_item_to_top1_(state.hyp_decoding_state, top1) - state.window_committed_tokens = list(all_tokens) - state.window_committed_timestamps = list(all_timestamps) - state.window_beam_tokens = None - state.window_beam_timestamps = None - return eou_detected + def _apply_beam_update_(self, state: CacheAwareRNNTBeamStreamingState, eou_detected: bool) -> None: + """After endpointing: refresh beam publish tokens and fold cumulative prefix on EOU.""" + if eou_detected and state.hyp_decoding_state is not None: + self.beam_decoder_computer.collapse_state_item_to_top1_( + state.hyp_decoding_state, state.get_top1_beam_index() + ) + state.update_(eou_detected) def run_greedy_decoder(self, state: CacheAwareRNNTStreamingState, request: Request, hyp: Hypothesis) -> bool: """ @@ -602,10 +542,9 @@ def cache_aware_transcribe_step( state.set_previous_hypothesis(hyp) for request, state, hyp in zip(requests, states, best_hyp): + eou_detected = self.run_greedy_decoder(state, request, hyp) if self.beam_decoder_computer is not None: - eou_detected = self.run_malsd_decoder(state, request, hyp) - else: - eou_detected = self.run_greedy_decoder(state, request, hyp) + self._apply_beam_update_(state, eou_detected) if eou_detected: self.bpe_decoder.decode_bpe_tokens(state) state.cleanup_after_eou() diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index c9d7475f1128..5d928ea3a35e 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -73,31 +73,109 @@ def reset_previous_hypothesis(self) -> None: class CacheAwareRNNTBeamStreamingState(CacheAwareRNNTStreamingState): - """Cache-aware RNNT state for MALSD beam-search streaming.""" + """MALSD beam-search streaming state; decoder carry + cumulative/partial tokens. + + ``hyp_decoding_state``: K-beam MALSD carry across chunks (collapsed to top1 on EOU in the pipeline). + ``cumulative_*``: tokens/timestamps sealed at each EOU (prior utterances in a stream). + ``partial_*[k]``: per-beam in-flight suffix since last EOU (chunk-local exports merged via lineage). + ``partial_top1_slot``: beam index for publish (chunk argmax). + + Chunk overlap is handled inside MALSD carry (same as greedy label-looping); Python only + concatenates chunk-local exports onto ``partial_*`` following ``root_ptrs`` beam lineage. + ``_cumulative_tokens_len`` slices ``state.tokens`` to the current utterance for publish. + On EOU: ``update_(eou_detected=True)`` folds tokens into ``cumulative_*``, clears ``partial_*``. + """ def _additional_params_reset(self) -> None: super()._additional_params_reset() self.hyp_decoding_state: MALSDStateItem | None = None - # Finalized transcript prefix at the last EOU; identical for every beam slot. - self.window_committed_tokens: list[int] = [] - # Frame timestamps aligned with ``window_committed_tokens``. - self.window_committed_timestamps: list[int] = [] - - # Per-beam suffix since last EOU; slot k may differ while beams compete. - self.window_beam_tokens: list[list[int]] | None = None - # Per-beam frame timestamps aligned with ``window_beam_tokens``. - self.window_beam_timestamps: list[list[int]] | None = None - - # Index into cumulative ``hyp.y_sequence`` where the current utterance starts - # (skips tokens from prior utterances still present in the cumulative hyp). - self._malsd_utterance_start: int = 0 + self.cumulative_tokens: list[int] = [] + self.cumulative_timestamps: list[int] = [] + self.partial_tokens: list[list[int]] | None = None + self.partial_timestamps: list[list[int]] | None = None + self._cumulative_tokens_len: int = 0 + self.partial_top1_slot: int | None = None def reset_previous_hypothesis(self) -> None: - """Reset carry and windowed-beam state at end-of-stream.""" + """Reset carry and partial-beam state at end-of-stream.""" super().reset_previous_hypothesis() self.hyp_decoding_state = None - self.window_committed_tokens = [] - self.window_committed_timestamps = [] - self.window_beam_tokens = None - self.window_beam_timestamps = None - # _malsd_utterance_start is bumped by the pipeline on EOU, not here. + self.cumulative_tokens = [] + self.cumulative_timestamps = [] + self.partial_tokens = None + self.partial_timestamps = None + self.partial_top1_slot = None + + def append_chunk_beam_( + self, + chunk_tokens: list[list[int]], + chunk_timestamps: list[list[int]], + root_ptrs: list[int], + beam_size: int, + top1_slot: int, + ) -> None: + """Append chunk-local beam exports into state.""" + prev_t = self.partial_tokens or [[] for _ in range(beam_size)] + prev_ts = self.partial_timestamps or [[] for _ in range(beam_size)] + next_tokens: list[list[int]] = [] + next_timestamps: list[list[int]] = [] + for k in range(beam_size): + lineage = int(root_ptrs[k]) + next_tokens.append(prev_t[lineage] + list(chunk_tokens[k])) + next_timestamps.append(prev_ts[lineage] + list(chunk_timestamps[k])) + self.partial_tokens = next_tokens + self.partial_timestamps = next_timestamps + self.partial_top1_slot = top1_slot + + def get_top1_beam_index(self) -> int: + """Beam slot used for publish (chunk argmax, or score argmax from carry).""" + if self.partial_top1_slot is not None: + return int(self.partial_top1_slot) + if self.hyp_decoding_state is None: + raise RuntimeError("Cannot resolve top-1 beam index without decoding carry.") + return int(self.hyp_decoding_state.score.argmax().item()) + + def _get_tokens(self) -> tuple[list[int], list[int]]: + """``cumulative_*`` plus the current top-1 ``partial_*`` suffix.""" + if self.partial_tokens is None or self.hyp_decoding_state is None: + return [], [] + top1 = self.get_top1_beam_index() + return ( + self.cumulative_tokens + list(self.partial_tokens[top1]), + self.cumulative_timestamps + list(self.partial_timestamps[top1]), + ) + + def get_hypothesis(self, score: float) -> Hypothesis: + """Build the publishable cumulative hypothesis for the current top-1 beam.""" + cum_tokens, cum_ts = self._get_tokens() + return Hypothesis( + score=score, + y_sequence=cum_tokens, + timestamp=cum_ts, + length=len(cum_tokens), + ) + + def update_(self, eou_detected: bool) -> None: + """Refresh publish tokens; on EOU fold utterance into ``cumulative_*`` and clear ``partial_*``.""" + cum_tokens, cum_ts = self._get_tokens() + if cum_tokens: + start = max(0, min(int(self._cumulative_tokens_len), len(cum_tokens))) + tokens = list(cum_tokens[start:]) + timesteps = list(cum_ts[start:]) + self.tokens = tokens + self.timesteps = timesteps + self.confidences = [0.0] * len(tokens) + if tokens: + self.last_token = tokens[-1] + self.last_token_idx = timesteps[-1] if timesteps else None + + if not eou_detected: + return + + if cum_tokens: + self._cumulative_tokens_len = len(cum_tokens) + self.cumulative_tokens = list(cum_tokens) + self.cumulative_timestamps = list(cum_ts) + self.partial_tokens = None + self.partial_timestamps = None + self.partial_top1_slot = None From f6da7a5154250ff48291f14943c673fc430cd11e Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 18 Jun 2026 20:58:41 +0400 Subject: [PATCH 17/28] isort and black Signed-off-by: lilithgrigoryan --- .../asr/inference/pipelines/cache_aware_rnnt_pipeline.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 262da0a48d5d..d59ab47f4767 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -422,10 +422,7 @@ def _malsd_stream_step( state.append_chunk_beam_(ct, cts, rp, best_batched_hyps.beam_size, top1) state.hyp_decoding_state = carry - hyps = [ - state.get_hypothesis(float(scores_cpu[b, beam_indices[b]].item())) - for b, state in enumerate(states) - ] + hyps = [state.get_hypothesis(float(scores_cpu[b, beam_indices[b]].item())) for b, state in enumerate(states)] return hyps, new_context def _apply_beam_update_(self, state: CacheAwareRNNTBeamStreamingState, eou_detected: bool) -> None: From 22f5f7d955ebe7904c6a8efec6c86060d70b17bf Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 18 Jun 2026 21:09:44 +0400 Subject: [PATCH 18/28] clean up Signed-off-by: lilithgrigoryan --- .../inference/pipelines/cache_aware_rnnt_pipeline.py | 12 ++++++++---- .../streaming/state/cache_aware_rnnt_state.py | 6 +++--- .../asr/parts/utils/batched_beam_decoding_utils.py | 12 +++++++++++- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index d59ab47f4767..f544aaae7a7d 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -534,8 +534,11 @@ def cache_aware_transcribe_step( self.context_manager.update_cache(stream_ids, new_context, mapping) self.context_manager.reset_slots(stream_ids, eos_flags) + # update the previous hypothesis and reset the previous hypothesis for the streams that has ended for state, hyp, eos in zip(states, best_hyp, eos_flags): - if not eos: + if eos: + state.reset_previous_hypothesis() + else: state.set_previous_hypothesis(hyp) for request, state, hyp in zip(requests, states, best_hyp): @@ -547,9 +550,10 @@ def cache_aware_transcribe_step( state.cleanup_after_eou() ready_state_ids.add(request.stream_id) - for state, eos in zip(states, eos_flags): - if eos: - state.reset_previous_hypothesis() + if self.beam_decoder_computer is not None: + for state, eos in zip(states, eos_flags): + if eos: + state.reset_beam_decoding_state_() def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: """ diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index 5d928ea3a35e..d8ccd78092e6 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -96,14 +96,14 @@ def _additional_params_reset(self) -> None: self._cumulative_tokens_len: int = 0 self.partial_top1_slot: int | None = None - def reset_previous_hypothesis(self) -> None: - """Reset carry and partial-beam state at end-of-stream.""" - super().reset_previous_hypothesis() + def reset_beam_decoding_state_(self) -> None: + """Clear MALSD carry and cumulative/partial tokens when a stream ends.""" self.hyp_decoding_state = None self.cumulative_tokens = [] self.cumulative_timestamps = [] self.partial_tokens = None self.partial_timestamps = None + self._cumulative_tokens_len = 0 self.partial_top1_slot = None def append_chunk_beam_( diff --git a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py index 5de814d5185a..23fe3bd9b845 100644 --- a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py +++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py @@ -81,7 +81,17 @@ def seed_batched_hyps_from_state( state: BatchedBeamState, batch_size: Optional[int] = None, ) -> None: - """Seed ``hyps`` cross-chunk fields from a ``BatchedBeamState`` snapshot (in-place).""" + """ + Copy cross-chunk per-beam fields from a :class:`BatchedBeamState` snapshot + into ``hyps`` (in-place). Inverse of + :meth:`BatchedBeamHyps.export_cross_chunk_state`. + + Args: + hyps: destination ``BatchedBeamHyps`` (modified in place). + state: source snapshot. + batch_size: optional number of leading rows to copy. Defaults to + ``state.scores.shape[0]``. + """ if state.scores is None: return bs = state.scores.shape[0] if batch_size is None else batch_size From a299ee4b67f3737d35def8f72c41c6195ad56054 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 18 Jun 2026 23:29:59 +0400 Subject: [PATCH 19/28] restore docstring Signed-off-by: lilithgrigoryan --- .../asr/parts/utils/batched_beam_decoding_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py index 23fe3bd9b845..2e3d68e01dfe 100644 --- a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py +++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py @@ -81,14 +81,18 @@ def seed_batched_hyps_from_state( state: BatchedBeamState, batch_size: Optional[int] = None, ) -> None: - """ - Copy cross-chunk per-beam fields from a :class:`BatchedBeamState` snapshot + """Copy cross-chunk per-beam fields from a :class:`BatchedBeamState` snapshot into ``hyps`` (in-place). Inverse of :meth:`BatchedBeamHyps.export_cross_chunk_state`. + Used by streaming beam-search decoders to seed a ``BatchedBeamHyps`` from the previous + chunk's snapshot. Chunk-local buffers (prefix tree / timestamps / write cursor) + and the per-beam time cursor are NOT touched -- the caller is responsible for + wiping them. + Args: hyps: destination ``BatchedBeamHyps`` (modified in place). - state: source snapshot. + state: source snapshot. No-op when ``state.scores`` is ``None`` (first chunk). batch_size: optional number of leading rows to copy. Defaults to ``state.scores.shape[0]``. """ From 957084f4755c5bd984eb771647b8049985e9f3f9 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 18 Jun 2026 23:32:35 +0400 Subject: [PATCH 20/28] move malsd stream step to model wrapper Signed-off-by: lilithgrigoryan --- .../cache_aware_rnnt_inference_wrapper.py | 64 ++++++++++++++++++ .../pipelines/cache_aware_rnnt_pipeline.py | 67 ++----------------- 2 files changed, 70 insertions(+), 61 deletions(-) diff --git a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py index ccb6c1faf533..3a659689391e 100644 --- a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py +++ b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py @@ -18,9 +18,12 @@ from nemo.collections.asr.inference.model_wrappers.cache_aware_asr_inference_wrapper import ( CacheAwareASRInferenceWrapper, ) +from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import CacheAwareRNNTBeamStreamingState from nemo.collections.asr.inference.utils.context_manager import CacheAwareContext from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder +from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import export_batched_beam_hyps_to_cpu_lists from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis @@ -165,6 +168,67 @@ def execute_step( ) return best_hyp, new_context + def malsd_stream_step( + self, + malsd_computer: ModifiedALSDBatchedRNNTComputer, + states: list[CacheAwareRNNTBeamStreamingState], + processed_signal: Tensor, + processed_signal_length: Tensor, + context: CacheAwareContext, + drop_extra_pre_encoded: int | None, + keep_all_outputs: bool, + drop_left_context: int | None = None, + valid_out_len: int | None = None, + multi_biasing_ids: Tensor | None = None, + ) -> tuple[list[Hypothesis], CacheAwareContext]: + """Cache-aware MALSD encode/decode step for one chunk.""" + if processed_signal.device != self.device: + processed_signal = processed_signal.to(self.device) + + if processed_signal_length.device != self.device: + processed_signal_length = processed_signal_length.to(self.device) + + carries = [state.hyp_decoding_state for state in states] + if all(c is None for c in carries): + batched_state = None + else: + batched_state = malsd_computer.merge_to_batched_state(carries) + + with ( + torch.amp.autocast(device_type=self.device_str, dtype=self.compute_dtype, enabled=self.use_amp), + torch.inference_mode(), + torch.no_grad(), + ): + processed_signal = processed_signal.to(self.cast_dtype) + encoded, encoded_len, new_context = self.encoder_step( + processed_signal=processed_signal, + processed_signal_length=processed_signal_length, + context=context, + drop_extra_pre_encoded=drop_extra_pre_encoded, + keep_all_outputs=keep_all_outputs, + drop_left_context=drop_left_context, + valid_out_len=valid_out_len, + ) + encs_dim_last = encoded.transpose(1, 2).contiguous() + + best_batched_hyps, batched_state = malsd_computer( + encs_dim_last, encoded_len, batched_state, multi_biasing_ids=multi_biasing_ids + ) + + chunk_tokens, chunk_timestamps, root_ptrs = export_batched_beam_hyps_to_cpu_lists(best_batched_hyps) + beam_indices = best_batched_hyps.scores.argmax(dim=-1).detach().cpu().tolist() + scores_cpu = best_batched_hyps.scores.detach().cpu() + + carry_items = malsd_computer.split_batched_state(batched_state) + for state, ct, cts, rp, top1, carry in zip( + states, chunk_tokens, chunk_timestamps, root_ptrs, beam_indices, carry_items + ): + state.append_chunk_beam_(ct, cts, rp, best_batched_hyps.beam_size, top1) + state.hyp_decoding_state = carry + + hyps = [state.get_hypothesis(float(scores_cpu[b, beam_indices[b]].item())) for b, state in enumerate(states)] + return hyps, new_context + def stream_step( self, processed_signal: Tensor, diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index f544aaae7a7d..ae91d2f8b5f6 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -43,7 +43,6 @@ get_confidence_utils, ) from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer -from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import export_batched_beam_hyps_to_cpu_lists from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.utils import logging @@ -202,7 +201,7 @@ def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingSta Args: options: (ASRRequestOptions) Request options for particular stream. Returns: - (CacheAwareRNNTStreamingState) New empty state. New empty state. + (CacheAwareRNNTStreamingState) New empty state. """ state = ( CacheAwareRNNTBeamStreamingState() @@ -310,14 +309,16 @@ def _streaming_step( valid_out_len=self.valid_out_len, prompt_vectors=prompt_vectors, ) - return self._malsd_stream_step( + return self.asr_model.malsd_stream_step( malsd_computer=self.beam_decoder_computer, states=states, - feature_buffers=feature_buffers, - feature_buffer_lens=feature_buffer_lens, + processed_signal=feature_buffers, + processed_signal_length=feature_buffer_lens, context=context, drop_extra_pre_encoded=drop_extra_pre_encoded, keep_all_outputs=keep_all_outputs, + drop_left_context=self.drop_left_context, + valid_out_len=self.valid_out_len, multi_biasing_ids=multi_biasing_ids, ) @@ -369,62 +370,6 @@ def _prepare_per_stream_biasing( multi_biasing_ids = torch.from_numpy(multi_biasing_ids_np).to(device=device) return previous_hypotheses, multi_biasing_ids - def _malsd_stream_step( - self, - malsd_computer: ModifiedALSDBatchedRNNTComputer, - states: list[CacheAwareRNNTBeamStreamingState], - feature_buffers: Tensor, - feature_buffer_lens: Tensor, - context, - drop_extra_pre_encoded: int, - keep_all_outputs: bool, - multi_biasing_ids: Tensor | None = None, - ) -> tuple[list[Hypothesis], object]: - """Cache-aware MALSD encode/decode step for one chunk.""" - carries = [state.hyp_decoding_state for state in states] - if all(c is None for c in carries): - batched_state = None - else: - batched_state = malsd_computer.merge_to_batched_state(carries) - - with ( - torch.amp.autocast( - device_type=self.asr_model.device_str, - dtype=self.asr_model.compute_dtype, - enabled=self.asr_model.use_amp, - ), - torch.inference_mode(), - ): - feature_buffers = feature_buffers.to(self.asr_model.cast_dtype) - encoded, encoded_len, new_context = self.asr_model.encoder_step( - processed_signal=feature_buffers, - processed_signal_length=feature_buffer_lens, - context=context, - drop_extra_pre_encoded=drop_extra_pre_encoded, - keep_all_outputs=keep_all_outputs, - drop_left_context=self.drop_left_context, - valid_out_len=self.valid_out_len, - ) - encs_dim_last = encoded.transpose(1, 2).contiguous() - - best_batched_hyps, batched_state = malsd_computer( - encs_dim_last, encoded_len, batched_state, multi_biasing_ids=multi_biasing_ids - ) - - chunk_tokens, chunk_timestamps, root_ptrs = export_batched_beam_hyps_to_cpu_lists(best_batched_hyps) - beam_indices = best_batched_hyps.scores.argmax(dim=-1).detach().cpu().tolist() - scores_cpu = best_batched_hyps.scores.detach().cpu() - - carry_items = malsd_computer.split_batched_state(batched_state) - for state, ct, cts, rp, top1, carry in zip( - states, chunk_tokens, chunk_timestamps, root_ptrs, beam_indices, carry_items - ): - state.append_chunk_beam_(ct, cts, rp, best_batched_hyps.beam_size, top1) - state.hyp_decoding_state = carry - - hyps = [state.get_hypothesis(float(scores_cpu[b, beam_indices[b]].item())) for b, state in enumerate(states)] - return hyps, new_context - def _apply_beam_update_(self, state: CacheAwareRNNTBeamStreamingState, eou_detected: bool) -> None: """After endpointing: refresh beam publish tokens and fold cumulative prefix on EOU.""" if eou_detected and state.hyp_decoding_state is not None: From 49eb9fe4b2d996582db3e8e4f198b28b6d22b58b Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 18 Jun 2026 23:37:18 +0400 Subject: [PATCH 21/28] clean up Signed-off-by: lilithgrigoryan --- .../cache_aware_rnnt_inference_wrapper.py | 4 +- .../pipelines/cache_aware_rnnt_pipeline.py | 2 +- .../streaming/state/cache_aware_rnnt_state.py | 37 ++++++++----------- 3 files changed, 19 insertions(+), 24 deletions(-) diff --git a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py index 3a659689391e..28046042b2c3 100644 --- a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py +++ b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py @@ -220,10 +220,10 @@ def malsd_stream_step( scores_cpu = best_batched_hyps.scores.detach().cpu() carry_items = malsd_computer.split_batched_state(batched_state) - for state, ct, cts, rp, top1, carry in zip( + for state, ct, cts, rp, best_hyp_idx, carry in zip( states, chunk_tokens, chunk_timestamps, root_ptrs, beam_indices, carry_items ): - state.append_chunk_beam_(ct, cts, rp, best_batched_hyps.beam_size, top1) + state.append_chunk_beam_(ct, cts, rp, best_batched_hyps.beam_size, best_hyp_idx) state.hyp_decoding_state = carry hyps = [state.get_hypothesis(float(scores_cpu[b, beam_indices[b]].item())) for b, state in enumerate(states)] diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index ae91d2f8b5f6..a4aae2d98f8c 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -374,7 +374,7 @@ def _apply_beam_update_(self, state: CacheAwareRNNTBeamStreamingState, eou_detec """After endpointing: refresh beam publish tokens and fold cumulative prefix on EOU.""" if eou_detected and state.hyp_decoding_state is not None: self.beam_decoder_computer.collapse_state_item_to_top1_( - state.hyp_decoding_state, state.get_top1_beam_index() + state.hyp_decoding_state, state.get_best_hyp_idx() ) state.update_(eou_detected) diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index d8ccd78092e6..715a3717a6e7 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -73,17 +73,12 @@ def reset_previous_hypothesis(self) -> None: class CacheAwareRNNTBeamStreamingState(CacheAwareRNNTStreamingState): - """MALSD beam-search streaming state; decoder carry + cumulative/partial tokens. + """Beam search streaming state; decoder carry + cumulative/partial tokens. - ``hyp_decoding_state``: K-beam MALSD carry across chunks (collapsed to top1 on EOU in the pipeline). + ``hyp_decoding_state``: K-beam carry across chunks (collapsed to top1 on EOU in the pipeline). ``cumulative_*``: tokens/timestamps sealed at each EOU (prior utterances in a stream). ``partial_*[k]``: per-beam in-flight suffix since last EOU (chunk-local exports merged via lineage). - ``partial_top1_slot``: beam index for publish (chunk argmax). - - Chunk overlap is handled inside MALSD carry (same as greedy label-looping); Python only - concatenates chunk-local exports onto ``partial_*`` following ``root_ptrs`` beam lineage. - ``_cumulative_tokens_len`` slices ``state.tokens`` to the current utterance for publish. - On EOU: ``update_(eou_detected=True)`` folds tokens into ``cumulative_*``, clears ``partial_*``. + ``best_hyp_idx``: index into ``partial_*`` for the chunk argmax beam used to publish. """ def _additional_params_reset(self) -> None: @@ -94,17 +89,17 @@ def _additional_params_reset(self) -> None: self.partial_tokens: list[list[int]] | None = None self.partial_timestamps: list[list[int]] | None = None self._cumulative_tokens_len: int = 0 - self.partial_top1_slot: int | None = None + self.best_hyp_idx: int | None = None def reset_beam_decoding_state_(self) -> None: - """Clear MALSD carry and cumulative/partial tokens when a stream ends.""" + """Clear beam search carry and cumulative/partial tokens when a stream ends.""" self.hyp_decoding_state = None self.cumulative_tokens = [] self.cumulative_timestamps = [] self.partial_tokens = None self.partial_timestamps = None self._cumulative_tokens_len = 0 - self.partial_top1_slot = None + self.best_hyp_idx = None def append_chunk_beam_( self, @@ -112,7 +107,7 @@ def append_chunk_beam_( chunk_timestamps: list[list[int]], root_ptrs: list[int], beam_size: int, - top1_slot: int, + best_hyp_idx: int, ) -> None: """Append chunk-local beam exports into state.""" prev_t = self.partial_tokens or [[] for _ in range(beam_size)] @@ -125,12 +120,12 @@ def append_chunk_beam_( next_timestamps.append(prev_ts[lineage] + list(chunk_timestamps[k])) self.partial_tokens = next_tokens self.partial_timestamps = next_timestamps - self.partial_top1_slot = top1_slot + self.best_hyp_idx = best_hyp_idx - def get_top1_beam_index(self) -> int: - """Beam slot used for publish (chunk argmax, or score argmax from carry).""" - if self.partial_top1_slot is not None: - return int(self.partial_top1_slot) + def get_best_hyp_idx(self) -> int: + """Index into ``partial_*`` for publish (chunk argmax, or score argmax from carry).""" + if self.best_hyp_idx is not None: + return int(self.best_hyp_idx) if self.hyp_decoding_state is None: raise RuntimeError("Cannot resolve top-1 beam index without decoding carry.") return int(self.hyp_decoding_state.score.argmax().item()) @@ -139,10 +134,10 @@ def _get_tokens(self) -> tuple[list[int], list[int]]: """``cumulative_*`` plus the current top-1 ``partial_*`` suffix.""" if self.partial_tokens is None or self.hyp_decoding_state is None: return [], [] - top1 = self.get_top1_beam_index() + best_hyp_idx = self.get_best_hyp_idx() return ( - self.cumulative_tokens + list(self.partial_tokens[top1]), - self.cumulative_timestamps + list(self.partial_timestamps[top1]), + self.cumulative_tokens + list(self.partial_tokens[best_hyp_idx]), + self.cumulative_timestamps + list(self.partial_timestamps[best_hyp_idx]), ) def get_hypothesis(self, score: float) -> Hypothesis: @@ -178,4 +173,4 @@ def update_(self, eou_detected: bool) -> None: self.cumulative_timestamps = list(cum_ts) self.partial_tokens = None self.partial_timestamps = None - self.partial_top1_slot = None + self.best_hyp_idx = None From 7be3088a9b5f802c286b45b2744eb94255df06d5 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Thu, 18 Jun 2026 23:56:59 +0400 Subject: [PATCH 22/28] refactor per-stream biasing, add utils Signed-off-by: lilithgrigoryan --- .../cache_aware_rnnt_inference_wrapper.py | 8 +- .../pipelines/cache_aware_rnnt_pipeline.py | 74 ++++-------- .../asr/inference/utils/per_stream_biasing.py | 105 ++++++++++++++++++ 3 files changed, 135 insertions(+), 52 deletions(-) create mode 100644 nemo/collections/asr/inference/utils/per_stream_biasing.py diff --git a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py index 28046042b2c3..4712c0a891f5 100644 --- a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py +++ b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py @@ -22,6 +22,7 @@ from nemo.collections.asr.inference.utils.context_manager import CacheAwareContext from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder +from nemo.collections.asr.inference.utils.per_stream_biasing import multi_biasing_ids_tensor_from_states from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import export_batched_beam_hyps_to_cpu_lists from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis @@ -179,7 +180,6 @@ def malsd_stream_step( keep_all_outputs: bool, drop_left_context: int | None = None, valid_out_len: int | None = None, - multi_biasing_ids: Tensor | None = None, ) -> tuple[list[Hypothesis], CacheAwareContext]: """Cache-aware MALSD encode/decode step for one chunk.""" if processed_signal.device != self.device: @@ -194,6 +194,12 @@ def malsd_stream_step( else: batched_state = malsd_computer.merge_to_batched_state(carries) + multi_biasing_ids = multi_biasing_ids_tensor_from_states( + states, + self.device, + per_stream_biasing_enabled=malsd_computer.per_stream_biasing_enabled, + ) + with ( torch.amp.autocast(device_type=self.device_str, dtype=self.compute_dtype, enabled=self.use_amp), torch.inference_mode(), diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index a4aae2d98f8c..e5c67dfbf4e0 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -42,6 +42,10 @@ drop_trailing_features, get_confidence_utils, ) +from nemo.collections.asr.inference.utils.per_stream_biasing import ( + build_multi_biasing_ids_np, + release_all_biasing_models, +) from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.utils import logging @@ -242,18 +246,7 @@ def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingSta def close_session(self) -> None: """Close the session and release per-stream biasing models held in the decoder.""" if self.decoding_computer is not None and self.decoding_computer.per_stream_biasing_enabled: - biasing_multi_model = self.decoding_computer.biasing_multi_model - active_model_ids = [ - model_id - for model_id in range(biasing_multi_model.num_models) - if biasing_multi_model.model2active[model_id].item() - ] - with torch.inference_mode(): - for model_id in sorted(active_model_ids, reverse=True): - biasing_multi_model.remove_model(model_id) - for state in self._state_pool.values(): - if state.has_biasing_request(): - state.options.biasing_cfg.multi_model_id = None + release_all_biasing_models(self.decoding_computer.biasing_multi_model, self._state_pool.values()) super().close_session() def get_sep(self) -> str: @@ -291,7 +284,6 @@ def _streaming_step( drop_extra_pre_encoded: int, keep_all_outputs: bool, prompt_vectors: Tensor | None, - multi_biasing_ids: Tensor | None = None, ) -> tuple[list[Hypothesis], object]: """ Run one cache-aware encode/decode step for the current chunk. @@ -319,56 +311,38 @@ def _streaming_step( keep_all_outputs=keep_all_outputs, drop_left_context=self.drop_left_context, valid_out_len=self.valid_out_len, - multi_biasing_ids=multi_biasing_ids, ) def _prepare_per_stream_biasing( self, states: list[CacheAwareRNNTStreamingState], previous_hypotheses: list[Hypothesis | None], - device: torch.device, - ) -> tuple[list[Hypothesis | None], Tensor | None]: + ) -> list[Hypothesis | None]: if self.decoding_computer is None or not self.decoding_computer.per_stream_biasing_enabled: if any(state.has_biasing_request() for state in states): logging.warning( "Biasing request is not empty, but decoder does not support per-stream biasing. Skipping" ) - return previous_hypotheses, None + return previous_hypotheses + + multi_biasing_ids_np = build_multi_biasing_ids_np( + states, + self.decoding_computer.biasing_multi_model, + self.asr_model.tokenizer, + ) + + if self.beam_decoder_computer is not None: + return previous_hypotheses - biasing_multi_model = self.decoding_computer.biasing_multi_model - multi_biasing_ids_np = np.full([len(states)], fill_value=-1) for i, (state, previous_hyp) in enumerate(zip(states, previous_hypotheses)): - if not state.has_biasing_request(): + if multi_biasing_ids_np[i] < 0: continue - biasing_cfg = state.options.biasing_cfg - model_id = biasing_cfg.multi_model_id - if model_id is not None and not biasing_multi_model.model2active[model_id].item(): - model_id = biasing_cfg.multi_model_id = None - - if model_id is None: - if biasing_cfg.auto_manage_multi_model: - with torch.inference_mode(): - biasing_cfg.add_to_multi_model( - tokenizer=self.asr_model.tokenizer, - biasing_multi_model=biasing_multi_model, - ) - else: - logging.warning("Biasing request is not empty, not auto managed and not compiled. Skipping") - continue - - multi_biasing_ids_np[i] = biasing_cfg.multi_model_id - - if self.beam_decoder_computer is None: - if previous_hyp is None: - previous_hypotheses[i] = Hypothesis.empty_with_biasing_cfg(biasing_cfg) - else: - previous_hyp.biasing_cfg = biasing_cfg - - multi_biasing_ids = None - if self.beam_decoder_computer is not None: - multi_biasing_ids = torch.from_numpy(multi_biasing_ids_np).to(device=device) - return previous_hypotheses, multi_biasing_ids + if previous_hyp is None: + previous_hypotheses[i] = Hypothesis.empty_with_biasing_cfg(biasing_cfg) + else: + previous_hyp.biasing_cfg = biasing_cfg + return previous_hypotheses def _apply_beam_update_(self, state: CacheAwareRNNTBeamStreamingState, eou_detected: bool) -> None: """After endpointing: refresh beam publish tokens and fold cumulative prefix on EOU.""" @@ -450,10 +424,9 @@ def cache_aware_transcribe_step( previous_hypotheses = [state.get_previous_hypothesis() for state in states] - previous_hypotheses, multi_biasing_ids = self._prepare_per_stream_biasing( + previous_hypotheses = self._prepare_per_stream_biasing( states=states, previous_hypotheses=previous_hypotheses, - device=feature_buffers.device, ) context, mapping = self.context_manager.get_context(stream_ids) @@ -472,7 +445,6 @@ def cache_aware_transcribe_step( drop_extra_pre_encoded=drop_extra_pre_encoded, keep_all_outputs=keep_all_outputs, prompt_vectors=prompt_vectors, - multi_biasing_ids=multi_biasing_ids, ) # update the cache and reset the cache slots for the streams that has ended diff --git a/nemo/collections/asr/inference/utils/per_stream_biasing.py b/nemo/collections/asr/inference/utils/per_stream_biasing.py new file mode 100644 index 000000000000..3ea49c813d74 --- /dev/null +++ b/nemo/collections/asr/inference/utils/per_stream_biasing.py @@ -0,0 +1,105 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence + +import numpy as np +import torch +from torch import Tensor + +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.utils import logging + +if TYPE_CHECKING: + from nemo.collections.asr.parts.context_biasing.biasing_multi_model import GPUBiasingMultiModelBase + + +def build_multi_biasing_ids_np( + states: Sequence[Any], + biasing_multi_model: GPUBiasingMultiModelBase, + tokenizer: TokenizerSpec, +) -> np.ndarray: + """Build per-stream biasing model ids; ``-1`` means no biasing for that stream.""" + ids_np = np.full([len(states)], fill_value=-1, dtype=np.int64) + for i, state in enumerate(states): + if not state.has_biasing_request(): + continue + + biasing_cfg = state.options.biasing_cfg + model_id = biasing_cfg.multi_model_id + if model_id is not None and not biasing_multi_model.model2active[model_id].item(): + biasing_cfg.multi_model_id = None + model_id = None + + if model_id is None: + if biasing_cfg.auto_manage_multi_model: + with torch.inference_mode(): + biasing_cfg.add_to_multi_model(tokenizer=tokenizer, biasing_multi_model=biasing_multi_model) + model_id = biasing_cfg.multi_model_id + else: + logging.warning("Biasing request is not empty, not auto managed and not compiled. Skipping") + continue + + ids_np[i] = model_id + return ids_np + + +def multi_biasing_ids_tensor_from_states( + states: Sequence[Any], + device: torch.device, + *, + per_stream_biasing_enabled: bool, +) -> Tensor | None: + """Build decode-time biasing ids from ``state.options.biasing_cfg`` (after registration).""" + if not per_stream_biasing_enabled: + return None + + ids_np = np.full([len(states)], fill_value=-1, dtype=np.int64) + for i, state in enumerate(states): + if not state.has_biasing_request(): + continue + model_id = state.options.biasing_cfg.multi_model_id + if model_id is None: + logging.warning(f"Boosting tree requested in index {i}, not compiled, skipping") + continue + ids_np[i] = model_id + + if (ids_np < 0).all(): + return None + return torch.from_numpy(ids_np).to(device=device) + + +def release_all_biasing_models(biasing_multi_model: GPUBiasingMultiModelBase, states: Sequence[Any]) -> None: + """Remove every active biasing model and clear per-stream ``multi_model_id`` bookkeeping.""" + active_model_ids = [ + model_id + for model_id in range(biasing_multi_model.num_models) + if biasing_multi_model.model2active[model_id].item() + ] + with torch.inference_mode(): + for model_id in sorted(active_model_ids, reverse=True): + biasing_multi_model.remove_model(model_id) + for state in states: + if state.has_biasing_request(): + state.options.biasing_cfg.multi_model_id = None + + +def release_auto_managed_stream_biasing(state: Any, biasing_multi_model: GPUBiasingMultiModelBase) -> None: + """Drop an auto-managed biasing model when a single stream ends.""" + if not state.has_biasing_request(): + return + if state.options.biasing_cfg.auto_manage_multi_model: + state.options.biasing_cfg.remove_from_multi_model(biasing_multi_model) From 629de908e34aad8b58f5a49c5bf3147ad743ce6f Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 19 Jun 2026 00:03:32 +0400 Subject: [PATCH 23/28] add malsd-only warning Signed-off-by: lilithgrigoryan --- .../asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml | 1 + .../asr/inference/pipelines/cache_aware_rnnt_pipeline.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml index ef8bc1d2b4d0..ee3512a3b38d 100644 --- a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml +++ b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml @@ -31,6 +31,7 @@ asr: # used with `key_phrases_file` and `key_phrases_list` boosting_tree_alpha: 0.0 # Weight of the boosting tree beam: + # Cache-aware streaming supports MALSD beam search only (set asr.decoding.strategy: malsd_batch). beam_size: 4 allow_cuda_graphs: true # n-gram LM (off by default) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index e5c67dfbf4e0..8dda5be3ae16 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -85,6 +85,12 @@ def __init__( self.init_text_processor(cfg, itn_model) self.init_nmt_model(nmt_model) self.init_decoding_computer() + strategy = str(getattr(cfg.asr.decoding, "strategy", "greedy_batch")) + if strategy in {"beam", "tsd", "alsd", "maes", "maes_batch"}: + logging.warning( + "Cache-aware RNNT streaming supports MALSD beam search only (`malsd_batch`); " + f"configured decoding strategy is `{strategy}`." + ) super().__init__() def init_decoding_computer(self) -> None: From c59bc006827654ae9e0a4d3e1133393ff987bcd9 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 19 Jun 2026 00:05:46 +0400 Subject: [PATCH 24/28] isort and black Signed-off-by: lilithgrigoryan --- .../cache_aware_rnnt_inference_wrapper.py | 2 +- .../inference/pipelines/cache_aware_rnnt_pipeline.py | 12 +++++------- .../asr/parts/utils/batched_beam_decoding_utils.py | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py index 4712c0a891f5..8beed8fa5621 100644 --- a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py +++ b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py @@ -20,9 +20,9 @@ ) from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import CacheAwareRNNTBeamStreamingState from nemo.collections.asr.inference.utils.context_manager import CacheAwareContext +from nemo.collections.asr.inference.utils.per_stream_biasing import multi_biasing_ids_tensor_from_states from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder -from nemo.collections.asr.inference.utils.per_stream_biasing import multi_biasing_ids_tensor_from_states from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import export_batched_beam_hyps_to_cpu_lists from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 8dda5be3ae16..9182f4aa25a5 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -37,15 +37,15 @@ ) from nemo.collections.asr.inference.utils.endpointing_utils import millisecond_to_frames from nemo.collections.asr.inference.utils.enums import RequestType +from nemo.collections.asr.inference.utils.per_stream_biasing import ( + build_multi_biasing_ids_np, + release_all_biasing_models, +) from nemo.collections.asr.inference.utils.pipeline_utils import ( check_existance_of_required_attributes, drop_trailing_features, get_confidence_utils, ) -from nemo.collections.asr.inference.utils.per_stream_biasing import ( - build_multi_biasing_ids_np, - release_all_biasing_models, -) from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.utils import logging @@ -353,9 +353,7 @@ def _prepare_per_stream_biasing( def _apply_beam_update_(self, state: CacheAwareRNNTBeamStreamingState, eou_detected: bool) -> None: """After endpointing: refresh beam publish tokens and fold cumulative prefix on EOU.""" if eou_detected and state.hyp_decoding_state is not None: - self.beam_decoder_computer.collapse_state_item_to_top1_( - state.hyp_decoding_state, state.get_best_hyp_idx() - ) + self.beam_decoder_computer.collapse_state_item_to_top1_(state.hyp_decoding_state, state.get_best_hyp_idx()) state.update_(eou_detected) def run_greedy_decoder(self, state: CacheAwareRNNTStreamingState, request: Request, hyp: Hypothesis) -> bool: diff --git a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py index 2e3d68e01dfe..71eaa94b4822 100644 --- a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py +++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py @@ -89,7 +89,7 @@ def seed_batched_hyps_from_state( chunk's snapshot. Chunk-local buffers (prefix tree / timestamps / write cursor) and the per-beam time cursor are NOT touched -- the caller is responsible for wiping them. - + Args: hyps: destination ``BatchedBeamHyps`` (modified in place). state: source snapshot. No-op when ``state.scores`` is ``None`` (first chunk). From 36569384e3dd4c7d5f1b686271acd6c833678616 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 19 Jun 2026 00:21:04 +0400 Subject: [PATCH 25/28] restore releasing biaing models Signed-off-by: lilithgrigoryan --- .../pipelines/cache_aware_rnnt_pipeline.py | 15 ++++++++++----- .../submodules/rnnt_malsd_batched_computer.py | 10 ++++++++-- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 9182f4aa25a5..bf8bc5fcc6e3 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -40,6 +40,7 @@ from nemo.collections.asr.inference.utils.per_stream_biasing import ( build_multi_biasing_ids_np, release_all_biasing_models, + release_auto_managed_stream_biasing, ) from nemo.collections.asr.inference.utils.pipeline_utils import ( check_existance_of_required_attributes, @@ -353,7 +354,7 @@ def _prepare_per_stream_biasing( def _apply_beam_update_(self, state: CacheAwareRNNTBeamStreamingState, eou_detected: bool) -> None: """After endpointing: refresh beam publish tokens and fold cumulative prefix on EOU.""" if eou_detected and state.hyp_decoding_state is not None: - self.beam_decoder_computer.collapse_state_item_to_top1_(state.hyp_decoding_state, state.get_best_hyp_idx()) + self.beam_decoder_computer.select_beam_in_state_item_(state.hyp_decoding_state, state.get_best_hyp_idx()) state.update_(eou_detected) def run_greedy_decoder(self, state: CacheAwareRNNTStreamingState, request: Request, hyp: Hypothesis) -> bool: @@ -462,6 +463,7 @@ def cache_aware_transcribe_step( else: state.set_previous_hypothesis(hyp) + # run greedy decoder for each request-state-hypothesis tuple for request, state, hyp in zip(requests, states, best_hyp): eou_detected = self.run_greedy_decoder(state, request, hyp) if self.beam_decoder_computer is not None: @@ -471,10 +473,13 @@ def cache_aware_transcribe_step( state.cleanup_after_eou() ready_state_ids.add(request.stream_id) - if self.beam_decoder_computer is not None: - for state, eos in zip(states, eos_flags): - if eos: - state.reset_beam_decoding_state_() + for state, eos in zip(states, eos_flags): + if not eos: + continue + if self.decoding_computer is not None and self.decoding_computer.per_stream_biasing_enabled: + release_auto_managed_stream_biasing(state, self.decoding_computer.biasing_multi_model) + if self.beam_decoder_computer is not None: + state.reset_beam_decoding_state_() def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: """ diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index adf28a42363c..f3785b6bc971 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -1701,8 +1701,14 @@ def collapse_batched_state_to_beams_( batched_hyps.keep_beam_(beam_indices) - def collapse_state_item_to_top1_(self, item: MALSDStateItem, beam_index: int) -> None: - """In-place per-stream collapse to one beam (used at EOU in streaming).""" + def select_beam_in_state_item_(self, item: MALSDStateItem, beam_index: int) -> None: + """In-place per-stream beam selection (used at EOU in streaming). + + Selects ``beam_index`` and replicates that beam's decoder carry across all + ``beam_size`` slots. Beam width is unchanged; every slot holds the same + predictor, fusion, and score state so the next decode step starts from one + committed hypothesis. + """ beam_size = self.beam_size if not 0 <= beam_index < beam_size: raise ValueError(f"beam_index must be in [0, {beam_size}), got {beam_index}") From 7840a2211536b54bb10fac5499dad3798191f240 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 19 Jun 2026 00:31:24 +0400 Subject: [PATCH 26/28] minor clean up Signed-off-by: lilithgrigoryan --- .../inference/pipelines/cache_aware_rnnt_pipeline.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index bf8bc5fcc6e3..8b51ce22a31c 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -87,11 +87,15 @@ def __init__( self.init_nmt_model(nmt_model) self.init_decoding_computer() strategy = str(getattr(cfg.asr.decoding, "strategy", "greedy_batch")) - if strategy in {"beam", "tsd", "alsd", "maes", "maes_batch"}: - logging.warning( - "Cache-aware RNNT streaming supports MALSD beam search only (`malsd_batch`); " + if strategy not in {"greedy_batch", "malsd_batch"}: + raise ValueError( + "Cache-aware RNNT streaming supports `greedy_batch` and `malsd_batch` only; " f"configured decoding strategy is `{strategy}`." ) + if self.beam_decoder_computer is not None and self.prompt_enabled: + raise ValueError( + "Cache-aware RNNT MALSD beam search does not yet support prompt vectors." + ) super().__init__() def init_decoding_computer(self) -> None: From fadef4e9217cf4515ae53fbbaa2940c04da932c4 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 19 Jun 2026 00:41:42 +0400 Subject: [PATCH 27/28] clean up Signed-off-by: lilithgrigoryan --- .../pipelines/cache_aware_rnnt_pipeline.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 8b51ce22a31c..26b7c59c7a4a 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -477,13 +477,17 @@ def cache_aware_transcribe_step( state.cleanup_after_eou() ready_state_ids.add(request.stream_id) - for state, eos in zip(states, eos_flags): - if not eos: - continue - if self.decoding_computer is not None and self.decoding_computer.per_stream_biasing_enabled: - release_auto_managed_stream_biasing(state, self.decoding_computer.biasing_multi_model) - if self.beam_decoder_computer is not None: - state.reset_beam_decoding_state_() + # Cleanup per-stream biasing models when stream ends + if self.decoding_computer is not None and self.decoding_computer.per_stream_biasing_enabled: + for request, state in zip(requests, states): + # only the first request contains biasing options; biasing options for the stream are stored in state + if request.is_last and state.has_biasing_request(): + release_auto_managed_stream_biasing(state, self.decoding_computer.biasing_multi_model) + + if self.beam_decoder_computer is not None: + for state, eos in zip(states, eos_flags): + if eos: + state.reset_beam_decoding_state_() def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: """ From b2b7116ddce22fbb780106b06f460ea6504861da Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Fri, 19 Jun 2026 00:42:20 +0400 Subject: [PATCH 28/28] isort and black Signed-off-by: lilithgrigoryan --- .../asr/inference/pipelines/cache_aware_rnnt_pipeline.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 26b7c59c7a4a..005ec62456ec 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -93,9 +93,7 @@ def __init__( f"configured decoding strategy is `{strategy}`." ) if self.beam_decoder_computer is not None and self.prompt_enabled: - raise ValueError( - "Cache-aware RNNT MALSD beam search does not yet support prompt vectors." - ) + raise ValueError("Cache-aware RNNT MALSD beam search does not yet support prompt vectors.") super().__init__() def init_decoding_computer(self) -> None: