1 // Copyright (c) Meta Platforms, Inc. and affiliates. 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 8 #include <ATen/ATen.h> 9 #include <vector> 10 11 namespace torch::distributed::c10d::quantization { 12 13 at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input); 14 at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input); 15 16 } // namespace torch::distributed::c10d::quantization 17