diff --git a/monai/apps/__init__.py b/monai/apps/__init__.py index 9cc7aeb8e0..e099efbe78 100644 --- a/monai/apps/__init__.py +++ b/monai/apps/__init__.py @@ -13,4 +13,13 @@ from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset, TciaDataset from .mmars import MODEL_DESC, RemoteMMARKeys, download_mmar, get_model_spec, load_from_mmar -from .utils import SUPPORTED_HASH_TYPES, check_hash, download_and_extract, download_url, extractall, get_logger, logger +from .utils import ( + SUPPORTED_HASH_TYPES, + HashMismatchError, + check_hash, + download_and_extract, + download_url, + extractall, + get_logger, + logger, +) diff --git a/monai/apps/utils.py b/monai/apps/utils.py index 856bc64c9e..09af7549a7 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -42,7 +42,20 @@ else: tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm") -__all__ = ["check_hash", "download_url", "extractall", "download_and_extract", "get_logger", "SUPPORTED_HASH_TYPES"] +__all__ = [ + "check_hash", + "download_url", + "extractall", + "download_and_extract", + "get_logger", + "HashMismatchError", + "SUPPORTED_HASH_TYPES", +] + + +class HashMismatchError(RuntimeError): + """Raised when the hash of a downloaded file does not match the expected value.""" + DEFAULT_FMT = "%(asctime)s - %(levelname)s - %(message)s" SUPPORTED_HASH_TYPES = {"md5": hashlib.md5, "sha1": hashlib.sha1, "sha256": hashlib.sha256, "sha512": hashlib.sha512} @@ -268,7 +281,7 @@ def download_url( pass logger.info(f"Downloaded: {filepath}") if not check_hash(filepath, hash_val, hash_type): - raise RuntimeError( + raise HashMismatchError( f"{hash_type} check of downloaded file failed: URL={url}, " f"filepath={filepath}, expected {hash_type}={hash_val}." ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 05f7cb88d9..f35341d5e3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -41,7 +41,7 @@ import torch import torch.distributed as dist -from monai.apps.utils import download_url +from monai.apps.utils import HashMismatchError, download_url from monai.config import NdarrayTensor from monai.config.deviceconfig import USE_COMPILED from monai.config.type_definitions import NdarrayOrTensor @@ -79,7 +79,6 @@ "unexpected EOF", # incomplete download "network issue", "gdown dependency", # gdown not installed - "md5 check", "limit", # HTTP Error 503: Egress is over the account limit "authenticate", "timed out", # urlopen error [Errno 110] Connection timed out @@ -171,6 +170,8 @@ def skip_if_downloading_fails(): try: yield + except HashMismatchError: + raise except DOWNLOAD_EXCEPTS as e: raise unittest.SkipTest(f"Error while downloading: {e}") from e except ssl.SSLError as ssl_e: @@ -206,7 +207,7 @@ def test_download_url(self): hash_val=SAMPLE_TIFF_HASH, hash_type=SAMPLE_TIFF_HASH_TYPE, ) - with self.assertRaises(RuntimeError): + with self.assertRaises(HashMismatchError): download_url( url=SAMPLE_TIFF, filepath=os.path.join(tempdir, "model_bad.tiff"),