xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/MPSBackend.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
//
//  Copyright (c) 2023 Apple Inc. All rights reserved.
//  Provided subject to the LICENSE file in the top level directory.
//
5
#import <Foundation/Foundation.h>

#include "MPSCompiler.h"
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/platform/profiler.h>
#include <cstddef>
#include <cstdio>
#include <cstdlib> /* strtol */
#include <iostream>
#include <memory>
#include <string>
#include <vector>
18
19namespace executorch {
20namespace backends {
21
22using executorch::aten::Tensor;
23using executorch::runtime::ArrayRef;
24using executorch::runtime::Backend;
25using executorch::runtime::BackendExecutionContext;
26using executorch::runtime::BackendInitContext;
27using executorch::runtime::CompileSpec;
28using executorch::runtime::DelegateHandle;
29using executorch::runtime::EValue;
30using executorch::runtime::Error;
31using executorch::runtime::FreeableBuffer;
32using executorch::runtime::Result;
33
34class MPSBackend final : public ::executorch::runtime::BackendInterface {
35 public:
36  ~MPSBackend() = default;
37
38  bool is_available() const override {
39    return true;
40  }
41
42  Result<DelegateHandle*> init(
43      BackendInitContext& context,
44      FreeableBuffer* processed,
45      ArrayRef<CompileSpec> compile_specs) const override {
46    auto executor = ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(
47        context.get_runtime_allocator(), mps::delegate::MPSExecutor);
48    // NOTE: Since we use placement new and since this type is not trivially
49    // destructible, we must call the destructor manually in destroy().
50    new (executor) mps::delegate::MPSExecutor;
51    Error err = mps::delegate::MPSCompiler::compileModel(
52        processed->data(),
53        processed->size(),
54        executor,
55        context.get_runtime_allocator(),
56        compile_specs);
57    ET_CHECK_OR_RETURN_ERROR(
58      err == Error::Ok,
59      Internal,
60      "Failed to initialize the MPS delegate");
61
62    // Free the flatbuffer.
63    processed->Free();
64
65    return executor;
66  }
67
68  // Function that actually executes the model in the backend.
69  Error execute(
70    ET_UNUSED BackendExecutionContext& context,
71    DelegateHandle* handle,
72    EValue** args) const override {
73    auto executor = static_cast<mps::delegate::MPSExecutor*>(handle);
74    std::vector<const Tensor*> input_pointers;
75    std::vector<const Tensor*> output_pointers;
76
77    Error err = Error::Ok;
78
79    int i = 0;
80    int total_placeholders = executor->getNumInputs() + executor->getNumOutputs();
81    while ((input_pointers.size() != executor->getNumInputs()    ||
82           output_pointers.size() != executor->getNumOutputs())  &&
83           (i < total_placeholders)) {
84      ET_CHECK_OR_RETURN_ERROR(
85        args[i] != nullptr,
86        Internal,
87        "Nullptr tensor received during graph execution");
88
89      if (args[i]->isTensor()) {
90        if (input_pointers.size() < executor->getNumInputs()) {
91          input_pointers.push_back(&args[i]->toTensor());
92        } else {
93          output_pointers.push_back(&args[i]->toTensor());
94        }
95      } else if (args[i]->isTensorList()) {
96        const executorch::aten::ArrayRef<executorch::aten::Tensor>& tensorList = args[i]->toTensorList();
97        for (auto& tensor_ : tensorList) {
98          if (input_pointers.size() < executor->getNumInputs()) {
99            input_pointers.push_back(&tensor_);
100          } else {
101            output_pointers.push_back(&tensor_);
102          }
103        }
104      } else {
105        ET_CHECK_OR_RETURN_ERROR(
106          false,
107          Internal,
108          "Unhandled tag during execution of the graph");
109      }
110      i++;
111    }
112
113    err = executor->set_inputs_outputs(input_pointers, output_pointers);
114    if (err != Error::Ok) {
115      return err;
116    }
117
118    err = executor->forward(output_pointers);
119    return err;
120  }
121
122  void destroy(DelegateHandle* handle) const override {
123    if (handle != nullptr) {
124      auto executor = static_cast<mps::delegate::MPSExecutor*>(handle);
125      // manually in init(), we must destroy it manually here.
126      executor->~MPSExecutor();
127    }
128  }
129};
130
131namespace {
132auto cls = MPSBackend();
133Backend backend{"MPSBackend", &cls};
134static auto success_with_compiler = register_backend(backend);
135} // namespace
136
137} // namespace backends
138} // namespace executorch
139