1 #pragma once 2 3 #include <ATen/core/function.h> 4 #include <ATen/core/ivalue.h> 5 #include <ATen/core/stack.h> 6 #include <torch/csrc/api/include/torch/imethod.h> 7 #include <torch/csrc/jit/api/function_impl.h> 8 9 namespace torch::jit { 10 11 using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>; 12 13 // A method in a module, e.g. f in: 14 // 15 // class M(ScriptModule): 16 // @script_method 17 // def f(self, x): 18 // ... 19 // Note: because Method/Module are exposed to python these 20 // classes use python method naming conventions 21 struct TORCH_API Method : public torch::IMethod { 22 Method(ObjectPtr owner, Function* function); 23 24 // the module that contains this method. 25 Module owner() const; 26 // the raw objectptr that owns this method, for when the method is owned by a 27 // torchbind object. 28 ObjectPtr raw_owner() const; 29 void run(Stack& stack); runMethod30 void run(Stack&& stack) { 31 run(stack); 32 } 33 34 c10::IValue operator()( 35 std::vector<c10::IValue> stack, 36 const Kwargs& kwargs = Kwargs()) const override; 37 38 // Run method async. Invocation on this function would invokes a JIT 39 // interpreter that executes ops inline, one by one, on caller's thread. A 40 // model can utilize async op, i.e. `fork`, to launch an asynchronous task 41 // which will be launched on provided `taskLauncher`. 42 c10::intrusive_ptr<c10::ivalue::Future> run_async( 43 std::vector<c10::IValue> stack, 44 const Kwargs& kwargs = Kwargs(), 45 TaskLauncher taskLauncher = at::launch); 46 graphMethod47 std::shared_ptr<Graph> graph() const { 48 return toGraphFunction(*function_).graph(); 49 } 50 nameMethod51 const std::string& name() const override { 52 return function_->name(); 53 } 54 num_inputsMethod55 size_t num_inputs() const { 56 return function_->num_inputs(); 57 } 58 get_executorMethod59 GraphExecutor& get_executor() { 60 return toGraphFunction(*function_).get_executor(); 61 } 62 functionMethod63 Function& function() const { 64 return *function_; 65 } 66 67 private: 68 void setArgumentNames(std::vector<std::string>&) const override; 69 70 // Methods are uniqued onwed by a single module. This raw pointer allows 71 // looking up the module. 72 ObjectPtr owner_; 73 74 // Underlying unbound function 75 Function* function_; 76 }; 77 78 namespace script { 79 // We once had a `script::` namespace that was deleted. This is for backcompat 80 // of the public API; new code should not use this type alias. 81 using Method = ::torch::jit::Method; 82 } // namespace script 83 84 } // namespace torch::jit 85