-
Notifications
You must be signed in to change notification settings - Fork 3.4k
add streaming beam search for cache aware models to NeMo inference #15768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3a6f29e
03937b0
e889eea
e51ee3c
3765acc
0e11a4f
913000a
6dc1423
0a69dee
e051b12
4dce1e6
893f656
664a246
b342a16
b9a31a4
0ed3d92
56816df
f6da7a5
22f5f7d
a299ee4
957084f
49eb9fe
7be3088
629de90
c59bc00
36e1702
3656938
7840a22
fadef4e
b2b7116
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,9 +18,13 @@ | |
| from nemo.collections.asr.inference.model_wrappers.cache_aware_asr_inference_wrapper import ( | ||
| CacheAwareASRInferenceWrapper, | ||
| ) | ||
| from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import CacheAwareRNNTBeamStreamingState | ||
| from nemo.collections.asr.inference.utils.context_manager import CacheAwareContext | ||
| from nemo.collections.asr.inference.utils.per_stream_biasing import multi_biasing_ids_tensor_from_states | ||
| from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel | ||
| from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder | ||
| from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer | ||
| from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import export_batched_beam_hyps_to_cpu_lists | ||
| from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis | ||
|
|
||
|
|
||
|
|
@@ -76,32 +80,19 @@ def get_vocabulary(self) -> list[str]: | |
| """ | ||
| return self.asr_model.joint.vocabulary | ||
|
|
||
| def execute_step( | ||
| def encoder_step( | ||
| self, | ||
| processed_signal: Tensor, | ||
| processed_signal_length: Tensor, | ||
| context: CacheAwareContext, | ||
| previous_hypotheses: list[Hypothesis] | None, | ||
| drop_extra_pre_encoded: int | None, | ||
| keep_all_outputs: bool, | ||
| drop_left_context: int | None = None, | ||
| valid_out_len: int | None = None, | ||
| prompt_vectors: Tensor | None = None, | ||
| ) -> tuple[list[Hypothesis], CacheAwareContext]: | ||
| ) -> tuple[Tensor, Tensor, CacheAwareContext]: | ||
| """ | ||
| Executes a single streaming step. | ||
| Args: | ||
| processed_signal: (Tensor) input signal tensor. | ||
| processed_signal_length: (Tensor) input signal length tensor. | ||
| context: (CacheAwareContext) context object. | ||
| previous_hypotheses: (list[Hypothesis] | None) list of previous hypotheses for RNNT decoding. | ||
| drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop. | ||
| keep_all_outputs: (bool) whether to keep all outputs or not. | ||
| drop_left_context: (int | None) number of left context frames to drop. | ||
| valid_out_len: (int | None) number of valid output frames. | ||
| prompt_vectors: (Tensor | None) Optional prompt vectors of shape [B, num_prompts]. | ||
| Returns: | ||
| (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context. | ||
| Run the cache-aware encoder for one streaming chunk, returning the (trimmed) | ||
|
Collaborator
For consistency, please bring back the argument descriptions in the docstring.
Collaborator
Author
I believe I kept the original docstring for `execute_step`. LMK if I missed something.
Collaborator
Argument descriptions are missing for the new methods (`encoder_step` and `malsd_stream_step`).
||
| encoder output and updated streaming context. Decoder is NOT invoked. | ||
| """ | ||
| ( | ||
| encoded, | ||
|
|
@@ -134,11 +125,116 @@ def execute_step( | |
| encoded = encoded[:, :, :valid_out_len] | ||
| encoded_len = torch.ones_like(encoded_len) * valid_out_len | ||
|
|
||
| return encoded, encoded_len, new_context | ||
|
|
||
def execute_step(
    self,
    processed_signal: Tensor,
    processed_signal_length: Tensor,
    context: CacheAwareContext,
    previous_hypotheses: list[Hypothesis] | None,
    drop_extra_pre_encoded: int | None,
    keep_all_outputs: bool,
    drop_left_context: int | None = None,
    valid_out_len: int | None = None,
    prompt_vectors: Tensor | None = None,
) -> tuple[list[Hypothesis], CacheAwareContext]:
    """
    Run one streaming step: encode the chunk, then decode it with the RNNT decoder.
    Args:
        processed_signal: (Tensor) input signal tensor.
        processed_signal_length: (Tensor) input signal length tensor.
        context: (CacheAwareContext) context object.
        previous_hypotheses: (list[Hypothesis] | None) list of previous hypotheses for RNNT decoding.
        drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop.
        keep_all_outputs: (bool) whether to keep all outputs or not.
        drop_left_context: (int | None) number of left context frames to drop.
        valid_out_len: (int | None) number of valid output frames.
        prompt_vectors: (Tensor | None) Optional prompt vectors of shape [B, num_prompts].
            Accepted for interface compatibility; this method does not forward it.
    Returns:
        (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context.
    """
    # Everything the encoder needs is passed by keyword; the decoder is not involved yet.
    encoder_kwargs = {
        "processed_signal": processed_signal,
        "processed_signal_length": processed_signal_length,
        "context": context,
        "drop_extra_pre_encoded": drop_extra_pre_encoded,
        "keep_all_outputs": keep_all_outputs,
        "drop_left_context": drop_left_context,
        "valid_out_len": valid_out_len,
    }
    enc_out, enc_out_len, updated_context = self.encoder_step(**encoder_kwargs)

    # Continue decoding from the hypotheses carried over from the previous chunk.
    decoding = self.asr_model.decoding
    hypotheses = decoding.rnnt_decoder_predictions_tensor(
        enc_out,
        enc_out_len,
        return_hypotheses=True,
        partial_hypotheses=previous_hypotheses,
    )
    return hypotheses, updated_context
|
|
||
def malsd_stream_step(
    self,
    malsd_computer: ModifiedALSDBatchedRNNTComputer,
    states: list[CacheAwareRNNTBeamStreamingState],
    processed_signal: Tensor,
    processed_signal_length: Tensor,
    context: CacheAwareContext,
    drop_extra_pre_encoded: int | None,
    keep_all_outputs: bool,
    drop_left_context: int | None = None,
    valid_out_len: int | None = None,
) -> tuple[list[Hypothesis], CacheAwareContext]:
    """
    Cache-aware MALSD encode/decode step for one chunk.

    Runs `encoder_step` on the chunk, feeds the encoder output to `malsd_computer`
    together with the decoding state carried by `states`, then writes the new chunk
    results and the new decoding state back into each entry of `states`.
    Args:
        malsd_computer: (ModifiedALSDBatchedRNNTComputer) batched beam-search computer. It is called on
            the encoder output and also used to merge/split the per-stream decoding state.
        states: (list[CacheAwareRNNTBeamStreamingState]) per-stream states. Each one is updated in place:
            `append_chunk_beam_` is called and `hyp_decoding_state` is overwritten.
            Entry `b` is paired with row `b` of the decoder scores.
        processed_signal: (Tensor) input signal tensor; moved to `self.device` and cast to `self.cast_dtype`.
        processed_signal_length: (Tensor) input signal length tensor; moved to `self.device`.
        context: (CacheAwareContext) context object passed to `encoder_step`.
        drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop.
        keep_all_outputs: (bool) whether to keep all outputs or not.
        drop_left_context: (int | None) number of left context frames to drop.
        valid_out_len: (int | None) number of valid output frames.
    Returns:
        (tuple[list[Hypothesis], CacheAwareContext]) one hypothesis per entry of `states`
        (built by `state.get_hypothesis` from the highest-scoring beam) and the new context.
    """
    if processed_signal.device != self.device:
        processed_signal = processed_signal.to(self.device)

    if processed_signal_length.device != self.device:
        processed_signal_length = processed_signal_length.to(self.device)

    # Gather the decoding state each stream carried over from its previous chunk.
    # If no stream has one yet, pass None to the computer instead of a merged state.
    carries = [state.hyp_decoding_state for state in states]
    if all(c is None for c in carries):
        batched_state = None
    else:
        batched_state = malsd_computer.merge_to_batched_state(carries)

    multi_biasing_ids = multi_biasing_ids_tensor_from_states(
        states,
        self.device,
        per_stream_biasing_enabled=malsd_computer.per_stream_biasing_enabled,
    )

    # NOTE(review): the flattened diff this block was recovered from carries no indentation,
    # so the exact extent of this `with` block is an assumption (encoder + beam-search call
    # inside, CPU post-processing outside) -- confirm against the real file.
    with (
        torch.amp.autocast(device_type=self.device_str, dtype=self.compute_dtype, enabled=self.use_amp),
        torch.inference_mode(),
        torch.no_grad(),
    ):
        processed_signal = processed_signal.to(self.cast_dtype)
        encoded, encoded_len, new_context = self.encoder_step(
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
            context=context,
            drop_extra_pre_encoded=drop_extra_pre_encoded,
            keep_all_outputs=keep_all_outputs,
            drop_left_context=drop_left_context,
            valid_out_len=valid_out_len,
        )
        # Swap dims 1 and 2 of the encoder output before handing it to the computer
        # (presumably [B, D, T] -> [B, T, D]; verify against the encoder's output layout).
        encs_dim_last = encoded.transpose(1, 2).contiguous()

        best_batched_hyps, batched_state = malsd_computer(
            encs_dim_last, encoded_len, batched_state, multi_biasing_ids=multi_biasing_ids
        )

    chunk_tokens, chunk_timestamps, root_ptrs = export_batched_beam_hyps_to_cpu_lists(best_batched_hyps)
    # Index of the highest-scoring beam for every row of `scores`.
    beam_indices = best_batched_hyps.scores.argmax(dim=-1).detach().cpu().tolist()
    scores_cpu = best_batched_hyps.scores.detach().cpu()

    # Split the batched decoding state back into per-stream items and store each one,
    # together with this chunk's beam results, on the matching stream state.
    carry_items = malsd_computer.split_batched_state(batched_state)
    for state, ct, cts, rp, best_hyp_idx, carry in zip(
        states, chunk_tokens, chunk_timestamps, root_ptrs, beam_indices, carry_items
    ):
        state.append_chunk_beam_(ct, cts, rp, best_batched_hyps.beam_size, best_hyp_idx)
        state.hyp_decoding_state = carry

    # Build one hypothesis per stream, passing the score of its best beam.
    hyps = [state.get_hypothesis(float(scores_cpu[b, beam_indices[b]].item())) for b, state in enumerate(states)]
    return hyps, new_context
|
|
||
| def stream_step( | ||
| self, | ||
| processed_signal: Tensor, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.