xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/resolver.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/jit_type.h>
4 #include <ATen/core/qualified_name.h>
5 #include <torch/csrc/jit/frontend/sugared_value.h>
6 
7 namespace torch::jit {
8 
9 struct Resolver;
10 using ResolverPtr = std::shared_ptr<Resolver>;
11 
12 /**
13  * class Resolver
14  *
15  * Represents an "outer environment" in which we an look up names and return
16  * a corresponding SugaredValue. This is used during compilation to resolve
17  * references to names which are not defined internal to the graph.
18  *
19  * Example: PythonResolver looks at the enclosing Python scope for `name`.
20  *
21  * NOTE: When adding methods, keep this an abstract class (i.e. all new methods
22  * should be purely virtual). Resist the urge to provide a default
23  * implementation; you should explicitly think about how each resolver would
24  * handle the method.
25  */
26 struct Resolver {
27   virtual ~Resolver() = default;
28 
29   // Resolve a given name to a SugaredValue. This takes the method `m` that the
30   // caller is currently constructing, since we may need to insert nodes into
31   // the graph to create a value.
resolveValueResolver32   virtual std::shared_ptr<SugaredValue> resolveValue(
33       const std::string& name,
34       GraphFunction& m,
35       const SourceRange& loc) {
36     return nullptr;
37   }
38 
39   // Resolve `name` to a TypePtr.
resolveTypeResolver40   virtual TypePtr resolveType(const std::string& name, const SourceRange& loc) {
41     return nullptr;
42   }
43 };
44 
45 // A resolver that only understands "torch.foo()" lookups.
46 struct NativeResolver : public Resolver {
resolveValueNativeResolver47   std::shared_ptr<SugaredValue> resolveValue(
48       const std::string& name,
49       GraphFunction& m,
50       const SourceRange& loc) override {
51     if (name == "torch") {
52       return std::make_shared<BuiltinModule>("aten");
53     }
54     return nullptr;
55   }
56 
resolveTypeNativeResolver57   TypePtr resolveType(const std::string& name, const SourceRange& loc)
58       override {
59     return nullptr;
60   }
61 };
62 
nativeResolver()63 inline std::shared_ptr<NativeResolver> nativeResolver() {
64   return std::make_shared<NativeResolver>();
65 }
66 } // namespace torch::jit
67