xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/api/compilation_unit.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
#include <ATen/core/function.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/name_mangler.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>

#include <torch/csrc/Export.h>

#include <ATen/core/function_schema.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/ArrayRef.h>

#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
24 
25 namespace torch::jit {
26 
// Forward declarations of types that this header only names in declarations
// (by reference, by pointer, or through a smart pointer).
struct Def;
struct Property;
struct ClassDef;
struct SugaredValue;
struct Resolver;

// Shared-ownership handle to a Resolver.
using ResolverPtr = std::shared_ptr<Resolver>;
34 struct Self {
35   virtual ~Self() = default;
36   virtual std::shared_ptr<SugaredValue> makeSugared(Value* v) const = 0;
37   virtual ClassTypePtr getClassType() const = 0;
38 };
39 
40 // A CompilationUnit is a list of named Functions
41 // with helper methods to iterate the list or invoke the function.
42 // Classes have a CompilationUnit holding the class methods,
43 // and Modules have a CompilationUnit holding the Functions that
44 // are used to implement their Methods
45 
46 struct TORCH_API CompilationUnit {
47   enum class FunctionType { Method, Hook, PreHook };
48   // constructor that takes a set of functions to compile using the native
49   // resolver
50   explicit CompilationUnit(const std::string& source);
51   CompilationUnit() = default;
52 
53   CompilationUnit& operator=(CompilationUnit&&) = default;
54   CompilationUnit(CompilationUnit&&) = default;
55   CompilationUnit& operator=(const CompilationUnit&) = delete;
56   CompilationUnit(const CompilationUnit&) = delete;
57 
find_functionCompilationUnit58   Function* find_function(const c10::QualifiedName& name) const {
59     auto it = dict_.find(name);
60     if (it == dict_.end()) {
61       return nullptr;
62     }
63     return functions_[it->second].get();
64   }
65 
get_functionCompilationUnit66   Function& get_function(const c10::QualifiedName& name) const {
67     if (auto r = find_function(name)) {
68       return *r;
69     }
70     TORCH_CHECK(false, "attempted to get undefined function ", name.name());
71   }
72 
set_optimizedCompilationUnit73   void set_optimized(bool o) {
74     TORCH_WARN(
75         "CompilationUnit::set_optimized() is deprecated and has no effect. "
76         "Please use setGraphExecutorOptimize()");
77   }
78 
is_optimizedCompilationUnit79   bool is_optimized() const {
80     TORCH_WARN(
81         "CompilationUnit::is_optimized() is deprecated and always returns true. "
82         "Please use getGraphExecutorOptimize()");
83     return true;
84   }
85 
86   // for historic reasons, these are defined in ir_emitter.cpp
87   // Returns the list of Functions just defined.
88   std::vector<Function*> define(
89       const std::optional<c10::QualifiedName>& prefix,
90       const std::vector<Property>& properties,
91       const std::vector<ResolverPtr>& propResolvers,
92       const std::vector<Def>& definitions,
93       const std::vector<ResolverPtr>&
94           defResolvers, /* determines how we handle free
95                      variables in each definition*/
96       // if non-null, the first argument to each def, is bound to this value
97       const Self* self,
98       // see [name mangling]
99       bool shouldMangle = false,
100       std::optional<size_t> operator_set_version = std::nullopt);
101 
102   void define_hooks(
103       const std::optional<c10::QualifiedName>& prefix,
104       const std::vector<Def>& hookDefs,
105       const std::vector<ResolverPtr>& hookResolvers,
106       const std::vector<Def>& preHookDefs,
107       const std::vector<ResolverPtr>& preHookResolvers,
108       const Self* self,
109       bool shouldMangle = false);
110 
111   // same as above but parse the definitions from source
112   // Returns the list of Functions just defined.
113   std::vector<Function*> define(
114       // prefix namespace to put all the defined functions into
115       const std::optional<c10::QualifiedName>& prefix,
116       const std::string& source,
117       const ResolverPtr& resolver,
118       const Self* self);
119 
120   void define_interface(
121       const c10::QualifiedName& qualifiedName,
122       const ClassDef& classDef,
123       ResolverPtr rcb,
124       bool is_module = false);
125 
126   Function* create_function(
127       c10::QualifiedName name,
128       std::shared_ptr<Graph> graph,
129       bool shouldMangle = false) {
130     if (shouldMangle) {
131       name = mangle(name);
132     }
133     auto fn = std::make_unique<GraphFunction>(
134         std::move(name), std::move(graph), nullptr);
135     auto ret = fn.get();
136     register_function(std::move(fn));
137     return ret;
138   }
139 
get_functionsCompilationUnit140   std::vector<Function*> get_functions() const {
141     return fmap(functions_, [](const std::unique_ptr<Function>& fn) {
142       return fn.get();
143     });
144   }
145 
146   /// Run a method from this compilation.
147   ///
148   /// For example:
149   /// @code
150   ///   IValue output = module->run("relu_script", a, b);
151   /// @endcode
152   ///
153   /// To get a compile a module from a source string, see torch::jit::compile
154   ///
155   /// @param method_name The name of the method to run
156   /// @param args Arguments to be passed to the method
157   /// @return An IValue containing the return value (or values if it is a tuple)
158   /// from the method
159   template <typename... Types>
run_methodCompilationUnit160   IValue run_method(const c10::QualifiedName& method_name, Types&&... args) {
161     return get_function(method_name)({IValue(std::forward<Types>(args))...});
162   }
163 
drop_all_functionsCompilationUnit164   void drop_all_functions() {
165     dict_.clear();
166     functions_.clear();
167   }
168 
169   /**
170    * Register a class as being owned by this compilation unit.
171    */
register_typeCompilationUnit172   void register_type(c10::NamedTypePtr namedType) {
173     // TODO: class types cannot be redefined because we have no way right now
174     // of invalidating their methods. NamedTuples are fine though, since they
175     // don't have methods.
176     TORCH_CHECK(
177         0 == classDict_.count(*namedType->name()),
178         "class '",
179         namedType->name()->qualifiedName(),
180         "' already defined.");
181     classes_.push_back(std::move(namedType));
182     classDict_[*classes_.back()->name()] = classes_.size() - 1;
183   };
184 
get_classCompilationUnit185   c10::ClassTypePtr get_class(const c10::QualifiedName& name) const {
186     auto type = get_type(name);
187     if (!type) {
188       return nullptr;
189     }
190     return type->cast<c10::ClassType>();
191   }
192 
get_interfaceCompilationUnit193   c10::InterfaceTypePtr get_interface(const c10::QualifiedName& name) const {
194     auto type = get_type(name);
195     if (!type) {
196       return nullptr;
197     }
198     return type->cast<c10::InterfaceType>();
199   }
200 
get_named_tupleCompilationUnit201   c10::TupleTypePtr get_named_tuple(const c10::QualifiedName& name) const {
202     for (const auto& cls : classes_) {
203       if (cls->name()->qualifiedName() == name.qualifiedName()) {
204         return cls->expect<TupleType>();
205       }
206     }
207     return nullptr;
208   }
209 
get_typeCompilationUnit210   c10::NamedTypePtr get_type(const c10::QualifiedName& name) const {
211     auto it = classDict_.find(name);
212     if (it == classDict_.end()) {
213       return nullptr;
214     }
215     return classes_[it->second];
216   }
217 
218   // For testing: clear all Python-defined classes to ensure that unit tests
219   // have isolation.
_clear_python_cuCompilationUnit220   void _clear_python_cu() {
221     // Delete all the associated class methods
222     for (const auto& type : classes_) {
223       if (auto cls = type->cast<ClassType>()) {
224         for (auto method : cls->methods()) {
225           // Tombstone the method in the compilation unit.
226           // Don't erase because the dict_
227           auto it = dict_.find(method->qualname());
228           if (it != dict_.end()) {
229             functions_[it->second] = nullptr;
230             // Erase in our big lookup table
231             dict_.erase(it);
232           }
233         }
234         // Classes can have multiple pointers to the same hook,
235         // need to make sure to not delete it twice
236         std::unordered_set<Function*> hooks_to_delete;
237         for (const auto& hook : cls->getForwardHooks()) {
238           hooks_to_delete.insert(hook);
239         }
240         for (const auto& pre_hook : cls->getForwardPreHooks()) {
241           hooks_to_delete.insert(pre_hook);
242         }
243         for (const auto& hook : hooks_to_delete) {
244           // Tombstone the hook in the compilation unit.
245           auto it = dict_.find(hook->qualname());
246           if (it != dict_.end()) {
247             functions_[it->second] = nullptr;
248             // Erase in our big lookup table
249             dict_.erase(it);
250           }
251         }
252       }
253     }
254     classes_.clear();
255     classDict_.clear();
256   }
257 
258   // [Internal Only] Remove method.
259   // Note Used for freezing.
unsafeRemoveMethodCompilationUnit260   void unsafeRemoveMethod(const c10::QualifiedName& method_name) {
261     auto it = dict_.find(method_name);
262     TORCH_CHECK(
263         it != dict_.end(),
264         "method '",
265         method_name.qualifiedName(),
266         "' does not exist.");
267     functions_[it->second] = nullptr;
268     dict_.erase(it);
269   }
270 
271   // [name mangling] All code objects must have a unique qualified name in a
272   // CompilationUnit. In Python, sometimes functions won't have unique qualified
273   // name (for example, nested functions). So we mangle Python functions to
274   // ensure that they are uniquely named.
275   //
276   // We also use mangling to distinguish different Module instances. Since each
277   // Module is a singleton class instance, different instances of the same
278   // Python Module will have different types but the same qualified name.
mangleCompilationUnit279   c10::QualifiedName mangle(const c10::QualifiedName& name) const {
280     auto mangled = name;
281     while (get_type(mangled) || find_function(mangled)) {
282       mangled = mangler_.mangle(mangled);
283     }
284     return mangled;
285   }
286 
287  private:
288   std::unique_ptr<Function> define(
289       const std::optional<c10::QualifiedName>& prefix,
290       const Def& def,
291       const ResolverPtr& resolver,
292       const Self* self,
293       const std::unordered_map<std::string, Function*>& function_table,
294       bool shouldMangle = false,
295       FunctionType type = FunctionType::Method,
296       std::optional<size_t> version = std::nullopt) const;
297 
298   // Define a property on \p self.
299   struct PropertyPair;
300   PropertyPair define_property(
301       const std::optional<c10::QualifiedName>& prefix,
302       const Property& prop,
303       const ResolverPtr& resolver,
304       const Self* self,
305       const std::unordered_map<std::string, Function*>& function_table,
306       bool shouldMangle = false) const;
307 
register_functionCompilationUnit308   Function& register_function(std::unique_ptr<Function> fn) {
309     TORCH_CHECK(
310         0 == dict_.count(fn->qualname().qualifiedName()),
311         "method '",
312         fn->qualname().qualifiedName(),
313         "' already defined.");
314     functions_.emplace_back(std::move(fn));
315     dict_[functions_.back()->qualname()] = functions_.size() - 1;
316     return *functions_.back();
317   }
318   std::vector<std::unique_ptr<Function>> functions_;
319   // for fast lookup
320   std::unordered_map<c10::QualifiedName, size_t> dict_;
321   std::unordered_map<c10::QualifiedName, size_t> classDict_;
322 
323   // [class ownership] Right now there are two relationships between classes
324   // and compilation units:
325   // 1. Classes have compilation units internally that hold their methods.
326   // 2. On load, the TypePtrs of any imported classes are owned by the main
327   // module's compilation unit.
328   std::vector<c10::NamedTypePtr> classes_;
329 
330   mutable NameMangler mangler_;
331 };
332 
333 // An owning pointer to a Function. Just a pair of a raw Function ptr and it's
334 // owning CU. We need this because pybind requires a ref-counted way to refer to
335 // Functions.
336 struct StrongFunctionPtr {
StrongFunctionPtrStrongFunctionPtr337   StrongFunctionPtr(std::shared_ptr<CompilationUnit> cu, Function* function)
338       : cu_(std::move(cu)), function_(function) {
339     TORCH_INTERNAL_ASSERT(cu_);
340     TORCH_INTERNAL_ASSERT(function_);
341   }
342   std::shared_ptr<CompilationUnit> cu_;
343   Function* function_;
344 };
345 
346 namespace script {
347 // We once had a `script::` namespace that was deleted. This is for backcompat
348 // of the public API; new code should not use this type alias.
349 using CompilationUnit = ::torch::jit::CompilationUnit;
350 } // namespace script
351 } // namespace torch::jit
352