#include <gtest/gtest.h>

#include <torch/jit.h>
#include <torch/script.h>
#include <torch/types.h>

#include <string>
#include <vector>

// Compiles several free functions into one TorchScript module and checks that
// each of them can be invoked by name through run_method().
TEST(TorchScriptTest, CanCompileMultipleFunctions) {
  auto module = torch::jit::compile(R"JIT(
      def test_mul(a, b):
        return a * b
      def test_relu(a, b):
        return torch.relu(a + b)
      def test_while(a, i):
        while bool(i < 10):
          a += a
          i += 1
        return a
      def test_len(a : List[int]):
        return len(a)
    )JIT");
  auto a = torch::ones(1);
  auto b = torch::ones(1);

  // 1 * 1 == 1
  ASSERT_EQ(1, module->run_method("test_mul", a, b).toTensor().item<int64_t>());

  // relu(1 + 1) == 2
  ASSERT_EQ(
      2, module->run_method("test_relu", a, b).toTensor().item<int64_t>());

  // With a == 1 and i == 1, the loop body runs for i = 1..9, doubling `a`
  // nine times: 2^9 == 0x200. The augmented assignments (`a += a`,
  // `i += 1`) may update tensor arguments in place, so pass clones to keep
  // `a` and `b` unchanged for the caller.
  ASSERT_EQ(
      0x200,
      module->run_method("test_while", a.clone(), b.clone())
          .toTensor()
          .item<int64_t>());

  // len([3, 4]) == 2
  at::IValue list = c10::List<int64_t>({3, 4});
  ASSERT_EQ(2, module->run_method("test_len", list).toInt());
}

// Checks that run_method() matches nested list arguments against the declared
// List[List[Tensor]] parameter: a statically typed nested list and a
// type-erased GenericList with the right element type are accepted, while a
// list with one extra level of nesting is rejected with a descriptive error.
TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) {
  auto module = torch::jit::compile(R"JIT(
      def nested_loop(a: List[List[Tensor]], b: int):
        return torch.tensor(1.0) + b
    )JIT");

  auto b = 3;

  // Statically typed torch::List<torch::List<Tensor>> holding one 4x4 tensor.
  torch::List<torch::Tensor> list({torch::rand({4, 4})});
  torch::List<torch::List<torch::Tensor>> list_of_lists;
  list_of_lists.push_back(list);
  ASSERT_DOUBLE_EQ(
      4.0,
      module->run_method("nested_loop", list_of_lists, b)
          .toTensor()
          .item<double>());

  // Type-erased GenericList whose element type is List[Tensor]. The inner list
  // is empty, so its type is known only from the element type it was
  // constructed with.
  auto empty_tensor_list = c10::impl::GenericList(at::TensorType::get());
  auto list_of_empty_lists =
      c10::impl::GenericList(at::ListType::create(at::TensorType::get()));
  list_of_empty_lists.push_back(empty_tensor_list);
  ASSERT_DOUBLE_EQ(
      4.0,
      module->run_method("nested_loop", list_of_empty_lists, b)
          .toTensor()
          .item<double>());

  // One level of nesting too many: List[List[List[Tensor]]] must be rejected.
  auto too_many_lists = c10::impl::GenericList(
      at::ListType::create(at::ListType::create(at::TensorType::get())));
  too_many_lists.push_back(list_of_empty_lists);
  try {
    module->run_method("nested_loop", too_many_lists, b);
    // FAIL() reports through gtest instead of throwing c10::Error, so a
    // missing rejection cannot be swallowed by the catch clause below.
    FAIL() << "Expected run_method to reject a List[List[List[Tensor]]] "
              "argument for 'a'";
  } catch (const c10::Error& error) {
    const std::string message = error.what_without_backtrace();
    // The error message must start with the schema-mismatch description.
    ASSERT_EQ(
        0u,
        message.find(
            "nested_loop() Expected a value of type 'List[List[Tensor]]'"
            " for argument 'a' but instead found type "
            "'List[List[List[Tensor]]]'"))
        << "Unexpected error message: " << message;
  }
}

// Checks that a c10::Dict<std::string, Tensor> is accepted for a
// Dict[str, Tensor] parameter and that lookup by key returns the stored tensor.
TEST(TorchScriptTest, TestDictArgMatching) {
  auto module = torch::jit::compile(R"JIT(
      def dict_op(a: Dict[str, Tensor], b: str):
        return a[b]
    )JIT");
  c10::Dict<std::string, at::Tensor> dict;
  dict.insert("hello", torch::ones({2}));
  auto output = module->run_method("dict_op", dict, std::string("hello"));
  ASSERT_EQ(1, output.toTensor()[0].item<int64_t>());
  // The whole stored tensor must come back, not just its first element.
  ASSERT_TRUE(output.toTensor().equal(torch::ones({2})));
}

// Checks that a tuple IValue built from a c10::List<int64_t> is accepted for a
// Tuple[List[int]] parameter.
TEST(TorchScriptTest, TestTupleArgMatching) {
  auto module = torch::jit::compile(R"JIT(
      def tuple_op(a: Tuple[List[int]]):
        return a
    )JIT");

  c10::List<int64_t> int_list({1});
  auto tuple_generic_list = c10::ivalue::Tuple::create({int_list});

  // Must not fail on argument matching; the identity function hands the tuple
  // straight back.
  auto output = module->run_method("tuple_op", tuple_generic_list);
  ASSERT_TRUE(output.isTuple());
}

// Checks that an Optional[Tuple[int, str]] parameter accepts both a concrete
// tuple and a None IValue, and that each takes the expected branch.
TEST(TorchScriptTest, TestOptionalArgMatching) {
  auto module = torch::jit::compile(R"JIT(
      def optional_tuple_op(a: Optional[Tuple[int, str]]):
        if a is None:
          return 0
        else:
          return a[0]
    )JIT");

  auto optional_tuple = c10::ivalue::Tuple::create({2, std::string("hi")});

  // Present value: returns the tuple's first element.
  ASSERT_EQ(2, module->run_method("optional_tuple_op", optional_tuple).toInt());
  // A default-constructed IValue is None: takes the `a is None` branch.
  ASSERT_EQ(
      0, module->run_method("optional_tuple_op", torch::jit::IValue()).toInt());
}

// Round-trips a double IValue through torch::jit::pickle / torch::jit::unpickle
// and checks that the unpickled value is a double equal to the original.
TEST(TorchScriptTest, TestPickle) {
  torch::IValue float_value(2.3);

  // TODO: when tensors are stored in the pickle, delete this
  std::vector<at::Tensor> tensor_table;
  auto data = torch::jit::pickle(float_value, &tensor_table);

  torch::IValue ivalue = torch::jit::unpickle(data.data(), data.size());

  // Check the tag first so a type mismatch is reported as such rather than as
  // an exception from toDouble().
  ASSERT_TRUE(ivalue.isDouble());
  ASSERT_NEAR(float_value.toDouble(), ivalue.toDouble(), 0.0001);
}