xref: /aosp_15_r20/external/pytorch/test/cpp/lazy/test_cache.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/lazy/core/cache.h>
5 #include <torch/csrc/lazy/core/hash.h>
6 #include <torch/csrc/lazy/core/ir.h>
7 #include <torch/csrc/lazy/core/shape.h>
8 #include <torch/csrc/lazy/ts_backend/ts_node.h>
9 
10 namespace torch {
11 namespace lazy {
12 
13 class CacheNode : public Node {
14  public:
CacheNode(const std::string & str)15   explicit CacheNode(const std::string& str)
16       : Node(OpKind(), /* num_outputs */ 1), hash_(Hash(str)), str_(str) {}
17   ~CacheNode() override = default;
18 
operands() const19   const std::vector<Output>& operands() const override {
20     TORCH_INTERNAL_ASSERT(false, "Can't access operands of test node");
21   }
22 
operand(size_t i) const23   const Output& operand(size_t i) const override {
24     TORCH_INTERNAL_ASSERT(false, "Can't access operand[i] of test node");
25   }
26 
hash() const27   hash_t hash() const override {
28     return hash_;
29   }
shapeHash() const30   hash_t shapeHash() const override {
31     return hash_;
32   }
33 
34  private:
35   hash_t hash_;
36   std::string str_;
37 };
38 
// Exercises Add / Get / Erase / Clear on a two-entry cache keyed by node hash.
TEST(CacheTest, BasicTest) {
  auto node_a = std::make_shared<CacheNode>("a");
  auto node_b = std::make_shared<CacheNode>("b");
  auto node_c = std::make_shared<CacheNode>("c");
  Cache<hash_t, CacheNode, HashReducer> cache(2);

  // Looks up a, b and c and compares each result with the expected pointer
  // (an empty pointer means "not cached"). The lookups are always issued in
  // the a, b, c order because Get may influence which entry is evicted next.
  const auto expect_lookups = [&](const std::shared_ptr<CacheNode>& want_a,
                                  const std::shared_ptr<CacheNode>& want_b,
                                  const std::shared_ptr<CacheNode>& want_c) {
    EXPECT_EQ(cache.Get(node_a->hash()), want_a);
    EXPECT_EQ(cache.Get(node_b->hash()), want_b);
    EXPECT_EQ(cache.Get(node_c->hash()), want_c);
  };

  cache.Add(node_a->hash(), node_a);
  expect_lookups(node_a, nullptr, nullptr);

  cache.Add(node_b->hash(), node_b);
  expect_lookups(node_a, node_b, nullptr);

  // A third insertion into the two-entry cache evicts a.
  cache.Add(node_c->hash(), node_c);
  expect_lookups(nullptr, node_b, node_c);

  // Explicitly removing c leaves only b.
  cache.Erase(node_c->hash());
  expect_lookups(nullptr, node_b, nullptr);

  // Clearing empties the cache entirely.
  cache.Clear();
  expect_lookups(nullptr, nullptr, nullptr);
}
70 
71 class CacheNodeWithShape : public TsNode {
72  public:
CacheNodeWithShape(const Shape & shape)73   explicit CacheNodeWithShape(const Shape& shape)
74       : TsNode(OpKind(), shape, /* num_outputs */ 1, /* seed */ 0) {}
75 };
76 
// With dynamic shapes enabled, computeShape() on each node must yield that
// node's own shape. In particular, a shape cached for the {2, 4} node must not
// be reused for the {4, 2} node.
TEST(CacheTest, ShapeCacheTestForDynamicShape) {
  // Enable dynamic shapes for the duration of this test only. The guard puts
  // back whatever value the flag had on entry when the test scope ends, even
  // if an exception escapes. Unconditionally writing `false` at the end would
  // clobber a `true` set by the harness or by an earlier test, and would be
  // skipped entirely on an early exit.
  const bool saved_flag = FLAGS_ltc_enable_dynamic_shapes;
  struct DynamicShapeFlagRestorer {
    bool value;
    ~DynamicShapeFlagRestorer() {
      FLAGS_ltc_enable_dynamic_shapes = value;
    }
  } restore_on_exit{saved_flag};
  FLAGS_ltc_enable_dynamic_shapes = true;

  CacheNodeWithShape nodes[] = {
      CacheNodeWithShape(Shape(c10::kFloat, {2, 4})),
      CacheNodeWithShape(Shape(c10::kFloat, {4, 2}))};

  // Make sure the cached shape for node (2, 4) is not used for node (4, 2).
  for (auto& node : nodes) {
    EXPECT_EQ(node.shape(), node.computeShape([&]() { return node.shape(); }));
  }
  // The flag is restored when restore_on_exit goes out of scope.
}
95 
96 } // namespace lazy
97 } // namespace torch
98