xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/function.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/Exception.h>
#include <c10/util/FunctionRef.h>

#include <exception>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
8 
// Forward declaration of c10::FunctionSchema. It overlaps with the
// <ATen/core/function_schema.h> include at the top of this header
// (NOTE(review): presumably that header provides the full definition);
// keeping it ensures the name is declared regardless of include order.
// The namespace is closed without a trailing ';' — a semicolon after a
// namespace's closing brace is an empty declaration flagged by -Wextra-semi.
namespace c10 {
struct FunctionSchema;
} // namespace c10
12 
13 namespace at {
14 TORCH_API void launch(std::function<void()> func);
15 }
16 
17 namespace torch::jit {
18 
19 struct Graph;
20 struct Code;
21 
22 namespace mobile {
23 struct Code;
24 }
25 
26 using Stack = std::vector<at::IValue>;
27 using Kwargs = std::unordered_map<std::string, at::IValue>;
28 struct RecursiveMethodCallError : public std::exception {};
29 using TaskLauncher = std::function<void(std::function<void()>)>;
30 
31 TORCH_API void preoptimizeGraph(
32     std::shared_ptr<Graph>& graph,
33     bool disable_autocast = false);
34 
35 // A Function is a pure Graph with no implicit `self` object bound.
36 // It contains schema information and the executor that manages the
37 // execution of the function. Method is a wrapper around an
38 // underlying Function that also provides a `self` object.
39 struct TORCH_API Function {
40   Function() = default;
41   Function(const Function&) = default;
42   Function& operator=(const Function&) = default;
43   Function(Function&&) noexcept = default;
44   Function& operator=(Function&&) noexcept = default;
doc_stringFunction45   virtual c10::string_view doc_string() const {
46     static constexpr c10::string_view no_doc_string = "";
47     return no_doc_string;
48   }
49 
isGraphFunctionFunction50   virtual bool isGraphFunction() const {
51     return false;
52   }
53 
54   virtual void run(Stack& stack) = 0;
55 
56   virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(
57       Stack& /*stack*/,
58       // NOLINTNEXTLINE(performance-unnecessary-value-param)
59       C10_UNUSED TaskLauncher taskLauncher = at::launch) {
60     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
61     return {};
62   }
63 
operatorFunction64   at::IValue operator()(Stack stack, const Kwargs& kwargs = Kwargs()) {
65     getSchema().checkAndNormalizeInputs(stack, kwargs);
66     run(stack);
67     return stack.front();
68   }
69 
70   virtual const c10::QualifiedName& qualname() const = 0;
71 
nameFunction72   const std::string& name() const {
73     return qualname().name();
74   }
75 
76   // if this isn't yet defined, run its method_creator function
77   virtual void ensure_defined() = 0;
78 
79   virtual const c10::FunctionSchema& getSchema() const = 0;
80 
81   virtual size_t num_inputs() const = 0;
82 
83   virtual Function& setSchema(c10::FunctionSchema schema) = 0;
84 
85   // call() defines how different interpreter implementations interacts with
86   // Function objects. Basically interpreters need to provide a callback to
87   // communicate to Functions what to do if provided a Code object.
88   // Alternatively we could design the signature to return an optional Code
89   // object, but that requires special handling the null case in interpreter
90   // and the fallback behavior is not well defined by interpreter but rather
91   // Function themselves, so a callback approach is more reasonable than
92   // returning values.
93   // If call() returns true, then callback completes successfully, otherwise
94   // call() returns false.
95 
96   // Overload for server interpreter, a bailout size is needed for graph
97   // executor.
callFunction98   virtual bool call(
99       Stack&,
100       std::optional<size_t>,
101       c10::function_ref<void(const Code&)>) {
102     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
103     return false;
104   }
105 
106   // Overload for mobile interpreter.
callFunction107   virtual bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) {
108     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
109     return false;
110   }
111 
112   virtual ~Function() = default;
113 };
114 } // namespace torch::jit
115