xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/MetalAten.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalTensorImpl.h>
2#import <ATen/native/metal/MetalTensorImplStorage.h>
3#import <ATen/native/metal/MetalContext.h>
4#import <ATen/native/metal/MetalTensorUtils.h>
5#include <ATen/metal/Context.h>
6#include <torch/script.h>
7
8namespace at {
9namespace native::metal {
10
11static Tensor& copy_from_metal_(Tensor& dst, const Tensor& src) {
12  TORCH_INTERNAL_ASSERT(
13      src.device().type() == DeviceType::Metal,
14      "copy_from_metal input tensor's device is not metal");
15  TORCH_INTERNAL_ASSERT(
16      dst.device().is_cpu(),
17      "copy_from_metal is implemented only for CPU device output");
18  TORCH_INTERNAL_ASSERT(
19      dst.layout() == Layout::Strided,
20      "copy_from_metal is implemented only for Strided layout output");
21  TORCH_INTERNAL_ASSERT(
22      dst.scalar_type() == ScalarType::Float,
23      "copy_from_metal is implemented only for float dtype output, got:",
24      dst.scalar_type());
25  TORCH_INTERNAL_ASSERT(
26      dst.is_contiguous(),
27      "copy_from_metal is implemented only for contiguous output tensor");
28  if(dst.numel() == 0){
29    return dst;
30  }
31  MetalTensorImplStorage& tensorImplStorage = getTensorImplStorage(src);
32  tensorImplStorage.copy_data_to_host(dst.data_ptr<float>());
33  return dst;
34}
35
36static Tensor& copy_to_metal_(Tensor& dst, const Tensor& src) {
37  TORCH_INTERNAL_ASSERT(
38      dst.device().type() == DeviceType::Metal,
39      "copy_to_metal_ output tensor's device is not metal");
40  TORCH_INTERNAL_ASSERT(
41      src.device().is_cpu(),
42      "copy_to_metal_ is implemented only for CPU device input");
43  TORCH_INTERNAL_ASSERT(
44      src.layout() == Layout::Strided,
45      "copy_to_metal_ is implemented only for Strided layout input");
46  TORCH_INTERNAL_ASSERT(
47      src.scalar_type() == ScalarType::Float,
48      "copy_to_metal_ is implemented only for float dtype");
49
50  auto cpu_tensor_contiguous = src.contiguous();
51  MetalTensorImplStorage& tensorImplStorage = getTensorImplStorage(dst);
52  tensorImplStorage.set_data_from_host(cpu_tensor_contiguous.data_ptr<float>());
53  return dst;
54}
55
56static Tensor& metal_copy_impl_(Tensor& dst, const Tensor& src) {
57  if (src.device().type() == at::kMetal && dst.device().type() == at::kCPU) {
58    return copy_from_metal_(dst, src);
59  }
60  if (src.device().type() == at::kCPU && dst.device().type() == at::kMetal) {
61    return copy_to_metal_(dst, src);
62  }
63  TORCH_INTERNAL_ASSERT(
64      src.device().type() == DeviceType::Metal,
65      "metal_copy_ is implemented only for CPU,Strided,float->Metal; Metal->CPU,Strided,float");
66  return dst;
67}
68
69#pragma mark - ATen Ops
70
71static Tensor empty(
72    c10::SymIntArrayRef sym_size,
73    std::optional<ScalarType> dtype,
74    std::optional<Layout> layout,
75    std::optional<Device> device,
76    std::optional<bool> pin_memory,
77    std::optional<MemoryFormat> memory_format) {
78  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
79  TORCH_CHECK(
80      !pin_memory.has_value(),
81      "'pin_memory' argument is incompatible with Metal tensor");
82  TORCH_CHECK(
83      !memory_format.has_value(),
84      "'memory_format' argument is incompatible with Metal tensor");
85  MetalTensorImplStorage mt{size.vec()};
86  return makeTensor(
87      std::move(mt), at::device(at::kMetal).dtype(dtype));
88};
89
90static Tensor empty_strided(
91    IntArrayRef size,
92    IntArrayRef stride,
93    std::optional<ScalarType> dtype,
94    std::optional<Layout> layout,
95    std::optional<Device> device,
96    std::optional<bool> pin_memory) {
97  TORCH_CHECK(
98      !pin_memory.has_value() || !pin_memory.value(),
99      "'pin_memory' argument is incompatible with Metal tensor");
100  MetalTensorImplStorage mt{size.vec(), stride.vec()};
101  return makeTensor(
102      std::move(mt), at::device(at::kMetal).dtype(dtype));
103}
104
105
106TORCH_LIBRARY_IMPL(aten, Metal, m) {
107  m.impl(TORCH_SELECTIVE_NAME("aten::empty.memory_format"), empty);
108  m.impl(TORCH_SELECTIVE_NAME("aten::empty_strided"), TORCH_FN(empty_strided));
109}
110
111} // namespace native::metal
112
113struct MetalImpl : public at::metal::MetalInterface {
114  bool is_metal_available() const override {
115#if defined(USE_PYTORCH_METAL)
116    return [[MetalContext sharedInstance] available];
117#else
118    return false;
119#endif
120  }
121  at::Tensor& metal_copy_(at::Tensor& input, const at::Tensor& src)
122      const override {
123    TORCH_CHECK(
124        is_metal_available(), "Metal is not available on the current device");
125    return native::metal::metal_copy_impl_(input, src);
126  }
127};
#if defined(USE_PYTORCH_METAL)
// Static-initialization-time registration of the Metal backend with the
// global at::metal context; compiled in only when Metal support is enabled.
static at::metal::MetalImplRegistrar g_metal_impl(new MetalImpl());
#endif
131
132} // namespace at
133