From bac6c33b46fa46ddbfa96718e4f7213a3bd013c7 Mon Sep 17 00:00:00 2001 From: francose <13445813+francose@users.noreply.github.com> Date: Mon, 8 Jun 2026 17:24:41 -0400 Subject: [PATCH] FEAT Add safe_extract_zip helper for defensive remote ZIP extraction Signed-off-by: francose <13445813+francose@users.noreply.github.com> --- pyrit/common/safe_extract.py | 159 +++++++++++++++++ .../seed_datasets/remote/figstep_dataset.py | 6 +- .../remote/jailbreakv_28k_dataset.py | 5 +- .../seed_datasets/remote/vlguard_dataset.py | 5 +- tests/unit/common/test_safe_extract.py | 166 ++++++++++++++++++ 5 files changed, 331 insertions(+), 10 deletions(-) create mode 100644 pyrit/common/safe_extract.py create mode 100644 tests/unit/common/test_safe_extract.py diff --git a/pyrit/common/safe_extract.py b/pyrit/common/safe_extract.py new file mode 100644 index 0000000000..b7ff99a43c --- /dev/null +++ b/pyrit/common/safe_extract.py @@ -0,0 +1,159 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Defensive ZIP extraction for untrusted remote archives. + +Remote dataset loaders in PyRIT download ZIP archives from third-party sources +and feed them to ``zipfile.ZipFile.extractall()``. ``extractall`` does not +validate member paths, file sizes, or entry types, which leaves the loader +vulnerable to Zip Slip (CWE-22), zip bombs, and symlink-based path escape if +any upstream source is tampered with. + +``safe_extract_zip`` validates every archive member before writing anything to +disk. If any member fails validation, the destination directory is left empty. +""" + +from __future__ import annotations + +import io +import logging +import os +import stat +import zipfile +from pathlib import Path +from typing import IO + +logger = logging.getLogger(__name__) + +# 5 GiB cumulative uncompressed size across all members +DEFAULT_MAX_TOTAL_SIZE = 5 * 1024**3 +# 1 GiB cap on any single member +DEFAULT_MAX_FILE_SIZE = 1 * 1024**3 +# 50_000 entries: above legitimate dataset sizes, defeats inode DoS +DEFAULT_MAX_FILE_COUNT = 50_000 +# Reject members whose uncompressed/compressed ratio exceeds this (zip bomb) +DEFAULT_MAX_COMPRESSION_RATIO = 100 + +ZipSource = str | os.PathLike | bytes | IO[bytes] + + +class UnsafeArchiveError(Exception): + """Raised when an archive member fails a safe-extraction precondition.""" + + +def safe_extract_zip( + source: ZipSource, + dest_dir: str | os.PathLike, + *, + max_total_size: int = DEFAULT_MAX_TOTAL_SIZE, + max_file_size: int = DEFAULT_MAX_FILE_SIZE, + max_file_count: int = DEFAULT_MAX_FILE_COUNT, + max_compression_ratio: int = DEFAULT_MAX_COMPRESSION_RATIO, +) -> Path: + """ + Extract a ZIP archive after validating every member. + + Validation runs in a single pass over the archive's central directory + before any bytes are written. If any check fails, ``UnsafeArchiveError`` is + raised and the destination directory is left without partial output from + this call. + + Args: + source: Path, bytes, or file-like object accepted by ``zipfile.ZipFile``. + dest_dir: Directory to extract into. Created if it does not exist. + max_total_size: Cap on the sum of uncompressed member sizes. + max_file_size: Cap on any single member's uncompressed size. + max_file_count: Cap on the number of members in the archive. + max_compression_ratio: Reject members whose uncompressed/compressed + ratio exceeds this value (zip bomb defense). + + Returns: + Resolved destination directory. + + Raises: + UnsafeArchiveError: If any member fails validation. + """ + if isinstance(source, (bytes, bytearray)): + source = io.BytesIO(source) + + dest_real = Path(dest_dir).resolve() + dest_real.mkdir(parents=True, exist_ok=True) + + with zipfile.ZipFile(source) as zf: + members = zf.infolist() + _validate_members( + members, + dest_real=dest_real, + max_total_size=max_total_size, + max_file_size=max_file_size, + max_file_count=max_file_count, + max_compression_ratio=max_compression_ratio, + ) + for m in members: + zf.extract(m, dest_real) + + return dest_real + + +def _validate_members( + members: list[zipfile.ZipInfo], + *, + dest_real: Path, + max_total_size: int, + max_file_size: int, + max_file_count: int, + max_compression_ratio: int, +) -> None: + if len(members) > max_file_count: + raise UnsafeArchiveError(f"archive contains {len(members)} entries (max {max_file_count})") + + total = 0 + for m in members: + _reject_disallowed_entry_type(m) + _reject_absolute_path(m) + _reject_path_traversal(m, dest_real) + _reject_oversized_member(m, max_file_size=max_file_size) + _reject_compression_bomb(m, max_ratio=max_compression_ratio) + + total += m.file_size + if total > max_total_size: + raise UnsafeArchiveError(f"total uncompressed size exceeds {max_total_size} bytes") + + +def _reject_disallowed_entry_type(m: zipfile.ZipInfo) -> None: + # Unix mode lives in the upper 16 bits of external_attr when create_system==3. + if m.create_system != 3: + return + mode = m.external_attr >> 16 + if stat.S_ISLNK(mode) or stat.S_ISBLK(mode) or stat.S_ISCHR(mode) or stat.S_ISFIFO(mode) or stat.S_ISSOCK(mode): + raise UnsafeArchiveError(f"disallowed entry type: {m.filename}") + + +def _reject_absolute_path(m: zipfile.ZipInfo) -> None: + name = m.filename + if name.startswith(("/", "\\")): + raise UnsafeArchiveError(f"absolute path in archive: {name}") + if len(name) >= 2 and name[1] == ":": + raise UnsafeArchiveError(f"drive-letter path in archive: {name}") + + +def _reject_path_traversal(m: zipfile.ZipInfo, dest_real: Path) -> None: + target = (dest_real / m.filename).resolve() + try: + target.relative_to(dest_real) + except ValueError as exc: + raise UnsafeArchiveError(f"path traversal in archive: {m.filename!r} escapes {dest_real}") from exc + + +def _reject_oversized_member(m: zipfile.ZipInfo, *, max_file_size: int) -> None: + if m.file_size > max_file_size: + raise UnsafeArchiveError(f"member {m.filename!r} uncompressed size {m.file_size} exceeds cap {max_file_size}") + + +def _reject_compression_bomb(m: zipfile.ZipInfo, *, max_ratio: int) -> None: + if m.compress_size <= 0 or m.file_size <= 0: + return + ratio = m.file_size / m.compress_size + if ratio > max_ratio: + raise UnsafeArchiveError(f"member {m.filename!r} compression ratio {ratio:.1f} exceeds cap {max_ratio}") diff --git a/pyrit/datasets/seed_datasets/remote/figstep_dataset.py b/pyrit/datasets/seed_datasets/remote/figstep_dataset.py index 3a7e10a34b..e2d07732cf 100644 --- a/pyrit/datasets/seed_datasets/remote/figstep_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/figstep_dataset.py @@ -6,7 +6,6 @@ import logging import re import uuid -import zipfile from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Literal @@ -15,6 +14,7 @@ from pyrit.common.net_utility import make_request_and_raise_if_error_async from pyrit.common.path import DB_DATA_PATH +from pyrit.common.safe_extract import safe_extract_zip from pyrit.datasets.seed_datasets.remote._image_cache import ( fetch_and_cache_image_async, ) @@ -562,9 +562,7 @@ async def _download_and_extract_pro_zip_async(self, *, cache: bool) -> Path: zip_bytes = response.content def _extract() -> None: - extract_dir.mkdir(parents=True, exist_ok=True) - with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: - zf.extractall(extract_dir) + safe_extract_zip(io.BytesIO(zip_bytes), extract_dir) await asyncio.to_thread(_extract) return extract_dir diff --git a/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py b/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py index 251cfb5405..7c07cae974 100644 --- a/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py @@ -4,10 +4,10 @@ import logging import pathlib import uuid -import zipfile from enum import Enum from typing import Literal +from pyrit.common.safe_extract import safe_extract_zip from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) @@ -149,8 +149,7 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: # Only unzip if the target directory does not already exist if not zip_extracted_path.exists(): logger.info(f"Extracting {zip_file_path} to {self.zip_dir}") - with zipfile.ZipFile(zip_file_path, "r") as zip_ref: - zip_ref.extractall(self.zip_dir) + safe_extract_zip(zip_file_path, self.zip_dir) try: logger.info(f"Loading JailBreakV-28K dataset from {self.source}") diff --git a/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py b/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py index b5e000e6c7..799f9f5c91 100644 --- a/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py @@ -6,7 +6,6 @@ import logging import os import uuid -import zipfile from enum import Enum from pathlib import Path from typing import TYPE_CHECKING @@ -15,6 +14,7 @@ from typing_extensions import override from pyrit.common.path import DB_DATA_PATH +from pyrit.common.safe_extract import safe_extract_zip from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) @@ -329,8 +329,7 @@ def _download_sync() -> tuple[str, str]: zip_path = cache_dir / "test.zip" if zip_path.exists(): logger.info("Extracting VLGuard test images...") - with zipfile.ZipFile(str(zip_path), "r") as zf: - zf.extractall(str(cache_dir)) + safe_extract_zip(zip_path, cache_dir) with open(json_path, encoding="utf-8") as f: metadata = json.load(f) diff --git a/tests/unit/common/test_safe_extract.py b/tests/unit/common/test_safe_extract.py new file mode 100644 index 0000000000..cd2233aa76 --- /dev/null +++ b/tests/unit/common/test_safe_extract.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import io +import stat +import zipfile + +import pytest + +from pyrit.common.safe_extract import ( + DEFAULT_MAX_COMPRESSION_RATIO, + UnsafeArchiveError, + safe_extract_zip, +) + + +def _zip_with(entries): + """ + Build an in-memory zip. + + entries: list of (filename, data, external_attr_mode_or_None) + """ + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_STORED) as zf: + for name, data, mode in entries: + info = zipfile.ZipInfo(name) + info.create_system = 3 # unix, so external_attr is interpreted as mode + if mode is not None: + info.external_attr = mode << 16 + zf.writestr(info, data) + buf.seek(0) + return buf + + +def test_happy_path_extracts_files(tmp_path): + archive = _zip_with( + [ + ("a.txt", b"hello", None), + ("nested/b.txt", b"world", None), + ] + ) + out = safe_extract_zip(archive, tmp_path / "out") + + assert (out / "a.txt").read_bytes() == b"hello" + assert (out / "nested" / "b.txt").read_bytes() == b"world" + + +def test_rejects_dotdot_traversal(tmp_path): + archive = _zip_with([("../escape.txt", b"x", None)]) + with pytest.raises(UnsafeArchiveError, match="path traversal"): + safe_extract_zip(archive, tmp_path / "out") + # destination should be created but empty + assert list((tmp_path / "out").iterdir()) == [] + + +def test_rejects_absolute_unix_path(tmp_path): + archive = _zip_with([("/etc/passwd", b"x", None)]) + with pytest.raises(UnsafeArchiveError, match="absolute path"): + safe_extract_zip(archive, tmp_path / "out") + + +def test_rejects_drive_letter_path(tmp_path): + archive = _zip_with([("C:/windows/system32/x.dll", b"x", None)]) + with pytest.raises(UnsafeArchiveError, match="drive-letter"): + safe_extract_zip(archive, tmp_path / "out") + + +def test_rejects_symlink_entry(tmp_path): + archive = _zip_with([("link", b"../target", stat.S_IFLNK | 0o777)]) + with pytest.raises(UnsafeArchiveError, match="disallowed entry type"): + safe_extract_zip(archive, tmp_path / "out") + + +def test_rejects_device_entry(tmp_path): + archive = _zip_with([("dev", b"", stat.S_IFBLK | 0o600)]) + with pytest.raises(UnsafeArchiveError, match="disallowed entry type"): + safe_extract_zip(archive, tmp_path / "out") + + +def test_rejects_fifo_entry(tmp_path): + archive = _zip_with([("pipe", b"", stat.S_IFIFO | 0o600)]) + with pytest.raises(UnsafeArchiveError, match="disallowed entry type"): + safe_extract_zip(archive, tmp_path / "out") + + +def test_rejects_total_size_bomb(tmp_path): + archive = _zip_with([(f"f{i}.txt", b"A" * 1000, None) for i in range(5)]) + with pytest.raises(UnsafeArchiveError, match="total uncompressed size"): + safe_extract_zip(archive, tmp_path / "out", max_total_size=2000) + + +def test_rejects_single_file_bomb(tmp_path): + archive = _zip_with([("big.bin", b"A" * 1000, None)]) + with pytest.raises(UnsafeArchiveError, match="exceeds cap"): + safe_extract_zip(archive, tmp_path / "out", max_file_size=500) + + +def test_rejects_compression_ratio_bomb(tmp_path): + # DEFLATE 1 MiB of zeros into a few hundred bytes, classic ratio bomb. + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED, compresslevel=9) as zf: + info = zipfile.ZipInfo("bomb.bin") + info.create_system = 3 + info.compress_type = zipfile.ZIP_DEFLATED + zf.writestr(info, b"\x00" * (1024 * 1024)) + buf.seek(0) + + with pytest.raises(UnsafeArchiveError, match="compression ratio"): + safe_extract_zip( + buf, + tmp_path / "out", + max_compression_ratio=DEFAULT_MAX_COMPRESSION_RATIO, + max_file_size=10 * 1024 * 1024, + ) + + +def test_rejects_excessive_file_count(tmp_path): + archive = _zip_with([(f"f{i}.txt", b"x", None) for i in range(10)]) + with pytest.raises(UnsafeArchiveError, match="entries"): + safe_extract_zip(archive, tmp_path / "out", max_file_count=5) + + +def test_no_partial_write_when_one_member_invalid(tmp_path): + # First 2 entries are valid, third escapes, nothing should be written. + archive = _zip_with( + [ + ("ok1.txt", b"one", None), + ("ok2.txt", b"two", None), + ("../escape.txt", b"bad", None), + ] + ) + out = tmp_path / "out" + with pytest.raises(UnsafeArchiveError): + safe_extract_zip(archive, out) + + assert list(out.iterdir()) == [] + + +def test_accepts_bytes_source(tmp_path): + buf = _zip_with([("a.txt", b"hi", None)]) + out = safe_extract_zip(buf.getvalue(), tmp_path / "out") + assert (out / "a.txt").read_bytes() == b"hi" + + +def test_accepts_path_source(tmp_path): + zip_path = tmp_path / "src.zip" + zip_path.write_bytes(_zip_with([("a.txt", b"hi", None)]).getvalue()) + + out = safe_extract_zip(zip_path, tmp_path / "out") + assert (out / "a.txt").read_bytes() == b"hi" + + +def test_destination_dir_is_created(tmp_path): + archive = _zip_with([("a.txt", b"hi", None)]) + target = tmp_path / "does" / "not" / "exist" + + out = safe_extract_zip(archive, target) + assert out.is_dir() + assert (out / "a.txt").read_bytes() == b"hi" + + +def test_returns_resolved_destination(tmp_path): + archive = _zip_with([("a.txt", b"hi", None)]) + out = safe_extract_zip(archive, tmp_path / "out") + assert out == (tmp_path / "out").resolve() + assert out.is_absolute()