xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/mini_environment.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/jit_type.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 
6 namespace torch::jit {
7 
8 // Simple data structure for containing a type T in nested control blocks
9 // Should only be used after initial compilation where type checking and
10 // loads and stores are emitted
11 
12 template <typename T>
13 struct MiniEnvironment {
14   MiniEnvironment(Block* b, std::shared_ptr<MiniEnvironment> next = nullptr)
nextMiniEnvironment15       : next(std::move(next)) {}
16 
17   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
18   std::shared_ptr<MiniEnvironment<T>> next;
19 
findInThisFrameMiniEnvironment20   T findInThisFrame(const std::string& name) {
21     auto it = table.find(name);
22     if (it != table.end()) {
23       return it->second;
24     }
25     return nullptr;
26   }
27 
findInAnyFrameMiniEnvironment28   T findInAnyFrame(const std::string& name) {
29     for (auto runner = this; runner; runner = runner->next.get()) {
30       if (auto r = runner->findInThisFrame(name)) {
31         return r;
32       }
33     }
34     return nullptr;
35   }
36 
setVarMiniEnvironment37   void setVar(const std::string& name, T value) {
38     table[name] = value;
39   }
40 
definedVariablesMiniEnvironment41   std::vector<std::string> definedVariables() {
42     std::vector<std::string> result;
43     result.reserve(table.size());
44     for (auto& kv : table) {
45       result.push_back(kv.first);
46     }
47     std::sort(result.begin(), result.end());
48     return result;
49   }
50 
51  private:
52   std::unordered_map<std::string, T> table;
53 };
54 
55 } // namespace torch::jit
56