#pragma once

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/VulkanPackedContext.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {

enum Conv2dMethod {
  Conv2dDepthwise,
  Conv2dPointwise,
  Conv2dSlidingWindow,
};

namespace conv2d {

Tensor rearrange_weights_dw(const Tensor& weight_in);
Tensor rearrange_weights_2d(const Tensor& weight_in, bool tconv);
Tensor rearrange_bias(
    const std::optional<Tensor>& bias_in,
    const at::Tensor& weight_in,
    bool tconv);

} // namespace conv2d

namespace qconv2d_vk {

struct QParams final {
  api::utils::uvec3 out_extents;
  int32_t ic4;
  api::utils::ivec4 sizes_2d;
  float output_scale;
  float input_scale;
  int32_t output_zero_point;
  int32_t input_zero_point;
  float weight_scale;
  float bias_scale;
  int32_t weight_zero_point;
  int32_t bias_zero_point;
  api::utils::ivec2 kernel_size;
  api::utils::ivec2 stride;
  api::utils::ivec2 padding;
  api::utils::ivec2 dilate;
  api::utils::vec2 clamp;
  api::utils::ivec4 src_filter;
};

} // namespace qconv2d_vk

class Conv2dPackedContext final : virtual public VulkanPackedContext,
                                  public torch::jit::CustomClassHolder {
 private:
  c10::impl::GenericList unpacked_;
  api::ShaderInfo compute_shader_{};

 public:
  Conv2dPackedContext(
      const Tensor& weight,
      const std::optional<Tensor>& bias,
      const IntArrayRef stride_arg,
      const IntArrayRef padding_arg,
      const IntArrayRef dilation_arg,
      const bool transposed,
      const bool quantized,
      const IntArrayRef output_padding_arg,
      const int64_t groups,
      const std::optional<Scalar>& output_min = std::nullopt,
      const std::optional<Scalar>& output_max = std::nullopt);

  /*
   * Assigns a name to each index in the unpacked list.
   */
  struct Unpacked final {
    static constexpr uint32_t Weight = 0u;
    static constexpr uint32_t Bias = 1u;
    static constexpr uint32_t Stride = 2u;
    static constexpr uint32_t Padding = 3u;
    static constexpr uint32_t Dilation = 4u;
    static constexpr uint32_t isTransposed = 5u;
    static constexpr uint32_t isQuantized = 6u;
    static constexpr uint32_t OutputPadding = 7u;
    static constexpr uint32_t Groups = 8u;
    static constexpr uint32_t OutputMin = 9u;
    static constexpr uint32_t OutputMax = 10u;

    static constexpr uint32_t NumArgs = 11u;
  };
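  // Illustrative sketch (not part of the original header): the Unpacked
  // indices name the slots of the c10::impl::GenericList returned by
  // unpack(). Assuming each slot holds the IValue its name implies, a caller
  // could read the list back roughly like this:
  //
  //   const auto unpacked = context.unpack();
  //   const Tensor weight =
  //       unpacked.get(Conv2dPackedContext::Unpacked::Weight).toTensor();
  //   const int64_t groups =
  //       unpacked.get(Conv2dPackedContext::Unpacked::Groups).toInt();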
  /*
   * Assigns a name to each index in the packed list.
   */
  struct Packed final {
    static constexpr uint32_t Weight = 0u;
    static constexpr uint32_t Bias = 1u;
    static constexpr uint32_t OverlayRegion = 2u;
    static constexpr uint32_t Stride = 3u;
    static constexpr uint32_t Padding = 4u;
    static constexpr uint32_t OutputPadding = 5u;
    static constexpr uint32_t Dilation = 6u;
    static constexpr uint32_t isTransposed = 7u;
    static constexpr uint32_t isQuantized = 8u;
    static constexpr uint32_t Groups = 9u;
    static constexpr uint32_t OutputMin = 10u;
    static constexpr uint32_t OutputMax = 11u;
    static constexpr uint32_t ConvMethod = 12u;
    static constexpr uint32_t WeightSizes = 13u;

    static constexpr uint32_t NumArgs = 14u;
  };

  static Conv2dPackedContext pack(c10::impl::GenericList);

  const c10::impl::GenericList unpack() const override {
    TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");

    return unpacked_;
  }

  inline api::ShaderInfo& compute_shader() {
    return compute_shader_;
  }
};

c10::intrusive_ptr<Conv2dPackedContext> create_conv2d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min = std::nullopt,
    const std::optional<Scalar>& output_max = std::nullopt);

Tensor run_conv2d_context(
    const Tensor& input,
    const c10::intrusive_ptr<Conv2dPackedContext>& context);

c10::intrusive_ptr<Conv2dPackedContext> create_tconv2d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& output_padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min = std::nullopt,
    const std::optional<Scalar>& output_max = std::nullopt);

Tensor run_tconv2d_context(
    const Tensor& input,
    const c10::intrusive_ptr<Conv2dPackedContext>& context);

c10::intrusive_ptr<Conv2dPackedContext> create_qconv2d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min = std::nullopt,
    const std::optional<Scalar>& output_max = std::nullopt);

Tensor run_qconv2d_context(
    const Tensor& input_arg,
    double scale,
    int64_t zero_point,
    const c10::intrusive_ptr<Conv2dPackedContext>& conv_context);

c10::intrusive_ptr<Conv2dPackedContext> create_qtconv2d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& output_padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min = std::nullopt,
    const std::optional<Scalar>& output_max = std::nullopt);
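// Illustrative sketch (not part of the original header): the intended
// prepack/run split for the context factories declared above. The tensor
// preparation details (float CPU weight/bias, Vulkan input) are assumptions;
// see the corresponding .cpp for the authoritative checks.
//
//   c10::intrusive_ptr<Conv2dPackedContext> ctx = create_conv2d_context(
//       std::move(weight),
//       std::move(bias),
//       /*stride=*/{1, 1},
//       /*padding=*/{1, 1},
//       /*dilation=*/{1, 1},
//       /*groups=*/1);
//   const Tensor output = run_conv2d_context(input.vulkan(), ctx);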
// Backwards compatibility
class Conv2dOpContext final : public torch::jit::CustomClassHolder {
 public:
  static Conv2dOpContext create(
      const Tensor& weight,
      const std::optional<Tensor>& bias,
      IntArrayRef stride,
      IntArrayRef padding,
      IntArrayRef dilation,
      bool transposed,
      IntArrayRef output_padding,
      int64_t groups,
      const std::optional<Scalar>& output_min = std::nullopt,
      const std::optional<Scalar>& output_max = std::nullopt);

  using State = std::tuple<
      Tensor,
      std::optional<Tensor>,
      std::vector<int64_t>,
      std::vector<int64_t>,
      std::vector<int64_t>,
      int64_t,
      std::optional<Scalar>,
      std::optional<Scalar>>;

  Tensor run(const Tensor& input) const;
  State unpack() const;

 private:
  explicit Conv2dOpContext(Conv2dPackedContext conv_context);
  Conv2dPackedContext conv_context_;
};

Tensor conv2d_clamp_run(
    const Tensor& input,
    const c10::intrusive_ptr<Conv2dOpContext>& context);

c10::intrusive_ptr<Conv2dOpContext> conv2d_clamp_prepack(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max);

class Conv1dPackedContext final : virtual public VulkanPackedContext,
                                  public torch::jit::CustomClassHolder {
 private:
  c10::impl::GenericList unpacked_;
  api::ShaderInfo compute_shader_{};

 public:
  Conv1dPackedContext(
      const Tensor& weight,
      const std::optional<Tensor>& bias,
      const IntArrayRef stride_arg,
      const IntArrayRef padding_arg,
      const IntArrayRef dilation_arg,
      const int64_t groups);

  /*
   * Assigns a name to each index in the unpacked list.
   */
  struct Unpacked final {
    static constexpr uint32_t Weight = 0u;
    static constexpr uint32_t Bias = 1u;
    static constexpr uint32_t Stride = 2u;
    static constexpr uint32_t Padding = 3u;
    static constexpr uint32_t Dilation = 4u;
    static constexpr uint32_t Groups = 5u;

    static constexpr uint32_t NumArgs = 6u;
  };

  /*
   * Assigns a name to each index in the packed list.
   */
  struct Packed final {
    static constexpr uint32_t Weight = 0u;
    static constexpr uint32_t Bias = 1u;
    static constexpr uint32_t Stride = 2u;
    static constexpr uint32_t Padding = 3u;
    static constexpr uint32_t Dilation = 4u;
    static constexpr uint32_t Groups = 5u;
    static constexpr uint32_t WeightSizes = 6u;

    static constexpr uint32_t NumArgs = 7u;
  };

  static Conv1dPackedContext pack(c10::impl::GenericList);

  const c10::impl::GenericList unpack() const override {
    TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");

    return unpacked_;
  }

  inline api::ShaderInfo& compute_shader() {
    return compute_shader_;
  }
};

c10::intrusive_ptr<Conv1dPackedContext> create_conv1d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups);

Tensor run_conv1d_context(
    const Tensor& input,
    const c10::intrusive_ptr<Conv1dPackedContext>& context);

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */
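// Illustrative sketch (not part of the original header): the backwards
// compatible Conv2dOpContext path wraps the same packed context. Tensor
// preparation details are assumptions; the calls are the free functions
// declared above in at::native::vulkan::ops.
//
//   auto octx = at::native::vulkan::ops::conv2d_clamp_prepack(
//       std::move(weight), std::move(bias),
//       /*stride=*/{1, 1}, /*padding=*/{0, 0}, /*dilation=*/{1, 1},
//       /*groups=*/1,
//       /*output_min=*/std::nullopt, /*output_max=*/std::nullopt);
//   const at::Tensor out =
//       at::native::vulkan::ops::conv2d_clamp_run(input.vulkan(), octx);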