|
4 | 4 | from einops import repeat, reduce |
5 | 5 | from typing import Optional, Union |
6 | 6 | from dataclasses import dataclass |
| 7 | +from huggingface_hub import snapshot_download as hf_snapshot_download |
7 | 8 | from modelscope import snapshot_download |
8 | 9 | import numpy as np |
9 | 10 | from PIL import Image |
@@ -196,13 +197,24 @@ def download_if_necessary(self, use_usp=False): |
196 | 197 | self.local_model_path = "./models" |
197 | 198 | if not skip_download: |
198 | 199 | downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id)) |
199 | | - snapshot_download( |
200 | | - self.model_id, |
201 | | - local_dir=os.path.join(self.local_model_path, self.model_id), |
202 | | - allow_file_pattern=allow_file_pattern, |
203 | | - ignore_file_pattern=downloaded_files, |
204 | | - local_files_only=False |
205 | | - ) |
| 200 | + if self.download_resource.lower() == "modelscope": |
| 201 | + snapshot_download( |
| 202 | + self.model_id, |
| 203 | + local_dir=os.path.join(self.local_model_path, self.model_id), |
| 204 | + allow_file_pattern=allow_file_pattern, |
| 205 | + ignore_file_pattern=downloaded_files, |
| 206 | + local_files_only=False |
| 207 | + ) |
| 208 | + elif self.download_resource.lower() == "huggingface": |
| 209 | + hf_snapshot_download( |
| 210 | + self.model_id, |
| 211 | + local_dir=os.path.join(self.local_model_path, self.model_id), |
| 212 | + allow_patterns=allow_file_pattern, |
| 213 | + ignore_patterns=downloaded_files, |
| 214 | + local_files_only=False |
| 215 | + ) |
| 216 | + else: |
| 217 | + raise ValueError("`download_resource` should be `modelscope` or `huggingface`.") |
206 | 218 |
|
207 | 219 | # Let rank 1, 2, ... wait for rank 0 |
208 | 220 | if use_usp: |
|
0 commit comments