Skip to content

[poetry] when running prep_binary_for_pypi inject linux metadata to all wheels #6681

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 99 additions & 45 deletions release/pypi/prep_binary_for_pypi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# flake8: noqa: E501

"""
Preps binaries for publishing to PyPI by removing the version suffix normally added for binaries.
Expand All @@ -17,7 +18,41 @@
from pathlib import Path


def process_wheel(whl_file, output_dir=None):
cuda_metadata_inject = """Requires-Dist: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == "Linux" and platform_machine == "x86_64"
Copy link
Contributor Author

@atalman atalman May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not really fan of this section.

One Possibility is to Construct this on the fly from:
https://github.com/pytorch/pytorch/blob/main/.github/scripts/generate_binary_build_matrix.py#L58

Triton can be extracted from here:
https://github.com/pytorch/pytorch/blob/main/.ci/docker/triton_version.txt

Let me know what you think, I can iterate on this PR and add the logic to construct these requirements on the fly.

Keeping it simple for now.

Requires-Dist: nvidia-cuda-runtime-cu12==12.6.77; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cuda-cupti-cu12==12.6.80; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cudnn-cu12==9.5.1.17; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cublas-cu12==12.6.4.1; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cufft-cu12==11.3.0.4; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-curand-cu12==10.3.7.77; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cusolver-cu12==11.7.1.2; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cusparse-cu12==12.5.4.2; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cusparselt-cu12==0.6.3; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-nccl-cu12==2.26.2; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-nvtx-cu12==12.6.77; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-nvjitlink-cu12==12.6.85; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: nvidia-cufile-cu12==1.11.1.6; platform_system == "Linux" and platform_machine == "x86_64"
Requires-Dist: triton==3.3.1; platform_system == "Linux" and platform_machine == "x86_64"
"""


def inject_text_block(filename, match_pattern, text_block) -> bool:
with open(filename, "r", encoding="utf-8") as file:
lines = file.readlines()

for i, line in enumerate(lines):
if match_pattern in line:
lines.insert(i, text_block)
break

with open(filename, "w", encoding="utf-8") as file:
file.writelines(lines)

return True


def process_wheel(whl_file, output_dir=None, package=""):
"""Process a single wheel file to remove version suffixes and perform METADATA injection"""
# Check if auditwheel is installed
try:
from auditwheel.wheeltools import InWheelCtx
Expand All @@ -27,14 +62,19 @@ def process_wheel(whl_file, output_dir=None):
)
sys.exit(1)

"""Process a single wheel file to remove version suffixes"""
metadata_injected = False
version_replacement_required = False

if output_dir is None:
output_dir = os.getcwd()

# Convert to absolute paths
whl_file = Path(os.path.abspath(whl_file))
output_dir = Path(os.path.abspath(output_dir))

# Get the original wheel filename
wheel_filename = os.path.basename(whl_file)

# Create a temporary directory for working
with tempfile.TemporaryDirectory() as tmp_dir:
# Open the wheel using auditwheel's tools
Expand Down Expand Up @@ -63,6 +103,11 @@ def process_wheel(whl_file, output_dir=None):
print(f"Error: METADATA file not found in {whl_file}")
return

if "manylinux_2_28_x86_64" not in wheel_filename and package == "torch":
metadata_injected = inject_text_block(
metadata_file, "Provides-Extra:", cuda_metadata_inject
)

# Extract version with suffix
version_with_suffix = None
with open(metadata_file, "r", encoding="utf-8") as f:
Expand All @@ -78,56 +123,62 @@ def process_wheel(whl_file, output_dir=None):
# Check if there's a suffix to remove
if "+" not in version_with_suffix:
print(f"No suffix found in version {version_with_suffix}, skipping")
return
if not metadata_injected:
return
else:
version_replacement_required = True

# Remove suffix from version
version_no_suffix = version_with_suffix.split("+")[0]
print(f"Removing suffix: {version_with_suffix} -> {version_no_suffix}")

# Update version in all files in dist-info
for root, _dirs, files in os.walk(os.path.join(ctx.path, dist_info_dir)):
for file in files:
file_path = os.path.join(root, file)
try:
with open(
file_path, "r", encoding="utf-8", errors="ignore"
) as f:
content = f.read()

# Replace version with suffix to version without suffix
updated_content = content.replace(
version_with_suffix, version_no_suffix
)

if content != updated_content:
with open(file_path, "w", encoding="utf-8") as f:
f.write(updated_content)
except UnicodeDecodeError:
# Skip binary files
pass

# Rename the dist-info directory
new_dist_info_dir = dist_info_dir.replace(
version_with_suffix, version_no_suffix
)
if new_dist_info_dir != dist_info_dir:
print(f"Renaming {new_dist_info_dir}")
os.rename(
os.path.join(ctx.path, dist_info_dir),
os.path.join(ctx.path, new_dist_info_dir),
if version_replacement_required:
version_no_suffix = version_with_suffix.split("+")[0]
print(f"Removing suffix: {version_with_suffix} -> {version_no_suffix}")

# Update version in all files in dist-info
for root, _dirs, files in os.walk(
os.path.join(ctx.path, dist_info_dir)
):
for file in files:
file_path = os.path.join(root, file)
try:
with open(
file_path, "r", encoding="utf-8", errors="ignore"
) as f:
content = f.read()

# Replace version with suffix to version without suffix
updated_content = content.replace(
version_with_suffix, version_no_suffix
)

if content != updated_content:
with open(file_path, "w", encoding="utf-8") as f:
f.write(updated_content)
except UnicodeDecodeError:
# Skip binary files
pass

# Rename the dist-info directory
new_dist_info_dir = dist_info_dir.replace(
version_with_suffix, version_no_suffix
)
if new_dist_info_dir != dist_info_dir:
print(f"Renaming {new_dist_info_dir}")
os.rename(
os.path.join(ctx.path, dist_info_dir),
os.path.join(ctx.path, new_dist_info_dir),
)

# Let auditwheel handle recreating the RECORD file when the context exits
pass

# Get the original wheel filename
wheel_filename = os.path.basename(whl_file)

# Create the new filename with updated version
version_with_suffix_escaped = version_with_suffix.replace("+", "%2B")
new_wheel_filename = wheel_filename.replace(
version_with_suffix_escaped, version_no_suffix
)
if version_replacement_required:
version_with_suffix_escaped = version_with_suffix.replace("+", "%2B")
new_wheel_filename = wheel_filename.replace(
version_with_suffix_escaped, version_no_suffix
)
else:
new_wheel_filename = wheel_filename

# The wheel will be created in the same directory as the original
# Move it to the requested output directory if needed
Expand All @@ -150,17 +201,20 @@ def main():
default=None,
help="Directory to output processed wheels (default: current directory)",
)
parser.add_argument("--package", default="", help="Package to process")

args = parser.parse_args()

output_dir = args.output_dir or os.getcwd()
os.makedirs(output_dir, exist_ok=True)

for whl_file in args.wheel_files:
if not os.path.exists(whl_file):
print(f"Error: Wheel file not found: {whl_file}")
continue

try:
process_wheel(whl_file, output_dir)
process_wheel(whl_file, output_dir, args.package)
except Exception as e:
print(f"Error processing {whl_file}: {e}")

Expand Down
4 changes: 2 additions & 2 deletions release/pypi/upload_pypi_to_staging.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ for pkg in ${pkgs_to_promote}; do
curl -fSL -o "${orig_pkg}" "https://download.pytorch.org${pkg}"
)

if [[ -n "${VERSION_SUFFIX}" ]]; then
OUTPUT_DIR="${output_tmp_dir}" python "${DIR}/prep_binary_for_pypi.py" "${orig_pkg}" --output-dir .
if [[ -n "${VERSION_SUFFIX}" || "${pkg}" == "torch"]]; then
OUTPUT_DIR="${output_tmp_dir}" python "${DIR}/prep_binary_for_pypi.py" "${orig_pkg}" --output-dir . --package ${pkg}
else
mv "${orig_pkg}" "${output_tmp_dir}/"
fi
Expand Down