xref: /aosp_15_r20/external/pytorch/test/cpp/api/jit.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <torch/jit.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/script.h>
5*da0073e9SAndroid Build Coastguard Worker #include <torch/types.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker #include <string>
8*da0073e9SAndroid Build Coastguard Worker 
// Compiles one TorchScript source containing four functions and invokes each
// of them by name through run_method(), checking the returned values.
TEST(TorchScriptTest, CanCompileMultipleFunctions) {
  auto compiled = torch::jit::compile(R"JIT(
      def test_mul(a, b):
        return a * b
      def test_relu(a, b):
        return torch.relu(a + b)
      def test_while(a, i):
        while bool(i < 10):
          a += a
          i += 1
        return a
      def test_len(a : List[int]):
        return len(a)
    )JIT");
  auto first = torch::ones(1);
  auto second = torch::ones(1);

  // The calls stay in this order: the scripted test_while uses `+=` on both
  // of its arguments, so it runs after the mul/relu checks.
  const auto product =
      compiled->run_method("test_mul", first, second).toTensor().item<int64_t>();
  ASSERT_EQ(1, product);

  const auto rectified =
      compiled->run_method("test_relu", first, second).toTensor().item<int64_t>();
  ASSERT_EQ(2, rectified);

  // `second` serves as the loop counter `i` of test_while.
  const auto doubled =
      compiled->run_method("test_while", first, second).toTensor().item<int64_t>();
  ASSERT_TRUE(doubled == 0x200);

  at::IValue two_ints = c10::List<int64_t>({3, 4});
  const auto length = compiled->run_method("test_len", two_ints).toInt();
  ASSERT_EQ(2, length);
}
38*da0073e9SAndroid Build Coastguard Worker 
// Exercises run_method() argument matching for a parameter declared as
// `List[List[Tensor]]`:
//   * a statically typed torch::List<torch::List<torch::Tensor>> is passed,
//   * a GenericList of Tensor lists holding one empty inner list is passed,
//   * a GenericList nested one level too deep must be rejected with a
//     c10::Error whose message starts with the expected type-mismatch text.
TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) {
  auto module = torch::jit::compile(R"JIT(
      def nested_loop(a: List[List[Tensor]], b: int):
        return torch.tensor(1.0) + b
    )JIT");

  auto b = 3;

  torch::List<torch::Tensor> list({torch::rand({4, 4})});

  torch::List<torch::List<torch::Tensor>> list_of_lists;
  list_of_lists.push_back(list);
  module->run_method("nested_loop", list_of_lists, b);

  auto generic_list = c10::impl::GenericList(at::TensorType::get());
  auto empty_generic_list =
      c10::impl::GenericList(at::ListType::create(at::TensorType::get()));
  empty_generic_list.push_back(generic_list);
  module->run_method("nested_loop", empty_generic_list, b);

  auto too_many_lists = c10::impl::GenericList(
      at::ListType::create(at::ListType::create(at::TensorType::get())));
  too_many_lists.push_back(empty_generic_list);
  try {
    module->run_method("nested_loop", too_many_lists, b);
    // Report the missing exception through gtest so it is recorded as an
    // ordinary test failure with a descriptive message.
    FAIL() << "run_method() accepted a List[List[List[Tensor]]] argument for "
              "a List[List[Tensor]] parameter";
  } catch (const c10::Error& error) {
    const std::string message = error.what_without_backtrace();
    const std::string expected_prefix =
        "nested_loop() Expected a value of type 'List[List[Tensor]]'"
        " for argument 'a' but instead found type "
        "'List[List[List[Tensor]]]'";
    // The message must begin with the expected text (find() returns 0).
    ASSERT_EQ(0u, message.find(expected_prefix))
        << "unexpected error message: " << message;
  }
}
73*da0073e9SAndroid Build Coastguard Worker 
// Passes a c10::Dict<std::string, at::Tensor> and a string key to a scripted
// function declared with `Dict[str, Tensor]` / `str` parameters and checks
// the first element of the tensor it returns.
TEST(TorchScriptTest, TestDictArgMatching) {
  auto compiled = torch::jit::compile(R"JIT(
      def dict_op(a: Dict[str, Tensor], b: str):
        return a[b]
    )JIT");

  c10::Dict<std::string, at::Tensor> tensors_by_name;
  tensors_by_name.insert("hello", torch::ones({2}));

  auto looked_up =
      compiled->run_method("dict_op", tensors_by_name, std::string("hello"));
  const auto first_element = looked_up.toTensor()[0].item<int64_t>();
  ASSERT_EQ(1, first_element);
}
84*da0073e9SAndroid Build Coastguard Worker 
// Passes a one-element tuple wrapping a List[int] to a scripted function
// declared with a `Tuple[List[int]]` parameter; the call itself is the check.
TEST(TorchScriptTest, TestTupleArgMatching) {
  auto compiled = torch::jit::compile(R"JIT(
      def tuple_op(a: Tuple[List[int]]):
        return a
    )JIT");

  c10::List<int64_t> single_int({1});
  auto wrapped_list = c10::ivalue::Tuple::create({single_int});

  // Argument matching is expected to accept this tuple; the result is unused.
  compiled->run_method("tuple_op", wrapped_list);
}
97*da0073e9SAndroid Build Coastguard Worker 
// Calls a scripted function taking `Optional[Tuple[int, str]]` twice: once
// with a populated tuple (expects its first element back) and once with a
// default-constructed IValue (expects the `a is None` branch, i.e. 0).
TEST(TorchScriptTest, TestOptionalArgMatching) {
  auto compiled = torch::jit::compile(R"JIT(
      def optional_tuple_op(a: Optional[Tuple[int, str]]):
        if a is None:
          return 0
        else:
          return a[0]
    )JIT");

  auto populated = c10::ivalue::Tuple::create({2, std::string("hi")});
  const auto from_tuple =
      compiled->run_method("optional_tuple_op", populated).toInt();
  ASSERT_EQ(2, from_tuple);

  const auto from_none =
      compiled->run_method("optional_tuple_op", torch::jit::IValue()).toInt();
  ASSERT_EQ(0, from_none);
}
113*da0073e9SAndroid Build Coastguard Worker 
// Round-trips a double-valued IValue through torch::jit::pickle() and
// torch::jit::unpickle() and checks the result matches within a tolerance.
TEST(TorchScriptTest, TestPickle) {
  torch::IValue original(2.3);

  // TODO: when tensors are stored in the pickle, delete this
  std::vector<at::Tensor> tensor_table;
  auto bytes = torch::jit::pickle(original, &tensor_table);

  torch::IValue restored = torch::jit::unpickle(bytes.data(), bytes.size());

  const double tolerance = 0.0001;
  const double delta = restored.toDouble() - original.toDouble();
  ASSERT_TRUE(delta < tolerance && delta > -tolerance);
}
127