// Copyright © 2022 Apple Inc.

#pragma once

#include <initializer_list>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/TensorFactory.h>
#include <c10/core/ScalarType.h>
#include <torch/library.h>
#include <unordered_map>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

#include <MetalPerformanceShaders/MetalPerformanceShaders.h>

// Fwd declarations
namespace at {
struct TensorIteratorBase;
}
using namespace at::mps;

namespace at::native::mps {

void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());

struct MPSScalar {
  id<MTLBuffer> getMTLBuffer() const { return __builtin_bit_cast(id<MTLBuffer>, buffer.get()); }

  size_t size = 0;
  ScalarType type = ScalarType::Undefined;
  c10::DataPtr buffer; // stores MTLBuffer (frees buffer if MPSScalar instance goes out of scope)
  union {
    float f; // MPS doesn't support 'double'
    at::Half h;
    int64_t i;
    bool b;
    c10::complex<float> cf;
    c10::complex<at::Half> ch;
    at::BFloat16 bf16;
  } value {};
};

void runMPSGraph(MPSStream* mpsStream,
                 MPSGraph* mpsGraph,
                 NSDictionary* feeds,
                 NSDictionary* results);

MPSDataType getMPSDataType(ScalarType scalar_type);
static inline MPSDataType getMPSDataType(const Tensor& t) {
  return getMPSDataType(t.scalar_type());
}
MPSDataType getMPSScalarType(ScalarType scalar_type);
static inline MPSDataType getMPSScalarType(const Tensor& t) {
  return getMPSScalarType(t.scalar_type());
}
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
static inline std::string getMPSTypeString(const Tensor& t, bool short_name = false) {
  return getMPSTypeString(t.scalar_type(), short_name);
}
std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
static inline std::string scalarToMetalTypeString(const Tensor& t) {
  return scalarToMetalTypeString(t.scalar_type());
}
NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output);
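// Illustrative sketch (not part of the API): a caller that wants a dense copy of a
// possibly-viewed `src` can inspect the returned tensor's storage as the comment
// above describes. `src` and `dst` are hypothetical tensors:
//
//   Tensor gathered = gatherViewTensor(src, dst);
//   const Tensor& dense = gathered.has_storage() ? gathered : src; // no storage => src needed no gather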
bool canSliceViewTensor(const Tensor& src, MPSShape* mpsShape);
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape* mpsShape, const MPSDataType mpsDataType);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);

MPSNDArray* getMPSNDArray(const at::Tensor& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
MPSNDArray* getMPSNDArray(const at::Tensor& t, MPSShape* sizes = nil, MPSShape* strides = nil);
// The MPSShape could vary based on memory format
Tensor getTensorView(const Tensor& t, MPSShape* shape);
MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);

static inline id<MTLBuffer> getMTLBufferStorage(const at::Tensor& tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

class Placeholder {
 public:
  Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray);
  Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape* mpsShape = nullptr,
              bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid, bool useMPSStridedAPI = true);
  MPSGraphTensor* getMPSGraphTensor() {
    return _placeholder;
  }
  MPSGraphTensorData* getMPSGraphTensorData() {
    return _value;
  }
  bool isIntermediate() {
    return _value == nullptr;
  }

 private:
  MPSGraphTensor* _placeholder;
  MPSGraphTensorData* _value;
  Tensor _tensor;
};

void resize_tensor(Tensor* output);
Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device);
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor);
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor);
MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);
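// Illustrative sketch (not part of the API): an op that feeds a host-side Scalar
// into a graph could package it as below. `alpha`, `stream` and `alphaPlaceholder`
// are hypothetical names, and `feeds` is assumed to be an NSMutableDictionary
// owned by the caller:
//
//   MPSScalar alphaScalar = getMPSScalar(alpha, ScalarType::Float);
//   feeds[alphaPlaceholder] = getMPSGraphTensorFromScalar(stream, alphaScalar);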

MPSGraph* make_mps_graph();
void printTensorNDArray(const Tensor& t);
MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape* shape, MPSDataType mpsType);

MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const Tensor& tensor);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar);

std::string get_mem_format_string(c10::MemoryFormat memory_format);

using MPSCacheKey = uint64_t;

// Derive from this class to cache a graph and its inputs/outputs;
// it can be used to store any NSObject.
struct MPSCachedGraph
{
  MPSCachedGraph(NSObject *object) : _object([object retain]) {}
  virtual ~MPSCachedGraph() {
    [_object release];
    _object = nullptr;
  }

  template<typename T>
  inline T* as() {
    return static_cast<T*>(this);
  }

  MPSGraph *graph() const { return (MPSGraph *)_object; }
  NSObject *object() const { return _object; }
 private:
  NSObject *_object = nullptr;
};

struct MPSUnaryCachedGraph : public MPSCachedGraph
{
  MPSUnaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor *inputTensor_ = nil;
  MPSGraphTensor *outputTensor_ = nil;
};

struct MPSUnaryGradCachedGraph : public MPSCachedGraph
{
  MPSUnaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor *gradOutputTensor_ = nil;
  MPSGraphTensor *inputTensor_ = nil;
  MPSGraphTensor *outputTensor_ = nil; // some backward input is actually the forward's output
  MPSGraphTensor *gradInputTensor_ = nil;
};

struct MPSBinaryCachedGraph : public MPSCachedGraph
{
  MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor *inputTensor_ = nil;
  MPSGraphTensor *otherTensor_ = nil;
  MPSGraphTensor *outputTensor_ = nil;
};

struct MPSBinaryGradCachedGraph : public MPSCachedGraph
{
  MPSBinaryGradCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor *gradOutputTensor_ = nil;
  MPSGraphTensor *inputTensor_ = nil;
  MPSGraphTensor *otherTensor_ = nil;
  MPSGraphTensor *gradInputTensor_ = nil;
};
// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs
struct MPSGraphCache
{
  typedef MPSCachedGraph * (^CreateCachedGraphBlock)();

  struct CacheEntry {
    CacheEntry(const std::string& key, MPSCachedGraph *cachedGraph) : cachedGraph_(cachedGraph), key_(key) {}
    MPSCachedGraph* cachedGraph_ = nullptr;
    std::string key_;
  };

 public:

  static MPSGraphCache* getInstance() {
    if (_instance_cache == nullptr) {
      _instance_cache = new MPSGraphCache();
    }
    return _instance_cache;
  }

  ~MPSGraphCache() {
    dispatch_release(serialQueue_);

    for (const auto& i : cache_) {
      delete i.second.cachedGraph_;
    }
  }

  // Disallow the copy constructor and operator= functions
  MPSGraphCache(const MPSGraphCache&) = delete;
  void operator=(const MPSGraphCache&) = delete;

  MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {

    __block MPSCachedGraph* cachedGraph = nil;

    MPSCacheKey hash = std::hash<std::string>{}(key);

    dispatch_sync_with_rethrow(serialQueue_, ^() {
      // reuse the cached entry if it already exists; otherwise create and cache a new one
      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
        cachedGraph = entry.cachedGraph_;
      } else {
        cachedGraph = createCacheBlock();
        CacheEntry entry(key, cachedGraph);
        cache_.emplace(hash, entry);
        profileCachedGraph(entry);
      }
    });
    return cachedGraph;
  }

  template<typename T>
  inline T* CreateCachedGraphAs(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
    return static_cast<T *>(CreateCachedGraph(key, createCacheBlock));
  }

  MPSCachedGraph* LookUp(const std::string& key) const {

    __block MPSCachedGraph* cachedGraph = nullptr;

    MPSCacheKey hash = std::hash<std::string>{}(key);

    dispatch_sync(serialQueue_, ^() {

      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
        cachedGraph = entry.cachedGraph_;
        profileCachedGraph(entry);
      }
    });
    return cachedGraph;
  }

  template<typename T>
  inline T* LookUpAs(const std::string& key) const {
    return static_cast<T *>(LookUp(key));
  }

 private:
  MPSGraphCache() {
    serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL);
  }
  // defined in OperationUtils.mm so that this header does not need to include MPSProfiler.h
  void profileCachedGraph(const CacheEntry& cacheEntry) const;

  static MPSGraphCache* _instance_cache;
  std::unordered_map<MPSCacheKey, CacheEntry> cache_;
  dispatch_queue_t serialQueue_ = nullptr;
};

// Common template for creating a graph with the specified cache if missing
template<typename T>
inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(MPSGraph*, T*)> instantiate) {
  auto cache_ = MPSGraphCache::getInstance();
  if (auto rc = cache_->LookUpAs<T>(key)) {
    return rc;
  }
  return cache_->CreateCachedGraphAs<T>(key, ^mps::MPSCachedGraph*() {
    T* newCachedGraph = nil;
    @autoreleasepool {
      // Initialize graph
      auto mpsGraph = mps::make_mps_graph();
      newCachedGraph = new T(mpsGraph);
      instantiate(mpsGraph, newCachedGraph);
    }
    return newCachedGraph;
  });
}
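// Illustrative sketch (not part of this header): a unary op could build or fetch its
// cached graph as below. `self` is a hypothetical input tensor, the key format is made
// up for the example, and `log1p` is the graph helper declared further down; real call
// sites typically fold shapes and dtypes into the cache key via getTensorsStringKey():
//
//   @autoreleasepool {
//     std::string key = "log1p_out_mps:" + getTensorsStringKey({self});
//     auto cachedGraph = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(
//         key, [&](MPSGraph* mpsGraph, MPSUnaryCachedGraph* newCachedGraph) {
//           newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self);
//           newCachedGraph->outputTensor_ = log1p(mpsGraph, newCachedGraph->inputTensor_);
//         });
//   }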

// Common math operations
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);

#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name)                                         \
  if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) {                                                     \
    TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name,                                                        \
                    ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
  }
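// Illustrative sketch: an op that downcasts int64 inputs on pre-13.3 systems might emit
// the one-time warning like this (assuming the MacOSVersion enum exposes a
// MACOS_VER_13_3_PLUS value, as the other version checks below suggest; the op name is
// hypothetical):
//
//   bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
//   MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "cumsum_out_mps");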

/**
 * Returns distance from lowest to highest element offset in given tensor.
 */
size_t compute_storage_numel_distance(const at::Tensor& t);

/**
 * Checks whether tensor is mapped to a contiguous area in the storage.
 */
inline bool is_dense_in_storage(const at::Tensor& t) {
  return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
}

class MetalShaderLibrary {
 public:
  MetalShaderLibrary(const std::string& src) : shaderSource(src), nparams(0), compile_options(nullptr) {}
  MetalShaderLibrary(const std::string& src, unsigned nparams_) : shaderSource(src), nparams(nparams_), compile_options(nullptr) {}
  MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_) : shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
    return getLibraryPipelineState(getLibrary(), fname).first;
  }
  id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname, const std::initializer_list<std::string>& params) {
    return getLibraryPipelineState(getLibrary(params), fname).first;
  }
  inline id<MTLFunction> getMTLFunction(const std::string& fname) {
    return getLibraryPipelineState(getLibrary(), fname).second;
  }
  id<MTLFunction> getMTLFunction(const std::string& fname, const std::initializer_list<std::string>& params) {
    return getLibraryPipelineState(getLibrary(params), fname).second;
  }

 private:
  std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
  id<MTLLibrary> getLibrary();
  id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);

  id<MTLLibrary> compileLibrary(const std::string& src);
  std::string shaderSource;
  unsigned nparams;
  MTLCompileOptions* compile_options;
  id<MTLLibrary> library = nil;
  std::unordered_map<std::string, id<MTLLibrary>> libMap;
  std::unordered_map<std::string, std::pair<id<MTLComputePipelineState>, id<MTLFunction>>> cplMap;
};
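// Illustrative sketch (hypothetical kernel name and shader source): operations that ship
// a custom Metal kernel typically declare a static library from the source string and
// look up a pipeline state by function name:
//
//   static MetalShaderLibrary lib(R"METAL(
//     kernel void fill_ones(device float* out [[buffer(0)]], uint idx [[thread_position_in_grid]]) {
//       out[idx] = 1.0;
//     }
//   )METAL");
//   id<MTLComputePipelineState> kernelPSO = lib.getPipelineStateForFunc("fill_ones");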

template<typename encoder_t,
         typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> || std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
static inline void mtl_setBuffer(encoder_t encoder, const Tensor& t, unsigned idx) {
  [encoder setBuffer:getMTLBufferStorage(t)
              offset:t.storage_offset() * t.element_size()
             atIndex:idx];
}

template<typename T,
         typename = std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T, float>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const T val, unsigned idx) {
  [encoder setBytes:&val length:sizeof(T) atIndex:idx];
}

template<typename Container,
         typename = std::enable_if_t<std::is_integral_v<typename Container::size_type>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const Container& values, unsigned idx) {
  [encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex:idx];
}

static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
                                     id<MTLComputePipelineState> cplState,
                                     uint32_t length) {
  const uint32_t maxThreadsPerGroup = [cplState maxTotalThreadsPerThreadgroup];
  auto size = MTLSizeMake(length, 1, 1);
  auto threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, length), 1, 1);
  [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
}
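// Illustrative sketch (hypothetical encoder, pipeline state, and tensors): once a
// compute encoder and pipeline state are in hand, a 1D elementwise kernel is typically
// encoded with the helpers above:
//
//   [computeEncoder setComputePipelineState:kernelPSO];
//   mtl_setBuffer(computeEncoder, output, 0);   // matches buffer(0) in the kernel
//   mtl_setBytes(computeEncoder, numel, 1);     // an integral scalar parameter, if the kernel takes one
//   mtl_dispatch1DJob(computeEncoder, kernelPSO, static_cast<uint32_t>(output.numel()));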

id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder, const TensorIteratorBase& iter, bool use_64bit_index = false);

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) {
  return @{ p1.getMPSGraphTensor(): p1.getMPSGraphTensorData() };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2) {
  return @{
    p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
  };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3) {
  return @{
    p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
    p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
  };
}

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1, Placeholder& p2, Placeholder& p3, Placeholder& p4) {
  return @{
    p1.getMPSGraphTensor(): p1.getMPSGraphTensorData(),
    p2.getMPSGraphTensor(): p2.getMPSGraphTensorData(),
    p3.getMPSGraphTensor(): p3.getMPSGraphTensorData(),
    p4.getMPSGraphTensor(): p4.getMPSGraphTensorData(),
  };
}

inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds, Placeholder& result) {
  runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
}
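// Illustrative sketch (continuing the hypothetical cached-graph example above): after the
// graph is built or fetched, inputs and outputs are wrapped in Placeholders and the graph
// is run on the current MPS stream. `self` and `output` are hypothetical tensors:
//
//   Placeholder selfPlaceholder(cachedGraph->inputTensor_, self);
//   Placeholder outputPlaceholder(cachedGraph->outputTensor_, output);
//   runMPSGraph(getCurrentMPSStream(),
//               cachedGraph->graph(),
//               dictionaryFromPlaceholders(selfPlaceholder),
//               outputPlaceholder);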

inline bool supportsComplex() {
  return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
}

// MPS does not support double types yet; bfloat16 is supported starting with macOS 14
inline bool supportedFloatingType(ScalarType dtype) {
  return dtype == kFloat || dtype == kHalf || dtype == kBFloat16;
}

inline bool supportedFloatingType(const Tensor& t) {
  return supportedFloatingType(t.scalar_type());
}

inline bool supportedFloatingOrComplexType(ScalarType dtype) {
  if (dtype == kComplexFloat || dtype == kComplexHalf) {
    return supportsComplex();
  }
  return supportedFloatingType(dtype);
}
inline bool supportedFloatingOrComplexType(const Tensor& t) {
  return supportedFloatingOrComplexType(t.scalar_type());
}

inline bool needsGather(const Tensor& t) {
  static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
  return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset());
}
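// Illustrative sketch: callers that cannot hand a strided view directly to MPSGraph often
// materialize a contiguous copy first. `src` is a hypothetical input; whether the extra
// copy is avoidable depends on the op and the OS version:
//
//   Tensor src_ = needsGather(src) ? src.contiguous() : src;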

} // namespace at::native::mps