From ace13843cf50e7b54562253dd96904e0d2e160e6 Mon Sep 17 00:00:00 2001 From: ddavis-2015 Date: Wed, 7 May 2025 14:59:12 -0700 Subject: [PATCH] Fix set_external_context() during state kInit @tensorflow/micro Fixes MicroInterpreterContext::set_external_context so that it can be called during InterpreterState::kInit. Add unit test for kInit state. Cleanup other external context tests. bug=fixes External Context in Prepare Stage #3101 --- .../lite/micro/micro_interpreter_context.cc | 3 +- .../micro/micro_interpreter_context_test.cc | 36 ++++++++++++++++--- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/micro/micro_interpreter_context.cc b/tensorflow/lite/micro/micro_interpreter_context.cc index 38c6225be0a..62b33afe631 100644 --- a/tensorflow/lite/micro/micro_interpreter_context.cc +++ b/tensorflow/lite/micro/micro_interpreter_context.cc @@ -112,7 +112,8 @@ void MicroInterpreterContext::SetScratchBufferHandles( TfLiteStatus MicroInterpreterContext::set_external_context( void* external_context_payload) { - TFLITE_DCHECK(state_ == InterpreterState::kPrepare || + TFLITE_DCHECK(state_ == InterpreterState::kInit || + state_ == InterpreterState::kPrepare || state_ == InterpreterState::kInvoke); if (external_context_payload == nullptr || external_context_payload_ != nullptr) { diff --git a/tensorflow/lite/micro/micro_interpreter_context_test.cc b/tensorflow/lite/micro/micro_interpreter_context_test.cc index 3af123f5511..fd7fb43831f 100644 --- a/tensorflow/lite/micro/micro_interpreter_context_test.cc +++ b/tensorflow/lite/micro/micro_interpreter_context_test.cc @@ -54,8 +54,8 @@ struct TestExternalContextPayloadData { TF_LITE_MICRO_TESTS_BEGIN -// Ensures that a regular set and get pair works ok. -TF_LITE_MICRO_TEST(TestSetGetExternalContextSuccess) { +// Ensures that a regular set and get pair works ok during state kInvoke. +TF_LITE_MICRO_TEST(TestSetGetExternalContextSuccessInvoke) { tflite::MicroInterpreterContext micro_context = tflite::CreateMicroInterpreterContext(); micro_context.SetInterpreterState( @@ -70,19 +70,36 @@ TF_LITE_MICRO_TEST(TestSetGetExternalContextSuccess) { micro_context.external_context()); // What is returned should be the same as what is set. - TF_LITE_MICRO_EXPECT((void*)returned_external_context == (void*)(&payload)); + TF_LITE_MICRO_EXPECT(returned_external_context == &payload); } -TF_LITE_MICRO_TEST(TestGetExternalContextWithoutSetShouldReturnNull) { +// Ensures that a regular set and get pair works ok during state kInit. +TF_LITE_MICRO_TEST(TestSetGetExternalContextSuccessInit) { tflite::MicroInterpreterContext micro_context = tflite::CreateMicroInterpreterContext(); + micro_context.SetInterpreterState( + tflite::MicroInterpreterContext::InterpreterState::kInit); + + tflite::TestExternalContextPayloadData payload; + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, + micro_context.set_external_context(&payload)); tflite::TestExternalContextPayloadData* returned_external_context = reinterpret_cast( micro_context.external_context()); + // What is returned should be the same as what is set. + TF_LITE_MICRO_EXPECT(returned_external_context == &payload); +} + +TF_LITE_MICRO_TEST(TestGetExternalContextWithoutSetShouldReturnNull) { + tflite::MicroInterpreterContext micro_context = + tflite::CreateMicroInterpreterContext(); + + void* returned_external_context = micro_context.external_context(); + // Return a null if nothing is set before. - TF_LITE_MICRO_EXPECT((void*)returned_external_context == (nullptr)); + TF_LITE_MICRO_EXPECT(returned_external_context == nullptr); } TF_LITE_MICRO_TEST(TestSetExternalContextCanOnlyBeCalledOnce) { @@ -98,6 +115,15 @@ TF_LITE_MICRO_TEST(TestSetExternalContextCanOnlyBeCalledOnce) { // Another set should fail. TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, micro_context.set_external_context(&payload)); + + // Null set should fail. + TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, + micro_context.set_external_context(nullptr)); + tflite::TestExternalContextPayloadData* returned_external_context = + reinterpret_cast( + micro_context.external_context()); + // Payload should be unchanged. + TF_LITE_MICRO_EXPECT(&payload == returned_external_context); } TF_LITE_MICRO_TEST(TestSetExternalContextToNullShouldFail) {