xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/quantization/quantization_gpu.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
8 #include <ATen/ATen.h>
9 #include <vector>
10 
11 namespace torch::distributed::c10d::quantization {
12 
13 at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input);
14 at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input);
15 
16 } // namespace torch::distributed::c10d::quantization
17