xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_jit_type.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <test/cpp/jit/test_utils.h>
4 #include <torch/csrc/jit/testing/file_check.h>
5 #include "torch/csrc/jit/ir/ir.h"
6 #include "torch/csrc/jit/ir/irparser.h"
7 
8 namespace torch {
9 namespace jit {
10 
// A float CPU tensor type whose sizes ({1, 49}) are fully specified but
// whose last Stride entry leaves one of its fields unset (std::nullopt).
// Neither the stride list nor the tensor type as a whole may then report
// itself as complete.
TEST(JitTypeTest, IsComplete) {
  const std::vector<std::optional<int64_t>> dims{1, 49};

  // NOTE(review): three Stride entries are supplied for a rank-2 shape.
  // This appears harmless for a pure completeness check, but confirm it is
  // intentional.
  const std::vector<c10::Stride> strides{
      c10::Stride{2, true, 1},
      c10::Stride{1, true, 1},
      c10::Stride{0, true, std::nullopt}};

  auto tensor_type = c10::TensorType::create(
      at::kFloat,
      at::kCPU,
      c10::SymbolicShape(dims),
      strides,
      /*requires_grad=*/false);

  TORCH_INTERNAL_ASSERT(!tensor_type->isComplete());
  TORCH_INTERNAL_ASSERT(!tensor_type->strides().isComplete());
}
24 
// Exercises unifyTypes() on several pairs of types:
//   * Tensor(bool) with Optional[Tensor(bool)],
//   * Optional[Tensor(bool)] with an unrefined Tensor,
//   * Optional[Tuple[None, int]] with Tuple[int, None],
//   * Future[int] with Future[None],
//   * Dict[int, None] with Dict[int, int] (expected to fail to unify).
TEST(JitTypeTest, UnifyTypes) {
  // Tensor(bool) with Optional[Tensor(bool)]: the optional input must be a
  // subtype of the unified result.
  auto bool_tensor = TensorType::get()->withScalarType(at::kBool);
  auto opt_bool_tensor = OptionalType::create(bool_tensor);
  auto unified_opt_bool = unifyTypes(bool_tensor, opt_bool_tensor);
  // Check presence before dereferencing (as the scenarios below do).
  // Dereferencing an empty std::optional is undefined behavior, not a clean
  // test failure.
  TORCH_INTERNAL_ASSERT(unified_opt_bool);
  TORCH_INTERNAL_ASSERT(opt_bool_tensor->isSubtypeOf(**unified_opt_bool));

  // Optional[Tensor(bool)] with a plain Tensor: Tensor is not a subtype of
  // the optional, yet the two still unify. The result is an Optional whose
  // element type is a subtype of Tensor.
  auto tensor = TensorType::get();
  TORCH_INTERNAL_ASSERT(!tensor->isSubtypeOf(*opt_bool_tensor));
  auto unified = unifyTypes(opt_bool_tensor, tensor);
  TORCH_INTERNAL_ASSERT(unified);
  auto elem = (*unified)->expectRef<OptionalType>().getElementType();
  TORCH_INTERNAL_ASSERT(elem->isSubtypeOf(*TensorType::get()));

  // Optional[Tuple[None, int]] with Tuple[int, None]: the unified type's
  // annotation must contain Optional[Tuple[Optional[int], Optional[int]]].
  auto opt_tuple_none_int = OptionalType::create(
      TupleType::create({NoneType::get(), IntType::get()}));
  auto tuple_int_none = TupleType::create({IntType::get(), NoneType::get()});
  auto out = unifyTypes(opt_tuple_none_int, tuple_int_none);
  TORCH_INTERNAL_ASSERT(out);
  // annotation_str() already returns a std::string, so it is handed to
  // FileCheck directly instead of round-tripping through a stringstream.
  testing::FileCheck()
      .check("Optional[Tuple[Optional[int], Optional[int]]]")
      ->run((*out)->annotation_str());

  // Future[int] with Future[None]: the result must be a subtype of
  // Future[Optional[int]].
  auto fut_1 = FutureType::create(IntType::get());
  auto fut_2 = FutureType::create(NoneType::get());
  auto fut_out = unifyTypes(fut_1, fut_2);
  TORCH_INTERNAL_ASSERT(fut_out);
  TORCH_INTERNAL_ASSERT((*fut_out)->isSubtypeOf(
      *FutureType::create(OptionalType::create(IntType::get()))));

  // Dict[int, None] with Dict[int, int]: unifyTypes is expected to return an
  // empty optional.
  auto dict_1 = DictType::create(IntType::get(), NoneType::get());
  auto dict_2 = DictType::create(IntType::get(), IntType::get());
  auto dict_out = unifyTypes(dict_1, dict_2);
  TORCH_INTERNAL_ASSERT(!dict_out);
}
62 
63 } // namespace jit
64 } // namespace torch
65