Skip to content

Commit 08dcaad

Browse files
nvgrwGoogle-ML-Automation
authored andcommitted
Split RunAndCompare with reference backend functionality into a mixin.
Many users don't require `RunAndCompare` functionality, but are forced to select and initialize a reference backend anyway. With this change, users can opt to extend their specific `HloRunnerAgnosticTestBase` implementation to add `RunAndCompare` functionality. The mixin acts as a wrapper around any `HloRunnerAgnosticTestBase` implementation, allowing a high degree of customization. PiperOrigin-RevId: 714085396
1 parent dd246d5 commit 08dcaad

File tree

3 files changed

+323
-0
lines changed

3 files changed

+323
-0
lines changed

xla/tests/BUILD

+25
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,31 @@ cc_library(
248248
],
249249
)
250250

251+
cc_library(
252+
name = "hlo_runner_agnostic_reference_mixin",
253+
testonly = True,
254+
srcs = ["hlo_runner_agnostic_reference_mixin.cc"],
255+
hdrs = ["hlo_runner_agnostic_reference_mixin.h"],
256+
deps = [
257+
":hlo_runner_agnostic_test_base",
258+
":literal_test_util",
259+
":test_utils",
260+
"//xla:error_spec",
261+
"//xla:literal",
262+
"//xla:shape_util",
263+
"//xla/hlo/ir:hlo",
264+
"//xla/hlo/testlib:verified_hlo_module",
265+
"//xla/service:hlo_runner_interface",
266+
"//xla/tsl/platform:errors",
267+
"//xla/tsl/platform:statusor",
268+
"//xla/tsl/platform:test",
269+
"@com_google_absl//absl/base:nullability",
270+
"@com_google_absl//absl/log",
271+
"@com_google_absl//absl/strings:string_view",
272+
"@com_google_absl//absl/types:span",
273+
],
274+
)
275+
251276
cc_library(
252277
name = "hlo_pjrt_test_base",
253278
testonly = True,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/tests/hlo_runner_agnostic_reference_mixin.h"
17+
18+
#include "xla/hlo/ir/hlo_module.h"
19+
#include "xla/shape.h"
20+
21+
namespace xla {
22+
23+
ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
24+
ProgramShape program_shape;
25+
const auto* entry = module.entry_computation();
26+
for (const auto* param : entry->parameter_instructions()) {
27+
*program_shape.add_parameters() = param->shape();
28+
*program_shape.add_parameter_names() = param->name();
29+
}
30+
*program_shape.mutable_result() = entry->root_instruction()->shape();
31+
return program_shape;
32+
}
33+
34+
bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) {
35+
if (lhs.parameters_size() != rhs.parameters_size()) {
36+
return false;
37+
}
38+
for (int i = 0; i < lhs.parameters_size(); ++i) {
39+
if (!Shape::Equal().IgnoreElementSizeInLayout()(lhs.parameters(i),
40+
rhs.parameters(i))) {
41+
return false;
42+
}
43+
}
44+
return Shape::Equal().IgnoreElementSizeInLayout()(lhs.result(), rhs.result());
45+
}
46+
47+
} // namespace xla
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_TESTS_HLO_RUNNER_AGNOSTIC_REFERENCE_MIXIN_H_
17+
#define XLA_TESTS_HLO_RUNNER_AGNOSTIC_REFERENCE_MIXIN_H_
18+
19+
#include <cstdint>
20+
#include <functional>
21+
#include <iterator>
22+
#include <memory>
23+
#include <optional>
24+
#include <utility>
25+
#include <vector>
26+
27+
#include "absl/base/nullability.h"
28+
#include "absl/log/log.h"
29+
#include "absl/strings/string_view.h"
30+
#include "absl/types/span.h"
31+
#include "xla/error_spec.h"
32+
#include "xla/hlo/ir/hlo_module.h"
33+
#include "xla/hlo/testlib/verified_hlo_module.h"
34+
#include "xla/literal.h"
35+
#include "xla/service/hlo_runner_interface.h"
36+
#include "xla/shape.h"
37+
#include "xla/tests/hlo_runner_agnostic_test_base.h"
38+
#include "xla/tests/literal_test_util.h"
39+
#include "xla/tests/test_utils.h"
40+
#include "xla/tsl/platform/errors.h"
41+
#include "xla/tsl/platform/statusor.h"
42+
#include "xla/tsl/platform/test.h"
43+
44+
namespace xla {
45+
46+
ProgramShape GetProgramShapeWithLayout(const HloModule& module);
47+
48+
bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs);
49+
50+
// This class is designed to be used as a mixin for tests that want to run
51+
// against a reference implementation via a runner implementing
52+
// HloRunnerInterface.
53+
//
54+
// The mixin requires that that the test class is a subclass of
55+
// HloRunnerAgnosticTestBase.
56+
template <typename T>
57+
class HloRunnerAgnosticReferenceMixin : public T {
58+
static_assert(
59+
std::is_base_of_v<HloRunnerAgnosticTestBase, T>,
60+
"Mixin must be used with a subclass of HloRunnerAgnosticTestBase.");
61+
62+
protected:
63+
template <typename... BaseArgs>
64+
explicit HloRunnerAgnosticReferenceMixin(
65+
absl::Nonnull<std::unique_ptr<HloRunnerInterface>> reference_runner,
66+
BaseArgs&&... base_args)
67+
: T(std::forward<BaseArgs>(base_args)...),
68+
reference_runner_(std::move(reference_runner)) {}
69+
~HloRunnerAgnosticReferenceMixin() override = default;
70+
71+
// Executes the given hlo module on two backends and compares results.
72+
//
73+
// 'arguments': the input of the hlo module.
74+
//
75+
// 'error': if has value, expects the results to be near (within the error
76+
// bound). Otherwise, expects the results to be equal.
77+
//
78+
// 'reference_preprocessor': the module should be ready to run on the test
79+
// backend, but it might need to be tailored so that it is able to run on the
80+
// reference backend. Note that the program shape of the module must not be
81+
// modified.
82+
::testing::AssertionResult RunAndCompare(
83+
std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
84+
const std::optional<ErrorSpec>& error,
85+
const std::function<void(HloModule*)>& reference_preprocessor = nullptr,
86+
const std::function<void(HloModule*)>& test_preprocessor = nullptr) {
87+
const absl::StatusOr<::testing::AssertionResult> result =
88+
RunAndCompareInternal(std::move(module), arguments, error,
89+
/*run_hlo_passes=*/true, reference_preprocessor,
90+
test_preprocessor);
91+
if (!result.ok()) {
92+
return ::testing::AssertionFailure() << result.status();
93+
}
94+
return *result;
95+
}
96+
97+
// Same as above, except that the module will be executed without Hlo
98+
// optimization.
99+
::testing::AssertionResult RunAndCompareNoHloPasses(
100+
std::unique_ptr<HloModule> module,
101+
const absl::Span<Literal* const> arguments,
102+
const std::optional<ErrorSpec>& error,
103+
const std::function<void(HloModule*)>& reference_preprocessor = nullptr,
104+
const std::function<void(HloModule*)>& test_preprocessor = nullptr) {
105+
const absl::StatusOr<::testing::AssertionResult> result =
106+
RunAndCompareInternal(std::move(module), arguments, error,
107+
/*run_hlo_passes=*/false, reference_preprocessor,
108+
test_preprocessor);
109+
if (!result.ok()) {
110+
return ::testing::AssertionFailure() << result.status();
111+
}
112+
return *result;
113+
}
114+
115+
// Executes an hlo module with fake inputs and compares the results.
116+
::testing::AssertionResult RunAndCompare(
117+
std::unique_ptr<HloModule> module, const std::optional<ErrorSpec>& error,
118+
const std::function<void(HloModule*)>& reference_preprocessor = nullptr,
119+
const std::function<void(HloModule*)>& test_preprocessor = nullptr,
120+
const std::optional<int64_t> args_max_bits_of_precision = std::nullopt) {
121+
const absl::StatusOr<std::vector<Literal>> fake_arguments =
122+
MakeFakeArguments(module.get(), /*pseudo_random=*/true,
123+
/*use_large_range=*/false,
124+
/*treat_gte_as_data_formatting=*/false,
125+
args_max_bits_of_precision);
126+
if (!fake_arguments.ok()) {
127+
return ::testing::AssertionFailure() << fake_arguments.status().message();
128+
}
129+
std::vector<Literal*> fake_argument_ptrs;
130+
absl::c_transform(
131+
*fake_arguments, std::back_inserter(fake_argument_ptrs),
132+
[](const Literal& literal) { return const_cast<Literal*>(&literal); });
133+
134+
return RunAndCompare(std::move(module), fake_argument_ptrs, error,
135+
reference_preprocessor, test_preprocessor);
136+
}
137+
138+
// Same as above, except that the module will be executed without Hlo
139+
// optimization.
140+
::testing::AssertionResult RunAndCompareNoHloPasses(
141+
std::unique_ptr<HloModule> module, const std::optional<ErrorSpec>& error,
142+
const std::function<void(HloModule*)>& reference_preprocessor = nullptr,
143+
const std::function<void(HloModule*)>& test_preprocessor = nullptr) {
144+
const absl::StatusOr<std::vector<Literal>> fake_arguments =
145+
MakeFakeArguments(module.get());
146+
if (!fake_arguments.ok()) {
147+
return ::testing::AssertionFailure() << fake_arguments.status().message();
148+
}
149+
std::vector<Literal*> fake_argument_ptrs;
150+
absl::c_transform(
151+
*fake_arguments, std::back_inserter(fake_argument_ptrs),
152+
[](const Literal& literal) { return const_cast<Literal*>(&literal); });
153+
return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs,
154+
error, reference_preprocessor,
155+
test_preprocessor);
156+
}
157+
158+
// Convenient wrapper for executing and comparing an hlo module with fake
159+
// input. Module can be passed in directly, or parsed from an hlo_string,
160+
// or loaded from a file.
161+
::testing::AssertionResult RunAndCompare(
162+
const absl::string_view hlo_string, const std::optional<ErrorSpec>& error,
163+
const std::function<void(HloModule*)>& reference_preprocessor = nullptr,
164+
const std::function<void(HloModule*)>& test_preprocessor = nullptr,
165+
const std::optional<int64_t> args_max_bits_of_precision = std::nullopt) {
166+
absl::StatusOr<std::unique_ptr<VerifiedHloModule>> module =
167+
this->ParseAndReturnVerifiedModule(hlo_string);
168+
if (!module.ok()) {
169+
return ::testing::AssertionFailure()
170+
<< "Error while parsing HLO text format: "
171+
<< module.status().ToString();
172+
}
173+
return RunAndCompare(*std::move(module), error, reference_preprocessor,
174+
test_preprocessor, args_max_bits_of_precision);
175+
}
176+
177+
::testing::AssertionResult RunAndCompareNoHloPasses(
178+
const absl::string_view hlo_string, const std::optional<ErrorSpec>& error,
179+
const std::function<void(HloModule*)>& reference_preprocessor = nullptr,
180+
const std::function<void(HloModule*)>& test_preprocessor = nullptr) {
181+
absl::StatusOr<std::unique_ptr<VerifiedHloModule>> module =
182+
this->ParseAndReturnVerifiedModule(hlo_string);
183+
if (!module.ok()) {
184+
return ::testing::AssertionFailure()
185+
<< "Error while parsing HLO text format: "
186+
<< module.status().ToString();
187+
}
188+
return RunAndCompareNoHloPasses(*std::move(module), error,
189+
reference_preprocessor, test_preprocessor);
190+
}
191+
192+
HloRunnerInterface& reference_runner() const { return *reference_runner_; }
193+
194+
private:
195+
// Given the test module, makes a reference module that is ready to run on the
196+
// reference platform. This assumes that the given module is ready to run on
197+
// the test platform.
198+
absl::StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
199+
const HloModule& test_module,
200+
const std::function<void(HloModule*)>& reference_preprocessor = nullptr) {
201+
std::unique_ptr<HloModule> reference_module = test_module.Clone();
202+
const ProgramShape program_shape = GetProgramShapeWithLayout(test_module);
203+
204+
if (reference_preprocessor != nullptr) {
205+
reference_preprocessor(reference_module.get());
206+
if (!ProgramShapesEqual(program_shape,
207+
GetProgramShapeWithLayout(*reference_module))) {
208+
return absl::InvalidArgumentError(
209+
"reference preprocessor must not modify the program shape");
210+
}
211+
}
212+
TF_RETURN_IF_ERROR(this->verifier().Run(reference_module.get()).status());
213+
return std::move(reference_module);
214+
}
215+
216+
// Runs the module on two platforms with or without running hlo passes and
217+
// compares the results. Returns whether the results are near or equal. If any
218+
// error happens before the results are computed, returns the error status.
219+
absl::StatusOr<::testing::AssertionResult> RunAndCompareInternal(
220+
std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
221+
const std::optional<ErrorSpec>& error, bool run_hlo_passes,
222+
const std::function<void(HloModule*)>& reference_preprocessor = nullptr,
223+
const std::function<void(HloModule*)>& test_preprocessor = nullptr) {
224+
TF_RETURN_IF_ERROR(this->verifier().Run(module.get()).status());
225+
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> reference_module,
226+
MakeReferenceModule(*module, reference_preprocessor));
227+
TF_RETURN_IF_ERROR(this->PreprocessModuleForTestRunner(module.get()));
228+
if (test_preprocessor != nullptr) {
229+
test_preprocessor(module.get());
230+
}
231+
// Execute on two backends.
232+
TF_ASSIGN_OR_RETURN(const Literal test,
233+
this->test_runner().Execute(std::move(module),
234+
arguments, run_hlo_passes));
235+
TF_ASSIGN_OR_RETURN(const Literal reference,
236+
reference_runner_->Execute(std::move(reference_module),
237+
arguments, run_hlo_passes));
238+
if (reference.IsAll(0)) {
239+
LOG(WARNING) << "Reference value is only zeros.";
240+
}
241+
242+
return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
243+
error);
244+
}
245+
246+
std::unique_ptr<HloRunnerInterface> reference_runner_;
247+
};
248+
249+
} // namespace xla
250+
251+
#endif // XLA_TESTS_HLO_RUNNER_AGNOSTIC_REFERENCE_MIXIN_H_

0 commit comments

Comments
 (0)