|
| 1 | +/* Copyright 2025 The OpenXLA Authors. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +==============================================================================*/ |
| 15 | + |
| 16 | +#ifndef XLA_TESTS_HLO_RUNNER_AGNOSTIC_REFERENCE_MIXIN_H_ |
| 17 | +#define XLA_TESTS_HLO_RUNNER_AGNOSTIC_REFERENCE_MIXIN_H_ |
| 18 | + |
| 19 | +#include <cstdint> |
| 20 | +#include <functional> |
| 21 | +#include <iterator> |
| 22 | +#include <memory> |
| 23 | +#include <optional> |
| 24 | +#include <utility> |
| 25 | +#include <vector> |
| 26 | + |
| 27 | +#include "absl/base/nullability.h" |
| 28 | +#include "absl/log/log.h" |
| 29 | +#include "absl/strings/string_view.h" |
| 30 | +#include "absl/types/span.h" |
| 31 | +#include "xla/error_spec.h" |
| 32 | +#include "xla/hlo/ir/hlo_module.h" |
| 33 | +#include "xla/hlo/testlib/verified_hlo_module.h" |
| 34 | +#include "xla/literal.h" |
| 35 | +#include "xla/service/hlo_runner_interface.h" |
| 36 | +#include "xla/shape.h" |
| 37 | +#include "xla/tests/hlo_runner_agnostic_test_base.h" |
| 38 | +#include "xla/tests/literal_test_util.h" |
| 39 | +#include "xla/tests/test_utils.h" |
| 40 | +#include "xla/tsl/platform/errors.h" |
| 41 | +#include "xla/tsl/platform/statusor.h" |
| 42 | +#include "xla/tsl/platform/test.h" |
| 43 | + |
| 44 | +namespace xla { |
| 45 | + |
| 46 | +ProgramShape GetProgramShapeWithLayout(const HloModule& module); |
| 47 | + |
| 48 | +bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs); |
| 49 | + |
| 50 | +// This class is designed to be used as a mixin for tests that want to run |
| 51 | +// against a reference implementation via a runner implementing |
| 52 | +// HloRunnerInterface. |
| 53 | +// |
| 54 | +// The mixin requires that that the test class is a subclass of |
| 55 | +// HloRunnerAgnosticTestBase. |
| 56 | +template <typename T> |
| 57 | +class HloRunnerAgnosticReferenceMixin : public T { |
| 58 | + static_assert( |
| 59 | + std::is_base_of_v<HloRunnerAgnosticTestBase, T>, |
| 60 | + "Mixin must be used with a subclass of HloRunnerAgnosticTestBase."); |
| 61 | + |
| 62 | + protected: |
| 63 | + template <typename... BaseArgs> |
| 64 | + explicit HloRunnerAgnosticReferenceMixin( |
| 65 | + absl::Nonnull<std::unique_ptr<HloRunnerInterface>> reference_runner, |
| 66 | + BaseArgs&&... base_args) |
| 67 | + : T(std::forward<BaseArgs>(base_args)...), |
| 68 | + reference_runner_(std::move(reference_runner)) {} |
| 69 | + ~HloRunnerAgnosticReferenceMixin() override = default; |
| 70 | + |
| 71 | + // Executes the given hlo module on two backends and compares results. |
| 72 | + // |
| 73 | + // 'arguments': the input of the hlo module. |
| 74 | + // |
| 75 | + // 'error': if has value, expects the results to be near (within the error |
| 76 | + // bound). Otherwise, expects the results to be equal. |
| 77 | + // |
| 78 | + // 'reference_preprocessor': the module should be ready to run on the test |
| 79 | + // backend, but it might need to be tailored so that it is able to run on the |
| 80 | + // reference backend. Note that the program shape of the module must not be |
| 81 | + // modified. |
| 82 | + ::testing::AssertionResult RunAndCompare( |
| 83 | + std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments, |
| 84 | + const std::optional<ErrorSpec>& error, |
| 85 | + const std::function<void(HloModule*)>& reference_preprocessor = nullptr, |
| 86 | + const std::function<void(HloModule*)>& test_preprocessor = nullptr) { |
| 87 | + const absl::StatusOr<::testing::AssertionResult> result = |
| 88 | + RunAndCompareInternal(std::move(module), arguments, error, |
| 89 | + /*run_hlo_passes=*/true, reference_preprocessor, |
| 90 | + test_preprocessor); |
| 91 | + if (!result.ok()) { |
| 92 | + return ::testing::AssertionFailure() << result.status(); |
| 93 | + } |
| 94 | + return *result; |
| 95 | + } |
| 96 | + |
| 97 | + // Same as above, except that the module will be executed without Hlo |
| 98 | + // optimization. |
| 99 | + ::testing::AssertionResult RunAndCompareNoHloPasses( |
| 100 | + std::unique_ptr<HloModule> module, |
| 101 | + const absl::Span<Literal* const> arguments, |
| 102 | + const std::optional<ErrorSpec>& error, |
| 103 | + const std::function<void(HloModule*)>& reference_preprocessor = nullptr, |
| 104 | + const std::function<void(HloModule*)>& test_preprocessor = nullptr) { |
| 105 | + const absl::StatusOr<::testing::AssertionResult> result = |
| 106 | + RunAndCompareInternal(std::move(module), arguments, error, |
| 107 | + /*run_hlo_passes=*/false, reference_preprocessor, |
| 108 | + test_preprocessor); |
| 109 | + if (!result.ok()) { |
| 110 | + return ::testing::AssertionFailure() << result.status(); |
| 111 | + } |
| 112 | + return *result; |
| 113 | + } |
| 114 | + |
| 115 | + // Executes an hlo module with fake inputs and compares the results. |
| 116 | + ::testing::AssertionResult RunAndCompare( |
| 117 | + std::unique_ptr<HloModule> module, const std::optional<ErrorSpec>& error, |
| 118 | + const std::function<void(HloModule*)>& reference_preprocessor = nullptr, |
| 119 | + const std::function<void(HloModule*)>& test_preprocessor = nullptr, |
| 120 | + const std::optional<int64_t> args_max_bits_of_precision = std::nullopt) { |
| 121 | + const absl::StatusOr<std::vector<Literal>> fake_arguments = |
| 122 | + MakeFakeArguments(module.get(), /*pseudo_random=*/true, |
| 123 | + /*use_large_range=*/false, |
| 124 | + /*treat_gte_as_data_formatting=*/false, |
| 125 | + args_max_bits_of_precision); |
| 126 | + if (!fake_arguments.ok()) { |
| 127 | + return ::testing::AssertionFailure() << fake_arguments.status().message(); |
| 128 | + } |
| 129 | + std::vector<Literal*> fake_argument_ptrs; |
| 130 | + absl::c_transform( |
| 131 | + *fake_arguments, std::back_inserter(fake_argument_ptrs), |
| 132 | + [](const Literal& literal) { return const_cast<Literal*>(&literal); }); |
| 133 | + |
| 134 | + return RunAndCompare(std::move(module), fake_argument_ptrs, error, |
| 135 | + reference_preprocessor, test_preprocessor); |
| 136 | + } |
| 137 | + |
| 138 | + // Same as above, except that the module will be executed without Hlo |
| 139 | + // optimization. |
| 140 | + ::testing::AssertionResult RunAndCompareNoHloPasses( |
| 141 | + std::unique_ptr<HloModule> module, const std::optional<ErrorSpec>& error, |
| 142 | + const std::function<void(HloModule*)>& reference_preprocessor = nullptr, |
| 143 | + const std::function<void(HloModule*)>& test_preprocessor = nullptr) { |
| 144 | + const absl::StatusOr<std::vector<Literal>> fake_arguments = |
| 145 | + MakeFakeArguments(module.get()); |
| 146 | + if (!fake_arguments.ok()) { |
| 147 | + return ::testing::AssertionFailure() << fake_arguments.status().message(); |
| 148 | + } |
| 149 | + std::vector<Literal*> fake_argument_ptrs; |
| 150 | + absl::c_transform( |
| 151 | + *fake_arguments, std::back_inserter(fake_argument_ptrs), |
| 152 | + [](const Literal& literal) { return const_cast<Literal*>(&literal); }); |
| 153 | + return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, |
| 154 | + error, reference_preprocessor, |
| 155 | + test_preprocessor); |
| 156 | + } |
| 157 | + |
| 158 | + // Convenient wrapper for executing and comparing an hlo module with fake |
| 159 | + // input. Module can be passed in directly, or parsed from an hlo_string, |
| 160 | + // or loaded from a file. |
| 161 | + ::testing::AssertionResult RunAndCompare( |
| 162 | + const absl::string_view hlo_string, const std::optional<ErrorSpec>& error, |
| 163 | + const std::function<void(HloModule*)>& reference_preprocessor = nullptr, |
| 164 | + const std::function<void(HloModule*)>& test_preprocessor = nullptr, |
| 165 | + const std::optional<int64_t> args_max_bits_of_precision = std::nullopt) { |
| 166 | + absl::StatusOr<std::unique_ptr<VerifiedHloModule>> module = |
| 167 | + this->ParseAndReturnVerifiedModule(hlo_string); |
| 168 | + if (!module.ok()) { |
| 169 | + return ::testing::AssertionFailure() |
| 170 | + << "Error while parsing HLO text format: " |
| 171 | + << module.status().ToString(); |
| 172 | + } |
| 173 | + return RunAndCompare(*std::move(module), error, reference_preprocessor, |
| 174 | + test_preprocessor, args_max_bits_of_precision); |
| 175 | + } |
| 176 | + |
| 177 | + ::testing::AssertionResult RunAndCompareNoHloPasses( |
| 178 | + const absl::string_view hlo_string, const std::optional<ErrorSpec>& error, |
| 179 | + const std::function<void(HloModule*)>& reference_preprocessor = nullptr, |
| 180 | + const std::function<void(HloModule*)>& test_preprocessor = nullptr) { |
| 181 | + absl::StatusOr<std::unique_ptr<VerifiedHloModule>> module = |
| 182 | + this->ParseAndReturnVerifiedModule(hlo_string); |
| 183 | + if (!module.ok()) { |
| 184 | + return ::testing::AssertionFailure() |
| 185 | + << "Error while parsing HLO text format: " |
| 186 | + << module.status().ToString(); |
| 187 | + } |
| 188 | + return RunAndCompareNoHloPasses(*std::move(module), error, |
| 189 | + reference_preprocessor, test_preprocessor); |
| 190 | + } |
| 191 | + |
| 192 | + HloRunnerInterface& reference_runner() const { return *reference_runner_; } |
| 193 | + |
| 194 | + private: |
| 195 | + // Given the test module, makes a reference module that is ready to run on the |
| 196 | + // reference platform. This assumes that the given module is ready to run on |
| 197 | + // the test platform. |
| 198 | + absl::StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule( |
| 199 | + const HloModule& test_module, |
| 200 | + const std::function<void(HloModule*)>& reference_preprocessor = nullptr) { |
| 201 | + std::unique_ptr<HloModule> reference_module = test_module.Clone(); |
| 202 | + const ProgramShape program_shape = GetProgramShapeWithLayout(test_module); |
| 203 | + |
| 204 | + if (reference_preprocessor != nullptr) { |
| 205 | + reference_preprocessor(reference_module.get()); |
| 206 | + if (!ProgramShapesEqual(program_shape, |
| 207 | + GetProgramShapeWithLayout(*reference_module))) { |
| 208 | + return absl::InvalidArgumentError( |
| 209 | + "reference preprocessor must not modify the program shape"); |
| 210 | + } |
| 211 | + } |
| 212 | + TF_RETURN_IF_ERROR(this->verifier().Run(reference_module.get()).status()); |
| 213 | + return std::move(reference_module); |
| 214 | + } |
| 215 | + |
| 216 | + // Runs the module on two platforms with or without running hlo passes and |
| 217 | + // compares the results. Returns whether the results are near or equal. If any |
| 218 | + // error happens before the results are computed, returns the error status. |
| 219 | + absl::StatusOr<::testing::AssertionResult> RunAndCompareInternal( |
| 220 | + std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments, |
| 221 | + const std::optional<ErrorSpec>& error, bool run_hlo_passes, |
| 222 | + const std::function<void(HloModule*)>& reference_preprocessor = nullptr, |
| 223 | + const std::function<void(HloModule*)>& test_preprocessor = nullptr) { |
| 224 | + TF_RETURN_IF_ERROR(this->verifier().Run(module.get()).status()); |
| 225 | + TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> reference_module, |
| 226 | + MakeReferenceModule(*module, reference_preprocessor)); |
| 227 | + TF_RETURN_IF_ERROR(this->PreprocessModuleForTestRunner(module.get())); |
| 228 | + if (test_preprocessor != nullptr) { |
| 229 | + test_preprocessor(module.get()); |
| 230 | + } |
| 231 | + // Execute on two backends. |
| 232 | + TF_ASSIGN_OR_RETURN(const Literal test, |
| 233 | + this->test_runner().Execute(std::move(module), |
| 234 | + arguments, run_hlo_passes)); |
| 235 | + TF_ASSIGN_OR_RETURN(const Literal reference, |
| 236 | + reference_runner_->Execute(std::move(reference_module), |
| 237 | + arguments, run_hlo_passes)); |
| 238 | + if (reference.IsAll(0)) { |
| 239 | + LOG(WARNING) << "Reference value is only zeros."; |
| 240 | + } |
| 241 | + |
| 242 | + return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test, |
| 243 | + error); |
| 244 | + } |
| 245 | + |
| 246 | + std::unique_ptr<HloRunnerInterface> reference_runner_; |
| 247 | +}; |
| 248 | + |
| 249 | +} // namespace xla |
| 250 | + |
| 251 | +#endif // XLA_TESTS_HLO_RUNNER_AGNOSTIC_REFERENCE_MIXIN_H_ |
0 commit comments