xref: /aosp_15_r20/external/executorch/extension/aten_util/test/aten_bridge_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <mutex>
10 #include <numeric>
11 #include <random>
12 
13 #include <executorch/extension/aten_util/aten_bridge.h>
14 #include <executorch/test/utils/DeathTest.h>
15 
16 #include <gtest/gtest.h>
17 
18 using namespace ::testing;
19 using namespace torch::executor;
20 using namespace torch::executor::util;
21 using namespace executorch::extension;
22 
23 namespace {
generate_at_tensor()24 at::Tensor generate_at_tensor() {
25   return at::empty({4, 5, 6});
26 }
get_default_dim_order(const at::Tensor & t)27 std::vector<Tensor::DimOrderType> get_default_dim_order(const at::Tensor& t) {
28   std::vector<Tensor::DimOrderType> dim_order(t.dim());
29   std::iota(dim_order.begin(), dim_order.end(), 0);
30   return dim_order;
31 }
32 } // namespace
33 
// An ETensor built without a data pointer, but with metadata copied from an
// ATen tensor, should end up sharing that ATen tensor's data after aliasing.
TEST(ATenBridgeTest, AliasETensorToATenTensor) {
  auto source = generate_at_tensor();
  const auto src_sizes = source.sizes();
  const auto src_strides = source.strides();

  std::vector<Tensor::SizesType> et_sizes(src_sizes.begin(), src_sizes.end());
  std::vector<Tensor::StridesType> et_strides(
      src_strides.begin(), src_strides.end());
  auto et_dim_order = get_default_dim_order(source);

  torch::executor::TensorImpl impl(
      torchToExecuTorchScalarType(source.options().dtype()),
      source.dim(),
      et_sizes.data(),
      /*data=*/nullptr,
      et_dim_order.data(),
      et_strides.data());
  torch::executor::Tensor etensor(&impl);

  alias_etensor_to_attensor(source, etensor);
  EXPECT_EQ(source.const_data_ptr(), etensor.const_data_ptr());
}
53 
// alias_etensor_to_attensor is expected to abort when the ETensor's metadata
// does not describe the ATen tensor it is being aliased to.
TEST(ATenBridgeTest, AliasETensorToATenTensorFail) {
  auto at_tensor = generate_at_tensor();
  std::vector<Tensor::SizesType> sizes(
      at_tensor.sizes().begin(), at_tensor.sizes().end());
  auto dim_order = get_default_dim_order(at_tensor);
  std::vector<Tensor::StridesType> strides(
      at_tensor.strides().begin(), at_tensor.strides().end());
  auto dtype = torchToExecuTorchScalarType(at_tensor.options().dtype());

  {
    // Rank mismatch: the etensor declares rank 1 while at_tensor has rank
    // at_tensor.dim() (3 for generate_at_tensor()).
    torch::executor::TensorImpl tensor_impl(
        dtype, 1, sizes.data(), nullptr, dim_order.data(), strides.data());
    torch::executor::Tensor etensor(&tensor_impl);
    ET_EXPECT_DEATH(alias_etensor_to_attensor(at_tensor, etensor), "");
  }

  {
    // Missing strides. Pass nullptr explicitly: the data() of an empty
    // std::vector is not guaranteed to be null, so the previous
    // `strides = std::vector<...>(); strides.data()` did not reliably
    // exercise this case.
    torch::executor::TensorImpl tensor_impl(
        dtype,
        at_tensor.dim(),
        sizes.data(),
        nullptr,
        dim_order.data(),
        /*strides=*/nullptr);
    torch::executor::Tensor etensor(&tensor_impl);
    ET_EXPECT_DEATH(alias_etensor_to_attensor(at_tensor, etensor), "");
  }
}
81 
// Aliasing an ETensor to the contiguous copy of a sliced (non-contiguous) ATen
// view should make the ETensor share the copy's data, not the view's.
TEST(ATenBridgeTest, AliasETensorToATenTensorNonContiguous) {
  auto at_tensor = generate_at_tensor();
  auto sliced_tensor = at_tensor.slice(1, 0, 2);
  auto sliced_tensor_contig = sliced_tensor.contiguous();

  // Describe the etensor entirely in terms of the tensor it will alias
  // (previously dim_order came from at_tensor, which only worked because the
  // ranks happen to match).
  std::vector<Tensor::SizesType> sizes(
      sliced_tensor_contig.sizes().begin(), sliced_tensor_contig.sizes().end());
  auto dim_order = get_default_dim_order(sliced_tensor_contig);
  std::vector<Tensor::StridesType> strides(
      sliced_tensor_contig.strides().begin(),
      sliced_tensor_contig.strides().end());
  auto dtype =
      torchToExecuTorchScalarType(sliced_tensor_contig.options().dtype());

  // Initial backing buffer for the etensor; the alias call below repoints it.
  std::vector<uint8_t> etensor_data(sliced_tensor_contig.nbytes());
  torch::executor::TensorImpl tensor_impl(
      dtype,
      sliced_tensor_contig.dim(),
      sizes.data(),
      etensor_data.data(),
      dim_order.data(),
      strides.data());
  torch::executor::Tensor etensor(&tensor_impl);

  alias_etensor_to_attensor(sliced_tensor_contig, etensor);
  EXPECT_EQ(sliced_tensor_contig.const_data_ptr(), etensor.const_data_ptr());
  // The contiguous copy lives in different memory than the sliced view.
  EXPECT_NE(sliced_tensor.const_data_ptr(), etensor.const_data_ptr());
}
106 
// Aliasing directly to a non-contiguous ATen view must abort: the etensor
// records the strides of the contiguous copy, which differ from the view's.
TEST(ATenBridgeTest, AliasETensorToATenTensorNonContiguousFail) {
  auto at_tensor = generate_at_tensor();
  auto sliced_tensor = at_tensor.slice(1, 0, 2);
  auto sliced_tensor_contig = sliced_tensor.contiguous();

  std::vector<Tensor::SizesType> sizes(
      sliced_tensor.sizes().begin(), sliced_tensor.sizes().end());
  // Intentionally contiguous strides, so they do not match sliced_tensor.
  std::vector<Tensor::StridesType> strides(
      sliced_tensor_contig.strides().begin(),
      sliced_tensor_contig.strides().end());
  auto dtype = torchToExecuTorchScalarType(sliced_tensor.options().dtype());
  std::vector<uint8_t> etensor_data(sliced_tensor_contig.nbytes());
  // Derive dim order from the tensor actually being described (previously
  // at_tensor, which only worked because the ranks happen to match).
  auto dim_order = get_default_dim_order(sliced_tensor);

  torch::executor::TensorImpl tensor_impl(
      dtype,
      sliced_tensor.dim(),
      sizes.data(),
      etensor_data.data(),
      dim_order.data(),
      strides.data());
  torch::executor::Tensor etensor(&tensor_impl);
  ET_EXPECT_DEATH(alias_etensor_to_attensor(sliced_tensor, etensor), "");
}
129 
// alias_attensor_to_etensor should return an at::Tensor whose data pointer is
// the ETensor's own backing buffer.
TEST(ATenBridgeTest, AliasATTensorToETensor) {
  auto reference = generate_at_tensor();
  const auto ref_sizes = reference.sizes();
  const auto ref_strides = reference.strides();

  std::vector<Tensor::SizesType> et_sizes(ref_sizes.begin(), ref_sizes.end());
  std::vector<Tensor::StridesType> et_strides(
      ref_strides.begin(), ref_strides.end());
  auto et_dim_order = get_default_dim_order(reference);

  // Buffer owned by the test; the resulting ATen tensor is expected to point
  // into it.
  std::vector<uint8_t> backing(reference.nbytes());
  torch::executor::TensorImpl impl(
      torchToExecuTorchScalarType(reference.options().dtype()),
      reference.dim(),
      et_sizes.data(),
      backing.data(),
      et_dim_order.data(),
      et_strides.data());
  torch::executor::Tensor etensor(&impl);

  const auto aliased = alias_attensor_to_etensor(etensor);
  EXPECT_EQ(aliased.const_data_ptr(), backing.data());
}
150 
// alias_tensor_ptr_to_attensor should return a TensorPtr that already shares
// the ATen tensor's data; re-aliasing it via alias_etensor_to_attensor should
// then succeed and keep sharing that data.
TEST(ATenBridgeTest, AliasTensorPtrToATenTensor) {
  auto at_tensor = generate_at_tensor();
  auto et_tensor_ptr = alias_tensor_ptr_to_attensor(at_tensor);

  // Check the returned TensorPtr itself, before any explicit re-aliasing.
  // Otherwise the final assertion would pass regardless of what
  // alias_tensor_ptr_to_attensor did.
  EXPECT_EQ(at_tensor.const_data_ptr(), et_tensor_ptr->const_data_ptr());
  EXPECT_EQ(at_tensor.dim(), et_tensor_ptr->dim());

  alias_etensor_to_attensor(at_tensor, *et_tensor_ptr);
  EXPECT_EQ(at_tensor.const_data_ptr(), et_tensor_ptr->const_data_ptr());
}
157