#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #include #else #include #include #include #include #endif namespace at::native { Tensor conv_tbc(const Tensor& self, const Tensor& weight, const Tensor& bias, int64_t pad) { TORCH_CHECK(self.dim() == 3, "Input must have 3 dims: time, batch, " "in_channel"); TORCH_CHECK(weight.dim() == 3, "Weight tensor must have 3 dims: kernel_width," " in_channels, out_channels."); TORCH_CHECK(bias.dim() == 1, "Bias must be 1-D"); auto input_size = self.sizes(); auto weight_size = weight.sizes(); auto ilen = input_size[0]; auto batchSize = input_size[1]; auto inputPlanes = input_size[2]; auto outputPlanes = weight_size[2]; auto kw = weight_size[0]; auto olen = input_size[0] - kw + 1 + pad * 2; auto real_pad = (olen - ilen + kw - 1) / 2; // Make sure shapes are correct. // Input = (time, batch, in_channels) // Weight = (kernel_width, in_channels, out_channels) // Bias = (out_channels) TORCH_CHECK(inputPlanes == weight_size[1], "Input dim 2 (input channels) " "is not == dim 1 in the weight tensor"); TORCH_CHECK(weight_size[2] == bias.sizes()[0], "Bias size must equal dim 2 in " "the weight tensor (output channels)."); // input * weights + bias -> output_features Tensor output = at::empty({ olen, input_size[1], weight_size[2], }, self.options()); output.copy_(bias.expand(output.sizes())); for (const auto k : c10::irange(kw)) { int iShift = std::max(0, static_cast(k - real_pad)); int oShift = std::max(0, static_cast(real_pad - k)); // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) int t = std::min(ilen + real_pad - k, olen) - oShift; // Note: gemm assumes column-major matrices // input is l*m (row-major) // weight is m*r (row-major) // output is l*r (row-major) if (t > 0) { auto W = weight[k]; auto I = self.narrow(0, iShift, t).view({t * batchSize, inputPlanes}); auto O = output.narrow(0, oShift, t).view({t * batchSize, outputPlanes}); O.addmm_(I, W); } } return output; } std::tuple conv_tbc_backward(const Tensor& dOutput, const Tensor& input, const Tensor& weight, const Tensor& bias, int64_t pad) { auto input_size = input.sizes(); auto weight_size = weight.sizes(); auto ilen = input_size[0]; auto batchSize = input_size[1]; auto inputPlanes = input_size[2]; auto outputPlanes = weight_size[2]; auto kw = weight.sizes()[0]; auto olen = input_size[0] - kw + 1 + pad * 2; // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) int real_pad = (olen - ilen + kw - 1) / 2; Tensor dInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); for (int k = 0; k < kw; k++) { int iShift = std::max(0, k - real_pad); int oShift = std::max(0, real_pad - k); // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) int t = std::min(ilen + real_pad - k, olen) - oShift; // dOutput * T(weight) -> dInput if (t > 0) { auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes}); auto dI = dInput.narrow(0, iShift, t).view({t * batchSize, inputPlanes}); dI.addmm_(dO, weight[k].t()); } } Tensor dWeight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT); for (int k = 0; k < kw; k++) { int iShift = std::max(0, k - real_pad); int oShift = std::max(0, real_pad - k); // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) int t = std::min(ilen + real_pad - k, olen) - oShift; // T(input) * dOutput -> dWeight if (t > 0) { auto dW = dWeight[k]; auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes}); auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes}).t(); dW.addmm_(I, dO); } } Tensor dBias = at::zeros_like(bias, LEGACY_CONTIGUOUS_MEMORY_FORMAT); auto tmp = dOutput.sum(0, false); dBias.copy_(tmp.sum(0)); return std::make_tuple(dInput, dWeight, dBias); } } // namespace at::native