1 #pragma once 2 3 #include <ATen/core/function_schema.h> 4 #include <ATen/core/ivalue.h> 5 #include <ATen/core/qualified_name.h> 6 #include <c10/util/Exception.h> 7 #include <c10/util/FunctionRef.h> 8 9 namespace c10 { 10 struct FunctionSchema; 11 }; 12 13 namespace at { 14 TORCH_API void launch(std::function<void()> func); 15 } 16 17 namespace torch::jit { 18 19 struct Graph; 20 struct Code; 21 22 namespace mobile { 23 struct Code; 24 } 25 26 using Stack = std::vector<at::IValue>; 27 using Kwargs = std::unordered_map<std::string, at::IValue>; 28 struct RecursiveMethodCallError : public std::exception {}; 29 using TaskLauncher = std::function<void(std::function<void()>)>; 30 31 TORCH_API void preoptimizeGraph( 32 std::shared_ptr<Graph>& graph, 33 bool disable_autocast = false); 34 35 // A Function is a pure Graph with no implicit `self` object bound. 36 // It contains schema information and the executor that manages the 37 // execution of the function. Method is a wrapper around an 38 // underlying Function that also provides a `self` object. 39 struct TORCH_API Function { 40 Function() = default; 41 Function(const Function&) = default; 42 Function& operator=(const Function&) = default; 43 Function(Function&&) noexcept = default; 44 Function& operator=(Function&&) noexcept = default; doc_stringFunction45 virtual c10::string_view doc_string() const { 46 static constexpr c10::string_view no_doc_string = ""; 47 return no_doc_string; 48 } 49 isGraphFunctionFunction50 virtual bool isGraphFunction() const { 51 return false; 52 } 53 54 virtual void run(Stack& stack) = 0; 55 56 virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync( 57 Stack& /*stack*/, 58 // NOLINTNEXTLINE(performance-unnecessary-value-param) 59 C10_UNUSED TaskLauncher taskLauncher = at::launch) { 60 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); 61 return {}; 62 } 63 operatorFunction64 at::IValue operator()(Stack stack, const Kwargs& kwargs = Kwargs()) { 65 getSchema().checkAndNormalizeInputs(stack, kwargs); 66 run(stack); 67 return stack.front(); 68 } 69 70 virtual const c10::QualifiedName& qualname() const = 0; 71 nameFunction72 const std::string& name() const { 73 return qualname().name(); 74 } 75 76 // if this isn't yet defined, run its method_creator function 77 virtual void ensure_defined() = 0; 78 79 virtual const c10::FunctionSchema& getSchema() const = 0; 80 81 virtual size_t num_inputs() const = 0; 82 83 virtual Function& setSchema(c10::FunctionSchema schema) = 0; 84 85 // call() defines how different interpreter implementations interacts with 86 // Function objects. Basically interpreters need to provide a callback to 87 // communicate to Functions what to do if provided a Code object. 88 // Alternatively we could design the signature to return an optional Code 89 // object, but that requires special handling the null case in interpreter 90 // and the fallback behavior is not well defined by interpreter but rather 91 // Function themselves, so a callback approach is more reasonable than 92 // returning values. 93 // If call() returns true, then callback completes successfully, otherwise 94 // call() returns false. 95 96 // Overload for server interpreter, a bailout size is needed for graph 97 // executor. callFunction98 virtual bool call( 99 Stack&, 100 std::optional<size_t>, 101 c10::function_ref<void(const Code&)>) { 102 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); 103 return false; 104 } 105 106 // Overload for mobile interpreter. callFunction107 virtual bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) { 108 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); 109 return false; 110 } 111 112 virtual ~Function() = default; 113 }; 114 } // namespace torch::jit 115