1 #pragma once 2 #include <ATen/core/function.h> 3 #include <c10/util/Exception.h> 4 #include <torch/csrc/jit/api/function_impl.h> 5 #include <torch/csrc/jit/frontend/name_mangler.h> 6 #include <torch/csrc/jit/frontend/source_range.h> 7 #include <torch/csrc/jit/ir/ir.h> 8 #include <torch/csrc/jit/runtime/graph_executor.h> 9 10 #include <torch/csrc/Export.h> 11 12 #include <ATen/core/function_schema.h> 13 #include <ATen/core/qualified_name.h> 14 #include <c10/util/ArrayRef.h> 15 #include <optional> 16 17 #include <functional> 18 #include <memory> 19 #include <mutex> 20 #include <ostream> 21 #include <string> 22 #include <unordered_map> 23 #include <vector> 24 25 namespace torch::jit { 26 27 struct Def; 28 struct Property; 29 struct ClassDef; 30 struct SugaredValue; 31 struct Resolver; 32 33 using ResolverPtr = std::shared_ptr<Resolver>; 34 struct Self { 35 virtual ~Self() = default; 36 virtual std::shared_ptr<SugaredValue> makeSugared(Value* v) const = 0; 37 virtual ClassTypePtr getClassType() const = 0; 38 }; 39 40 // A CompilationUnit is a list of named Functions 41 // with helper methods to iterate the list or invoke the function. 42 // Classes have a CompilationUnit holding the class methods, 43 // and Modules have a CompilationUnit holding the Functions that 44 // are used to implement their Methods 45 46 struct TORCH_API CompilationUnit { 47 enum class FunctionType { Method, Hook, PreHook }; 48 // constructor that takes a set of functions to compile using the native 49 // resolver 50 explicit CompilationUnit(const std::string& source); 51 CompilationUnit() = default; 52 53 CompilationUnit& operator=(CompilationUnit&&) = default; 54 CompilationUnit(CompilationUnit&&) = default; 55 CompilationUnit& operator=(const CompilationUnit&) = delete; 56 CompilationUnit(const CompilationUnit&) = delete; 57 find_functionCompilationUnit58 Function* find_function(const c10::QualifiedName& name) const { 59 auto it = dict_.find(name); 60 if (it == dict_.end()) { 61 return nullptr; 62 } 63 return functions_[it->second].get(); 64 } 65 get_functionCompilationUnit66 Function& get_function(const c10::QualifiedName& name) const { 67 if (auto r = find_function(name)) { 68 return *r; 69 } 70 TORCH_CHECK(false, "attempted to get undefined function ", name.name()); 71 } 72 set_optimizedCompilationUnit73 void set_optimized(bool o) { 74 TORCH_WARN( 75 "CompilationUnit::set_optimized() is deprecated and has no effect. " 76 "Please use setGraphExecutorOptimize()"); 77 } 78 is_optimizedCompilationUnit79 bool is_optimized() const { 80 TORCH_WARN( 81 "CompilationUnit::is_optimized() is deprecated and always returns true. " 82 "Please use getGraphExecutorOptimize()"); 83 return true; 84 } 85 86 // for historic reasons, these are defined in ir_emitter.cpp 87 // Returns the list of Functions just defined. 88 std::vector<Function*> define( 89 const std::optional<c10::QualifiedName>& prefix, 90 const std::vector<Property>& properties, 91 const std::vector<ResolverPtr>& propResolvers, 92 const std::vector<Def>& definitions, 93 const std::vector<ResolverPtr>& 94 defResolvers, /* determines how we handle free 95 variables in each definition*/ 96 // if non-null, the first argument to each def, is bound to this value 97 const Self* self, 98 // see [name mangling] 99 bool shouldMangle = false, 100 std::optional<size_t> operator_set_version = std::nullopt); 101 102 void define_hooks( 103 const std::optional<c10::QualifiedName>& prefix, 104 const std::vector<Def>& hookDefs, 105 const std::vector<ResolverPtr>& hookResolvers, 106 const std::vector<Def>& preHookDefs, 107 const std::vector<ResolverPtr>& preHookResolvers, 108 const Self* self, 109 bool shouldMangle = false); 110 111 // same as above but parse the definitions from source 112 // Returns the list of Functions just defined. 113 std::vector<Function*> define( 114 // prefix namespace to put all the defined functions into 115 const std::optional<c10::QualifiedName>& prefix, 116 const std::string& source, 117 const ResolverPtr& resolver, 118 const Self* self); 119 120 void define_interface( 121 const c10::QualifiedName& qualifiedName, 122 const ClassDef& classDef, 123 ResolverPtr rcb, 124 bool is_module = false); 125 126 Function* create_function( 127 c10::QualifiedName name, 128 std::shared_ptr<Graph> graph, 129 bool shouldMangle = false) { 130 if (shouldMangle) { 131 name = mangle(name); 132 } 133 auto fn = std::make_unique<GraphFunction>( 134 std::move(name), std::move(graph), nullptr); 135 auto ret = fn.get(); 136 register_function(std::move(fn)); 137 return ret; 138 } 139 get_functionsCompilationUnit140 std::vector<Function*> get_functions() const { 141 return fmap(functions_, [](const std::unique_ptr<Function>& fn) { 142 return fn.get(); 143 }); 144 } 145 146 /// Run a method from this compilation. 147 /// 148 /// For example: 149 /// @code 150 /// IValue output = module->run("relu_script", a, b); 151 /// @endcode 152 /// 153 /// To get a compile a module from a source string, see torch::jit::compile 154 /// 155 /// @param method_name The name of the method to run 156 /// @param args Arguments to be passed to the method 157 /// @return An IValue containing the return value (or values if it is a tuple) 158 /// from the method 159 template <typename... Types> run_methodCompilationUnit160 IValue run_method(const c10::QualifiedName& method_name, Types&&... args) { 161 return get_function(method_name)({IValue(std::forward<Types>(args))...}); 162 } 163 drop_all_functionsCompilationUnit164 void drop_all_functions() { 165 dict_.clear(); 166 functions_.clear(); 167 } 168 169 /** 170 * Register a class as being owned by this compilation unit. 171 */ register_typeCompilationUnit172 void register_type(c10::NamedTypePtr namedType) { 173 // TODO: class types cannot be redefined because we have no way right now 174 // of invalidating their methods. NamedTuples are fine though, since they 175 // don't have methods. 176 TORCH_CHECK( 177 0 == classDict_.count(*namedType->name()), 178 "class '", 179 namedType->name()->qualifiedName(), 180 "' already defined."); 181 classes_.push_back(std::move(namedType)); 182 classDict_[*classes_.back()->name()] = classes_.size() - 1; 183 }; 184 get_classCompilationUnit185 c10::ClassTypePtr get_class(const c10::QualifiedName& name) const { 186 auto type = get_type(name); 187 if (!type) { 188 return nullptr; 189 } 190 return type->cast<c10::ClassType>(); 191 } 192 get_interfaceCompilationUnit193 c10::InterfaceTypePtr get_interface(const c10::QualifiedName& name) const { 194 auto type = get_type(name); 195 if (!type) { 196 return nullptr; 197 } 198 return type->cast<c10::InterfaceType>(); 199 } 200 get_named_tupleCompilationUnit201 c10::TupleTypePtr get_named_tuple(const c10::QualifiedName& name) const { 202 for (const auto& cls : classes_) { 203 if (cls->name()->qualifiedName() == name.qualifiedName()) { 204 return cls->expect<TupleType>(); 205 } 206 } 207 return nullptr; 208 } 209 get_typeCompilationUnit210 c10::NamedTypePtr get_type(const c10::QualifiedName& name) const { 211 auto it = classDict_.find(name); 212 if (it == classDict_.end()) { 213 return nullptr; 214 } 215 return classes_[it->second]; 216 } 217 218 // For testing: clear all Python-defined classes to ensure that unit tests 219 // have isolation. _clear_python_cuCompilationUnit220 void _clear_python_cu() { 221 // Delete all the associated class methods 222 for (const auto& type : classes_) { 223 if (auto cls = type->cast<ClassType>()) { 224 for (auto method : cls->methods()) { 225 // Tombstone the method in the compilation unit. 226 // Don't erase because the dict_ 227 auto it = dict_.find(method->qualname()); 228 if (it != dict_.end()) { 229 functions_[it->second] = nullptr; 230 // Erase in our big lookup table 231 dict_.erase(it); 232 } 233 } 234 // Classes can have multiple pointers to the same hook, 235 // need to make sure to not delete it twice 236 std::unordered_set<Function*> hooks_to_delete; 237 for (const auto& hook : cls->getForwardHooks()) { 238 hooks_to_delete.insert(hook); 239 } 240 for (const auto& pre_hook : cls->getForwardPreHooks()) { 241 hooks_to_delete.insert(pre_hook); 242 } 243 for (const auto& hook : hooks_to_delete) { 244 // Tombstone the hook in the compilation unit. 245 auto it = dict_.find(hook->qualname()); 246 if (it != dict_.end()) { 247 functions_[it->second] = nullptr; 248 // Erase in our big lookup table 249 dict_.erase(it); 250 } 251 } 252 } 253 } 254 classes_.clear(); 255 classDict_.clear(); 256 } 257 258 // [Internal Only] Remove method. 259 // Note Used for freezing. unsafeRemoveMethodCompilationUnit260 void unsafeRemoveMethod(const c10::QualifiedName& method_name) { 261 auto it = dict_.find(method_name); 262 TORCH_CHECK( 263 it != dict_.end(), 264 "method '", 265 method_name.qualifiedName(), 266 "' does not exist."); 267 functions_[it->second] = nullptr; 268 dict_.erase(it); 269 } 270 271 // [name mangling] All code objects must have a unique qualified name in a 272 // CompilationUnit. In Python, sometimes functions won't have unique qualified 273 // name (for example, nested functions). So we mangle Python functions to 274 // ensure that they are uniquely named. 275 // 276 // We also use mangling to distinguish different Module instances. Since each 277 // Module is a singleton class instance, different instances of the same 278 // Python Module will have different types but the same qualified name. mangleCompilationUnit279 c10::QualifiedName mangle(const c10::QualifiedName& name) const { 280 auto mangled = name; 281 while (get_type(mangled) || find_function(mangled)) { 282 mangled = mangler_.mangle(mangled); 283 } 284 return mangled; 285 } 286 287 private: 288 std::unique_ptr<Function> define( 289 const std::optional<c10::QualifiedName>& prefix, 290 const Def& def, 291 const ResolverPtr& resolver, 292 const Self* self, 293 const std::unordered_map<std::string, Function*>& function_table, 294 bool shouldMangle = false, 295 FunctionType type = FunctionType::Method, 296 std::optional<size_t> version = std::nullopt) const; 297 298 // Define a property on \p self. 299 struct PropertyPair; 300 PropertyPair define_property( 301 const std::optional<c10::QualifiedName>& prefix, 302 const Property& prop, 303 const ResolverPtr& resolver, 304 const Self* self, 305 const std::unordered_map<std::string, Function*>& function_table, 306 bool shouldMangle = false) const; 307 register_functionCompilationUnit308 Function& register_function(std::unique_ptr<Function> fn) { 309 TORCH_CHECK( 310 0 == dict_.count(fn->qualname().qualifiedName()), 311 "method '", 312 fn->qualname().qualifiedName(), 313 "' already defined."); 314 functions_.emplace_back(std::move(fn)); 315 dict_[functions_.back()->qualname()] = functions_.size() - 1; 316 return *functions_.back(); 317 } 318 std::vector<std::unique_ptr<Function>> functions_; 319 // for fast lookup 320 std::unordered_map<c10::QualifiedName, size_t> dict_; 321 std::unordered_map<c10::QualifiedName, size_t> classDict_; 322 323 // [class ownership] Right now there are two relationships between classes 324 // and compilation units: 325 // 1. Classes have compilation units internally that hold their methods. 326 // 2. On load, the TypePtrs of any imported classes are owned by the main 327 // module's compilation unit. 328 std::vector<c10::NamedTypePtr> classes_; 329 330 mutable NameMangler mangler_; 331 }; 332 333 // An owning pointer to a Function. Just a pair of a raw Function ptr and it's 334 // owning CU. We need this because pybind requires a ref-counted way to refer to 335 // Functions. 336 struct StrongFunctionPtr { StrongFunctionPtrStrongFunctionPtr337 StrongFunctionPtr(std::shared_ptr<CompilationUnit> cu, Function* function) 338 : cu_(std::move(cu)), function_(function) { 339 TORCH_INTERNAL_ASSERT(cu_); 340 TORCH_INTERNAL_ASSERT(function_); 341 } 342 std::shared_ptr<CompilationUnit> cu_; 343 Function* function_; 344 }; 345 346 namespace script { 347 // We once had a `script::` namespace that was deleted. This is for backcompat 348 // of the public API; new code should not use this type alias. 349 using CompilationUnit = ::torch::jit::CompilationUnit; 350 } // namespace script 351 } // namespace torch::jit 352