1 // 2 // multiarray.h 3 // 4 // Copyright © 2024 Apple Inc. All rights reserved. 5 // 6 // Please refer to the license found in the LICENSE file in the root directory of the source tree. 7 8 #pragma once 9 10 #import <CoreML/CoreML.h> 11 #import <iostream> 12 #import <optional> 13 #import <vector> 14 15 namespace executorchcoreml { 16 17 /// A class representing an unowned buffer. 18 class Buffer { 19 public: 20 /// Constructs a buffer from data and size. Buffer(const void * data,size_t size)21 explicit Buffer(const void* data, size_t size) noexcept : data_(data), size_(size) { } 22 23 /// Returns the data pointer. data()24 inline const void* data() const noexcept { return data_; } 25 26 /// Returns the size of the buffer. size()27 inline size_t size() const noexcept { return size_; } 28 29 private: 30 const void* data_; 31 size_t size_; 32 }; 33 34 /// A class representing a MultiArray. 35 class MultiArray final { 36 public: 37 /// The MultiArray datatype. 38 enum class DataType : uint8_t { 39 Bool = 0, 40 Byte, 41 Char, 42 Short, 43 Int32, 44 Int64, 45 Float16, 46 Float32, 47 Float64, 48 }; 49 50 /// Options for copying. 51 struct CopyOptions { CopyOptionsCopyOptions52 inline CopyOptions() noexcept : use_bnns(true), use_memcpy(true) { } 53 CopyOptionsCopyOptions54 inline CopyOptions(bool use_bnns, bool use_memcpy) noexcept : use_bnns(use_bnns), use_memcpy(use_memcpy) { } 55 56 bool use_bnns = true; 57 bool use_memcpy = true; 58 }; 59 60 /// A class describing the memory layout of a MultiArray. 61 class MemoryLayout final { 62 public: MemoryLayout(DataType dataType,std::vector<size_t> shape,std::vector<ssize_t> strides)63 MemoryLayout(DataType dataType, std::vector<size_t> shape, std::vector<ssize_t> strides) 64 : dataType_(dataType), shape_(std::move(shape)), strides_(std::move(strides)) { } 65 66 /// Returns the datatype of the MultiArray. dataType()67 inline DataType dataType() const noexcept { return dataType_; } 68 69 /// Returns the shape of the MultiArray. shape()70 inline const std::vector<size_t>& shape() const noexcept { return shape_; } 71 72 /// Returns the strides of the MultiArray. strides()73 inline const std::vector<ssize_t>& strides() const noexcept { return strides_; } 74 75 /// Returns the MultiArray rank. rank()76 inline size_t rank() const noexcept { return shape_.size(); } 77 78 /// Returns the number of elements in the MultiArray. 79 size_t num_elements() const noexcept; 80 81 /// Returns the byte size of an element. 82 size_t num_bytes() const noexcept; 83 84 /// Returns `true` if the memory layout is packed otherwise `false`. 85 bool is_packed() const noexcept; 86 87 private: 88 DataType dataType_; 89 std::vector<size_t> shape_; 90 std::vector<ssize_t> strides_; 91 }; 92 93 /// Constructs a `MultiArray` from data and it's memory layout. 94 /// 95 /// The data is not owned by the `MultiArray`. MultiArray(void * data,MemoryLayout layout)96 MultiArray(void* data, MemoryLayout layout) : data_(data), layout_(std::move(layout)) { } 97 98 /// Returns the data pointer. data()99 inline void* data() const noexcept { return data_; } 100 101 /// Returns the layout of the MultiArray. layout()102 inline const MemoryLayout& layout() const noexcept { return layout_; } 103 104 /// Copies this into another `MultiArray`. 105 /// 106 /// @param dst The destination `MultiArray`. 107 void copy(MultiArray& dst, CopyOptions options = CopyOptions()) const noexcept; 108 109 /// Get the value at `indices`. value(const std::vector<size_t> & indices)110 template <typename T> inline T value(const std::vector<size_t>& indices) const noexcept { 111 return *(static_cast<T*>(data(indices))); 112 } 113 114 /// Set the value at `indices`. set_value(const std::vector<size_t> & indices,T value)115 template <typename T> inline void set_value(const std::vector<size_t>& indices, T value) const noexcept { 116 T* ptr = static_cast<T*>(data(indices)); 117 *ptr = value; 118 } 119 120 /// Get the value at `index`. value(size_t index)121 template <typename T> inline T value(size_t index) const noexcept { return *(static_cast<T*>(data(index))); } 122 123 /// Set the value at `index`. set_value(size_t index,T value)124 template <typename T> inline void set_value(size_t index, T value) const noexcept { 125 T* ptr = static_cast<T*>(data(index)); 126 *ptr = value; 127 } 128 129 private: 130 void* data(const std::vector<size_t>& indices) const noexcept; 131 132 void* data(size_t index) const noexcept; 133 134 void* data_; 135 MemoryLayout layout_; 136 }; 137 138 /// Converts `MultiArray::DataType` to `MLMultiArrayDataType`. 139 std::optional<MLMultiArrayDataType> to_ml_multiarray_data_type(MultiArray::DataType data_type); 140 141 /// Converts `MLMultiArrayDataType` to `MultiArray::DataType`. 142 std::optional<MultiArray::DataType> to_multiarray_data_type(MLMultiArrayDataType data_type); 143 144 145 } // namespace executorchcoreml 146