Skip to content

fix(cpu): enable robust CPU support and address PR feedback#2958

Open
Manamama-Gemini-Cloud-AI-01 wants to merge 10 commits into
modelscope:mainfrom
Manamama-Gemini-Cloud-AI-01:cpu-support-patch-v3
Open

fix(cpu): enable robust CPU support and address PR feedback#2958
Manamama-Gemini-Cloud-AI-01 wants to merge 10 commits into
modelscope:mainfrom
Manamama-Gemini-Cloud-AI-01:cpu-support-patch-v3

Conversation

@Manamama-Gemini-Cloud-AI-01
Copy link
Copy Markdown
Contributor

This PR enables robust CPU support for the FunASR Nano real-time server.

Key changes:

  • Dynamically bypasses vLLM when the device is set to a non-CUDA device (e.g., 'cpu').
  • Implements robust inference result extraction to handle different return types between AutoModel (PyTorch) and AutoModelVLLM.
  • Improves tokenizer access safety and input buffer handling (AI review feedback).
  • Optimizes hallucination detection.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a streaming WebSocket server for Fun-ASR-Nano, featuring VAD segmentation, ASR decoding, speaker diarization, and hallucination detection. Key feedback focuses on optimizing the audio buffer concatenation to prevent O(N^2) copy overhead, correcting the parameter name hotwords to hotword for standard AutoModel compatibility, fixing the hallucination truncation logic to keep exactly one occurrence of repeated patterns, and improving robustness through safer attribute access and exception handling around tokenizer operations.

Comment on lines +249 to +271
def add_audio(self, pcm_bytes):
if len(pcm_bytes) % 2 != 0:
pcm_bytes = pcm_bytes[:len(pcm_bytes) - (len(pcm_bytes) % 2)]
audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
audio_float = audio_int16.astype(np.float32) / 32768.0
self.audio_buffer = np.concatenate([self.audio_buffer, audio_float])

new_audio = self.audio_buffer[self.vad_fed_samples:]
if len(new_audio) > 0:
new_confirmed = self.vad.feed(torch.from_numpy(new_audio).float(), is_final=False)
self.vad_fed_samples = len(self.audio_buffer)

for seg in new_confirmed:
seg_text = self._decode_segment(seg)
self.prev_text = ""
if not seg_text.strip():
continue
self.locked_sentences.append({"text": seg_text, "start": int(seg[0]), "end": int(seg[1])})
if self.spk_tracker:
s0 = int(seg[0] * self.sample_rate / 1000)
s1 = min(int(seg[1] * self.sample_rate / 1000), len(self.audio_buffer))
self.spk_tracker.assign_streaming(self.audio_buffer[s0:s1], seg[0]/1000, seg[1]/1000, self.locked_sentences[-1])
logger.info(f"Locked: [{seg[0]}-{seg[1]}ms] \"{seg_text[:40]}\"")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Appending and concatenating the audio buffer on every single incoming WebSocket packet (which can be as frequent as every 10-20ms) results in an O(N^2) copy overhead. Over a long streaming session, this will cause severe CPU spikes, massive latency, and potential server freezes.

We can optimize this by keeping a list of chunks (self.audio_chunks) and only concatenating them on-demand when the contiguous audio_buffer is actually accessed (e.g., during decoding or speaker diarization). This is implemented cleanly using a property and setter.

    @property
    def audio_buffer(self):
        if getattr(self, '_audio_buffer_dirty', False):
            if self.audio_chunks:
                self._audio_buffer = np.concatenate([self._audio_buffer, *self.audio_chunks])
                self.audio_chunks = []
            self._audio_buffer_dirty = False
        return self._audio_buffer

    @audio_buffer.setter
    def audio_buffer(self, value):
        self._audio_buffer = value
        self.audio_chunks = []
        self._audio_buffer_dirty = False

    def add_audio(self, pcm_bytes):
        if len(pcm_bytes) % 2 != 0:
            pcm_bytes = pcm_bytes[:len(pcm_bytes) - (len(pcm_bytes) % 2)]
        if not pcm_bytes:
            return
        audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
        audio_float = audio_int16.astype(np.float32) / 32768.0
        
        self.audio_chunks.append(audio_float)
        self._audio_buffer_dirty = True

        new_confirmed = self.vad.feed(torch.from_numpy(audio_float).float(), is_final=False)
        self.vad_fed_samples += len(audio_float)

        for seg in new_confirmed:
            seg_text = self._decode_segment(seg)
            self.prev_text = ""
            if not seg_text.strip():
                continue
            self.locked_sentences.append({"text": seg_text, "start": int(seg[0]), "end": int(seg[1])})
            if self.spk_tracker:
                s0 = int(seg[0] * self.sample_rate / 1000)
                s1 = min(int(seg[1] * self.sample_rate / 1000), len(self.audio_buffer))
                self.spk_tracker.assign_streaming(self.audio_buffer[s0:s1], seg[0]/1000, seg[1]/1000, self.locked_sentences[-1])
            logger.info(f"Locked: [{seg[0]}-{seg[1]}ms] \"{seg_text[:40]}\"")

Comment on lines +320 to +324
results = self.vllm_engine.generate(
input=audio_tensor,
hotwords=self.asr_kwargs.get("hotwords"),
language=self.asr_kwargs.get("language"),
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Standard AutoModel (used on non-CUDA devices like CPU) expects the parameter name hotword (singular) instead of hotwords (plural) as documented in auto_model.py. Passing hotwords will be ignored, causing hotword customization to fail on CPU.

Suggested change
results = self.vllm_engine.generate(
input=audio_tensor,
hotwords=self.asr_kwargs.get("hotwords"),
language=self.asr_kwargs.get("language"),
)
results = self.vllm_engine.generate(
input=audio_tensor,
hotword=self.asr_kwargs.get("hotwords"),
language=self.asr_kwargs.get("language"),
)

Comment on lines +385 to +389
results = self.vllm_engine.generate(
input=audio_tensor,
hotwords=self.asr_kwargs.get("hotwords"),
language=self.asr_kwargs.get("language"),
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similarly, standard AutoModel expects hotword instead of hotwords here.

Suggested change
results = self.vllm_engine.generate(
input=audio_tensor,
hotwords=self.asr_kwargs.get("hotwords"),
language=self.asr_kwargs.get("language"),
)
results = self.vllm_engine.generate(
input=audio_tensor,
hotword=self.asr_kwargs.get("hotwords"),
language=self.asr_kwargs.get("language"),
)

Comment on lines +41 to +44
if pos >= 0:
end_pos = text.find(repeated, pos + len(repeated))
if end_pos >= 0:
return text[:end_pos + len(repeated)], True
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of hallucination truncation keeps two occurrences of the repeated pattern instead of one, because it slices up to end_pos + len(repeated). Slicing up to end_pos (the start of the second occurrence) will correctly keep exactly one occurrence of the pattern as intended by the docstring.

Suggested change
if pos >= 0:
end_pos = text.find(repeated, pos + len(repeated))
if end_pos >= 0:
return text[:end_pos + len(repeated)], True
if pos >= 0:
end_pos = text.find(repeated, pos + len(repeated))
if end_pos >= 0:
return text[:end_pos], True

Comment on lines +53 to +57
pos = text.find(repeated)
if pos >= 0:
end_pos = text.find(repeated, pos + len(repeated))
if end_pos >= 0:
return text[:end_pos + len(repeated)], True
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similarly, this truncation keeps two occurrences of the repeated pattern instead of one. Slicing up to end_pos will correctly keep exactly one occurrence.

Suggested change
pos = text.find(repeated)
if pos >= 0:
end_pos = text.find(repeated, pos + len(repeated))
if end_pos >= 0:
return text[:end_pos + len(repeated)], True
if pos >= 0:
end_pos = text.find(repeated, pos + len(repeated))
if end_pos >= 0:
return text[:end_pos], True

Comment on lines +354 to +358
if hasattr(self.vllm_engine, '_engine'):
tokenizer = self.vllm_engine._engine.tokenizer
else:
tokenizer = self.vllm_engine.kwargs.get("tokenizer")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using getattr with a default value is safer to prevent potential AttributeError if self.vllm_engine does not have the kwargs attribute.

Suggested change
if hasattr(self.vllm_engine, '_engine'):
tokenizer = self.vllm_engine._engine.tokenizer
else:
tokenizer = self.vllm_engine.kwargs.get("tokenizer")
if hasattr(self.vllm_engine, '_engine'):
tokenizer = self.vllm_engine._engine.tokenizer
else:
tokenizer = getattr(self.vllm_engine, 'kwargs', {}).get("tokenizer")

Comment on lines +359 to +369
if tokenizer is not None:
encoded = tokenizer.encode(text)
if len(encoded) > 5:
try:
self.prev_text = tokenizer.decode(encoded[:-5], skip_special_tokens=True)
except TypeError:
self.prev_text = tokenizer.decode(encoded[:-5])
else:
self.prev_text = ""
else:
self.prev_text = ""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Wrapping the tokenizer's encode/decode operations in a try...except block prevents the entire WebSocket session from crashing if the tokenizer fails or has an unexpected interface.

Suggested change
if tokenizer is not None:
encoded = tokenizer.encode(text)
if len(encoded) > 5:
try:
self.prev_text = tokenizer.decode(encoded[:-5], skip_special_tokens=True)
except TypeError:
self.prev_text = tokenizer.decode(encoded[:-5])
else:
self.prev_text = ""
else:
self.prev_text = ""
if tokenizer is not None:
try:
encoded = tokenizer.encode(text)
if len(encoded) > 5:
try:
self.prev_text = tokenizer.decode(encoded[:-5], skip_special_tokens=True)
except TypeError:
self.prev_text = tokenizer.decode(encoded[:-5])
else:
self.prev_text = ""
except Exception as e:
logger.warning(f"Failed to encode/decode text with tokenizer: {e}")
self.prev_text = ""
else:
self.prev_text = ""

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant