/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/aten_util/aten_bridge.h>

#include <executorch/extension/tensor/tensor_ptr.h>
#include <executorch/runtime/platform/assert.h>

namespace executorch {
namespace extension {

namespace {
void check_tensor_meta(const at::Tensor& a, const exec_aten::Tensor& b) {
  // Check sizes/strides pointers.
  ET_CHECK_MSG(
      b.sizes().data() != nullptr, "ETensor must have valid sizes array");
  ET_CHECK_MSG(
      b.strides().data() != nullptr, "ETensor must have valid strides array");
  // Check disabled because in the ASR model we get a 1-element tensor with a
  // different rank.
  /*
  ET_CHECK_MSG(
      a.dim() == b.dim(),
      "at::Tensor and ETensor must have same rank."
      " at::Tensor rank %zd, ETensor rank %zd.",
      a.dim(),
      b.dim());
  */
  // Check sizes.
  for (size_t i = 0, dims = a.dim(); i < dims; ++i) {
    ET_CHECK_MSG(
        a.size(i) == b.size(i),
        "Sizes don't match at index:%zd, a size %zd != b size %zd",
        i,
        ssize_t(a.size(i)),
        ssize_t(b.size(i)));
  }
  // Check strides.
  for (size_t i = 0, dims = a.dim(); i < dims; ++i) {
    // Don't compare strides when the size is 1: a tensor is non-contiguous
    // only if strides don't match product(sizes[i:]) for dimensions where
    // size(i) > 1. Strong assumption that must be tested and validated.
    ET_CHECK_MSG(
        (a.size(i) == 1 || (a.stride(i) == b.strides()[i])),
        "Strides don't match at index:%zd, a stride %zd != b stride %zd",
        i,
        ssize_t(a.stride(i)),
        ssize_t(b.strides()[i]));
  }
  // Check dtype.
  ET_CHECK_MSG(
      b.scalar_type() == torch_to_executorch_scalar_type(a.options().dtype()),
      "dtypes don't match a %hhd vs. b %hhd",
      static_cast<int8_t>(
          torch_to_executorch_scalar_type(a.options().dtype())),
      static_cast<int8_t>(b.scalar_type()));
}
} // namespace

executorch::runtime::etensor::ScalarType torch_to_executorch_scalar_type(
    caffe2::TypeMeta type) {
  const auto intermediate =
      static_cast<std::underlying_type<c10::ScalarType>::type>(
          c10::typeMetaToScalarType(type));

  ET_CHECK_MSG(
      intermediate >= 0 &&
          intermediate <=
              static_cast<std::underlying_type<
                  executorch::runtime::etensor::ScalarType>::type>(
                  executorch::runtime::etensor::ScalarType::UInt64),
      "ScalarType %d unsupported in Executorch",
      intermediate);
  return static_cast<executorch::runtime::etensor::ScalarType>(intermediate);
}

c10::ScalarType executorch_to_torch_scalar_type(
    torch::executor::ScalarType type) {
  const auto intermediate = static_cast<
      std::underlying_type<executorch::runtime::etensor::ScalarType>::type>(
      type);

  ET_CHECK_MSG(
      intermediate >= 0 &&
          intermediate <=
              static_cast<std::underlying_type<
                  executorch::runtime::etensor::ScalarType>::type>(
                  executorch::runtime::etensor::ScalarType::UInt64),
      "ScalarType %d unsupported in Executorch",
      intermediate);
  return static_cast<c10::ScalarType>(intermediate);
}
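// Illustrative round-trip check (comment-only sketch, not part of the API):
// both converters above cast through the enums' shared underlying integer
// values, so for any dtype supported by both runtimes they invert each other:
//
//   at::Tensor src = at::ones({2, 3}, at::kFloat);
//   auto et_type = torch_to_executorch_scalar_type(src.options().dtype());
//   ET_CHECK(executorch_to_torch_scalar_type(et_type) == src.scalar_type());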
/*
 * The following makes two assumptions:
 * 1. aten_tensor's lifetime is longer than the lifetime within which
 *    mutable_et is consumed.
 * 2. Memory previously allocated to mutable_et is leaked. However, under the
 *    strong assumption that such memory is arena-allocated, with a lifetime
 *    tied to the model's lifetime, the memory is not actually leaked, since
 *    it is freed when the arena is freed.
 * @param[in] aten_tensor Input at::Tensor
 * @param[in/out] mutable_et ETensor whose underlying memory will now alias
 * aten_tensor
 */
void alias_etensor_to_attensor(
    at::Tensor& aten_tensor,
    torch::executor::Tensor& mutable_et) {
  // TODO(kimishpatel): make the contiguity check respect memory format.
  // Right now we assume everything is channels-first contiguous.
  // Note that the input tensor must be contiguous for us to alias.
  // Mixing aliasing and copying is dangerous: if we aliased this instance of
  // mutable_et to aten_tensor in a previous call, then copying in the next
  // call would not be the correct behavior.
  ET_CHECK_MSG(aten_tensor.is_contiguous(), "Input tensor must be contiguous");
  check_tensor_meta(aten_tensor, mutable_et);
  mutable_et.unsafeGetTensorImpl()->set_data(aten_tensor.mutable_data_ptr());
}

at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& etensor) {
  c10::ScalarType dtype =
      executorch_to_torch_scalar_type(etensor.scalar_type());
  std::vector<int64_t> at_tensor_sizes(
      etensor.sizes().begin(), etensor.sizes().end());
  std::vector<int64_t> at_tensor_strides(
      etensor.strides().begin(), etensor.strides().end());

  at::Tensor t = at::from_blob(
      etensor.mutable_data_ptr(),
      at_tensor_sizes,
      at_tensor_strides,
      at::TensorOptions(dtype));

  check_tensor_meta(t, etensor);
  return t;
}

TensorPtr alias_tensor_ptr_to_attensor(at::Tensor& t) {
  return make_tensor_ptr(
      {t.sizes().begin(), t.sizes().end()},
      t.mutable_data_ptr(),
      torch::executor::ScalarType(t.scalar_type()));
}

} // namespace extension
} // namespace executorch
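// Illustrative usage (comment-only sketch; `program_input` and `at_input` are
// hypothetical caller-owned tensors, not part of this file):
//
//   // Point an ETensor at a contiguous ATen tensor's storage, without
//   // copying. `at_input` must outlive every use of `program_input`.
//   executorch::extension::alias_etensor_to_attensor(at_input, program_input);
//
//   // Wrap an ETensor's storage in an ATen tensor, without copying. The
//   // returned at::Tensor does not own the underlying memory.
//   at::Tensor view =
//       executorch::extension::alias_attensor_to_etensor(program_input);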