#include #include #include #ifdef AT_PER_OPERATOR_HEADERS #include #endif namespace at { class Tensor; namespace native { template void _assert_match(const O& original, const C& compared, const std::string& name) { if (compared) { bool equal = (original == compared.value()); if (!equal) { std::stringstream msg; msg << "Tensor " << name << " mismatch!"; AT_ASSERT(equal, msg.str()); } } } void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, std::optional dtype) { _assert_match(tensor.sizes(), sizes, "sizes"); _assert_match(tensor.strides(), strides, "strides"); _assert_match(tensor.dtype(), dtype, "dtype"); } } } // namespace at::native