xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ATen.h>
2 #include <ATen/core/op_registration/op_registration.h>
3 #include <ATen/native/metal/MetalPrepackOpContext.h>
4 #include <c10/util/accumulate.h>
5 
6 namespace at::native::metal {
7 
unpack(Tensor && weight,std::optional<Tensor> && bias,std::vector<int64_t> && stride,std::vector<int64_t> && padding,std::vector<int64_t> && dilation,int64_t groups,const std::optional<Scalar> & output_min,const std::optional<Scalar> & output_max)8 static c10::intrusive_ptr<Conv2dOpContext> unpack(
9     Tensor&& weight,
10     std::optional<Tensor>&& bias,
11     std::vector<int64_t>&& stride,
12     std::vector<int64_t>&& padding,
13     std::vector<int64_t>&& dilation,
14     int64_t groups,
15     const std::optional<Scalar>& output_min,
16     const std::optional<Scalar>& output_max) {
17   auto packedWeight = weight.contiguous(MemoryFormat::ChannelsLast);
18   return c10::make_intrusive<Conv2dOpContext>(
19       std::move(packedWeight),
20       std::move(bias),
21       stride,
22       padding,
23       dilation,
24       groups,
25       output_min,
26       output_max);
27 }
28 
unpack(Tensor && weight,std::optional<Tensor> && bias,const std::optional<Scalar> & output_min,const std::optional<Scalar> & output_max)29 static c10::intrusive_ptr<LinearOpContext> unpack(
30     Tensor&& weight,
31     std::optional<Tensor>&& bias,
32     const std::optional<Scalar>& output_min,
33     const std::optional<Scalar>& output_max) {
34   TORCH_CHECK(weight.dim() == 2);
35   // Don't need to do `weight.t()`
36   auto packedWeight = weight.view({weight.size(0), weight.size(1), 1, 1})
37                           .contiguous(MemoryFormat::ChannelsLast);
38   return c10::make_intrusive<LinearOpContext>(
39       std::move(packedWeight), std::move(bias), output_min, output_max);
40 }
41 
// Registers the `metal` library: the two prepack context classes and the
// `copy_to_host` operator schema.
//
// Both context classes are made picklable. __getstate__ returns the result of
// the context's pack(); __setstate__ passes the elements of that state tuple
// to the matching unpack() overload defined earlier in this file.
TORCH_LIBRARY(metal, m) {
  m.class_<Conv2dOpContext>("Conv2dOpContext")
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<Conv2dOpContext>& op_context)
              -> SerializationTypeConv2dPrePack { return op_context->pack(); },
          // __setstate__
          [](SerializationTypeConv2dPrePack state)
              -> c10::intrusive_ptr<Conv2dOpContext> {
            // `state` is this lambda's own by-value copy, so the elements
            // that unpack() takes by rvalue reference can be moved out of it.
            // The remaining elements are passed by value or bound to a const
            // reference, where std::move would have no effect.
            return unpack(
                std::move(std::get<0>(state)), // weight
                std::move(std::get<1>(state)), // bias
                std::move(std::get<2>(state)), // stride
                std::move(std::get<3>(state)), // padding
                std::move(std::get<4>(state)), // dilation
                std::get<5>(state), // groups
                std::get<6>(state), // output_min
                std::get<7>(state)); // output_max
          });
  m.class_<LinearOpContext>("LinearOpContext")
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<LinearOpContext>& op_context)
              -> SerializationTypeLinearPrePack { return op_context->pack(); },
          // __setstate__
          [](SerializationTypeLinearPrePack state)
              -> c10::intrusive_ptr<LinearOpContext> {
            return unpack(
                std::move(std::get<0>(state)), // weight
                std::move(std::get<1>(state)), // bias
                std::get<2>(state), // output_min
                std::get<3>(state)); // output_max
          });
  m.def("copy_to_host(Tensor X) -> Tensor Y");
}
77 
// Declares the operator schemas of the `metal_prepack` library: a
// prepack/run pair for conv2d and a prepack/run pair for linear.
//
// Each schema is written as adjacent string literals that the compiler
// concatenates into a single string; keep the trailing spaces inside the
// pieces intact when editing.
TORCH_LIBRARY(metal_prepack, m) {
  // conv2d
  m.def(TORCH_SELECTIVE_SCHEMA(
      "metal_prepack::conv2d_prepack("
      "Tensor W, Tensor? B, "
      "int[2] stride, int[2] padding, int[2] dilation, int groups, "
      "Scalar? output_min=None, Scalar? output_max=None) "
      "-> __torch__.torch.classes.metal.Conv2dOpContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "metal_prepack::conv2d_run("
      "Tensor X, "
      "__torch__.torch.classes.metal.Conv2dOpContext W_prepack) "
      "-> Tensor Y"));

  // linear
  m.def(TORCH_SELECTIVE_SCHEMA(
      "metal_prepack::linear_prepack("
      "Tensor W, Tensor? B, "
      "Scalar? output_min=None, Scalar? output_max=None) "
      "-> __torch__.torch.classes.metal.LinearOpContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "metal_prepack::linear_run("
      "Tensor X, "
      "__torch__.torch.classes.metal.LinearOpContext W_prepack) "
      "-> Tensor Y"));
}
94 
conv2d_prepack(Tensor && weight,std::optional<Tensor> && bias,std::vector<int64_t> && stride,std::vector<int64_t> && padding,std::vector<int64_t> && dilation,int64_t groups,const std::optional<Scalar> & output_min,const std::optional<Scalar> & output_max)95 static c10::intrusive_ptr<Conv2dOpContext> conv2d_prepack(
96     Tensor&& weight,
97     std::optional<Tensor>&& bias,
98     std::vector<int64_t>&& stride,
99     std::vector<int64_t>&& padding,
100     std::vector<int64_t>&& dilation,
101     int64_t groups,
102     const std::optional<Scalar>& output_min,
103     const std::optional<Scalar>& output_max) {
104   TORCH_CHECK(weight.dim() == 4);
105   return c10::make_intrusive<Conv2dOpContext>(
106       std::move(weight),
107       std::move(bias),
108       stride,
109       padding,
110       dilation,
111       groups,
112       output_min,
113       output_max);
114 }
115 
linear_prepack(Tensor && weight,std::optional<Tensor> && bias,const std::optional<Scalar> & output_min,const std::optional<Scalar> & output_max)116 static c10::intrusive_ptr<LinearOpContext> linear_prepack(
117     Tensor&& weight,
118     std::optional<Tensor>&& bias,
119     const std::optional<Scalar>& output_min,
120     const std::optional<Scalar>& output_max) {
121   return c10::make_intrusive<LinearOpContext>(
122       std::move(weight), std::move(bias), output_min, output_max);
123 }
124 
// Binds the CPU implementations of the two prepack operators declared in the
// `metal_prepack` library. Only the *_prepack ops are bound here; the *_run
// ops are not given an implementation in this block.
TORCH_LIBRARY_IMPL(metal_prepack, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("metal_prepack::conv2d_prepack"),
      TORCH_FN(conv2d_prepack));
  m.impl(
      TORCH_SELECTIVE_NAME("metal_prepack::linear_prepack"),
      TORCH_FN(linear_prepack));
}
129 
130 } // namespace at::native::metal
131