xref: /aosp_15_r20/external/pytorch/test/cpp/api/jit.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/jit.h>
4 #include <torch/script.h>
5 #include <torch/types.h>
6 
7 #include <string>
8 
TEST(TorchScriptTest, CanCompileMultipleFunctions) {
  // Compiles several standalone TorchScript functions into one module and
  // checks that each can be invoked through run_method with correct results.
  auto module = torch::jit::compile(R"JIT(
      def test_mul(a, b):
        return a * b
      def test_relu(a, b):
        return torch.relu(a + b)
      def test_while(a, i):
        while bool(i < 10):
          a += a
          i += 1
        return a
      def test_len(a : List[int]):
        return len(a)
    )JIT");
  auto a = torch::ones(1);
  auto b = torch::ones(1);

  ASSERT_EQ(1, module->run_method("test_mul", a, b).toTensor().item<int64_t>());

  ASSERT_EQ(
      2, module->run_method("test_relu", a, b).toTensor().item<int64_t>());

  // test_while doubles `a` once per iteration for i = 1..9 (nine doublings),
  // so the result is 2^9 == 512 == 0x200. ASSERT_EQ reports both operands on
  // failure, unlike ASSERT_TRUE(x == y) which only prints "false".
  ASSERT_EQ(
      0x200,
      module->run_method("test_while", a, b).toTensor().item<int64_t>());

  at::IValue list = c10::List<int64_t>({3, 4});
  ASSERT_EQ(2, module->run_method("test_len", list).toInt());
}
38 
TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) {
  // Verifies argument matching for nested list types: both strongly-typed and
  // generic (type-erased) lists must match List[List[Tensor]], while an extra
  // nesting level must be rejected with a descriptive error.
  auto module = torch::jit::compile(R"JIT(
      def nested_loop(a: List[List[Tensor]], b: int):
        return torch.tensor(1.0) + b
    )JIT");

  auto b = 3;

  torch::List<torch::Tensor> list({torch::rand({4, 4})});

  // A strongly-typed nested list matches directly.
  torch::List<torch::List<torch::Tensor>> list_of_lists;
  list_of_lists.push_back(list);
  module->run_method("nested_loop", list_of_lists, b);

  // Generic lists carrying the right element type must also match, including
  // when the inner list is empty.
  auto generic_list = c10::impl::GenericList(at::TensorType::get());
  auto empty_generic_list =
      c10::impl::GenericList(at::ListType::create(at::TensorType::get()));
  empty_generic_list.push_back(generic_list);
  module->run_method("nested_loop", empty_generic_list, b);

  // One nesting level too many must fail argument matching.
  auto too_many_lists = c10::impl::GenericList(
      at::ListType::create(at::ListType::create(at::TensorType::get())));
  too_many_lists.push_back(empty_generic_list);
  try {
    module->run_method("nested_loop", too_many_lists, b);
    // FAIL() integrates with gtest reporting, unlike AT_ASSERT(false).
    FAIL() << "run_method should have thrown for mismatched list nesting";
  } catch (const c10::Error& error) {
    // The error message must *start* with the argument-mismatch description
    // (find(...) == 0). ASSERT_EQ gives a usable diagnostic on mismatch.
    ASSERT_EQ(
        0u,
        std::string(error.what_without_backtrace())
            .find("nested_loop() Expected a value of type 'List[List[Tensor]]'"
                  " for argument 'a' but instead found type "
                  "'List[List[List[Tensor]]]'"));
  }
}
73 
TEST(TorchScriptTest, TestDictArgMatching) {
  // A c10::Dict<std::string, Tensor> argument must match Dict[str, Tensor]
  // and the looked-up tensor must round-trip intact.
  auto module = torch::jit::compile(R"JIT(
      def dict_op(a: Dict[str, Tensor], b: str):
        return a[b]
    )JIT");
  c10::Dict<std::string, at::Tensor> table;
  table.insert("hello", torch::ones({2}));
  const auto result = module->run_method("dict_op", table, std::string("hello"));
  ASSERT_EQ(1, result.toTensor()[0].item<int64_t>());
}
84 
TEST(TorchScriptTest, TestTupleArgMatching) {
  // A tuple IValue wrapping an int list must match the Tuple[List[int]]
  // annotation during argument matching.
  auto module = torch::jit::compile(R"JIT(
      def tuple_op(a: Tuple[List[int]]):
        return a
    )JIT");

  c10::List<int64_t> ints({1});
  auto wrapped = c10::ivalue::Tuple::create({ints});

  // The call succeeds iff argument matching accepts the tuple-of-list value.
  module->run_method("tuple_op", wrapped);
}
97 
TEST(TorchScriptTest, TestOptionalArgMatching) {
  // An Optional[Tuple[int, str]] parameter must accept both a concrete tuple
  // and an undefined IValue (None).
  auto module = torch::jit::compile(R"JIT(
      def optional_tuple_op(a: Optional[Tuple[int, str]]):
        if a is None:
          return 0
        else:
          return a[0]
    )JIT");

  auto tuple_arg = c10::ivalue::Tuple::create({2, std::string("hi")});
  const auto none_arg = torch::jit::IValue();

  // A concrete tuple takes the non-None branch and returns its first element.
  ASSERT_EQ(2, module->run_method("optional_tuple_op", tuple_arg).toInt());
  // An undefined IValue matches None and takes the fallback branch.
  ASSERT_EQ(0, module->run_method("optional_tuple_op", none_arg).toInt());
}
113 
TEST(TorchScriptTest, TestPickle) {
  // Round-trips a double IValue through the TorchScript pickle/unpickle
  // serializer and checks the value survives within a small tolerance.
  torch::IValue float_value(2.3);

  // TODO: when tensors are stored in the pickle, delete this
  std::vector<at::Tensor> tensor_table;
  auto data = torch::jit::pickle(float_value, &tensor_table);

  torch::IValue ivalue = torch::jit::unpickle(data.data(), data.size());

  // ASSERT_NEAR prints both values and the tolerance on failure, unlike the
  // hand-rolled `diff < eps && diff > -eps` comparison it replaces.
  ASSERT_NEAR(float_value.toDouble(), ivalue.toDouble(), 0.0001);
}
127