Skip to content

[1/N] Add BackendOptions class #11389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: gh/cccclai/21/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions runtime/backend/backend_options.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once
#include <executorch/runtime/core/error.h>
#include <cstddef>
#include <cstring>
#include <variant>

namespace executorch {
namespace runtime {

// Strongly-typed option key template
template <typename T>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this class doesnt need to be templated. No use of T

struct OptionKey {
const char* key;
constexpr explicit OptionKey(const char* k) : key(k) {}
};

// Union replaced with std::variant
using OptionValue = std::variant<bool, int, const char*>;

struct BackendOption {
const char* key; // key is the name of the backend option, like num_threads,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit concerned about storing pointer to the key. Users may do things like

std::string my_key("my_key");
BackendOption().set_option<int>(my_key.c_str(), 10);

When that goes out of scope we have a problem. I suggest instead use something like const char[kMaxKeyLength] key

// enable_profiling, etc
OptionValue
value; // value is the value of the backend option, like 4, true, etc
Comment on lines +31 to +32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

format?

};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other thing I realize is why do we need BackendOption as a separate struct?

Can we not just have BackendOptions class instantiate array of keys and values? The interface there is set_option with key name and value, get_option using key name. User never really needs to get BackendOption

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users don't need get backend option, but the dispatcher needs it and dispatch to the actual backend.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually you're right. The backend options map is for dispatching. and each backend will get an array ref of backend options.


template <size_t MaxCapacity>
class BackendOptions {
public:
// Initialize with zero options
BackendOptions() : size_(0) {}

// Type-safe setters
template <typename T>
void set_option(OptionKey<T> key, T value) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bassed on suggestion above

Suggested change
void set_option(OptionKey<T> key, T value) {
void set_option(const char [kMaxKeyLength]& key, T value) {

const char* k = key.key;
// Update existing if found
for (size_t i = 0; i < size_; ++i) {
if (strcmp(options_[i].key, k) == 0) {
options_[i].value = value;
return;
}
}
// Add new option if space available
if (size_ < MaxCapacity) {
options_[size_++] = BackendOption{k, value};
}
Comment on lines +53 to +55
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set error if no capacity

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good point. Will update

}

// Type-safe getters
template <typename T>
Error get_option(OptionKey<T> key, T& out) const {
const char* k = key.key;
for (size_t i = 0; i < size_; ++i) {
if (strcmp(options_[i].key, k) == 0) {
if (auto* val = std::get_if<T>(&options_[i].value)) {
out = *val;
return Error::Ok;
}
return Error::InvalidArgument;
}
}
return Error::NotFound;
}

private:
BackendOption options_[MaxCapacity]{}; // Storage for backend options
size_t size_; // Current number of options
};

// Helper functions for creating typed option keys (unchanged)
constexpr OptionKey<bool> BoolKey(const char* k) {
return OptionKey<bool>(k);
}

constexpr OptionKey<int> IntKey(const char* k) {
return OptionKey<int>(k);
}

constexpr OptionKey<const char*> StrKey(const char* k) {
return OptionKey<const char*>(k);
}
Comment on lines +79 to +90
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dont think you need these

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just helper functions to simplify users' code.


} // namespace runtime
} // namespace executorch
1 change: 1 addition & 0 deletions runtime/backend/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def define_common_targets():
exported_headers = [
"backend_execution_context.h",
"backend_init_context.h",
"backend_options.h",
"interface.h",
],
preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
Expand Down
130 changes: 130 additions & 0 deletions runtime/backend/test/backend_options_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/backend/backend_options.h>
#include <executorch/runtime/platform/runtime.h>

#include <gtest/gtest.h>

using namespace ::testing;
using executorch::runtime::BackendOptions;
using executorch::runtime::BoolKey;
using executorch::runtime::Error;
using executorch::runtime::IntKey;
using executorch::runtime::OptionKey;
using executorch::runtime::StrKey;

class BackendOptionsTest : public ::testing::Test {
protected:
void SetUp() override {
// Since these tests cause ET_LOG to be called, the PAL must be initialized
// first.
executorch::runtime::runtime_init();
}
BackendOptions<5> options; // Capacity of 5 for testing limits
};

// Test basic string functionality
TEST_F(BackendOptionsTest, HandlesStringOptions) {
// Set and retrieve valid string
options.set_option(StrKey("backend_type"), "GPU");
const char* result = nullptr;
EXPECT_EQ(options.get_option(StrKey("backend_type"), result), Error::Ok);
EXPECT_STREQ(result, "GPU");

// Update existing key
options.set_option(StrKey("backend_type"), "CPU");
EXPECT_EQ(options.get_option(StrKey("backend_type"), result), Error::Ok);
EXPECT_STREQ(result, "CPU");
}

// Test boolean options
TEST_F(BackendOptionsTest, HandlesBoolOptions) {
options.set_option(BoolKey("debug"), true);
bool debug = false;
EXPECT_EQ(options.get_option(BoolKey("debug"), debug), Error::Ok);
EXPECT_TRUE(debug);

// Test false value
options.set_option(BoolKey("verbose"), false);
EXPECT_EQ(options.get_option(BoolKey("verbose"), debug), Error::Ok);
EXPECT_FALSE(debug);
}

// Test integer options
TEST_F(BackendOptionsTest, HandlesIntOptions) {
options.set_option(IntKey("num_threads"), 256);
int num_threads = 0;
EXPECT_EQ(options.get_option(IntKey("num_threads"), num_threads), Error::Ok);
EXPECT_EQ(num_threads, 256);
}

// Test error conditions
TEST_F(BackendOptionsTest, HandlesErrors) {
// Non-existent key
bool dummy_bool;
EXPECT_EQ(
options.get_option(BoolKey("missing"), dummy_bool), Error::NotFound);

// Type mismatch
options.set_option(IntKey("threshold"), 100);
const char* dummy_str = nullptr;
EXPECT_EQ(
options.get_option(StrKey("threshold"), dummy_str),
Error::InvalidArgument);

// Null value handling
options.set_option(StrKey("nullable"), static_cast<const char*>(nullptr));
EXPECT_EQ(options.get_option(StrKey("nullable"), dummy_str), Error::Ok);
EXPECT_EQ(dummy_str, nullptr);
}

// Test capacity limits
TEST_F(BackendOptionsTest, HandlesCapacity) {
// Use persistent storage for keys
std::vector<std::string> keys = {"key0", "key1", "key2", "key3", "key4"};

// Fill to capacity with persistent keys
for (int i = 0; i < 5; i++) {
options.set_option(IntKey(keys[i].c_str()), i);
}

// Verify all exist
int value;
for (int i = 0; i < 5; i++) {
EXPECT_EQ(options.get_option(IntKey(keys[i].c_str()), value), Error::Ok);
EXPECT_EQ(value, i);
}

// Add beyond capacity - should fail
const char* overflow_key = "overflow";
options.set_option(IntKey(overflow_key), 99);
EXPECT_EQ(options.get_option(IntKey(overflow_key), value), Error::NotFound);

// Update existing within capacity
options.set_option(IntKey(keys[2].c_str()), 222);
EXPECT_EQ(options.get_option(IntKey(keys[2].c_str()), value), Error::Ok);
EXPECT_EQ(value, 222);
}

// Test type-specific keys
TEST_F(BackendOptionsTest, EnforcesKeyTypes) {
// Same key name - later set operations overwrite earlier ones
options.set_option(BoolKey("flag"), true);
options.set_option(IntKey("flag"), 123); // Overwrites the boolean entry

bool bval;
int ival;

// Boolean get should fail - type was overwritten to INT
EXPECT_EQ(options.get_option(BoolKey("flag"), bval), Error::InvalidArgument);

// Integer get should succeed with correct value
EXPECT_EQ(options.get_option(IntKey("flag"), ival), Error::Ok);
EXPECT_EQ(ival, 123);
}
11 changes: 10 additions & 1 deletion runtime/backend/test/targets.bzl
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.

The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""
pass
runtime.cxx_test(
name = "backend_options_test",
srcs = ["backend_options_test.cpp"],
deps = [
"//executorch/runtime/core:core",
"//executorch/runtime/backend:interface",
],
)
Loading