xref: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/test_ops.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <torch/csrc/jit/tensorexpr/eval.h>
3 #include <torch/csrc/jit/tensorexpr/expr.h>
4 #include <torch/csrc/jit/tensorexpr/loopnest.h>
5 #include <torch/csrc/jit/tensorexpr/operators/operators.h>
6 #include <torch/torch.h>
7 
8 using namespace torch::jit::tensorexpr;
9 
10 using Tensors = std::vector<Tensor>;
11 using Args = std::vector<CodeGen::BufferArg>;
compile(const Args & inputs,const Tensors & outputs)12 std::unique_ptr<SimpleIREvaluator> compile(
13     const Args& inputs,
14     const Tensors& outputs) {
15   LoopNest nest({outputs});
16   nest.prepareForCodegen();
17   nest.simplify();
18   auto join = inputs;
19   join.insert(join.end(), outputs.begin(), outputs.end());
20   return std::make_unique<SimpleIREvaluator>(nest.root_stmt(), join);
21 }
22 
// Compares computeSum over a contiguous [M, N] float buffer against at::sum
// for several sets of reduced dimensions.
TEST(Ops, Sum) {
  constexpr int M = 8;
  constexpr int N = 16;
  // Reducing over testDims[i] is expected to yield shape outputShapes[i]
  // (reduced dimensions are dropped, matching at::sum's default keepdim).
  std::vector<IntList> testDims = {{0}, {1}, {0, 1}};
  std::vector<std::vector<ExprHandle>> outputShapes = {{N}, {M}, {}};
  // The two tables are indexed in lockstep; if they ever drift apart the
  // loop below would index past the end of the shorter one.
  ASSERT_EQ(testDims.size(), outputShapes.size());

  for (size_t idx = 0; idx < testDims.size(); ++idx) {
    // Identify the failing configuration in any assertion message below.
    SCOPED_TRACE(::testing::Message() << "testDims index " << idx);
    const auto& dims = testDims[idx];
    const auto& outShape = outputShapes[idx];

    BufHandle a("a", {M, N}, kFloat);
    std::vector<ExprHandle> outStrides =
        c10::fmap<ExprHandle>(make_contiguous_strides(outShape));
    // NOTE(review): the trailing `false` is presumably keepdim; it agrees
    // with outputShapes dropping the reduced dims — confirm in computeSum.
    Tensor b = computeSum(
        {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
    auto cg = compile({a}, {b});

    // Do not name this local `at`: it would shadow the `at` namespace that
    // is used on these same lines.
    auto input = at::arange(M * N, at::kFloat).view({M, N});
    auto ref = at::sum(input, dims);
    auto output = at::empty_like(ref);

    cg->call({input.data_ptr<float>(), output.data_ptr<float>()});

    ASSERT_TRUE(at::allclose(output, ref));
  }
}
48 
// Compares computeSum over a 5-D [A, B, C, D, E] float buffer against
// at::sum, with output strides produced by make_channels_last_strides,
// for several sets of reduced dimensions.
TEST(Ops, ChannelsLastSum) {
  constexpr int A = 2;
  constexpr int B = 3;
  constexpr int C = 4;
  constexpr int D = 5;
  constexpr int E = 6;
  // Reducing over testDims[i] is expected to yield shape outputShapes[i]
  // (reduced dimensions are dropped, matching at::sum's default keepdim).
  std::vector<IntList> testDims = {{0}, {1}, {0, 1}};
  std::vector<std::vector<ExprHandle>> outputShapes = {
      {B, C, D, E}, {A, C, D, E}, {C, D, E}};
  // The two tables are indexed in lockstep; if they ever drift apart the
  // loop below would index past the end of the shorter one.
  ASSERT_EQ(testDims.size(), outputShapes.size());

  for (size_t idx = 0; idx < testDims.size(); ++idx) {
    // Identify the failing configuration in any assertion message below.
    SCOPED_TRACE(::testing::Message() << "testDims index " << idx);
    const auto& dims = testDims[idx];
    const auto& outShape = outputShapes[idx];

    BufHandle a("a", {A, B, C, D, E}, kFloat);
    std::vector<ExprHandle> outStrides =
        c10::fmap<ExprHandle>(make_channels_last_strides(outShape));
    // NOTE(review): the trailing `false` is presumably keepdim; it agrees
    // with outputShapes dropping the reduced dims — confirm in computeSum.
    Tensor b = computeSum(
        {a, dims, false}, outShape, outStrides, c10::kFloat, at::kCPU);
    auto cg = compile({a}, {b});

    // Do not name this local `at`: it would shadow the `at` namespace that
    // is used on these same lines.
    auto input =
        at::arange(A * B * C * D * E, at::kFloat).view({A, B, C, D, E});
    auto ref = at::sum(input, dims);
    auto output = at::empty_like(ref);

    cg->call({input.data_ptr<float>(), output.data_ptr<float>()});

    ASSERT_TRUE(at::allclose(output, ref));
  }
}
79