xref: /aosp_15_r20/external/pytorch/aten/src/ATen/BlasBackend.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <ostream>
6*da0073e9SAndroid Build Coastguard Worker #include <string>
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker namespace at {
9*da0073e9SAndroid Build Coastguard Worker 
/// Identifies a BLAS backend. Stored as an `int8_t`; the enumerators are
/// declared in the order `Cublas`, `Cublaslt`.
enum class BlasBackend : int8_t { Cublas, Cublaslt };
11*da0073e9SAndroid Build Coastguard Worker 
BlasBackendToString(at::BlasBackend backend)12*da0073e9SAndroid Build Coastguard Worker inline std::string BlasBackendToString(at::BlasBackend backend) {
13*da0073e9SAndroid Build Coastguard Worker   switch (backend) {
14*da0073e9SAndroid Build Coastguard Worker     case BlasBackend::Cublas:
15*da0073e9SAndroid Build Coastguard Worker       return "at::BlasBackend::Cublas";
16*da0073e9SAndroid Build Coastguard Worker     case BlasBackend::Cublaslt:
17*da0073e9SAndroid Build Coastguard Worker       return "at::BlasBackend::Cublaslt";
18*da0073e9SAndroid Build Coastguard Worker     default:
19*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(false, "Unknown blas backend");
20*da0073e9SAndroid Build Coastguard Worker   }
21*da0073e9SAndroid Build Coastguard Worker }
22*da0073e9SAndroid Build Coastguard Worker 
23*da0073e9SAndroid Build Coastguard Worker inline std::ostream& operator<<(std::ostream& stream, at::BlasBackend backend) {
24*da0073e9SAndroid Build Coastguard Worker   return stream << BlasBackendToString(backend);
25*da0073e9SAndroid Build Coastguard Worker }
26*da0073e9SAndroid Build Coastguard Worker 
27*da0073e9SAndroid Build Coastguard Worker } // namespace at
28