xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/MPSCompiler.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1//
2//  Copyright (c) 2023 Apple Inc. All rights reserved.
3//  Provided subject to the LICENSE file in the top level directory.
4//
5
6// Obj-C headers
7#import <Foundation/Foundation.h>
8#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
9#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
10
11// MPS headers
12#include <executorch/backends/apple/mps/runtime/MPSDevice.h>
13#include <executorch/backends/apple/mps/runtime/MPSCompiler.h>
14#include <executorch/backends/apple/mps/runtime/MPSGraphBuilder.h>
15#include <executorch/backends/apple/mps/schema_generated.h>
16
17// Runtime headers
18#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
19
20#include <unordered_map>
21#include <string>
22#include <iostream>
23
24#define MPS_UNUSED(x) ( (void)(x) )
25
26namespace executorch {
27namespace backends {
28namespace mps {
29namespace delegate {
30
31using executorch::runtime::ArrayRef;
32using executorch::runtime::CompileSpec;
33using executorch::runtime::Error;
34using executorch::runtime::MemoryAllocator;
35
/*
Builds the MPS runtime state for `executor` from a serialized MPS object.

The work happens in three steps, each of which aborts compilation on failure:
  1. An MPSGraphBuilder is created over the buffer and its compileModel() is run.
  2. The MPSGraphExecutable produced by the builder is stored in
     executor->_executable.
  3. executor->initDataBuffers() is called.

Parameters:
  buffer_pointer     Pointer to the serialized mps object. Must be valid; it is
                     forwarded to MPSGraphBuilder without being inspected here.
  num_bytes          Size value forwarded to MPSGraphBuilder with the buffer.
  executor           Executor to populate. Must not be null.
  runtime_allocator  Not used by this function.
  compile_specs      Not used by this function.

Returns:
  Error::Ok               on success.
  Error::InvalidArgument  if `executor` is null.
  Error::Internal         if the graph builder reports an error, or if
                          initDataBuffers() fails.
  Error::InvalidProgram   if the builder yields a nil MPSGraphExecutable.
*/
ET_NODISCARD Error MPSCompiler::compileModel(
  const void* buffer_pointer,
  size_t num_bytes,
  MPSExecutor* executor,
  MemoryAllocator* runtime_allocator,
  ArrayRef<CompileSpec> compile_specs) {
  // Neither parameter is consumed below; mark both so that unused-parameter
  // warnings stay quiet.
  MPS_UNUSED(runtime_allocator);
  MPS_UNUSED(compile_specs);

  // `executor` is dereferenced immediately below, so reject a null pointer
  // with an error instead of crashing.
  ET_CHECK_OR_RETURN_ERROR(
      executor != nullptr,
      InvalidArgument,
      "MPSExecutor must not be null");

  Error err = Error::Ok;

  std::unique_ptr<MPSGraphBuilder> mpsGraphBuilder(
    new MPSGraphBuilder(buffer_pointer, num_bytes, executor->_mpsGraphTensorToId));
  err = mpsGraphBuilder->compileModel();
  ET_CHECK_OR_RETURN_ERROR(
    err == Error::Ok, Internal, "Failed to construct the MPS graph object");

  executor->_executable = mpsGraphBuilder->getMPSGraphExecutable();
  ET_CHECK_OR_RETURN_ERROR(
      executor->_executable != nil,
      InvalidProgram,
      "Invalid FlatBuffer contents - could not create MPSGraphExecutable");

  err = executor->initDataBuffers();
  ET_CHECK_OR_RETURN_ERROR(
      err == Error::Ok, Internal, "Could not allocate data buffers");

  // -count returns NSUInteger, whose width is platform dependent; cast so the
  // argument always matches the %lu conversion.
  ET_LOG(
      Debug,
      "MPSGraphExecutable total inputs: %lu",
      static_cast<unsigned long>([executor->_inputShapes count]));
  ET_LOG(
      Debug,
      "MPSGraphExecutable total outputs: %lu",
      static_cast<unsigned long>([executor->_outputShapes count]));

  return err;
}
71
72} // namespace delegate
73} // namespace mps
74} // namespace backends
75} // namespace executorch
76