Skip to content

Commit 14e33ff

Browse files
committed
Append extra args to pip_install_requirements
1 parent 37d0f0b commit 14e33ff

File tree

25 files changed

+36
-136
lines changed

25 files changed

+36
-136
lines changed

torchbenchmark/canary_models/DALLE2_pytorch/install.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import patch
33
import subprocess
44
import sys
5+
from utils.python_utils import pip_install_requirements
56

67
def patch_dalle2():
78
import dalle2_pytorch
@@ -12,8 +13,8 @@ def patch_dalle2():
1213
print("Failed to patch dalle2_pytorch/dalle2_pytorch.py. Exit.")
1314
exit(1)
1415

15-
def pip_install_requirements():
16-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
16+
def pip_install_requirements_dalle2():
17+
pip_install_requirements()
1718
# DALLE2_pytorch requires embedding-reader
1819
# https://github.com/lucidrains/DALLE2-pytorch/blob/00e07b7d61e21447d55e6d06d5c928cf8b67601d/setup.py#L34
1920
# embedding-reader requires an old version of pandas and pyarrow
@@ -22,5 +23,5 @@ def pip_install_requirements():
2223
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-U', 'pandas', 'pyarrow'])
2324

2425
if __name__ == '__main__':
25-
pip_install_requirements()
26+
pip_install_requirements_dalle2()
2627
patch_dalle2()

torchbenchmark/canary_models/codellama/install.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11

2-
import subprocess
3-
import sys
42
import os
53
from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model
64

torchbenchmark/canary_models/fambench_dlrm/install.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
import subprocess
44
from torchbenchmark import REPO_PATH
5+
from utils.python_utils import pip_install_requirements
56

67

78
def update_fambench_submodule():
@@ -17,12 +18,6 @@ def update_fambench_submodule():
1718
subprocess.check_call(update_command, cwd=REPO_PATH)
1819

1920

20-
def pip_install_requirements():
21-
subprocess.check_call(
22-
[sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"]
23-
)
24-
25-
2621
if __name__ == "__main__":
2722
update_fambench_submodule()
2823
pip_install_requirements()

torchbenchmark/canary_models/fambench_xlmr/install.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33
import subprocess
44
from torchbenchmark import REPO_PATH
5-
5+
from utils.python_utils import pip_install_requirements
66

77
def update_fambench_submodule():
88
"Update FAMBench submodule of the benchmark repo"
@@ -19,9 +19,7 @@ def update_fambench_submodule():
1919

2020
def pip_install_requirements():
2121
try:
22-
subprocess.check_call(
23-
[sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"]
24-
)
22+
pip_install_requirements()
2523
# pin fairseq version
2624
# ignore deps specified in requirements.txt
2725
subprocess.check_call(
Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
1-
2-
import subprocess
3-
import sys
41
from utils import s3_utils
5-
6-
7-
def pip_install_requirements():
8-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt', '-f', 'https://data.pyg.org/whl/torch-2.1.0+cpu.html'])
9-
2+
from utils.python_utils import pip_install_requirements
103

114
if __name__ == '__main__':
125
s3_utils.checkout_s3_data("INPUT_TARBALLS", "Reddit_minimal.tar.gz", decompress=True)
13-
pip_install_requirements()
6+
pip_install_requirements(extra_args=["-f", "https://data.pyg.org/whl/torch-2.1.0+cpu.html"])

torchbenchmark/canary_models/hf_MPT_7b_instruct/install.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
import subprocess
2-
import sys
31
import os
42
from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model
5-
6-
def pip_install_requirements():
7-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
3+
from utils.python_utils import pip_install_requirements
84

95
if __name__ == '__main__':
106
pip_install_requirements()

torchbenchmark/canary_models/hf_Yi/install.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
import sys
33
import os
44
from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model
5-
6-
def pip_install_requirements():
7-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
5+
from utils.python_utils import pip_install_requirements
86

97
if __name__ == '__main__':
108
pip_install_requirements()

torchbenchmark/canary_models/hf_mixtral/install.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
import subprocess
2-
import sys
31
import os
42
from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model
3+
from utils.python_utils import pip_install_requirements
54

6-
def pip_install_requirements():
7-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
85

96
if __name__ == '__main__':
107
pip_install_requirements()

torchbenchmark/canary_models/phi_1_5/install.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
import sys
33
import os
44
from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model
5-
6-
def pip_install_requirements():
7-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
5+
from utils.python_utils import pip_install_requirements
86

97
if __name__ == '__main__':
108
pip_install_requirements()

torchbenchmark/canary_models/phi_2/install.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
import subprocess
2-
import sys
31
import os
42
from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model
5-
6-
def pip_install_requirements():
7-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
3+
from utils.python_utils import pip_install_requirements
84

95
if __name__ == '__main__':
106
pip_install_requirements()

torchbenchmark/canary_models/sage/install.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,5 @@
1-
2-
import subprocess
3-
import sys
41
from utils import s3_utils
5-
6-
7-
def pip_install_requirements():
8-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt', '-f', 'https://data.pyg.org/whl/torch-2.1.0+cpu.html'])
9-
2+
from utils.python_utils import pip_install_requirements
103

114
if __name__ == '__main__':
125
s3_utils.checkout_s3_data("INPUT_TARBALLS", "Reddit_minimal.tar.gz", decompress=True)
Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,4 @@
1-
import subprocess
2-
import sys
3-
4-
5-
def pip_install_requirements():
6-
subprocess.check_call(
7-
[sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"]
8-
)
9-
1+
from utils.python_utils import pip_install_requirements
102

113
if __name__ == "__main__":
124
pip_install_requirements()

torchbenchmark/models/maml_omniglot/install.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
1-
import subprocess
2-
import sys
31
from utils import s3_utils
4-
5-
6-
def pip_install_requirements():
7-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
2+
from utils.python_utils import pip_install_requirements
83

94
if __name__ == '__main__':
105
pip_install_requirements()

torchbenchmark/models/moondream/install.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
import subprocess
2-
import sys
31
import os
42
from torchbenchmark.util.framework.huggingface.patch_hf import patch_transformers, cache_model
3+
from utils.python_utils import pip_install_requirements
54

6-
def pip_install_requirements():
7-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
85

96
if __name__ == '__main__':
107
pip_install_requirements()

torchbenchmark/models/nvidia_deeprecommender/install.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
import subprocess
22
import sys
3-
4-
5-
def pip_install_requirements():
6-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
3+
from utils.python_utils import pip_install_requirements
74

85
if __name__ == '__main__':
96
pip_install_requirements()
Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
import subprocess
2-
import sys
3-
4-
def pip_install_requirements():
5-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
1+
from utils.python_utils import pip_install_requirements
62

73
if __name__ == '__main__':
84
pip_install_requirements()

torchbenchmark/models/pytorch_CycleGAN_and_pix2pix/install.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
1-
import subprocess
2-
import sys
31
from utils import s3_utils
4-
5-
6-
def pip_install_requirements():
7-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
2+
from utils.python_utils import pip_install_requirements
83

94
if __name__ == '__main__':
105
s3_utils.checkout_s3_data("INPUT_TARBALLS", "pytorch_CycleGAN_and_pix2pix_inputs.tar.gz", decompress=True)

torchbenchmark/models/pytorch_stargan/install.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
import subprocess
2-
import sys
31
from utils import s3_utils
4-
5-
def pip_install_requirements():
6-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
2+
from utils.python_utils import pip_install_requirements
73

84
if __name__ == '__main__':
95
s3_utils.checkout_s3_data("INPUT_TARBALLS", "pytorch_stargan_inputs.tar.gz", decompress=True)
Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
1-
import subprocess
2-
import sys
3-
4-
5-
def pip_install_requirements():
6-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'pytorch_unet/requirements.txt'])
1+
from utils.python_utils import pip_install_requirements
72

83
if __name__ == '__main__':
94
pip_install_requirements()

torchbenchmark/models/sam/install.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import subprocess
33
import sys
44
import requests
5+
from utils.python_utils import pip_install_requirements
56

67
def download(uri):
78
directory = '.data'
@@ -16,10 +17,6 @@ def download(uri):
1617
else:
1718
print(f'Failed to download file with status code {response.status_code}')
1819

19-
20-
def pip_install_requirements():
21-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
22-
2320
def download_checkpoint():
2421
download('https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth')
2522

torchbenchmark/models/sam_fast/install.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
2-
import subprocess
3-
import sys
42
import requests
3+
from utils.python_utils import pip_install_requirements
54

65
def download(uri):
76
directory = '.data'
@@ -16,10 +15,6 @@ def download(uri):
1615
else:
1716
print(f'Failed to download file with status code {response.status_code}')
1817

19-
20-
def pip_install_requirements():
21-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
22-
2318
def download_checkpoint():
2419
download('https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth')
2520

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,4 @@
1-
import subprocess
2-
import sys
3-
4-
5-
def pip_install_requirements():
6-
subprocess.check_call(
7-
[sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"]
8-
)
9-
1+
from utils.python_utils import pip_install_requirements
102

113
if __name__ == "__main__":
124
pip_install_requirements()

torchbenchmark/models/speech_transformer/install.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
import sys
2-
import subprocess
31
from utils import s3_utils
4-
5-
def pip_install_requirements():
6-
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
2+
from utils.python_utils import pip_install_requirements
73

84
if __name__ == '__main__':
95
s3_utils.checkout_s3_data("INPUT_TARBALLS", "speech_transformer_inputs.tar.gz", decompress=True)

torchbenchmark/util/framework/lit_llama.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66

77
from torchbenchmark import REPO_PATH
8+
from utils.python_utils import pip_install_requirements
89

910
LIT_LLAMA_PATH = os.path.join(REPO_PATH, "submodules", "lit-llama")
1011

@@ -21,20 +22,6 @@ def update_lit_llama_submodule():
2122
subprocess.check_call(update_command, cwd=REPO_PATH)
2223

2324

24-
def pip_install_requirements():
25-
subprocess.check_call(
26-
[
27-
sys.executable,
28-
"-m",
29-
"pip",
30-
"install",
31-
"-q",
32-
"-r",
33-
os.path.join(LIT_LLAMA_PATH, "requirements.txt"),
34-
]
35-
)
36-
37-
3825
def openllama_download():
3926
if os.path.exists(
4027
os.path.join(LIT_LLAMA_PATH, "checkpoints/lit-llama/7B/lit-llama.pth")

utils/python_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from pathlib import Path
33
import subprocess
44

5+
from typing import Optional, List
6+
57
DEFAULT_PYTHON_VERSION = "3.11"
68

79
PYTHON_VERSION_MAP = {
@@ -23,7 +25,10 @@ def create_conda_env(pyver: str, name: str):
2325
subprocess.check_call(command)
2426

2527

26-
def pip_install_requirements(requirements_txt="requirements.txt", continue_on_fail=False, no_build_isolation=False):
28+
def pip_install_requirements(requirements_txt="requirements.txt",
29+
continue_on_fail=False,
30+
no_build_isolation=False,
31+
extra_args: Optional[List[str]]=None):
2732
import sys
2833
constraints_file = REPO_DIR.joinpath("build", "constraints.txt")
2934
if not constraints_file.exists():
@@ -33,6 +38,8 @@ def pip_install_requirements(requirements_txt="requirements.txt", continue_on_fa
3338
constraints_parameters = ["-c", str(constraints_file.resolve())]
3439
if no_build_isolation:
3540
constraints_parameters.append("--no-build-isolation")
41+
if extra_args and isinstance(extra_args, list):
42+
constraints_parameters.extend(extra_args)
3643
if not continue_on_fail:
3744
subprocess.check_call(
3845
[sys.executable, "-m", "pip", "install", "-r", requirements_txt] + constraints_parameters,

0 commit comments

Comments
 (0)