fix(cpu): enable robust CPU support and address PR feedback#2958
fix(cpu): enable robust CPU support and address PR feedback#2958Manamama-Gemini-Cloud-AI-01 wants to merge 10 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a streaming WebSocket server for Fun-ASR-Nano, featuring VAD segmentation, ASR decoding, speaker diarization, and hallucination detection. Key feedback focuses on optimizing the audio buffer concatenation to prevent O(N^2) copy overhead, correcting the parameter name hotwords to hotword for standard AutoModel compatibility, fixing the hallucination truncation logic to keep exactly one occurrence of repeated patterns, and improving robustness through safer attribute access and exception handling around tokenizer operations.
| def add_audio(self, pcm_bytes): | ||
| if len(pcm_bytes) % 2 != 0: | ||
| pcm_bytes = pcm_bytes[:len(pcm_bytes) - (len(pcm_bytes) % 2)] | ||
| audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16) | ||
| audio_float = audio_int16.astype(np.float32) / 32768.0 | ||
| self.audio_buffer = np.concatenate([self.audio_buffer, audio_float]) | ||
|
|
||
| new_audio = self.audio_buffer[self.vad_fed_samples:] | ||
| if len(new_audio) > 0: | ||
| new_confirmed = self.vad.feed(torch.from_numpy(new_audio).float(), is_final=False) | ||
| self.vad_fed_samples = len(self.audio_buffer) | ||
|
|
||
| for seg in new_confirmed: | ||
| seg_text = self._decode_segment(seg) | ||
| self.prev_text = "" | ||
| if not seg_text.strip(): | ||
| continue | ||
| self.locked_sentences.append({"text": seg_text, "start": int(seg[0]), "end": int(seg[1])}) | ||
| if self.spk_tracker: | ||
| s0 = int(seg[0] * self.sample_rate / 1000) | ||
| s1 = min(int(seg[1] * self.sample_rate / 1000), len(self.audio_buffer)) | ||
| self.spk_tracker.assign_streaming(self.audio_buffer[s0:s1], seg[0]/1000, seg[1]/1000, self.locked_sentences[-1]) | ||
| logger.info(f"Locked: [{seg[0]}-{seg[1]}ms] \"{seg_text[:40]}\"") |
There was a problem hiding this comment.
Appending and concatenating the audio buffer on every single incoming WebSocket packet (which can be as frequent as every 10-20ms) results in an O(N^2) copy overhead. Over a long streaming session, this will cause severe CPU spikes, massive latency, and potential server freezes.
We can optimize this by keeping a list of chunks (self.audio_chunks) and only concatenating them on-demand when the contiguous audio_buffer is actually accessed (e.g., during decoding or speaker diarization). This is implemented cleanly using a property and setter.
@property
def audio_buffer(self):
if getattr(self, '_audio_buffer_dirty', False):
if self.audio_chunks:
self._audio_buffer = np.concatenate([self._audio_buffer, *self.audio_chunks])
self.audio_chunks = []
self._audio_buffer_dirty = False
return self._audio_buffer
@audio_buffer.setter
def audio_buffer(self, value):
self._audio_buffer = value
self.audio_chunks = []
self._audio_buffer_dirty = False
def add_audio(self, pcm_bytes):
if len(pcm_bytes) % 2 != 0:
pcm_bytes = pcm_bytes[:len(pcm_bytes) - (len(pcm_bytes) % 2)]
if not pcm_bytes:
return
audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
audio_float = audio_int16.astype(np.float32) / 32768.0
self.audio_chunks.append(audio_float)
self._audio_buffer_dirty = True
new_confirmed = self.vad.feed(torch.from_numpy(audio_float).float(), is_final=False)
self.vad_fed_samples += len(audio_float)
for seg in new_confirmed:
seg_text = self._decode_segment(seg)
self.prev_text = ""
if not seg_text.strip():
continue
self.locked_sentences.append({"text": seg_text, "start": int(seg[0]), "end": int(seg[1])})
if self.spk_tracker:
s0 = int(seg[0] * self.sample_rate / 1000)
s1 = min(int(seg[1] * self.sample_rate / 1000), len(self.audio_buffer))
self.spk_tracker.assign_streaming(self.audio_buffer[s0:s1], seg[0]/1000, seg[1]/1000, self.locked_sentences[-1])
logger.info(f"Locked: [{seg[0]}-{seg[1]}ms] \"{seg_text[:40]}\"")| results = self.vllm_engine.generate( | ||
| input=audio_tensor, | ||
| hotwords=self.asr_kwargs.get("hotwords"), | ||
| language=self.asr_kwargs.get("language"), | ||
| ) |
There was a problem hiding this comment.
Standard AutoModel (used on non-CUDA devices like CPU) expects the parameter name hotword (singular) instead of hotwords (plural) as documented in auto_model.py. Passing hotwords will be ignored, causing hotword customization to fail on CPU.
| results = self.vllm_engine.generate( | |
| input=audio_tensor, | |
| hotwords=self.asr_kwargs.get("hotwords"), | |
| language=self.asr_kwargs.get("language"), | |
| ) | |
| results = self.vllm_engine.generate( | |
| input=audio_tensor, | |
| hotword=self.asr_kwargs.get("hotwords"), | |
| language=self.asr_kwargs.get("language"), | |
| ) |
| results = self.vllm_engine.generate( | ||
| input=audio_tensor, | ||
| hotwords=self.asr_kwargs.get("hotwords"), | ||
| language=self.asr_kwargs.get("language"), | ||
| ) |
There was a problem hiding this comment.
Similarly, standard AutoModel expects hotword instead of hotwords here.
| results = self.vllm_engine.generate( | |
| input=audio_tensor, | |
| hotwords=self.asr_kwargs.get("hotwords"), | |
| language=self.asr_kwargs.get("language"), | |
| ) | |
| results = self.vllm_engine.generate( | |
| input=audio_tensor, | |
| hotword=self.asr_kwargs.get("hotwords"), | |
| language=self.asr_kwargs.get("language"), | |
| ) |
| if pos >= 0: | ||
| end_pos = text.find(repeated, pos + len(repeated)) | ||
| if end_pos >= 0: | ||
| return text[:end_pos + len(repeated)], True |
There was a problem hiding this comment.
The current implementation of hallucination truncation keeps two occurrences of the repeated pattern instead of one, because it slices up to end_pos + len(repeated). Slicing up to end_pos (the start of the second occurrence) will correctly keep exactly one occurrence of the pattern as intended by the docstring.
| if pos >= 0: | |
| end_pos = text.find(repeated, pos + len(repeated)) | |
| if end_pos >= 0: | |
| return text[:end_pos + len(repeated)], True | |
| if pos >= 0: | |
| end_pos = text.find(repeated, pos + len(repeated)) | |
| if end_pos >= 0: | |
| return text[:end_pos], True |
| pos = text.find(repeated) | ||
| if pos >= 0: | ||
| end_pos = text.find(repeated, pos + len(repeated)) | ||
| if end_pos >= 0: | ||
| return text[:end_pos + len(repeated)], True |
There was a problem hiding this comment.
Similarly, this truncation keeps two occurrences of the repeated pattern instead of one. Slicing up to end_pos will correctly keep exactly one occurrence.
| pos = text.find(repeated) | |
| if pos >= 0: | |
| end_pos = text.find(repeated, pos + len(repeated)) | |
| if end_pos >= 0: | |
| return text[:end_pos + len(repeated)], True | |
| if pos >= 0: | |
| end_pos = text.find(repeated, pos + len(repeated)) | |
| if end_pos >= 0: | |
| return text[:end_pos], True |
| if hasattr(self.vllm_engine, '_engine'): | ||
| tokenizer = self.vllm_engine._engine.tokenizer | ||
| else: | ||
| tokenizer = self.vllm_engine.kwargs.get("tokenizer") | ||
|
|
There was a problem hiding this comment.
Using getattr with a default value is safer to prevent potential AttributeError if self.vllm_engine does not have the kwargs attribute.
| if hasattr(self.vllm_engine, '_engine'): | |
| tokenizer = self.vllm_engine._engine.tokenizer | |
| else: | |
| tokenizer = self.vllm_engine.kwargs.get("tokenizer") | |
| if hasattr(self.vllm_engine, '_engine'): | |
| tokenizer = self.vllm_engine._engine.tokenizer | |
| else: | |
| tokenizer = getattr(self.vllm_engine, 'kwargs', {}).get("tokenizer") |
| if tokenizer is not None: | ||
| encoded = tokenizer.encode(text) | ||
| if len(encoded) > 5: | ||
| try: | ||
| self.prev_text = tokenizer.decode(encoded[:-5], skip_special_tokens=True) | ||
| except TypeError: | ||
| self.prev_text = tokenizer.decode(encoded[:-5]) | ||
| else: | ||
| self.prev_text = "" | ||
| else: | ||
| self.prev_text = "" |
There was a problem hiding this comment.
Wrapping the tokenizer's encode/decode operations in a try...except block prevents the entire WebSocket session from crashing if the tokenizer fails or has an unexpected interface.
| if tokenizer is not None: | |
| encoded = tokenizer.encode(text) | |
| if len(encoded) > 5: | |
| try: | |
| self.prev_text = tokenizer.decode(encoded[:-5], skip_special_tokens=True) | |
| except TypeError: | |
| self.prev_text = tokenizer.decode(encoded[:-5]) | |
| else: | |
| self.prev_text = "" | |
| else: | |
| self.prev_text = "" | |
| if tokenizer is not None: | |
| try: | |
| encoded = tokenizer.encode(text) | |
| if len(encoded) > 5: | |
| try: | |
| self.prev_text = tokenizer.decode(encoded[:-5], skip_special_tokens=True) | |
| except TypeError: | |
| self.prev_text = tokenizer.decode(encoded[:-5]) | |
| else: | |
| self.prev_text = "" | |
| except Exception as e: | |
| logger.warning(f"Failed to encode/decode text with tokenizer: {e}") | |
| self.prev_text = "" | |
| else: | |
| self.prev_text = "" |
This PR enables robust CPU support for the FunASR Nano real-time server.
Key changes:
AutoModel(PyTorch) andAutoModelVLLM.