1 #include <ATen/ATen.h>
2 #include <ATen/core/op_registration/op_registration.h>
3 #include <ATen/native/metal/MetalPrepackOpContext.h>
4 #include <c10/util/accumulate.h>
5
6 namespace at::native::metal {
7
unpack(Tensor && weight,std::optional<Tensor> && bias,std::vector<int64_t> && stride,std::vector<int64_t> && padding,std::vector<int64_t> && dilation,int64_t groups,const std::optional<Scalar> & output_min,const std::optional<Scalar> & output_max)8 static c10::intrusive_ptr<Conv2dOpContext> unpack(
9 Tensor&& weight,
10 std::optional<Tensor>&& bias,
11 std::vector<int64_t>&& stride,
12 std::vector<int64_t>&& padding,
13 std::vector<int64_t>&& dilation,
14 int64_t groups,
15 const std::optional<Scalar>& output_min,
16 const std::optional<Scalar>& output_max) {
17 auto packedWeight = weight.contiguous(MemoryFormat::ChannelsLast);
18 return c10::make_intrusive<Conv2dOpContext>(
19 std::move(packedWeight),
20 std::move(bias),
21 stride,
22 padding,
23 dilation,
24 groups,
25 output_min,
26 output_max);
27 }
28
unpack(Tensor && weight,std::optional<Tensor> && bias,const std::optional<Scalar> & output_min,const std::optional<Scalar> & output_max)29 static c10::intrusive_ptr<LinearOpContext> unpack(
30 Tensor&& weight,
31 std::optional<Tensor>&& bias,
32 const std::optional<Scalar>& output_min,
33 const std::optional<Scalar>& output_max) {
34 TORCH_CHECK(weight.dim() == 2);
35 // Don't need to do `weight.t()`
36 auto packedWeight = weight.view({weight.size(0), weight.size(1), 1, 1})
37 .contiguous(MemoryFormat::ChannelsLast);
38 return c10::make_intrusive<LinearOpContext>(
39 std::move(packedWeight), std::move(bias), output_min, output_max);
40 }
41
// Registers the `metal` library: the Conv2dOpContext and LinearOpContext
// custom classes (each with pickling hooks) and the `copy_to_host` schema.
//
// For both classes, __getstate__ returns the context's `pack()` result and
// __setstate__ feeds the tuple elements to the matching `unpack` overload.
TORCH_LIBRARY(metal, m) {
  using Conv2dContextPtr = c10::intrusive_ptr<Conv2dOpContext>;
  using LinearContextPtr = c10::intrusive_ptr<LinearOpContext>;

  m.class_<Conv2dOpContext>("Conv2dOpContext")
      .def_pickle(
          // __getstate__
          [](const Conv2dContextPtr& op_context)
              -> SerializationTypeConv2dPrePack { return op_context->pack(); },
          // __setstate__
          [](SerializationTypeConv2dPrePack state) -> Conv2dContextPtr {
            auto& [weight,
                   bias,
                   stride,
                   padding,
                   dilation,
                   groups,
                   output_min,
                   output_max] = state;
            return unpack(
                std::move(weight),
                std::move(bias),
                std::move(stride),
                std::move(padding),
                std::move(dilation),
                groups,
                output_min,
                output_max);
          });

  m.class_<LinearOpContext>("LinearOpContext")
      .def_pickle(
          // __getstate__
          [](const LinearContextPtr& op_context)
              -> SerializationTypeLinearPrePack { return op_context->pack(); },
          // __setstate__
          [](SerializationTypeLinearPrePack state) -> LinearContextPtr {
            auto& [weight, bias, output_min, output_max] = state;
            return unpack(
                std::move(weight), std::move(bias), output_min, output_max);
          });

  m.def("copy_to_host(Tensor X) -> Tensor Y");
}
77
// Declares the operator schemas of the `metal_prepack` library.
//
// Each schema is written as adjacent string literals that the compiler
// concatenates: operator name, argument list, then return type.
TORCH_LIBRARY(metal_prepack, m) {
  // Convolution: build a prepacked context, then run with it.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "metal_prepack::conv2d_prepack("
      "Tensor W, Tensor? B, "
      "int[2] stride, int[2] padding, int[2] dilation, int groups, "
      "Scalar? output_min=None, Scalar? output_max=None) "
      "-> __torch__.torch.classes.metal.Conv2dOpContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "metal_prepack::conv2d_run("
      "Tensor X, __torch__.torch.classes.metal.Conv2dOpContext W_prepack) "
      "-> Tensor Y"));

  // Linear: build a prepacked context, then run with it.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "metal_prepack::linear_prepack("
      "Tensor W, Tensor? B, "
      "Scalar? output_min=None, Scalar? output_max=None) "
      "-> __torch__.torch.classes.metal.LinearOpContext"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "metal_prepack::linear_run("
      "Tensor X, __torch__.torch.classes.metal.LinearOpContext W_prepack) "
      "-> Tensor Y"));
}
94
conv2d_prepack(Tensor && weight,std::optional<Tensor> && bias,std::vector<int64_t> && stride,std::vector<int64_t> && padding,std::vector<int64_t> && dilation,int64_t groups,const std::optional<Scalar> & output_min,const std::optional<Scalar> & output_max)95 static c10::intrusive_ptr<Conv2dOpContext> conv2d_prepack(
96 Tensor&& weight,
97 std::optional<Tensor>&& bias,
98 std::vector<int64_t>&& stride,
99 std::vector<int64_t>&& padding,
100 std::vector<int64_t>&& dilation,
101 int64_t groups,
102 const std::optional<Scalar>& output_min,
103 const std::optional<Scalar>& output_max) {
104 TORCH_CHECK(weight.dim() == 4);
105 return c10::make_intrusive<Conv2dOpContext>(
106 std::move(weight),
107 std::move(bias),
108 stride,
109 padding,
110 dilation,
111 groups,
112 output_min,
113 output_max);
114 }
115
linear_prepack(Tensor && weight,std::optional<Tensor> && bias,const std::optional<Scalar> & output_min,const std::optional<Scalar> & output_max)116 static c10::intrusive_ptr<LinearOpContext> linear_prepack(
117 Tensor&& weight,
118 std::optional<Tensor>&& bias,
119 const std::optional<Scalar>& output_min,
120 const std::optional<Scalar>& output_max) {
121 return c10::make_intrusive<LinearOpContext>(
122 std::move(weight), std::move(bias), output_min, output_max);
123 }
124
// Binds the two prepack operators of `metal_prepack` to their CPU kernels
// defined in this file.
TORCH_LIBRARY_IMPL(metal_prepack, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("metal_prepack::conv2d_prepack"),
      TORCH_FN(conv2d_prepack));
  m.impl(
      TORCH_SELECTIVE_NAME("metal_prepack::linear_prepack"),
      TORCH_FN(linear_prepack));
}
129
130 } // namespace at::native::metal
131