7 changes: 6 additions & 1 deletion funasr/auto/auto_model.py
@@ -29,7 +29,11 @@
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils import export_utils
from funasr.utils import misc

try:
    import torch_npu
    npu_is_available = torch_npu.npu.is_available()
except ImportError:
    npu_is_available = False
Comment on lines +32 to +36
medium

For consistency with checks such as torch.cuda.is_available(), and to improve encapsulation, consider wrapping the NPU availability check in a function. This avoids attempting to import torch_npu at module load time; the check only runs when it is actually needed. Please also call this function accordingly in the build_model method.

Suggested change
try:
    import torch_npu
    npu_is_available = torch_npu.npu.is_available()
except ImportError:
    npu_is_available = False
def is_npu_available():
    """Check whether the NPU is available."""
    try:
        import torch_npu
        return torch_npu.npu.is_available()
    except ImportError:
        return False
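
As an aside, here is a minimal self-contained sketch of the lazy-check pattern this comment describes: the torch_npu import is only attempted when the helper is called, so importing the module on a machine without torch_npu installed never fails. The __main__ usage below is illustrative only and not part of the PR.

def is_npu_available():
    """Check whether an Ascend NPU is usable; torch_npu is imported lazily."""
    try:
        import torch_npu  # import attempted only when the check actually runs
        return torch_npu.npu.is_available()
    except ImportError:
        return False


if __name__ == "__main__":
    # Safe on any machine: prints False when torch_npu is not installed.
    print("NPU available:", is_npu_available())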


def _resolve_ncpu(config, fallback=4):
"""Return a positive integer representing CPU threads from config."""
@@ -199,6 +203,7 @@ def build_model(**kwargs):
    if ((device =="cuda" and not torch.cuda.is_available())
        or (device == "xpu" and not torch.xpu.is_available())
        or (device == "mps" and not torch.backends.mps.is_available())
        or (device == "npu" and not npu_is_available)
medium

To stay consistent with the suggested is_npu_available() function, please call that function here to perform the check.

Suggested change
        or (device == "npu" and not npu_is_available)
        or (device == "npu" and not is_npu_available())

        or kwargs.get("ngpu", 1) == 0):
        device = "cpu"
        kwargs["batch_size"] = 1
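
Putting the pieces together, a hedged sketch of how build_model's device fallback could read once the suggested is_npu_available() helper is adopted. resolve_device and the standalone framing are illustrative only, not FunASR's actual API; note that torch.xpu requires a recent PyTorch build.

import torch


def is_npu_available():
    """Lazy NPU availability check, as suggested in the review above."""
    try:
        import torch_npu
        return torch_npu.npu.is_available()
    except ImportError:
        return False


def resolve_device(device, kwargs):
    """Fall back to CPU (with batch_size forced to 1) when the requested device is unusable."""
    if ((device == "cuda" and not torch.cuda.is_available())
            or (device == "xpu" and not torch.xpu.is_available())
            or (device == "mps" and not torch.backends.mps.is_available())
            or (device == "npu" and not is_npu_available())
            or kwargs.get("ngpu", 1) == 0):
        device = "cpu"
        kwargs["batch_size"] = 1
    return device


For example, resolve_device("npu", {"ngpu": 1}) returns "cpu" on a host without torch_npu installed.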