Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/private/pypi/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
load("//python:py_library.bzl", "py_library")

package(default_visibility = ["//:__subpackages__"])

Expand Down Expand Up @@ -377,6 +378,12 @@ bzl_library(
],
)

py_library(
name = "repack_whl",
srcs = ["repack_whl.py"],
deps = ["//tools:wheelmaker"],
)

bzl_library(
name = "requirements_files_by_platform_bzl",
srcs = ["requirements_files_by_platform.bzl"],
Expand Down
18 changes: 16 additions & 2 deletions python/private/pypi/repack_whl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@
_DISTINFO = "dist-info"


def _has_all_quoted_filenames(record_contents: str) -> bool:
"""Check if all filenames in the RECORD are quoted.

Some wheels (like torch) have all filenames quoted in their RECORD file.
We detect this to preserve the quoting style when repacking.
"""
lines = record_contents.splitlines()
return all(line.startswith('"') for line in lines)


def _unidiff_output(expected, actual, record):
"""
Helper function. Returns a string containing the unified diff of two
Expand Down Expand Up @@ -151,17 +161,21 @@ def main(sys_argv):
logging.debug(f"Found dist-info dir: {distinfo_dir}")
record_path = distinfo_dir / "RECORD"
record_contents = record_path.read_text() if record_path.exists() else ""
quote_files = _has_all_quoted_filenames(record_contents)
distribution_prefix = distinfo_dir.with_suffix("").name

with _WhlFile(
args.output, mode="w", distribution_prefix=distribution_prefix
args.output,
mode="w",
distribution_prefix=distribution_prefix,
quote_all_filenames=quote_files,
) as out:
for p in _files_to_pack(patched_wheel_dir, record_contents):
rel_path = p.relative_to(patched_wheel_dir)
out.add_file(str(rel_path), p)

logging.debug(f"Writing RECORD file")
got_record = out.add_recordfile().decode("utf-8", "surrogateescape")
got_record = out.add_recordfile()

if got_record == record_contents:
logging.info(f"Created a whl file: {args.output}")
Expand Down
8 changes: 8 additions & 0 deletions tests/pypi/repack_whl/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
load("//python:py_test.bzl", "py_test")

py_test(
name = "repack_whl_test",
size = "small",
srcs = ["repack_whl_test.py"],
deps = ["//python/private/pypi:repack_whl"],
)
37 changes: 37 additions & 0 deletions tests/pypi/repack_whl/repack_whl_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import unittest

from python.private.pypi import repack_whl


class HasAllQuotedFilenamesTest(unittest.TestCase):
"""Tests for _has_all_quoted_filenames detection logic."""

def test_all_quoted(self) -> None:
"""Returns True when all lines start with quotes (torch-style)."""
record = """\
"torch/__init__.py",sha256=abc,123
"torch/utils.py",sha256=def,456
"torch-2.0.0.dist-info/WHEEL",sha256=ghi,789
"""
self.assertTrue(repack_whl._has_all_quoted_filenames(record))

def test_none_quoted(self) -> None:
"""Returns False when no lines are quoted (standard style)."""
record = """\
torch/__init__.py,sha256=abc,123
torch/utils.py,sha256=def,456
torch-2.0.0.dist-info/WHEEL,sha256=ghi,789
"""
self.assertFalse(repack_whl._has_all_quoted_filenames(record))

def test_mixed_quoting(self) -> None:
"""Returns False when only some lines are quoted."""
record = """\
"file,with,commas.py",sha256=abc,123
normal_file.py,sha256=def,456
"""
self.assertFalse(repack_whl._has_all_quoted_filenames(record))


if __name__ == "__main__":
unittest.main()
37 changes: 37 additions & 0 deletions tests/tools/wheelmaker_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,45 @@
import io
import unittest

import tools.wheelmaker as wheelmaker


class QuoteAllFilenamesTest(unittest.TestCase):
"""Tests for quote_all_filenames behavior in _WhlFile.

Some wheels (like torch) have all filenames quoted in their RECORD file.
When repacking, we preserve this style to minimize diffs.
"""

def _make_whl_file(self, quote_all: bool) -> wheelmaker._WhlFile:
"""Create a _WhlFile instance for testing."""
buf = io.BytesIO()
return wheelmaker._WhlFile(
buf,
mode="w",
distribution_prefix="test-1.0.0",
quote_all_filenames=quote_all,
)

def test_quote_all_quotes_simple_filenames(self) -> None:
"""When quote_all_filenames=True, all filenames are quoted."""
whl = self._make_whl_file(quote_all=True)
self.assertEqual(whl._quote_filename("foo/bar.py"), '"foo/bar.py"')

def test_quote_all_false_leaves_simple_filenames_unquoted(self) -> None:
"""When quote_all_filenames=False, simple filenames stay unquoted."""
whl = self._make_whl_file(quote_all=False)
self.assertEqual(whl._quote_filename("foo/bar.py"), "foo/bar.py")

def test_quote_all_quotes_filenames_with_commas(self) -> None:
"""Filenames with commas are always quoted, regardless of quote_all_filenames."""
whl = self._make_whl_file(quote_all=True)
self.assertEqual(whl._quote_filename("foo,bar/baz.py"), '"foo,bar/baz.py"')

whl = self._make_whl_file(quote_all=False)
self.assertEqual(whl._quote_filename("foo,bar/baz.py"), '"foo,bar/baz.py"')


class ArcNameFromTest(unittest.TestCase):
def test_arcname_from(self) -> None:
# (name, distribution_prefix, strip_path_prefixes, want) tuples
Expand Down
57 changes: 29 additions & 28 deletions tools/wheelmaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,17 @@ def __init__(
distribution_prefix: str,
strip_path_prefixes=None,
compression=zipfile.ZIP_DEFLATED,
quote_all_filenames: bool = False,
**kwargs,
):
self._distribution_prefix = distribution_prefix

self._strip_path_prefixes = strip_path_prefixes or []
# Entries for the RECORD file as (filename, hash, size) tuples.
self._record = []
# Entries for the RECORD file as (filename, digest, size) tuples.
self._record: list[tuple[str, str, str]] = []
# Whether to quote filenames in the RECORD file (for compatibility with
# some wheels like torch that have quoted filenames in their RECORD).
self.quote_all_filenames = quote_all_filenames

super().__init__(filename, mode=mode, compression=compression, **kwargs)

Expand Down Expand Up @@ -192,16 +196,15 @@ def add_string(self, filename, contents):
hash.update(contents)
self._add_to_record(filename, self._serialize_digest(hash), len(contents))

def _serialize_digest(self, hash):
def _serialize_digest(self, hash) -> str:
# https://www.python.org/dev/peps/pep-0376/#record
# "base64.urlsafe_b64encode(digest) with trailing = removed"
digest = base64.urlsafe_b64encode(hash.digest())
digest = b"sha256=" + digest.rstrip(b"=")
return digest
return digest.decode("utf-8", "surrogateescape")

def _add_to_record(self, filename, hash, size):
size = str(size).encode("ascii")
self._record.append((filename, hash, size))
def _add_to_record(self, filename: str, hash: str, size: int) -> None:
self._record.append((filename, hash, str(size)))

def _zipinfo(self, filename):
"""Construct deterministic ZipInfo entry for a file named filename"""
Expand All @@ -223,29 +226,27 @@ def _zipinfo(self, filename):
zinfo.compress_type = self.compression
return zinfo

def add_recordfile(self):
def _quote_filename(self, filename: str) -> str:
"""Return a possibly quoted filename for RECORD file."""
filename = filename.lstrip("/")
# Some RECORDs like torch have *all* filenames quoted and we must minimize diff.
# Otherwise, we quote only when necessary (e.g. for filenames with commas).
quoting = csv.QUOTE_ALL if self.quote_all_filenames else csv.QUOTE_MINIMAL
with io.StringIO() as buf:
csv.writer(buf, quoting=quoting).writerow([filename])
return buf.getvalue().strip()

def add_recordfile(self) -> str:
"""Write RECORD file to the distribution."""
record_path = self.distinfo_path("RECORD")
entries = self._record + [(record_path, b"", b"")]
with io.StringIO() as contents_io:
writer = csv.writer(contents_io, lineterminator="\n")
for filename, digest, size in entries:
if isinstance(filename, str):
filename = filename.lstrip("/")
writer.writerow(
(
(
c
if isinstance(c, str)
else c.decode("utf-8", "surrogateescape")
)
for c in (filename, digest, size)
)
)

contents = contents_io.getvalue()
self.add_string(record_path, contents)
return contents.encode("utf-8", "surrogateescape")
entries = self._record + [(record_path, "", "")]
entries = [
(self._quote_filename(fname), digest, size)
for fname, digest, size in entries
]
contents = "\n".join(",".join(entry) for entry in entries) + "\n"
self.add_string(record_path, contents)
return contents


class WheelMaker(object):
Expand Down