Skip to content

Commit ef770cf

Browse files
committed
Add package contraints to torchbench
1 parent 8ab8a3e commit ef770cf

File tree

4 files changed

+40
-18
lines changed

4 files changed

+40
-18
lines changed

install.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pathlib import Path
66

77
from userbenchmark import list_userbenchmarks
8-
from utils import get_pkg_versions, TORCH_DEPS
8+
from utils import get_pkg_versions, TORCH_DEPS, generate_pkg_constraints
99

1010
REPO_ROOT = Path(__file__).parent
1111

@@ -38,6 +38,11 @@ def pip_install_requirements(requirements_txt="requirements.txt"):
3838
action="store_true",
3939
help="Run in test mode and check package versions",
4040
)
41+
parser.add_argument(
42+
"--check-only",
43+
action="store_true",
44+
help="Only run the version check and generate the contraints"
45+
)
4146
parser.add_argument("--canary", action="store_true", help="Install canary model.")
4247
parser.add_argument("--continue_on_fail", action="store_true")
4348
parser.add_argument("--verbose", "-v", action="store_true")
@@ -51,12 +56,12 @@ def pip_install_requirements(requirements_txt="requirements.txt"):
5156
os.chdir(os.path.realpath(os.path.dirname(__file__)))
5257

5358
print(
54-
f"checking packages {', '.join(TORCH_DEPS)} are installed...",
59+
f"checking packages {', '.join(TORCH_DEPS)} are installed, generating constaints...",
5560
end="",
5661
flush=True,
5762
)
5863
if args.userbenchmark:
59-
TORCH_DEPS = ["torch"]
64+
TORCH_DEPS = ["numpy", "torch"]
6065
try:
6166
versions = get_pkg_versions(TORCH_DEPS)
6267
except ModuleNotFoundError as e:
@@ -65,8 +70,12 @@ def pip_install_requirements(requirements_txt="requirements.txt"):
6570
f"Error: Users must first manually install packages {TORCH_DEPS} before installing the benchmark."
6671
)
6772
sys.exit(-1)
73+
generate_pkg_constraints(versions)
6874
print("OK")
6975

76+
if args.check_only:
77+
exit(0)
78+
7079
if args.userbenchmark:
7180
# Install userbenchmark dependencies if exists
7281
userbenchmark_dir = REPO_ROOT.joinpath("userbenchmark", args.userbenchmark)
@@ -101,7 +110,9 @@ def pip_install_requirements(requirements_txt="requirements.txt"):
101110
new_versions = get_pkg_versions(TORCH_DEPS)
102111
if versions != new_versions:
103112
print(
104-
f"The torch packages are re-installed after installing the benchmark deps. \
113+
f"The numpy and torch package versions become inconsistent after installing the benchmark deps. \
105114
Before: {versions}, after: {new_versions}"
106115
)
107116
sys.exit(-1)
117+
else:
118+
print(f"installed torchbench with package constraints: {versions}")

torchbenchmark/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,10 @@ def setup(
181181
versions = get_pkg_versions(TORCH_DEPS)
182182
success, errmsg, stdout_stderr = _install_deps(model_path, verbose=verbose)
183183
if test_mode:
184-
new_versions = get_pkg_versions(TORCH_DEPS, reload=True)
184+
new_versions = get_pkg_versions(TORCH_DEPS)
185185
if versions != new_versions:
186186
print(
187-
f"The torch packages are re-installed after installing the benchmark model {model_path}. \
187+
f"The numpy and torch packages are re-installed after installing the benchmark model {model_path}. \
188188
Before: {versions}, after: {new_versions}"
189189
)
190190
sys.exit(-1)

torchbenchmark/util/env_check.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
This file may be loaded without torch packages installed, e.g., in OnDemand CI.
44
"""
55

6-
import argparse
76
import copy
8-
import importlib
97
import os
108
import shutil
119
import argparse
@@ -187,10 +185,13 @@ def deterministic_torch_manual_seed(*args, **kwargs):
187185

188186

189187
def get_pkg_versions(packages: List[str]) -> Dict[str, str]:
188+
import sys
189+
import subprocess
190190
versions = {}
191191
for module in packages:
192-
module = importlib.import_module(module)
193-
versions[module] = module.__version__
192+
cmd = [sys.executable, "-c", f'import {module}; print({module}.__version__)']
193+
version = subprocess.check_output(cmd).decode().strip()
194+
versions[module] = version
194195
return versions
195196

196197

utils/__init__.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
import importlib
21
import sys
2+
import subprocess
33
from typing import Dict, List
4+
from pathlib import Path
45

5-
TORCH_DEPS = ["torch", "torchvision", "torchaudio"]
6+
REPO_DIR = Path(__file__).parent.parent
7+
TORCH_DEPS = ["numpy", "torch", "torchvision", "torchaudio"]
68

79

810
class add_path:
@@ -18,12 +20,20 @@ def __exit__(self, exc_type, exc_value, traceback):
1820
except ValueError:
1921
pass
2022

21-
22-
def get_pkg_versions(packages: List[str], reload: bool = False) -> Dict[str, str]:
23+
def get_pkg_versions(packages: List[str]) -> Dict[str, str]:
2324
versions = {}
2425
for module in packages:
25-
module = importlib.import_module(module)
26-
if reload:
27-
module = importlib.reload(module)
28-
versions[module.__name__] = module.__version__
26+
cmd = [sys.executable, "-c", f'import {module}; print({module}.__version__)']
27+
version = subprocess.check_output(cmd).decode().strip()
28+
versions[module] = version
2929
return versions
30+
31+
def generate_pkg_constraints(package_versions: Dict[str, str]):
32+
"""
33+
Generate package versions dict and save them to {REPO_ROOT}/build/constraints.txt
34+
"""
35+
output_dir = REPO_DIR.joinpath("build")
36+
output_dir.mkdir(exist_ok=True)
37+
with open(output_dir.joinpath("constraints.txt"), "w") as fp:
38+
for k, v in package_versions.items():
39+
fp.write(f"{k}=={v}\n")

0 commit comments

Comments
 (0)