//  Copyright © 2022 Apple Inc.

#pragma once

#include <initializer_list>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/TensorFactory.h>
#include <c10/core/ScalarType.h>
#include <torch/library.h>
#include <unordered_map>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

#include <MetalPerformanceShaders/MetalPerformanceShaders.h>

// Fwd declarations
namespace at {
  struct TensorIteratorBase;
}
using namespace at::mps;

namespace at::native::mps {

void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());

struct MPSScalar {
  id<MTLBuffer> getMTLBuffer() const { return __builtin_bit_cast(id<MTLBuffer>, buffer.get()); }

  size_t size = 0;
  ScalarType type = ScalarType::Undefined;
  c10::DataPtr buffer; // stores MTLBuffer (frees buffer if MPSScalar instance goes out of scope)
  union {
    float f; // MPS doesn't support 'double'
    at::Half h;
    int64_t i;
    bool b;
    c10::complex<float> cf;
    c10::complex<at::Half> ch;
    at::BFloat16 bf16;
  } value {};
};
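
// Example: lowering a c10::Scalar into a graph feed via MPSScalar. A minimal,
// hypothetical sketch -- the placeholder tensor and the `feeds` dictionary are
// assumptions; see OperationUtils.mm and the op kernels for real call sites:
//
//   MPSScalar alphaScalar = getMPSScalar(alpha, ScalarType::Float);
//   MPSGraphTensorData* alphaData =
//       getMPSGraphTensorFromScalar(getCurrentMPSStream(), alphaScalar);
//   feeds[alphaPlaceholderTensor] = alphaData;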

void runMPSGraph(MPSStream* mpsStream,
    MPSGraph* mpsGraph,
    NSDictionary* feeds,
    NSDictionary* results);

MPSDataType getMPSDataType(ScalarType scalar_type);
static inline MPSDataType getMPSDataType(const Tensor& t) {
  return getMPSDataType(t.scalar_type());
}
MPSDataType getMPSScalarType(ScalarType scalar_type);
static inline MPSDataType getMPSScalarType(const Tensor& t) {
  return getMPSScalarType(t.scalar_type());
}
MPSScalar   getMPSScalar(const Scalar& scalar, ScalarType type);
std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
static inline std::string getMPSTypeString(const Tensor& t, bool short_name = false) {
  return getMPSTypeString(t.scalar_type(), short_name);
}
std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
static inline std::string scalarToMetalTypeString(const Tensor& t) {
  return scalarToMetalTypeString(t.scalar_type());
}
NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output);
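
// Example: resolving a strided view into a dense tensor before graph execution.
// A minimal sketch (`src` and `dst` are hypothetical tensors):
//
//   Tensor gathered = gatherViewTensor(src, dst);
//   if (gathered.has_storage()) {
//     // src was a view; feed the gathered (dense) tensor to the graph
//   } else {
//     // src was already dense; keep using src directly
//   }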
bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape);
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);

MPSNDArray* getMPSNDArray(const at::Tensor& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
MPSNDArray* getMPSNDArray(const at::Tensor& t, MPSShape* sizes = nil, MPSShape* strides = nil);
// The MPSShape could vary based on memory format
Tensor getTensorView(const Tensor& t, MPSShape* shape);
MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);

static inline id<MTLBuffer> getMTLBufferStorage(const at::Tensor& tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

class Placeholder {
 public:
  Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray);
  Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr,
              bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid, bool useMPSStridedAPI = true);
  MPSGraphTensor* getMPSGraphTensor() {
    return _placeholder;
  }
  MPSGraphTensorData* getMPSGraphTensorData() {
    return _value;
  }
  bool isIntermediate() {
    return _value == nullptr;
  }

 private:
  MPSGraphTensor* _placeholder;
  MPSGraphTensorData* _value;
  Tensor _tensor;
};
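
// Example: binding graph tensors to ATen tensors and running the graph.
// A minimal sketch assuming a cached graph with inputTensor_/outputTensor_
// members (see the MPSUnaryCachedGraph struct below) and tensors
// `self`/`output`:
//
//   Placeholder inputPlaceholder(cachedGraph->inputTensor_, self);
//   Placeholder outputPlaceholder(cachedGraph->outputTensor_, output);
//   auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
//   runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);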

void resize_tensor(Tensor* output);
Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device);
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor);
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor);
MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);

MPSGraph* make_mps_graph();
void printTensorNDArray(const Tensor& t);
MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType);

MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar);

std::string get_mem_format_string(c10::MemoryFormat memory_format);

using MPSCacheKey = uint64_t;

// derive this class to cache a graph and its inputs/outputs
// can be used to store any NSObject
struct MPSCachedGraph
{
  MPSCachedGraph(NSObject *object) : _object([object retain]) {}
  virtual ~MPSCachedGraph() {
   [_object release];
   _object = nullptr;
  }

  template<typename T>
  inline T* as() {
    return static_cast<T*>(this);
  }

  MPSGraph *graph() const { return (MPSGraph *)_object; }
  NSObject *object() const { return _object; }
private:
  NSObject *_object = nullptr;
};

struct MPSUnaryCachedGraph : public MPSCachedGraph
{
  MPSUnaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor *inputTensor_ = nil;
  MPSGraphTensor *outputTensor_ = nil;
};

struct MPSUnaryGradCachedGraph : public MPSCachedGraph
{
  MPSUnaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor *gradOutputTensor_ = nil;
  MPSGraphTensor *inputTensor_ = nil;
  MPSGraphTensor *outputTensor_ = nil; // some backward input is actually the forward's output
  MPSGraphTensor *gradInputTensor_ = nil;
};

struct MPSBinaryCachedGraph : public MPSCachedGraph
{
  MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor *inputTensor_ = nil;
  MPSGraphTensor *otherTensor_ = nil;
  MPSGraphTensor *outputTensor_ = nil;
};

struct MPSBinaryGradCachedGraph : public MPSCachedGraph
{
  MPSBinaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor *gradOutputTensor_ = nil;
  MPSGraphTensor *inputTensor_ = nil;
  MPSGraphTensor *otherTensor_ = nil;
  MPSGraphTensor *gradInputTensor_ = nil;
};

// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs
struct MPSGraphCache
{
  typedef MPSCachedGraph * (^CreateCachedGraphBlock)();

  struct CacheEntry {
    CacheEntry(const std::string& key, MPSCachedGraph *cachedGraph) : cachedGraph_(cachedGraph), key_(key) {}
    MPSCachedGraph* cachedGraph_ = nullptr;
    std::string key_;
  };

 public:

  static MPSGraphCache* getInstance() {
    if (_instance_cache == nullptr) {
      _instance_cache = new MPSGraphCache();
    }
    return _instance_cache;
  }

  ~MPSGraphCache() {
    dispatch_release(serialQueue_);

    for (const auto& i : cache_) {
      delete i.second.cachedGraph_;
    }
  }

  // Disallow the copy constructor and operator= functions
  MPSGraphCache(const MPSGraphCache&) = delete;
  void operator=(const MPSGraphCache&) = delete;

  MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {

    __block MPSCachedGraph* cachedGraph = nil;

    MPSCacheKey hash = std::hash<std::string>{}(key);

    dispatch_sync_with_rethrow(serialQueue_, ^() {
      // verify the cached entry doesn't already exist
      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
        cachedGraph = entry.cachedGraph_;
      } else {
        cachedGraph = createCacheBlock();
        CacheEntry entry(key, cachedGraph);
        cache_.emplace(hash, entry);
        profileCachedGraph(entry);
      }
    });
    return cachedGraph;
  }

  template<typename T>
  inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
    return static_cast<T *>(CreateCachedGraph(key, createCacheBlock));
  }

  MPSCachedGraph* LookUp(const std::string& key) const {

    __block MPSCachedGraph* cachedGraph = nullptr;

    MPSCacheKey hash = std::hash<std::string>{}(key);

    dispatch_sync(serialQueue_, ^() {
      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
        cachedGraph = entry.cachedGraph_;
        profileCachedGraph(entry);
      }
    });
    return cachedGraph;
  }

  template<typename T>
  inline T* LookUpAs(const std::string& key) const {
    return static_cast<T *>(LookUp(key));
  }

 private:
  MPSGraphCache() {
    serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL);
  }
  // this is defined in OperationUtils.mm to not include
  // MPSProfiler.h in header OperationUtils.h
  void profileCachedGraph(const CacheEntry& cacheEntry) const;

  static MPSGraphCache* _instance_cache;
  std::unordered_map<MPSCacheKey, CacheEntry> cache_;
  dispatch_queue_t serialQueue_ = nullptr;
};

// Common template for creating graph with a specified cache if missing
template<typename T>
inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(MPSGraph*, T*)> instantiate) {
  auto cache_ = MPSGraphCache::getInstance();
  if (auto rc = cache_->LookUpAs<T>(key)) {
    return rc;
  }
  return cache_->CreateCachedGraphAs<T>(key, ^mps::MPSCachedGraph*() {
    T* newCachedGraph = nil;
    @autoreleasepool {
      // Initialize graph
      auto mpsGraph = mps::make_mps_graph();
      newCachedGraph = new T(mpsGraph);
      instantiate(mpsGraph, newCachedGraph);
    }
    return newCachedGraph;
  });
}
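
// Example: a typical unary-op kernel body built on the cache. A minimal,
// hypothetical sketch -- the op (exp), key format, and tensor `self` are
// assumptions, not an existing kernel:
//
//   auto cachedGraph = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(
//       "exp_out_mps" + getTensorsStringKey({self}),
//       [&](MPSGraph* mpsGraph, MPSUnaryCachedGraph* newCachedGraph) {
//         newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self);
//         newCachedGraph->outputTensor_ = [mpsGraph exponentWithTensor:newCachedGraph->inputTensor_
//                                                                 name:nil];
//       });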

// Common math operations
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);

#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name)                                           \
  if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) {                                                       \
     TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name,                                                         \
     ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3.");   \
  }
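
// Example usage, assuming `self` is the input tensor of a hypothetical op:
//
//   bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
//   MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "bitwise_not");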

/**
 * Returns distance from lowest to highest element offset in given tensor.
 */
size_t compute_storage_numel_distance(const at::Tensor& t);

/**
 * Checks whether tensor is mapped to a contiguous area in the storage.
 */
inline bool is_dense_in_storage(const at::Tensor& t) {
  return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
}

class MetalShaderLibrary {
public:
  MetalShaderLibrary(const std::string& src): shaderSource(src), nparams(0), compile_options(nullptr) {}
  MetalShaderLibrary(const std::string& src, unsigned nparams_): shaderSource(src), nparams(nparams_), compile_options(nullptr) {}
  MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_): shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
    return getLibraryPipelineState(getLibrary(), fname).first;
  }
  id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname, const std::initializer_list<std::string>& params) {
    return getLibraryPipelineState(getLibrary(params), fname).first;
  }
  inline id<MTLFunction> getMTLFunction(const std::string& fname) {
    return getLibraryPipelineState(getLibrary(), fname).second;
  }
  id<MTLFunction> getMTLFunction(const std::string& fname, const std::initializer_list<std::string>& params) {
    return getLibraryPipelineState(getLibrary(params), fname).second;
  }
private:
  std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
  id<MTLLibrary> getLibrary();
  id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);

  id<MTLLibrary> compileLibrary(const std::string& src);
  std::string shaderSource;
  unsigned nparams;
  MTLCompileOptions* compile_options;
  id<MTLLibrary> library = nil;
  std::unordered_map<std::string, id<MTLLibrary>> libMap;
  std::unordered_map<std::string, std::pair<id<MTLComputePipelineState>, id<MTLFunction>>> cplMap;
};
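
// Example: compiling an embedded Metal kernel once and fetching its pipeline
// state. A minimal sketch; the shader source and kernel name are hypothetical:
//
//   static MetalShaderLibrary lib(R"MTL(
//     kernel void fill_one(device float* out [[buffer(0)]],
//                          uint idx [[thread_position_in_grid]]) {
//       out[idx] = 1.0;
//     }
//   )MTL");
//   id<MTLComputePipelineState> fillPSO = lib.getPipelineStateForFunc("fill_one");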

template<typename encoder_t,
         typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> || std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
static inline void mtl_setBuffer(encoder_t encoder, const Tensor& t, unsigned idx) {
  [encoder setBuffer:getMTLBufferStorage(t)
              offset:t.storage_offset() * t.element_size()
             atIndex:idx];
}

template<typename T,
         typename = std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T, float>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const T val, unsigned idx) {
  [encoder setBytes:&val length:sizeof(T) atIndex:idx];
}

template<typename Container,
         typename = std::enable_if_t<std::is_integral_v<typename Container::size_type>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const Container& values, unsigned idx) {
  [encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex:idx];
}

static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
                                     id<MTLComputePipelineState> cplState,
                                     uint32_t length) {
  const uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
  auto size = MTLSizeMake(length, 1, 1);
  auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1);
  [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
}
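
// Example: encoding a custom 1-D kernel dispatch with the helpers above.
// A minimal sketch assuming an MPSStream* `stream`, a pipeline state `fillPSO`
// (e.g. from the MetalShaderLibrary example), and an output tensor `out`:
//
//   dispatch_sync(stream->queue(), ^() {
//     id<MTLComputeCommandEncoder> encoder = stream->commandEncoder();
//     [encoder setComputePipelineState:fillPSO];
//     mtl_setBuffer(encoder, out, 0);
//     mtl_dispatch1DJob(encoder, fillPSO, static_cast<uint32_t>(out.numel()));
//   });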

id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder, const TensorIteratorBase& iter, bool use_64bit_index = false);

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) {
  return @{ p1.getMPSGraphTensor(): p1.getMPSGraphTensorData() };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) {
  return @{
    p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
  };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) {
  return @{
    p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
    p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
  };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) {
  return @{
    p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
    p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
    p4.getMPSGraphTensor(): p4.getMPSGraphTensorData(),
  };
}

inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) {
  runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
}

inline bool supportsComplex() {
  return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
}

// MPS does not yet support double types, but starting with macOS 14 it supports bfloat16
inline bool supportedFloatingType(ScalarType dtype) {
  return dtype == kFloat || dtype == kHalf || dtype == kBFloat16;
}

inline bool supportedFloatingType(const Tensor& t) {
  return supportedFloatingType(t.scalar_type());
}

inline bool supportedFloatingOrComplexType(ScalarType dtype) {
  if (dtype == kComplexFloat || dtype == kComplexHalf) {
    return supportsComplex();
  }
  return supportedFloatingType(dtype);
}
inline bool supportedFloatingOrComplexType(const Tensor& t) {
  return supportedFloatingOrComplexType(t.scalar_type());
}

inline bool needsGather(const Tensor& t) {
  static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
  return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset());
}

} // namespace at::native::mps