diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 9bf1ad57e..81ecaa879 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -29,7 +29,11 @@ from funasr.train_utils.load_pretrained_model import load_pretrained_model from funasr.utils import export_utils from funasr.utils import misc - +try: + import torch_npu + npu_is_available = torch_npu.npu.is_available() +except ImportError: + npu_is_available = False def _resolve_ncpu(config, fallback=4): """Return a positive integer representing CPU threads from config.""" @@ -199,6 +203,7 @@ def build_model(**kwargs): if ((device =="cuda" and not torch.cuda.is_available()) or (device == "xpu" and not torch.xpu.is_available()) or (device == "mps" and not torch.backends.mps.is_available()) + or (device == "npu" and not npu_is_available) or kwargs.get("ngpu", 1) == 0): device = "cpu" kwargs["batch_size"] = 1