Skip to content

Commit 1b7c8e8

Browse files
Add editable jax wheel target.
The set of editable wheels (`jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt`) was used as dependencies in `requirements.in` file together with `:build_jaxlib=false` flag. After [adding `jax` wheel dependencies](f5a4d1a) to the tests when `:build_jaxlib=false` is used, we need an editable `jax` wheel target as well to get the tests passing. PiperOrigin-RevId: 740840736
1 parent feed69c commit 1b7c8e8

File tree

3 files changed

+53
-31
lines changed

3 files changed

+53
-31
lines changed

BUILD.bazel

+22-20
Original file line numberDiff line numberDiff line change
@@ -72,35 +72,37 @@ py_binary(
7272
],
7373
)
7474

75+
WHEEL_SOURCE_FILES = [
76+
":transitive_py_data",
77+
":transitive_py_deps",
78+
"//jax:py.typed",
79+
"AUTHORS",
80+
"LICENSE",
81+
"README.md",
82+
"pyproject.toml",
83+
"setup.py",
84+
]
85+
7586
jax_wheel(
7687
name = "jax_wheel",
7788
platform_independent = True,
78-
source_files = [
79-
":transitive_py_data",
80-
":transitive_py_deps",
81-
"//jax:py.typed",
82-
"AUTHORS",
83-
"LICENSE",
84-
"README.md",
85-
"pyproject.toml",
86-
"setup.py",
87-
],
89+
source_files = WHEEL_SOURCE_FILES,
90+
wheel_binary = ":build_wheel",
91+
wheel_name = "jax",
92+
)
93+
94+
jax_wheel(
95+
name = "jax_wheel_editable",
96+
editable = True,
97+
platform_independent = True,
98+
source_files = WHEEL_SOURCE_FILES,
8899
wheel_binary = ":build_wheel",
89100
wheel_name = "jax",
90101
)
91102

92103
jax_source_package(
93104
name = "jax_source_package",
94-
source_files = [
95-
":transitive_py_data",
96-
":transitive_py_deps",
97-
"//jax:py.typed",
98-
"AUTHORS",
99-
"LICENSE",
100-
"README.md",
101-
"pyproject.toml",
102-
"setup.py",
103-
],
105+
source_files = WHEEL_SOURCE_FILES,
104106
source_package_binary = ":build_wheel",
105107
source_package_name = "jax",
106108
)

build/build.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,14 @@
6868
# rule as the default.
6969
WHEEL_BUILD_TARGET_DICT_NEW = {
7070
"jax": "//:jax_wheel",
71+
"jax_editable": "//:jax_wheel_editable",
7172
"jax_source_package": "//:jax_source_package",
7273
"jaxlib": "//jaxlib/tools:jaxlib_wheel",
74+
"jaxlib_editable": "//jaxlib/tools:jaxlib_wheel_editable",
7375
"jax-cuda-plugin": "//jaxlib/tools:jax_cuda_plugin_wheel",
76+
"jax-cuda-plugin_editable": "//jaxlib/tools:jax_cuda_plugin_wheel_editable",
7477
"jax-cuda-pjrt": "//jaxlib/tools:jax_cuda_pjrt_wheel",
78+
"jax-cuda-pjrt_editable": "//jaxlib/tools:jax_cuda_pjrt_wheel_editable",
7579
"jax-rocm-plugin": "//jaxlib/tools:jax_rocm_plugin_wheel",
7680
"jax-rocm-pjrt": "//jaxlib/tools:jax_rocm_pjrt_wheel",
7781
}
@@ -662,9 +666,12 @@ async def main():
662666
)
663667

664668
# Append the build target to the Bazel command.
665-
build_target = wheel_build_targets[wheel]
669+
if args.use_new_wheel_build_rule and args.editable:
670+
build_target = wheel_build_targets[wheel + "_editable"]
671+
else:
672+
build_target = wheel_build_targets[wheel]
666673
wheel_build_command.append(build_target)
667-
if args.use_new_wheel_build_rule and wheel == "jax":
674+
if args.use_new_wheel_build_rule and wheel == "jax" and not args.editable:
668675
wheel_build_command.append(wheel_build_targets["jax_source_package"])
669676

670677
if not args.use_new_wheel_build_rule:

build_wheel.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@
6161
"Whether to build the source package only. Optional."
6262
),
6363
)
64+
parser.add_argument(
65+
"--editable",
66+
action="store_true",
67+
help="Create an 'editable' jax build instead of a wheel.",
68+
)
6469
args = parser.parse_args()
6570

6671

@@ -90,7 +95,11 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
9095
"""
9196

9297
for file in deps:
93-
if not (file.startswith("bazel-out") or file.startswith("external")):
98+
if not (
99+
file.startswith("bazel-out")
100+
or file.startswith("external")
101+
or file.startswith("jaxlib")
102+
):
94103
copy_file(file, srcs_dir)
95104

96105

@@ -103,14 +112,18 @@ def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
103112
try:
104113
os.makedirs(args.output_path, exist_ok=True)
105114
prepare_srcs(args.srcs, pathlib.Path(sources_path))
106-
build_utils.build_wheel(
107-
sources_path,
108-
args.output_path,
109-
package_name="jax",
110-
git_hash=args.jaxlib_git_hash,
111-
build_wheel_only=args.build_wheel_only,
112-
build_source_package_only=args.build_source_package_only,
113-
)
115+
package_name = "jax"
116+
if args.editable:
117+
build_utils.build_editable(sources_path, args.output_path, package_name)
118+
else:
119+
build_utils.build_wheel(
120+
sources_path,
121+
args.output_path,
122+
package_name,
123+
git_hash=args.jaxlib_git_hash,
124+
build_wheel_only=args.build_wheel_only,
125+
build_source_package_only=args.build_source_package_only,
126+
)
114127
finally:
115128
if tmpdir:
116129
tmpdir.cleanup()

0 commit comments

Comments
 (0)