xref: /aosp_15_r20/external/executorch/backends/apple/coreml/runtime/delegate/multiarray.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 //
2 // multiarray.h
3 //
4 // Copyright © 2024 Apple Inc. All rights reserved.
5 //
6 // Please refer to the license found in the LICENSE file in the root directory of the source tree.
7 
8 #pragma once
9 
10 #import <CoreML/CoreML.h>
11 #import <iostream>
12 #import <optional>
13 #import <vector>
14 
15 namespace executorchcoreml {
16 
/// A non-owning view over a contiguous region of memory.
///
/// `Buffer` only records a pointer and a length. It never allocates, copies,
/// or frees the memory it refers to, so whoever created the region must keep
/// it alive while the `Buffer` is in use.
class Buffer {
public:
    /// Creates a view over the region starting at `bytes`.
    ///
    /// @param bytes Pointer to the start of the region (not owned).
    /// @param size  Length of the region (presumably in bytes — confirm with callers).
    explicit Buffer(const void* bytes, size_t size) noexcept : data_(bytes), size_(size) { }

    /// The pointer this buffer was constructed with.
    auto data() const noexcept -> const void* { return data_; }

    /// The length this buffer was constructed with.
    auto size() const noexcept -> size_t { return size_; }

private:
    /// Start of the viewed region; not owned.
    const void* data_;
    /// Length of the viewed region.
    size_t size_;
};
33 
/// A typed, strided view over memory owned by someone else.
///
/// A `MultiArray` pairs a raw data pointer with a `MemoryLayout` (element
/// type, shape, strides). It never allocates or frees the memory it points to.
class MultiArray final {
public:
    /// Element types a `MultiArray` can describe. The underlying values are
    /// fixed: `Bool` is 0 and each following enumerator is one greater.
    enum class DataType : uint8_t {
        Bool = 0,
        Byte = 1,
        Char = 2,
        Short = 3,
        Int32 = 4,
        Int64 = 5,
        Float16 = 6,
        Float32 = 7,
        Float64 = 8,
    };

    /// Options accepted by `copy`. Both flags default to `true`. Judging by
    /// their names they gate BNNS- and memcpy-based copy paths; the logic that
    /// reads them lives in the implementation file (not visible here).
    struct CopyOptions {
        /// Enables every flag; equivalent to `CopyOptions(true, true)`.
        CopyOptions() noexcept : CopyOptions(true, true) { }

        /// @param enable_bnns   Initial value of `use_bnns`.
        /// @param enable_memcpy Initial value of `use_memcpy`.
        CopyOptions(bool enable_bnns, bool enable_memcpy) noexcept
            : use_bnns(enable_bnns), use_memcpy(enable_memcpy) { }

        /// Whether `copy` may use a BNNS-based path (TODO confirm in the implementation).
        bool use_bnns = true;
        /// Whether `copy` may use a plain `memcpy` path (TODO confirm in the implementation).
        bool use_memcpy = true;
    };

    /// Describes how a `MultiArray`'s elements are laid out in memory.
    class MemoryLayout final {
    public:
        /// @param dataType Element type.
        /// @param shape    Extent of each dimension; its length is the rank.
        /// @param strides  Per-dimension strides (signed). Whether they are
        ///                 counted in elements or bytes is not fixed by this
        ///                 header — TODO confirm.
        MemoryLayout(DataType dataType, std::vector<size_t> shape, std::vector<ssize_t> strides)
            : dataType_(dataType), shape_(std::move(shape)), strides_(std::move(strides)) { }

        /// Element type of the array.
        auto dataType() const noexcept -> DataType { return dataType_; }

        /// Extent of each dimension.
        auto shape() const noexcept -> const std::vector<size_t>& { return shape_; }

        /// Stride of each dimension.
        auto strides() const noexcept -> const std::vector<ssize_t>& { return strides_; }

        /// Number of dimensions, i.e. the length of `shape()`.
        auto rank() const noexcept -> size_t { return shape_.size(); }

        /// Returns the number of elements in the MultiArray.
        size_t num_elements() const noexcept;

        /// Returns the byte size of an element.
        size_t num_bytes() const noexcept;

        /// Returns `true` if the memory layout is packed, otherwise `false`.
        bool is_packed() const noexcept;

    private:
        DataType dataType_;
        std::vector<size_t> shape_;
        std::vector<ssize_t> strides_;
    };

    /// Wraps `storage`, interpreting it according to `layout`.
    ///
    /// The memory is not owned: the caller must keep it alive for as long as
    /// this `MultiArray` is used.
    MultiArray(void* storage, MemoryLayout layout) : data_(storage), layout_(std::move(layout)) { }

    /// Base pointer supplied at construction.
    auto data() const noexcept -> void* { return data_; }

    /// Layout supplied at construction.
    auto layout() const noexcept -> const MemoryLayout& { return layout_; }

    /// Copies this array's contents into `dst`.
    ///
    /// @param dst     The destination `MultiArray`.
    /// @param options Copy flags; by default every flag is enabled.
    void copy(MultiArray& dst, CopyOptions options = CopyOptions()) const noexcept;

    /// Reads the element addressed by the multi-dimensional `indices`.
    ///
    /// `T` is not checked against `layout().dataType()`; the caller must pick
    /// the matching C++ type.
    template <typename T> T value(const std::vector<size_t>& indices) const noexcept {
        return *static_cast<const T*>(data(indices));
    }

    /// Reads the element addressed by the single `index`.
    ///
    /// `T` is not checked against `layout().dataType()`.
    template <typename T> T value(size_t index) const noexcept {
        return *static_cast<const T*>(data(index));
    }

    /// Writes `newValue` to the element addressed by `indices`.
    ///
    /// Declared `const` because the array does not own its memory; the write
    /// goes through the stored pointer. `T` is not checked against
    /// `layout().dataType()`.
    template <typename T> void set_value(const std::vector<size_t>& indices, T newValue) const noexcept {
        *static_cast<T*>(data(indices)) = newValue;
    }

    /// Writes `newValue` to the element addressed by the single `index`.
    ///
    /// `T` is not checked against `layout().dataType()`.
    template <typename T> void set_value(size_t index, T newValue) const noexcept {
        *static_cast<T*>(data(index)) = newValue;
    }

private:
    /// Address of the element at `indices` (defined out of line).
    void* data(const std::vector<size_t>& indices) const noexcept;

    /// Address of the element at `index` (defined out of line).
    void* data(size_t index) const noexcept;

    void* data_;
    MemoryLayout layout_;
};
137 
138 /// Converts `MultiArray::DataType` to `MLMultiArrayDataType`.
139 std::optional<MLMultiArrayDataType> to_ml_multiarray_data_type(MultiArray::DataType data_type);
140 
141 /// Converts `MLMultiArrayDataType` to `MultiArray::DataType`.
142 std::optional<MultiArray::DataType> to_multiarray_data_type(MLMultiArrayDataType data_type);
143 
144 
145 } // namespace executorchcoreml
146