Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions pyrit/common/safe_extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Defensive ZIP extraction for untrusted remote archives.

Remote dataset loaders in PyRIT download ZIP archives from third-party sources
and feed them to ``zipfile.ZipFile.extractall()``. ``extractall`` does not
validate member paths, file sizes, or entry types, which leaves the loader
vulnerable to Zip Slip (CWE-22), zip bombs, and symlink-based path escape if
any upstream source is tampered with.

``safe_extract_zip`` validates every archive member before writing anything to
disk. If any member fails validation, the destination directory is left empty.
"""

from __future__ import annotations

import io
import logging
import os
import stat
import zipfile
from pathlib import Path
from typing import IO

logger = logging.getLogger(__name__)

# 5 GiB cumulative uncompressed size across all members
DEFAULT_MAX_TOTAL_SIZE = 5 * 1024**3
# 1 GiB cap on any single member
DEFAULT_MAX_FILE_SIZE = 1 * 1024**3
# 50_000 entries: above legitimate dataset sizes, defeats inode DoS
DEFAULT_MAX_FILE_COUNT = 50_000
# Reject members whose uncompressed/compressed ratio exceeds this (zip bomb)
DEFAULT_MAX_COMPRESSION_RATIO = 100

ZipSource = str | os.PathLike | bytes | IO[bytes]


class UnsafeArchiveError(Exception):
"""Raised when an archive member fails a safe-extraction precondition."""


def safe_extract_zip(
source: ZipSource,
dest_dir: str | os.PathLike,
*,
max_total_size: int = DEFAULT_MAX_TOTAL_SIZE,
max_file_size: int = DEFAULT_MAX_FILE_SIZE,
max_file_count: int = DEFAULT_MAX_FILE_COUNT,
max_compression_ratio: int = DEFAULT_MAX_COMPRESSION_RATIO,
) -> Path:
"""
Extract a ZIP archive after validating every member.

Validation runs in a single pass over the archive's central directory
before any bytes are written. If any check fails, ``UnsafeArchiveError`` is
raised and the destination directory is left without partial output from
this call.

Args:
source: Path, bytes, or file-like object accepted by ``zipfile.ZipFile``.
dest_dir: Directory to extract into. Created if it does not exist.
max_total_size: Cap on the sum of uncompressed member sizes.
max_file_size: Cap on any single member's uncompressed size.
max_file_count: Cap on the number of members in the archive.
max_compression_ratio: Reject members whose uncompressed/compressed
ratio exceeds this value (zip bomb defense).

Returns:
Resolved destination directory.

Raises:
UnsafeArchiveError: If any member fails validation.
"""
if isinstance(source, (bytes, bytearray)):
source = io.BytesIO(source)

dest_real = Path(dest_dir).resolve()
dest_real.mkdir(parents=True, exist_ok=True)

with zipfile.ZipFile(source) as zf:
members = zf.infolist()
_validate_members(
members,
dest_real=dest_real,
max_total_size=max_total_size,
max_file_size=max_file_size,
max_file_count=max_file_count,
max_compression_ratio=max_compression_ratio,
)
for m in members:
zf.extract(m, dest_real)

return dest_real


def _validate_members(
members: list[zipfile.ZipInfo],
*,
dest_real: Path,
max_total_size: int,
max_file_size: int,
max_file_count: int,
max_compression_ratio: int,
) -> None:
if len(members) > max_file_count:
raise UnsafeArchiveError(f"archive contains {len(members)} entries (max {max_file_count})")

total = 0
for m in members:
_reject_disallowed_entry_type(m)
_reject_absolute_path(m)
_reject_path_traversal(m, dest_real)
_reject_oversized_member(m, max_file_size=max_file_size)
_reject_compression_bomb(m, max_ratio=max_compression_ratio)

total += m.file_size
if total > max_total_size:
raise UnsafeArchiveError(f"total uncompressed size exceeds {max_total_size} bytes")


def _reject_disallowed_entry_type(m: zipfile.ZipInfo) -> None:
# Unix mode lives in the upper 16 bits of external_attr when create_system==3.
if m.create_system != 3:
return
mode = m.external_attr >> 16
if stat.S_ISLNK(mode) or stat.S_ISBLK(mode) or stat.S_ISCHR(mode) or stat.S_ISFIFO(mode) or stat.S_ISSOCK(mode):
raise UnsafeArchiveError(f"disallowed entry type: {m.filename}")


def _reject_absolute_path(m: zipfile.ZipInfo) -> None:
name = m.filename
if name.startswith(("/", "\\")):
raise UnsafeArchiveError(f"absolute path in archive: {name}")
if len(name) >= 2 and name[1] == ":":
raise UnsafeArchiveError(f"drive-letter path in archive: {name}")


def _reject_path_traversal(m: zipfile.ZipInfo, dest_real: Path) -> None:
target = (dest_real / m.filename).resolve()
try:
target.relative_to(dest_real)
except ValueError as exc:
raise UnsafeArchiveError(f"path traversal in archive: {m.filename!r} escapes {dest_real}") from exc


def _reject_oversized_member(m: zipfile.ZipInfo, *, max_file_size: int) -> None:
if m.file_size > max_file_size:
raise UnsafeArchiveError(f"member {m.filename!r} uncompressed size {m.file_size} exceeds cap {max_file_size}")


def _reject_compression_bomb(m: zipfile.ZipInfo, *, max_ratio: int) -> None:
if m.compress_size <= 0 or m.file_size <= 0:
return
ratio = m.file_size / m.compress_size
if ratio > max_ratio:
raise UnsafeArchiveError(f"member {m.filename!r} compression ratio {ratio:.1f} exceeds cap {max_ratio}")
6 changes: 2 additions & 4 deletions pyrit/datasets/seed_datasets/remote/figstep_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import logging
import re
import uuid
import zipfile
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Literal
Expand All @@ -15,6 +14,7 @@

from pyrit.common.net_utility import make_request_and_raise_if_error_async
from pyrit.common.path import DB_DATA_PATH
from pyrit.common.safe_extract import safe_extract_zip
from pyrit.datasets.seed_datasets.remote._image_cache import (
fetch_and_cache_image_async,
)
Expand Down Expand Up @@ -562,9 +562,7 @@ async def _download_and_extract_pro_zip_async(self, *, cache: bool) -> Path:
zip_bytes = response.content

def _extract() -> None:
extract_dir.mkdir(parents=True, exist_ok=True)
with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
zf.extractall(extract_dir)
safe_extract_zip(io.BytesIO(zip_bytes), extract_dir)

await asyncio.to_thread(_extract)
return extract_dir
Expand Down
5 changes: 2 additions & 3 deletions pyrit/datasets/seed_datasets/remote/jailbreakv_28k_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import logging
import pathlib
import uuid
import zipfile
from enum import Enum
from typing import Literal

from pyrit.common.safe_extract import safe_extract_zip
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
_RemoteDatasetLoader,
)
Expand Down Expand Up @@ -149,8 +149,7 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:
# Only unzip if the target directory does not already exist
if not zip_extracted_path.exists():
logger.info(f"Extracting {zip_file_path} to {self.zip_dir}")
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
zip_ref.extractall(self.zip_dir)
safe_extract_zip(zip_file_path, self.zip_dir)

try:
logger.info(f"Loading JailBreakV-28K dataset from {self.source}")
Expand Down
5 changes: 2 additions & 3 deletions pyrit/datasets/seed_datasets/remote/vlguard_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import logging
import os
import uuid
import zipfile
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING
Expand All @@ -15,6 +14,7 @@
from typing_extensions import override

from pyrit.common.path import DB_DATA_PATH
from pyrit.common.safe_extract import safe_extract_zip
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
_RemoteDatasetLoader,
)
Expand Down Expand Up @@ -329,8 +329,7 @@ def _download_sync() -> tuple[str, str]:
zip_path = cache_dir / "test.zip"
if zip_path.exists():
logger.info("Extracting VLGuard test images...")
with zipfile.ZipFile(str(zip_path), "r") as zf:
zf.extractall(str(cache_dir))
safe_extract_zip(zip_path, cache_dir)

with open(json_path, encoding="utf-8") as f:
metadata = json.load(f)
Expand Down
Loading