diff --git a/pyrit/common/safe_extract.py b/pyrit/common/safe_extract.py new file mode 100644 index 0000000000..f83e08e841 --- /dev/null +++ b/pyrit/common/safe_extract.py @@ -0,0 +1,204 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Defensive ZIP extraction for untrusted remote archives. + +Remote dataset loaders in PyRIT download ZIP archives from third-party sources +and feed them to ``zipfile.ZipFile.extractall()``. ``extractall`` does not +validate member paths, file sizes, or entry types, which leaves the loader +vulnerable to Zip Slip (CWE-22), zip bombs, and symlink-based path escape if +any upstream source is tampered with. + +``safe_extract_zip`` validates every archive member before writing anything to +disk. If any member fails validation, no archive members are written from the +failing call (pre-existing contents of ``dest_dir`` are untouched). +""" + +from __future__ import annotations + +import io +import logging +import os +import stat +import zipfile +from pathlib import Path +from typing import IO + +logger = logging.getLogger(__name__) + +# 5 GiB cumulative uncompressed size across all members +DEFAULT_MAX_TOTAL_SIZE = 5 * 1024**3 +# 1 GiB cap on any single member +DEFAULT_MAX_FILE_SIZE = 1 * 1024**3 +# 50_000 entries: above legitimate dataset sizes, defeats inode DoS +DEFAULT_MAX_FILE_COUNT = 50_000 +# Reject members whose uncompressed/compressed ratio exceeds this (zip bomb) +DEFAULT_MAX_COMPRESSION_RATIO = 100 + +# Sanitized permissions applied to extracted entries, stripping any setuid / +# setgid / sticky / world-write bits the archive may have requested. +_EXTRACTED_FILE_MODE = 0o644 +_EXTRACTED_DIR_MODE = 0o755 + +# Predicates for entry types we refuse to extract. +_DISALLOWED_TYPE_PREDICATES = ( + stat.S_ISLNK, + stat.S_ISBLK, + stat.S_ISCHR, + stat.S_ISFIFO, + stat.S_ISSOCK, +) + +ZipSource = str | os.PathLike | bytes | IO[bytes] + + +class UnsafeArchiveError(Exception): + """Raised when an archive member fails a safe-extraction precondition.""" + + +def safe_extract_zip( + *, + source: ZipSource, + dest_dir: str | os.PathLike, + max_total_size: int = DEFAULT_MAX_TOTAL_SIZE, + max_file_size: int = DEFAULT_MAX_FILE_SIZE, + max_file_count: int = DEFAULT_MAX_FILE_COUNT, + max_compression_ratio: int = DEFAULT_MAX_COMPRESSION_RATIO, +) -> Path: + """ + Extract a ZIP archive after validating every member. + + Validation runs in a single pass over the archive's central directory + before any bytes are written. If any check fails, ``UnsafeArchiveError`` is + raised and no archive members are written from this call. After extraction + each member's filesystem mode is replaced with a sanitized default so a + tampered archive cannot set setuid/setgid/sticky/exec bits on the host. + + Args: + source: Path, bytes, or file-like object accepted by ``zipfile.ZipFile``. + dest_dir: Directory to extract into. Created if it does not exist. + max_total_size: Cap on the sum of uncompressed member sizes. + max_file_size: Cap on any single member's uncompressed size. + max_file_count: Cap on the number of members in the archive. + max_compression_ratio: Reject members whose uncompressed/compressed + ratio exceeds this value (zip bomb defense). + + Returns: + Resolved destination directory. + + Raises: + UnsafeArchiveError: If any member fails validation. + """ + if isinstance(source, (bytes, bytearray)): + source = io.BytesIO(source) + + dest_real = Path(dest_dir).resolve() + dest_real.mkdir(parents=True, exist_ok=True) + + with zipfile.ZipFile(source) as zf: + members = zf.infolist() + try: + _validate_members( + members, + dest_real=dest_real, + max_total_size=max_total_size, + max_file_size=max_file_size, + max_file_count=max_file_count, + max_compression_ratio=max_compression_ratio, + ) + except UnsafeArchiveError as exc: + logger.warning("safe_extract_zip rejected archive: %s", exc) + raise + for m in members: + extracted = Path(zf.extract(m, dest_real)) + _sanitize_extracted_permissions(extracted) + + return dest_real + + +def _sanitize_extracted_permissions(path: Path) -> None: + # zipfile.ZipFile.extract applies the archive's external_attr mode bits on + # POSIX, so a tampered archive can request setuid/setgid/sticky or + # executable bits on extracted entries. Replace with a sane default. + try: + if path.is_dir(): + os.chmod(path, _EXTRACTED_DIR_MODE) + else: + os.chmod(path, _EXTRACTED_FILE_MODE) + except OSError as exc: + logger.warning("safe_extract_zip could not chmod %s: %s", path, exc) + + +def _validate_members( + members: list[zipfile.ZipInfo], + *, + dest_real: Path, + max_total_size: int, + max_file_size: int, + max_file_count: int, + max_compression_ratio: int, +) -> None: + if len(members) > max_file_count: + raise UnsafeArchiveError(f"archive contains {len(members)} entries (max {max_file_count})") + + total = 0 + for m in members: + _reject_disallowed_entry_type(m) + _reject_absolute_path(m) + _reject_path_traversal(m, dest_real) + _reject_oversized_member(m, max_file_size=max_file_size) + _reject_compression_bomb(m, max_ratio=max_compression_ratio) + + total += m.file_size + if total > max_total_size: + raise UnsafeArchiveError(f"total uncompressed size exceeds {max_total_size} bytes") + + +def _reject_disallowed_entry_type(m: zipfile.ZipInfo) -> None: + # The upper 16 bits of external_attr hold the Unix mode when the archive + # was created on a Unix system. Check unconditionally because create_system + # is attacker-controlled metadata: a zip crafted with create_system=0 (DOS) + # but Unix-style mode bits set should still be rejected. + mode = m.external_attr >> 16 + if any(predicate(mode) for predicate in _DISALLOWED_TYPE_PREDICATES): + raise UnsafeArchiveError(f"disallowed entry type: {m.filename}") + + +def _reject_absolute_path(m: zipfile.ZipInfo) -> None: + name = m.filename + if name.startswith(("/", "\\")): + raise UnsafeArchiveError(f"absolute path in archive: {name}") + if len(name) >= 2 and name[1] == ":": + raise UnsafeArchiveError(f"drive-letter path in archive: {name}") + + +def _reject_path_traversal(m: zipfile.ZipInfo, dest_real: Path) -> None: + try: + target = (dest_real / m.filename).resolve() + except ValueError as exc: + # Path raises ValueError on null bytes and other invalid path characters. + raise UnsafeArchiveError(f"invalid characters in archive entry: {m.filename!r}") from exc + try: + target.relative_to(dest_real) + except ValueError as exc: + raise UnsafeArchiveError(f"path traversal in archive: {m.filename!r} escapes {dest_real}") from exc + + +def _reject_oversized_member(m: zipfile.ZipInfo, *, max_file_size: int) -> None: + if m.file_size > max_file_size: + raise UnsafeArchiveError(f"member {m.filename!r} uncompressed size {m.file_size} exceeds cap {max_file_size}") + + +def _reject_compression_bomb(m: zipfile.ZipInfo, *, max_ratio: int) -> None: + if m.file_size <= 0: + return + if m.compress_size <= 0: + # Declared non-zero uncompressed size with zero compressed size is + # malformed metadata, refuse rather than skip the ratio check. + raise UnsafeArchiveError( + f"member {m.filename!r} declares uncompressed size {m.file_size} but compressed size {m.compress_size}" + ) + ratio = m.file_size / m.compress_size + if ratio > max_ratio: + raise UnsafeArchiveError(f"member {m.filename!r} compression ratio {ratio:.1f} exceeds cap {max_ratio}") diff --git a/pyrit/datasets/seed_datasets/remote/figstep_dataset.py b/pyrit/datasets/seed_datasets/remote/figstep_dataset.py index 3a7e10a34b..12866279b5 100644 --- a/pyrit/datasets/seed_datasets/remote/figstep_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/figstep_dataset.py @@ -6,7 +6,6 @@ import logging import re import uuid -import zipfile from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Literal @@ -15,6 +14,7 @@ from pyrit.common.net_utility import make_request_and_raise_if_error_async from pyrit.common.path import DB_DATA_PATH +from pyrit.common.safe_extract import safe_extract_zip from pyrit.datasets.seed_datasets.remote._image_cache import ( fetch_and_cache_image_async, ) @@ -562,9 +562,7 @@ async def _download_and_extract_pro_zip_async(self, *, cache: bool) -> Path: zip_bytes = response.content def _extract() -> None: - extract_dir.mkdir(parents=True, exist_ok=True) - with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: - zf.extractall(extract_dir) + safe_extract_zip(source=io.BytesIO(zip_bytes), dest_dir=extract_dir) await asyncio.to_thread(_extract) return extract_dir diff --git a/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py b/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py index 251cfb5405..07306f2bc3 100644 --- a/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py @@ -1,13 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import asyncio import logging import pathlib import uuid -import zipfile from enum import Enum from typing import Literal +from pyrit.common.safe_extract import safe_extract_zip from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) @@ -149,8 +150,7 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: # Only unzip if the target directory does not already exist if not zip_extracted_path.exists(): logger.info(f"Extracting {zip_file_path} to {self.zip_dir}") - with zipfile.ZipFile(zip_file_path, "r") as zip_ref: - zip_ref.extractall(self.zip_dir) + await asyncio.to_thread(safe_extract_zip, source=zip_file_path, dest_dir=self.zip_dir) try: logger.info(f"Loading JailBreakV-28K dataset from {self.source}") diff --git a/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py b/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py index b5e000e6c7..ccafe058c9 100644 --- a/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py @@ -6,7 +6,6 @@ import logging import os import uuid -import zipfile from enum import Enum from pathlib import Path from typing import TYPE_CHECKING @@ -15,6 +14,7 @@ from typing_extensions import override from pyrit.common.path import DB_DATA_PATH +from pyrit.common.safe_extract import safe_extract_zip from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) @@ -329,8 +329,7 @@ def _download_sync() -> tuple[str, str]: zip_path = cache_dir / "test.zip" if zip_path.exists(): logger.info("Extracting VLGuard test images...") - with zipfile.ZipFile(str(zip_path), "r") as zf: - zf.extractall(str(cache_dir)) + await asyncio.to_thread(safe_extract_zip, source=zip_path, dest_dir=cache_dir) with open(json_path, encoding="utf-8") as f: metadata = json.load(f) diff --git a/tests/unit/common/test_safe_extract.py b/tests/unit/common/test_safe_extract.py new file mode 100644 index 0000000000..651f535109 --- /dev/null +++ b/tests/unit/common/test_safe_extract.py @@ -0,0 +1,245 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import io +import stat +import zipfile + +import pytest + +from pyrit.common.safe_extract import ( + DEFAULT_MAX_COMPRESSION_RATIO, + UnsafeArchiveError, + safe_extract_zip, +) + + +def _zip_with(entries): + """ + Build an in-memory zip. + + entries: list of (filename, data, external_attr_mode_or_None) + """ + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_STORED) as zf: + for name, data, mode in entries: + info = zipfile.ZipInfo(name) + info.create_system = 3 # unix, so external_attr is interpreted as mode + if mode is not None: + info.external_attr = mode << 16 + zf.writestr(info, data) + buf.seek(0) + return buf + + +def test_happy_path_extracts_files(tmp_path): + archive = _zip_with( + [ + ("a.txt", b"hello", None), + ("nested/b.txt", b"world", None), + ] + ) + out = safe_extract_zip(source=archive, dest_dir=tmp_path / "out") + + assert (out / "a.txt").read_bytes() == b"hello" + assert (out / "nested" / "b.txt").read_bytes() == b"world" + + +def test_rejects_dotdot_traversal(tmp_path): + archive = _zip_with([("../escape.txt", b"x", None)]) + with pytest.raises(UnsafeArchiveError, match="path traversal"): + safe_extract_zip(source=archive, dest_dir=tmp_path / "out") + # destination should be created but empty + assert list((tmp_path / "out").iterdir()) == [] + + +def test_rejects_absolute_unix_path(tmp_path): + archive = _zip_with([("/etc/passwd", b"x", None)]) + with pytest.raises(UnsafeArchiveError, match="absolute path"): + safe_extract_zip(source=archive, dest_dir=tmp_path / "out") + + +def test_rejects_drive_letter_path(tmp_path): + archive = _zip_with([("C:/windows/system32/x.dll", b"x", None)]) + with pytest.raises(UnsafeArchiveError, match="drive-letter"): + safe_extract_zip(source=archive, dest_dir=tmp_path / "out") + + +def test_rejects_symlink_entry(tmp_path): + archive = _zip_with([("link", b"../target", stat.S_IFLNK | 0o777)]) + with pytest.raises(UnsafeArchiveError, match="disallowed entry type"): + safe_extract_zip(source=archive, dest_dir=tmp_path / "out") + + +def test_rejects_device_entry(tmp_path): + archive = _zip_with([("dev", b"", stat.S_IFBLK | 0o600)]) + with pytest.raises(UnsafeArchiveError, match="disallowed entry type"): + safe_extract_zip(source=archive, dest_dir=tmp_path / "out") + + +def test_rejects_fifo_entry(tmp_path): + archive = _zip_with([("pipe", b"", stat.S_IFIFO | 0o600)]) + with pytest.raises(UnsafeArchiveError, match="disallowed entry type"): + safe_extract_zip(source=archive, dest_dir=tmp_path / "out") + + +def test_rejects_total_size_bomb(tmp_path): + archive = _zip_with([(f"f{i}.txt", b"A" * 1000, None) for i in range(5)]) + with pytest.raises(UnsafeArchiveError, match="total uncompressed size"): + safe_extract_zip(source=archive, dest_dir=tmp_path / "out", max_total_size=2000) + + +def test_rejects_single_file_bomb(tmp_path): + archive = _zip_with([("big.bin", b"A" * 1000, None)]) + with pytest.raises(UnsafeArchiveError, match="exceeds cap"): + safe_extract_zip(source=archive, dest_dir=tmp_path / "out", max_file_size=500) + + +def test_rejects_compression_ratio_bomb(tmp_path): + # DEFLATE 1 MiB of zeros into a few hundred bytes, classic ratio bomb. + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED, compresslevel=9) as zf: + info = zipfile.ZipInfo("bomb.bin") + info.create_system = 3 + info.compress_type = zipfile.ZIP_DEFLATED + zf.writestr(info, b"\x00" * (1024 * 1024)) + buf.seek(0) + + with pytest.raises(UnsafeArchiveError, match="compression ratio"): + safe_extract_zip( + source=buf, + dest_dir=tmp_path / "out", + max_compression_ratio=DEFAULT_MAX_COMPRESSION_RATIO, + max_file_size=10 * 1024 * 1024, + ) + + +def test_rejects_excessive_file_count(tmp_path): + archive = _zip_with([(f"f{i}.txt", b"x", None) for i in range(10)]) + with pytest.raises(UnsafeArchiveError, match="entries"): + safe_extract_zip(source=archive, dest_dir=tmp_path / "out", max_file_count=5) + + +def test_no_partial_write_when_one_member_invalid(tmp_path): + # First 2 entries are valid, third escapes, nothing should be written. + archive = _zip_with( + [ + ("ok1.txt", b"one", None), + ("ok2.txt", b"two", None), + ("../escape.txt", b"bad", None), + ] + ) + out = tmp_path / "out" + with pytest.raises(UnsafeArchiveError): + safe_extract_zip(source=archive, dest_dir=out) + + assert list(out.iterdir()) == [] + + +def test_accepts_bytes_source(tmp_path): + buf = _zip_with([("a.txt", b"hi", None)]) + out = safe_extract_zip(source=buf.getvalue(), dest_dir=tmp_path / "out") + assert (out / "a.txt").read_bytes() == b"hi" + + +def test_accepts_path_source(tmp_path): + zip_path = tmp_path / "src.zip" + zip_path.write_bytes(_zip_with([("a.txt", b"hi", None)]).getvalue()) + + out = safe_extract_zip(source=zip_path, dest_dir=tmp_path / "out") + assert (out / "a.txt").read_bytes() == b"hi" + + +def test_destination_dir_is_created(tmp_path): + archive = _zip_with([("a.txt", b"hi", None)]) + target = tmp_path / "does" / "not" / "exist" + + out = safe_extract_zip(source=archive, dest_dir=target) + assert out.is_dir() + assert (out / "a.txt").read_bytes() == b"hi" + + +def test_returns_resolved_destination(tmp_path): + archive = _zip_with([("a.txt", b"hi", None)]) + out = safe_extract_zip(source=archive, dest_dir=tmp_path / "out") + assert out == (tmp_path / "out").resolve() + assert out.is_absolute() + + +def test_path_traversal_check_handles_invalid_chars(tmp_path): + # Python's zipfile reader truncates filenames at null bytes, so this can't + # be triggered through a real archive — but the validator should still + # surface UnsafeArchiveError rather than leak a ValueError if a future + # caller hands us a manually-built ZipInfo with such a name. + from pyrit.common.safe_extract import _reject_path_traversal + + info = zipfile.ZipInfo() + info.filename = "foo\x00.txt" + + with pytest.raises(UnsafeArchiveError, match="invalid characters"): + _reject_path_traversal(info, tmp_path.resolve()) + + +def test_rejects_symlink_when_create_system_is_dos(tmp_path): + # Adversary sets create_system=0 (DOS) but with Unix S_IFLNK upper bits. + # Helper must still reject because create_system is attacker-controlled. + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_STORED) as zf: + info = zipfile.ZipInfo("link") + info.create_system = 0 + info.external_attr = (stat.S_IFLNK | 0o777) << 16 + zf.writestr(info, b"/etc/passwd") + buf.seek(0) + + with pytest.raises(UnsafeArchiveError, match="disallowed entry type"): + safe_extract_zip(source=buf, dest_dir=tmp_path / "out") + + +def test_directory_entry_happy_path(tmp_path): + # Explicit directory entry (filename ending in '/') plus a regular file inside it. + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_STORED) as zf: + dir_info = zipfile.ZipInfo("subdir/") + dir_info.create_system = 3 + zf.writestr(dir_info, b"") + file_info = zipfile.ZipInfo("subdir/file.txt") + file_info.create_system = 3 + zf.writestr(file_info, b"hi") + buf.seek(0) + + out = safe_extract_zip(source=buf, dest_dir=tmp_path / "out") + assert (out / "subdir").is_dir() + assert (out / "subdir" / "file.txt").read_bytes() == b"hi" + + +def test_rejects_directory_entry_with_traversal(tmp_path): + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_STORED) as zf: + info = zipfile.ZipInfo("../escape/") + info.create_system = 3 + zf.writestr(info, b"") + buf.seek(0) + + with pytest.raises(UnsafeArchiveError, match="path traversal"): + safe_extract_zip(source=buf, dest_dir=tmp_path / "out") + + +def test_rejects_zero_compress_size_with_nonzero_file_size(tmp_path): + # Malformed metadata: declared compress_size=0 but file_size>0. + # This is the bypass path previously short-circuited by the ratio check. + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_STORED) as zf: + info = zipfile.ZipInfo("malformed.bin") + info.create_system = 3 + zf.writestr(info, b"X" * 100) + raw = bytearray(buf.getvalue()) + # Patch the central-directory entry: set compress_size to 0, leave file_size. + import struct + + idx = raw.rfind(b"PK\x01\x02") + assert idx != -1, "central directory signature missing, zip layout assumption is stale" + struct.pack_into("