1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <executorch/extension/data_loader/buffer_data_loader.h>
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/runner_util/inputs.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/executor/test/managed_memory_manager.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>
#include <executorch/test/utils/alignment.h>

#include <gtest/gtest.h>
29
30 using namespace ::testing;
31 using exec_aten::ArrayRef;
32 using executorch::runtime::BackendExecutionContext;
33 using executorch::runtime::BackendInitContext;
34 using executorch::runtime::BackendInterface;
35 using executorch::runtime::CompileSpec;
36 using executorch::runtime::DataLoader;
37 using executorch::runtime::DelegateHandle;
38 using executorch::runtime::Error;
39 using executorch::runtime::EValue;
40 using executorch::runtime::FreeableBuffer;
41 using executorch::runtime::MemoryAllocator;
42 using executorch::runtime::Method;
43 using executorch::runtime::Program;
44 using executorch::runtime::Result;
45 using executorch::runtime::testing::ManagedMemoryManager;
46 using torch::executor::util::FileDataLoader;
47
48 /**
49 * A backend class whose methods can be overridden individually.
50 */
51 class StubBackend final : public BackendInterface {
52 public:
53 // Function signature types that match the BackendInterface methods.
54 using IsAvailableFn = std::function<bool()>;
55 using InitFn = std::function<Result<DelegateHandle*>(
56 FreeableBuffer*,
57 ArrayRef<CompileSpec>,
58 BackendInitContext&)>;
59 using ExecuteFn =
60 std::function<Error(BackendExecutionContext&, DelegateHandle*, EValue**)>;
61 using DestroyFn = std::function<void(DelegateHandle*)>;
62
63 // Default name that this backend is registered as.
64 static constexpr char kName[] = "StubBackend";
65
install_is_available(IsAvailableFn fn)66 void install_is_available(IsAvailableFn fn) {
67 is_available_fn_ = fn;
68 }
69
is_available() const70 bool is_available() const override {
71 if (is_available_fn_) {
72 return is_available_fn_.value()();
73 }
74 // Return a benign value otherwise.
75 return true;
76 }
77
install_init(InitFn fn)78 void install_init(InitFn fn) {
79 init_fn_ = fn;
80 }
81
init(BackendInitContext & context,FreeableBuffer * processed,ArrayRef<CompileSpec> compile_specs) const82 Result<DelegateHandle*> init(
83 BackendInitContext& context,
84 FreeableBuffer* processed,
85 ArrayRef<CompileSpec> compile_specs) const override {
86 if (init_fn_) {
87 return init_fn_.value()(processed, compile_specs, context);
88 }
89 // Return a benign value otherwise.
90 return nullptr;
91 }
92
install_execute(ExecuteFn fn)93 void install_execute(ExecuteFn fn) {
94 execute_fn_ = fn;
95 }
96
execute(BackendExecutionContext & context,DelegateHandle * handle,EValue ** args) const97 Error execute(
98 BackendExecutionContext& context,
99 DelegateHandle* handle,
100 EValue** args) const override {
101 if (execute_fn_) {
102 return execute_fn_.value()(context, handle, args);
103 }
104 // Return a benign value otherwise.
105 return Error::Ok;
106 }
107
install_destroy(DestroyFn fn)108 void install_destroy(DestroyFn fn) {
109 destroy_fn_ = fn;
110 }
111
destroy(DelegateHandle * handle) const112 void destroy(DelegateHandle* handle) const override {
113 if (destroy_fn_) {
114 destroy_fn_.value()(handle);
115 }
116 }
117
118 /**
119 * Resets to the original constructed state.
120 */
reset()121 void reset() {
122 is_available_fn_.reset();
123 init_fn_.reset();
124 execute_fn_.reset();
125 destroy_fn_.reset();
126 }
127
128 /**
129 * Registers the singleton instance if not already registered.
130 *
131 * Note that this can be used to install the stub as the implementation for
132 * any export-time backend by passing in the right name, as long as no other
133 * backend with that name has been registered yet.
134 */
register_singleton(const char * name=kName)135 static Error register_singleton(const char* name = kName) {
136 if (!registered_) {
137 registered_ = true;
138 return executorch::runtime::register_backend({name, &singleton_});
139 }
140 return Error::Ok;
141 }
142
143 /**
144 * Returns the instance that was added to the backend registry.
145 */
singleton()146 static StubBackend& singleton() {
147 return singleton_;
148 }
149
150 private:
151 static bool registered_;
152 static StubBackend singleton_;
153
154 std::optional<IsAvailableFn> is_available_fn_;
155 std::optional<InitFn> init_fn_;
156 std::optional<ExecuteFn> execute_fn_;
157 std::optional<DestroyFn> destroy_fn_;
158 };
159
160 bool StubBackend::registered_ = false;
161 StubBackend StubBackend::singleton_;
162
163 /**
164 * A DataLoader that wraps a real DataLoader and records the operations
165 * performed on it and the FreeableBuffers it loads.
166 */
167 class DataLoaderSpy final : public DataLoader {
168 public:
169 /// A record of an operation performed on this DataLoader.
170 struct Operation {
171 enum { Load, Free } op;
172 size_t offset; // Set for Load; zero for Free.
173 void* data; // Set for Free; nullptr for Load.
174 size_t size; // Set for Load and Free.
175 std::unique_ptr<const DataLoader::SegmentInfo>
176 segment_info; // Set for Load; nullptr for Free.
177 };
178
DataLoaderSpy(DataLoader * delegate)179 explicit DataLoaderSpy(DataLoader* delegate) : delegate_(delegate) {}
180
load(size_t offset,size_t size,const SegmentInfo & segment_info) const181 Result<FreeableBuffer> load(
182 size_t offset,
183 size_t size,
184 const SegmentInfo& segment_info) const override {
185 Result<FreeableBuffer> buf = delegate_->load(offset, size, segment_info);
186 if (!buf.ok()) {
187 return buf.error();
188 }
189
190 auto segment_info_cpy =
191 std::make_unique<const DataLoader::SegmentInfo>(segment_info);
192 operations_.push_back(
193 {Operation::Load,
194 offset,
195 /*data=*/nullptr,
196 size,
197 /*segment_info=*/std::move(segment_info_cpy)});
198 auto* context = new SpyContext(&operations_, std::move(buf.get()));
199 // Use context->buffer since buf has been moved.
200 return FreeableBuffer(
201 context->buffer.data(), context->buffer.size(), FreeBuffer, context);
202 }
203
size() const204 Result<size_t> size() const override {
205 return delegate_->size();
206 }
207
208 /**
209 * Returns records of the operations performed on this DataLoader and the
210 * FreeableBuffers it returned, in order they were performed.
211 */
operations() const212 const std::vector<Operation>& operations() const {
213 return operations_;
214 }
215
216 /**
217 * Returns true if the DataLoader::load() method was called with the correct
218 * segment info.
219 */
UsedLoad(DataLoader::SegmentInfo::Type segment_type,const char * descriptor=nullptr) const220 bool UsedLoad(
221 DataLoader::SegmentInfo::Type segment_type,
222 const char* descriptor = nullptr) const {
223 for (const auto& op : operations_) {
224 if (op.op != Operation::Load) {
225 continue;
226 }
227 // We have a load op.
228 if (op.segment_info->segment_type == segment_type) {
229 if (segment_type != DataLoader::SegmentInfo::Type::Backend) {
230 // For non-backend segments, the descriptor is irrelevant / a nullptr.
231 return true;
232 } else {
233 if (strcmp(op.segment_info->descriptor, descriptor) == 0) {
234 return true;
235 }
236 }
237 }
238 }
239 return false;
240 }
241
242 /**
243 * Returns true if the operations list shows that the provided data pointer
244 * was freed.
245 */
WasFreed(const void * data) const246 bool WasFreed(const void* data) const {
247 for (const auto& op : operations_) {
248 if (op.op == Operation::Free && op.data == data) {
249 return true;
250 }
251 }
252 return false;
253 }
254
255 private:
256 struct SpyContext {
SpyContextDataLoaderSpy::SpyContext257 SpyContext(std::vector<Operation>* operations, FreeableBuffer&& buffer)
258 : operations(operations), buffer(std::move(buffer)) {}
259 std::vector<Operation>* operations;
260 FreeableBuffer buffer;
261 };
262
FreeBuffer(void * context,void * data,size_t size)263 static void FreeBuffer(void* context, void* data, size_t size) {
264 auto* sc = reinterpret_cast<SpyContext*>(context);
265 sc->operations->push_back(
266 {Operation::Free, /*offset=*/0, data, size, /*segment_info=*/nullptr});
267 delete sc;
268 }
269
270 /// The real loader to delegate to.
271 DataLoader* delegate_;
272
273 mutable std::vector<Operation> operations_;
274 };
275
// Pool sizes, in bytes, passed to ManagedMemoryManager by every test below:
// the first for non-constant memory, the second for the runtime allocator.
constexpr size_t kDefaultNonConstMemBytes = size_t{32} << 10;
constexpr size_t kDefaultRuntimeMemBytes = size_t{32} << 10;
278
279 class BackendIntegrationTest : public ::testing::TestWithParam<bool> {
280 protected:
SetUp()281 void SetUp() override {
282 // Since these tests cause ET_LOG to be called, the PAL must be initialized
283 // first.
284 executorch::runtime::runtime_init();
285
286 // Make sure that the backend has been registered. Safe to call multiple
287 // times. Doing this at runtime ensures that it's only registered if these
288 // tests are run.
289 ASSERT_EQ(StubBackend::register_singleton(), Error::Ok);
290
291 // Paths to the test program files.
292 program_path_ = std::getenv("ET_MODULE_ADD_MUL_PATH");
293 ASSERT_FALSE(program_path_.empty());
294 program_nosegments_path_ = std::getenv("ET_MODULE_ADD_MUL_NOSEGMENTS_PATH");
295 ASSERT_FALSE(program_nosegments_path_.empty());
296 }
297
TearDown()298 void TearDown() override {
299 // Clean up any modifications to the singleton.
300 StubBackend::singleton().reset();
301 }
302
303 /**
304 * Returns true if program_path() returns a file with extracted segments.
305 */
using_segments() const306 bool using_segments() const {
307 return GetParam();
308 }
309
310 /**
311 * Returns tha path to the program to load. May or may not have extracted
312 * segments, depending on the return value of using_segments().
313 */
program_path() const314 const char* program_path() const {
315 if (using_segments()) {
316 return program_path_.c_str();
317 } else {
318 return program_nosegments_path_.c_str();
319 }
320 }
321
322 private:
323 std::string program_path_;
324 std::string program_nosegments_path_;
325 };
326
// The registry must hand back exactly the singleton registered in SetUp().
TEST_P(BackendIntegrationTest, BackendIsPresent) {
  StubBackend* const expected = &StubBackend::singleton();
  BackendInterface* const looked_up =
      executorch::runtime::get_backend_class(StubBackend::kName);
  ASSERT_EQ(looked_up, expected);
}
332
333 // Demonstrate that installed StubBackend initializes successfully by default.
TEST_P(BackendIntegrationTest,BasicInitSucceeds)334 TEST_P(BackendIntegrationTest, BasicInitSucceeds) {
335 Result<FileDataLoader> loader = FileDataLoader::from(program_path());
336 ASSERT_EQ(loader.error(), Error::Ok);
337
338 Result<Program> program = Program::load(&loader.get());
339 ASSERT_EQ(program.error(), Error::Ok);
340
341 ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
342 Result<Method> method_res = program->load_method("forward", &mmm.get());
343 EXPECT_EQ(method_res.error(), Error::Ok);
344 }
345
TEST_P(BackendIntegrationTest, FreeingProcessedBufferSucceeds) {
  // The installed init() releases its processed blob right away. It records
  // that it ran and which address it released.
  bool saw_init = false;
  const void* released_blob = nullptr;
  auto release_on_init = [&](FreeableBuffer* processed,
                             ET_UNUSED ArrayRef<CompileSpec> compile_specs,
                             ET_UNUSED BackendInitContext& backend_init_context)
      -> Result<DelegateHandle*> {
    saw_init = true;
    released_blob = processed->data();
    processed->Free();
    return nullptr;
  };
  StubBackend::singleton().install_init(release_on_init);

  // Route every load through a spy so the Free() calls are observable.
  Result<FileDataLoader> file_loader = FileDataLoader::from(program_path());
  ASSERT_EQ(file_loader.error(), Error::Ok);
  DataLoaderSpy spy(&file_loader.get());

  Result<Program> loaded = Program::load(&spy);
  ASSERT_EQ(loaded.error(), Error::Ok);
  ManagedMemoryManager memory(
      kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
  Result<Method> forward = loaded->load_method("forward", &memory.get());
  EXPECT_EQ(forward.error(), Error::Ok);

  // Our init() hook must actually have run.
  EXPECT_TRUE(saw_init);

  // The spy can observe the Free() only when it created the buffer handed to
  // the backend, which is exactly the extracted-segments case.
  const bool observed_free = spy.WasFreed(released_blob);
  EXPECT_EQ(observed_free, using_segments());
}
390
TEST_P(BackendIntegrationTest, EndToEndTestWithProcessedAsHandle) {
  // Install an init() implementation that does not free its processed buffer,
  // and returns the FreeableBuffer as the delegate handle.
  FreeableBuffer* init_processed = nullptr;
  StubBackend::singleton().install_init(
      [&](FreeableBuffer* processed,
          ET_UNUSED ArrayRef<CompileSpec> compile_specs,
          ET_UNUSED BackendInitContext& backend_init_context)
          -> Result<DelegateHandle*> {
        init_processed = processed;
        return processed;
      });

  // Install an execute() that expects the handle to be the processed
  // FreeableBuffer.
  DelegateHandle* execute_handle = nullptr;
  StubBackend::singleton().install_execute(
      [&](ET_UNUSED BackendExecutionContext& backend_execution_context,
          DelegateHandle* handle,
          ET_UNUSED EValue** args) -> Error {
        execute_handle = handle;
        auto* processed = reinterpret_cast<FreeableBuffer*>(handle);

        // Read the data, which will tend to cause an ASAN error if it's not
        // valid.
        auto copy = std::make_unique<char[]>(processed->size());
        std::memcpy(copy.get(), processed->data(), processed->size());

        return Error::Ok;
      });

  // Install a destroy() that expects the handle to be the processed
  // FreeableBuffer.
  DelegateHandle* destroy_handle = nullptr;
  StubBackend::singleton().install_destroy(
      [&](DelegateHandle* handle) -> void { destroy_handle = handle; });

  // Wrap the real loader in a spy so we can see which operations were
  // performed.
  Result<FileDataLoader> loader = FileDataLoader::from(program_path());
  ASSERT_EQ(loader.error(), Error::Ok);
  DataLoaderSpy spy_loader(&loader.get());

  // Load the program.
  Result<Program> program = Program::load(&spy_loader);
  ASSERT_EQ(program.error(), Error::Ok);

  // Hold onto the address of the processed buffer so we can compare against
  // it after the FreeableBuffer has been destroyed.
  const void* processed_data = nullptr;

  // Add a scope so we can watch the Method be destroyed.
  {
    ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
    Result<Method> method_res = program->load_method("forward", &mmm.get());
    // Fatal: method_res.get() is used below.
    ASSERT_TRUE(method_res.ok());

    // Demonstrate that our installed init was called. Fatal because
    // init_processed is dereferenced below.
    ASSERT_NE(init_processed, nullptr);
    // Not freed yet.
    EXPECT_GT(init_processed->size(), 0u);
    EXPECT_NE(init_processed->data(), nullptr);
    processed_data = init_processed->data();

    // The processed data should not have been freed during init.
    EXPECT_FALSE(spy_loader.WasFreed(init_processed->data()));
    auto method(std::move(method_res.get()));
    // Execute the model.
    auto input_cleanup = executorch::extension::prepare_input_tensors(method);
    ASSERT_EQ(input_cleanup.error(), Error::Ok);
    auto err = method.execute();
    EXPECT_EQ(err, Error::Ok);

    // Check that the processed buffer was passed to execute() as the handle.
    EXPECT_EQ(init_processed, execute_handle);

    // The processed data should not have been freed during execution.
    EXPECT_FALSE(spy_loader.WasFreed(init_processed->data()));
  }

  // The Method has now been destroyed, which should have freed the processed
  // data.
  bool processed_was_freed = spy_loader.WasFreed(processed_data);
  if (using_segments()) {
    // Used the loader to create the FreeableBuffer that was passed to the
    // backend, so we can see its Free() call.
    EXPECT_TRUE(processed_was_freed);
  } else {
    // Didn't use the loader to create the FreeableBuffer that was passed to the
    // backend, so we can't see its Free() call.
    EXPECT_FALSE(processed_was_freed);
  }

  // And it should have destroyed the backend handle that execute() received.
  // The non-null check keeps this from passing when destroy() never ran.
  EXPECT_NE(destroy_handle, nullptr);
  EXPECT_EQ(execute_handle, destroy_handle);
}
487
488 /**
489 * Tests that the DataLoader's load is receiving the correct segment info for
490 * different types of segments.
491 */
TEST_P(BackendIntegrationTest,SegmentInfoIsPassedIntoDataLoader)492 TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
493 const void* processed_data = nullptr;
494 StubBackend::singleton().install_init(
495 [&](FreeableBuffer* processed,
496 ET_UNUSED ArrayRef<CompileSpec> compile_specs,
497 ET_UNUSED BackendInitContext& backend_init_context)
498 -> Result<DelegateHandle*> {
499 processed_data = processed->data();
500 processed->Free();
501 return nullptr;
502 });
503
504 // Wrap the real loader in a spy so we can see which operations were
505 // performed.
506 Result<FileDataLoader> loader = FileDataLoader::from(program_path());
507 ASSERT_EQ(loader.error(), Error::Ok);
508 DataLoaderSpy spy_loader(&loader.get());
509
510 // Load the program.
511 Result<Program> program = Program::load(&spy_loader);
512 ASSERT_EQ(program.error(), Error::Ok);
513 ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
514
515 // Expect that load was called correctly on program segments.
516 bool program_load_was_called =
517 spy_loader.UsedLoad(DataLoader::SegmentInfo::Type::Program, nullptr);
518
519 // Load a method.
520 Result<Method> method_res = program->load_method("forward", &mmm.get());
521 EXPECT_EQ(method_res.error(), Error::Ok);
522
523 // Expect that load was called correctly on a backend segment.
524 bool backend_load_was_called = spy_loader.UsedLoad(
525 DataLoader::SegmentInfo::Type::Backend,
526 "StubBackend"); // This backend id is taken from the StubBackend defined
527 // in export_delegated_program.py.
528
529 EXPECT_TRUE(program_load_was_called);
530 EXPECT_EQ(backend_load_was_called, using_segments());
531 }
532
TEST_P(BackendIntegrationTest, GetMethodNameDuringInitSuccess) {
  Result<FileDataLoader> loader = FileDataLoader::from(program_path());
  ASSERT_EQ(loader.error(), Error::Ok);

  bool init_called = false;
  StubBackend::singleton().install_init(
      [&](ET_UNUSED FreeableBuffer* processed,
          ET_UNUSED ArrayRef<CompileSpec> compile_specs,
          BackendInitContext& backend_init_context)
          -> Result<DelegateHandle*> {
        init_called = true;
        // Ensure that we can get the method name during init via context.
        auto method_name = backend_init_context.get_method_name();
        EXPECT_STREQ(method_name, "forward");
        return nullptr;
      });

  // Check the program before dereferencing it in load_method().
  Result<Program> program = Program::load(&loader.get());
  ASSERT_EQ(program.error(), Error::Ok);

  ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
  Result<Method> method = program->load_method("forward", &mmm.get());
  EXPECT_TRUE(method.ok());
  // Without this, the STREQ above would pass vacuously if init() never ran.
  EXPECT_TRUE(init_called);
}
554
TEST_P(BackendIntegrationTest, GetMethodNameDuringExecuteSuccess) {
  Result<FileDataLoader> loader = FileDataLoader::from(program_path());
  ASSERT_EQ(loader.error(), Error::Ok);

  bool execute_called = false;
  StubBackend::singleton().install_execute(
      [&](BackendExecutionContext& backend_execution_context,
          ET_UNUSED DelegateHandle* handle,
          ET_UNUSED EValue** args) -> Error {
        execute_called = true;
        // Ensure that we can get the method name during execution via context
        auto method_name = backend_execution_context.get_method_name();
        EXPECT_STREQ(method_name, "forward");
        return Error::Ok;
      });

  // Check the program before dereferencing it in load_method().
  Result<Program> program = Program::load(&loader.get());
  ASSERT_EQ(program.error(), Error::Ok);

  ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
  Result<Method> method = program->load_method("forward", &mmm.get());
  // Fatal: the Method is dereferenced by execute() below.
  ASSERT_TRUE(method.ok());

  Error err = method->execute();
  ASSERT_EQ(err, Error::Ok);
  // Without this, the STREQ above would pass vacuously if execute() never ran.
  EXPECT_TRUE(execute_called);
}
574
575 // TODO: Add more tests for the runtime-to-backend interface. E.g.:
576 // - Errors during init() or execute() result in runtime init/execution failures
577 // - Correct values are passed to init()/execute()
578 // - Demonstrate use of the runtime allocator
579 // - ...
580
581 // Run all BackendIntegrationTests multiple times, varying the return value of
582 // `GetParam()` based on the `testing::Values` list. The tests will interpret
583 // the boolean as "using segments".
584 INSTANTIATE_TEST_SUITE_P(
585 VariedSegments,
586 BackendIntegrationTest,
587 testing::Values(false, true));
588
589 class DelegateDataAlignmentTest : public ::testing::TestWithParam<bool> {
590 protected:
SetUp()591 void SetUp() override {
592 // Since these tests cause ET_LOG to be called, the PAL must be initialized
593 // first.
594 executorch::runtime::runtime_init();
595
596 // Make sure that the backend has been registered. Safe to call multiple
597 // times. Doing this at runtime ensures that it's only registered if these
598 // tests are run.
599 ASSERT_EQ(StubBackend::register_singleton(), Error::Ok);
600
601 // Paths to the test program files.
602 default_alignment_program_path_ =
603 std::getenv("ET_MODULE_ADD_MUL_NOSEGMENTS_PATH");
604 ASSERT_FALSE(default_alignment_program_path_.empty());
605 override_alignment_program_path_ =
606 std::getenv("ET_MODULE_ADD_MUL_NOSEGMENTS_DA1024_PATH");
607 ASSERT_FALSE(override_alignment_program_path_.empty());
608 }
609
TearDown()610 void TearDown() override {
611 // Clean up any modifications to the singleton.
612 StubBackend::singleton().reset();
613 }
614
615 /**
616 * Returns the expected minimum alignment of inline tensor data, given
617 * the testing parameter.
618 */
expected_alignment() const619 size_t expected_alignment() const {
620 if (GetParam()) {
621 // The delegate data inline alignment used by the -da1024 file.
622 return 1024;
623 } else {
624 // A small alignment that's compatible with any realistic alignment.
625 return 4;
626 }
627 }
628
629 /**
630 * Returns tha path to the program to load. May or may not have an alignment
631 * override, depending on the return value of expected_alignment().
632 */
program_path() const633 const char* program_path() const {
634 if (GetParam()) {
635 return override_alignment_program_path_.c_str();
636 } else {
637 return default_alignment_program_path_.c_str();
638 }
639 }
640
641 private:
642 std::string default_alignment_program_path_;
643 std::string override_alignment_program_path_;
644 };
645
TEST_P(DelegateDataAlignmentTest, ExpectedDataAlignment) {
  // Capture the address of the delegate data blob handed to init() so that
  // its alignment can be checked once the method has loaded.
  const void* blob = nullptr;
  StubBackend::singleton().install_init(
      [&blob](
          FreeableBuffer* processed,
          ET_UNUSED ArrayRef<CompileSpec> compile_specs,
          ET_UNUSED BackendInitContext& backend_init_context)
          -> Result<DelegateHandle*> {
        blob = processed->data();
        return nullptr;
      });

  // The file loader itself has to be able to honor the requested alignment.
  Result<FileDataLoader> file_loader =
      FileDataLoader::from(program_path(), /*alignment=*/expected_alignment());
  ASSERT_EQ(file_loader.error(), Error::Ok);

  // Route every load through a spy wrapped around the real loader.
  DataLoaderSpy spy(&file_loader.get());

  Result<Program> loaded = Program::load(&spy);
  ASSERT_EQ(loaded.error(), Error::Ok);
  ManagedMemoryManager memory(
      kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
  Result<Method> forward = loaded->load_method("forward", &memory.get());
  EXPECT_TRUE(forward.ok());

  // init() must have run and reported a blob.
  EXPECT_NE(blob, nullptr);

  // 1024 exceeds the alignment the default-alignment test file happens to
  // use, so the override case cannot be satisfied by accident.
  EXPECT_ALIGNED(blob, expected_alignment());
}
683
684 // Run all DelegateDataAlignmentTests multiple times, varying the return value
685 // of `GetParam()` based on the `testing::Values` list. The tests will interpret
686 // the boolean as "was inline delegate data alignment overridden to 1024".
687 INSTANTIATE_TEST_SUITE_P(
688 VariedAlignment,
689 DelegateDataAlignmentTest,
690 testing::Values(false, true));
691