xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_schema_matching.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 #include <torch/csrc/jit/runtime/custom_operator.h>
5 #include <torch/csrc/jit/testing/file_check.h>
6 #include <torch/jit.h>
7 
8 #include <sstream>
9 #include <string>
10 
11 namespace torch {
12 namespace jit {
13 
TEST(SchemaMatchingTest,VarType)14 TEST(SchemaMatchingTest, VarType) {
15   RegisterOperators reg({
16       Operator(
17           "aten::test_vartype(t[] a, t b) -> (t)",
18           [](Stack& stack) {
19             c10::List<double> list;
20             // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
21             double a;
22             pop(stack, list, a);
23             push(stack, a);
24           },
25           c10::AliasAnalysisKind::FROM_SCHEMA),
26   });
27   Module m("m");
28   m.define(R"(
29       def test(self):
30         a = (1.0, 2.0)
31         return torch.test_vartype(a, 2.0)
32     )");
33   auto result = m.run_method("test");
34   TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0);
35 
36   const std::string error_example = R"JIT(
37       def test_2(self):
38           a = (1.0, 2.0)
39           non_float = (1, 1)
40           return torch.test_vartype(a, non_float)
41     )JIT";
42 
43   std::string err = "";
44   try {
45     m.define(error_example);
46   } catch (const std::exception& e) {
47     err = e.what();
48   }
49   TORCH_INTERNAL_ASSERT(
50       err.find("previously matched to type") != std::string::npos);
51 }
52 
TEST(SchemaMatchingTest,VarType2)53 TEST(SchemaMatchingTest, VarType2) {
54   RegisterOperators reg({
55       Operator(
56           "aten::test_vartype2(t a, t[] b) -> (t[])",
57           [](Stack& stack) {
58             // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
59             double a;
60             c10::List<double> list;
61             pop(stack, a, list);
62             push(stack, a);
63           },
64           AliasAnalysisKind::FROM_SCHEMA),
65   });
66   Module m("m");
67   m.define(R"JIT(
68       def test(self):
69           a = (1.0, 2.0)
70           return torch.test_vartype2(3.0, a)
71     )JIT");
72   auto result = m.run_method("test");
73   TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0);
74 
75   static const auto error_exam2 = R"JIT(
76       def test_2(self):
77           a = (1, 2)
78           return torch.test_vartype2(3.0, a)
79     )JIT";
80 
81   std::string err = "";
82   try {
83     m.define(error_exam2);
84   } catch (const std::exception& e) {
85     err = e.what();
86   }
87   TORCH_INTERNAL_ASSERT(
88       err.find("previously matched to type") != std::string::npos);
89 }
90 } // namespace jit
91 } // namespace torch
92