Skip to content

Commit 22375ab

Browse files
Add Bazel CPU tests with py_import dependency to continuous tests.
This is needed to collect the statistics and identify the failures which are detected by pytest but not by Bazel or vice versa. These tests don't require pre-built wheels downloaded from GCS. Instead they build the wheels as transitive dependencies of the test targets and unpack them using `py_import`. Execution example - https://github.com/jax-ml/jax/actions/runs/14800773538/job/41558800743 PiperOrigin-RevId: 754113096
1 parent 5a3605d commit 22375ab

File tree

3 files changed

+152
-2
lines changed

3 files changed

+152
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# CI - Bazel CPU tests with py_import (RBE)
2+
#
3+
# This workflow runs the Bazel CPU tests with py_import dependency. It can only be triggered by
4+
# other workflows via `workflow_call`. It is used by the `CI - Wheel Tests (Continuous)` workflows
5+
# to run the Bazel CPU tests.
6+
#
7+
# It consists of the following job:
8+
# run-tests:
9+
# - Executes the `run_bazel_test_cpu_py_import_rbe.sh` script, which performs the following actions:
10+
# - Runs the Bazel CPU tests with py_import dependency.
11+
name: CI - Bazel CPU tests with py_import (RBE)
12+
permissions:
13+
contents: read
14+
15+
on:
16+
workflow_call:
17+
inputs:
18+
runner:
19+
description: "Which runner should the workflow run on?"
20+
type: string
21+
required: true
22+
default: "linux-x86-n2-16"
23+
python:
24+
description: "Which python version to test?"
25+
type: string
26+
required: true
27+
default: "3.12"
28+
enable-x64:
29+
description: "Should x64 mode be enabled?"
30+
type: string
31+
required: true
32+
default: "0"
33+
halt-for-connection:
34+
description: 'Should this workflow run wait for a remote connection?'
35+
type: string
36+
required: false
37+
default: 'no'
38+
39+
jobs:
40+
run-tests:
41+
defaults:
42+
run:
43+
# Explicitly set the shell to bash
44+
shell: bash
45+
runs-on: ${{ inputs.runner }}
46+
container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
47+
(contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }}
48+
env:
49+
JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
50+
JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}
51+
52+
name: "Bazel CPU tests with py_import (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
53+
54+
steps:
55+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
56+
# Halt for testing
57+
- name: Wait For Connection
58+
uses: google-ml-infra/actions/ci_connection@main
59+
with:
60+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
61+
- name: Run Bazel CPU tests with py_import (RBE)
62+
timeout-minutes: 60
63+
run: ./ci/run_bazel_test_cpu_py_import_rbe.sh

.github/workflows/wheel_tests_continuous.yml

+20-2
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,17 @@
99
# that was built in the previous step and runs CPU tests.
1010
# 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and
1111
# uploads them to a GCS bucket.
12-
# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA
12+
# 4. run-bazel-test-cpu-py-import: Calls the `bazel_cpu_py_import_rbe.yml` workflow which
13+
# runs Bazel CPU tests with py_import on RBE.
14+
# 5. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA
1315
# artifacts that were built in the previous steps and runs the CUDA tests.
14-
# 5. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow which downloads the jaxlib
16+
# 6. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow which downloads the jaxlib
1517
# and CUDA artifacts that were built in the previous steps and runs the
1618
# CUDA tests using Bazel.
1719

1820
name: CI - Wheel Tests (Continuous)
21+
permissions:
22+
contents: read
1923

2024
on:
2125
schedule:
@@ -136,6 +140,20 @@ jobs:
136140
# GCS upload URI is the same for both artifact build jobs
137141
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
138142

143+
run-bazel-test-cpu-py-import:
144+
uses: ./.github/workflows/bazel_cpu_py_import_rbe.yml
145+
strategy:
146+
fail-fast: false # don't cancel all jobs on failure
147+
matrix:
148+
runner: ["linux-x86-n2-16", "linux-arm64-t2a-48"]
149+
python: ["3.10",]
150+
enable-x64: [1, 0]
151+
name: "Bazel CPU tests with ${{ format('{0}', 'py_import') }}"
152+
with:
153+
runner: ${{ matrix.runner }}
154+
python: ${{ matrix.python }}
155+
enable-x64: ${{ matrix.enable-x64 }}
156+
139157
run-bazel-test-cuda:
140158
# Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
141159
# build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#!/bin/bash
2+
# Copyright 2025 The JAX Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
# Runs Bazel CPU tests with py_import on RBE.
17+
#
18+
# -e: abort script if one command fails
19+
# -u: error if undefined variable used
20+
# -x: log all commands
21+
# -o history: record shell history
22+
# -o allexport: export all functions and variables to be available to subscripts
23+
set -exu -o history -o allexport
24+
25+
# Source default JAXCI environment variables.
26+
source ci/envs/default.env
27+
28+
# Clone XLA at HEAD if path to local XLA is not provided
29+
if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then
30+
export JAXCI_CLONE_MAIN_XLA=1
31+
fi
32+
33+
# Set up the build environment.
34+
source "ci/utilities/setup_build_environment.sh"
35+
36+
# Run Bazel CPU tests with RBE.
37+
os=$(uname -s | awk '{print tolower($0)}')
38+
arch=$(uname -m)
39+
40+
echo "Running CPU tests..."
41+
# When running on Mac or Linux Aarch64, we build the test targets on RBE
42+
# and run the tests locally. These platforms do not have native RBE support so
43+
# we RBE cross-compile them on remote Linux x86 machines.
44+
if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ); then
45+
bazel test --config=rbe_cross_compile_${os}_${arch} \
46+
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
47+
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
48+
--test_env=JAX_NUM_GENERATED_CASES=25 \
49+
--test_env=JAX_SKIP_SLOW_TESTS=true \
50+
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
51+
--test_output=errors \
52+
--color=yes \
53+
--strategy=TestRunner=local \
54+
--//jax:build_jaxlib=wheel \
55+
--//jax:build_jax=wheel \
56+
//tests:cpu_tests //tests:backend_independent_tests
57+
else
58+
bazel test --config=rbe_${os}_${arch} \
59+
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
60+
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
61+
--test_env=JAX_NUM_GENERATED_CASES=25 \
62+
--test_env=JAX_SKIP_SLOW_TESTS=true \
63+
--action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \
64+
--test_output=errors \
65+
--color=yes \
66+
--//jax:build_jaxlib=wheel \
67+
--//jax:build_jax=wheel \
68+
//tests:cpu_tests //tests:backend_independent_tests
69+
fi

0 commit comments

Comments
 (0)