xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/TraceTypeManual.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/TracerMode.h>
2 #include <ATen/core/op_registration/op_registration.h>
3 #include <c10/core/ScalarType.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/frontend/tracer.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 #include <torch/library.h>
8 #include <optional>
9 
10 using namespace at;
11 
12 namespace torch::TraceType {
13 
14 namespace {
15 
copy_(Tensor & self,const Tensor & src,bool non_blocking)16 Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
17   jit::Value* output = nullptr;
18   if (torch::jit::tracer::isTracing()) {
19     const jit::tracer::TracingState& state = *jit::tracer::getTracingState();
20     auto& graph = state.graph;
21     if (state.force_outplace && self.storage().use_count() <= 1) {
22       // if you have no views of self, then an in place copy is equivalent to
23       // making sure we expand src to the same size as self
24       jit::Node* node = graph->create(jit::aten::expand_as, /*num_outputs=*/1);
25       jit::tracer::addInputs(node, "src", src);
26       jit::tracer::addInputs(node, "self", self);
27       graph->insertNode(node);
28       output = node->output();
29     } else {
30       output = graph->insert(
31           jit::aten::copy_,
32           {jit::tracer::getValueTrace(self), jit::tracer::getValueTrace(src)});
33       jit::tracer::recordSourceLocation(output->node());
34     }
35     jit::tracer::ensureUniqueIfOutOfPlaced(
36         "copy_ (possibly due to an assignment)", self);
37   }
38 
39   {
40     at::tracer::impl::NoTracerDispatchMode tracer_guard;
41     self.copy_(src, non_blocking);
42   }
43 
44   if (torch::jit::tracer::isTracing()) {
45     jit::tracer::setOutput(output, self);
46   }
47   return self;
48 }
49 
// Tracer kernel for resize_.
//
// resize_ is not recorded in the trace. While tracing, any stashed "size"
// argument is discarded, a WARN_RESIZE warning is emitted, and the value
// trace attached to `self` is removed. The real resize then runs with tracer
// dispatch disabled. Returns `self`.
const Tensor& resize_(
    const Tensor& self,
    IntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  namespace trc = jit::tracer;

  if (trc::isTracing()) {
    // Drop a pending stashed "size" so it is not left behind.
    const bool has_stashed_size = trc::ArgumentStash::hasIntArrayRef("size");
    if (has_stashed_size) {
      trc::ArgumentStash::popIntArrayRef("size");
    }
    trc::warn("resize_", trc::WARN_RESIZE);
    trc::delValueTrace(self);
  }

  {
    at::tracer::impl::NoTracerDispatchMode no_tracer;
    self.resize_(size, optional_memory_format);
  }

  return self;
}
68 
// Tracer kernel for resize_as_.
//
// resize_as_ is not recorded in the trace. While tracing, a WARN_RESIZE
// warning is emitted and the value trace attached to `self` is removed. The
// real resize then runs with tracer dispatch disabled. Returns `self`.
const Tensor& resize_as_(
    const Tensor& self,
    const Tensor& the_template,
    std::optional<MemoryFormat> optional_memory_format) {
  namespace trc = jit::tracer;

  if (trc::isTracing()) {
    trc::warn("resize_as_", trc::WARN_RESIZE);
    trc::delValueTrace(self);
  }

  {
    at::tracer::impl::NoTracerDispatchMode no_tracer;
    self.resize_as_(the_template, optional_memory_format);
  }

  return self;
}
84 
// Tracer kernel for detach.
//
// While tracing, an aten::detach node taking `self` as input is inserted into
// the trace graph, and the detached tensor is registered as that node's
// output once it has been computed. The detach itself runs with tracer
// dispatch disabled. Returns the detached tensor.
Tensor detach(const Tensor& self) {
  torch::jit::Node* detach_node = nullptr;

  if (jit::tracer::isTracing()) {
    auto& trace_graph = jit::tracer::getTracingState()->graph;
    // Created with no outputs; the output is attached after the call below.
    detach_node = trace_graph->create(jit::aten::detach, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(detach_node);
    jit::tracer::addInputs(detach_node, "self", self);
    trace_graph->insertNode(detach_node);
  }

  Tensor detached;
  {
    at::tracer::impl::NoTracerDispatchMode no_tracer;
    detached = self.detach();
  }

  if (jit::tracer::isTracing()) {
    jit::tracer::addOutput(detach_node, detached);
  }

  return detached;
}
105 
detach_(Tensor & self)106 Tensor& detach_(Tensor& self) {
107   torch::jit::Node* node = nullptr;
108   if (jit::tracer::isTracing()) {
109     auto& graph = jit::tracer::getTracingState()->graph;
110     node = graph->create(jit::aten::detach, /*num_outputs=*/0);
111     jit::tracer::recordSourceLocation(node);
112     jit::tracer::addInputs(node, "self", self);
113     graph->insertNode(node);
114     jit::tracer::ensureUniqueIfOutOfPlaced("detach_", self);
115   }
116 
117   {
118     at::tracer::impl::NoTracerDispatchMode tracer_guard;
119     self.detach_();
120   }
121 
122   if (jit::tracer::isTracing() && node) {
123     jit::tracer::addOutput(node, self);
124   }
125   return self;
126 }
127 
128 // Invariant:
129 // - Ops registered to DispatchKey::Tracer below must be included in
130 // `MANUAL_TRACER` in tools/autograd/gen_variable_type.py
TORCH_LIBRARY_IMPL(aten, Tracer, m) {
  // Hand-written tracer kernels defined above in this file.
  m.impl("resize_", resize_);
  m.impl("resize_as_", resize_as_);
  m.impl("detach", TORCH_FN(detach));
  m.impl("detach_", detach_);
  m.impl("copy_", copy_);

  // Ops that are not traced: each one gets an explicit fallthrough kernel,
  // registered in the order listed here.
  constexpr const char* untraced_ops[] = {
      "_backward",
      "set_data",
      "data",
      "is_leaf",
      "output_nr",
      "_version",
      "requires_grad_",
      "retain_grad",
      "_fw_primal",
      "_make_dual",
  };
  for (const char* op_name : untraced_ops) {
    m.impl(op_name, CppFunction::makeFallthrough());
  }
}
151 
152 } // namespace
153 
154 } // namespace torch::TraceType
155 
156 namespace torch::jit {
general_trace_function(const c10::OperatorHandle & op,Stack * stack)157 static void general_trace_function(
158     const c10::OperatorHandle& op,
159     Stack* stack) {
160   const auto input_size = op.schema().arguments().size();
161   const auto output_size = op.schema().returns().size();
162 
163   Node* node = nullptr;
164   std::shared_ptr<tracer::TracingState> tracer_state;
165 
166   // trace the input before unwrapping, otherwise we may lose
167   // the input information
168   if (tracer::isTracing()) {
169     tracer_state = tracer::getTracingState();
170     auto symbol = Symbol::fromQualString(op.schema().name());
171     const auto& graph = tracer::getTracingState()->graph;
172     node = graph->create(symbol, 0);
173     tracer::recordSourceLocation(node);
174     const auto& args = op.schema().arguments();
175     int i = 0;
176     for (auto iter = stack->end() - static_cast<std::ptrdiff_t>(input_size);
177          iter != stack->end();
178          ++iter, ++i) {
179       // TODO we need to refactor graph APIs (e.g., addInputs)
180       // appropriately; after that, we can get rid of the giant if-else
181       // block we will clean this tech debt together in the following PRs
182       auto type = args[i].type();
183       if (type->kind() == TypeKind::OptionalType) {
184         if (iter->isNone()) {
185           Value* none = graph->insertNode(graph->createNone())->output();
186           node->addInput(none);
187           continue;
188         } else {
189           type = type->expectRef<OptionalType>().getElementType();
190         }
191       }
192       if (type->isSubtypeOf(*TensorType::get())) {
193         AT_ASSERT(iter->isTensor());
194         tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
195       } else if (type->kind() == TypeKind::FloatType) {
196         AT_ASSERT(iter->isDouble());
197         tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
198       } else if (type->kind() == TypeKind::IntType) {
199         AT_ASSERT(iter->isInt());
200         tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
201       } else if (type->kind() == TypeKind::BoolType) {
202         AT_ASSERT(iter->isBool());
203         tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
204       } else if (type->kind() == TypeKind::StringType) {
205         AT_ASSERT(iter->isString());
206         tracer::addInputs(node, args[i].name().c_str(), iter->toStringView());
207       } else if (type->kind() == TypeKind::NumberType) {
208         tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
209       } else if (type->kind() == TypeKind::ListType) {
210         const auto& elem_type = type->expectRef<ListType>().getElementType();
211         if (elem_type->isSubtypeOf(*TensorType::get())) {
212           AT_ASSERT(iter->isTensorList());
213           auto list = iter->toTensorVector();
214           tracer::addInputs(node, args[i].name().c_str(), list);
215         } else if (auto class_type = elem_type->cast<ClassType>()) {
216           AT_ASSERT(iter->isList());
217           auto list = iter->toList();
218           std::vector<c10::intrusive_ptr<c10::ivalue::Object>> objects;
219           for (IValue iv : list) {
220             objects.emplace_back(std::move(iv).toObject());
221           }
222           tracer::addInputs(node, args[i].name().c_str(), objects, class_type);
223         } else if (elem_type->kind() == TypeKind::FloatType) {
224           AT_ASSERT(iter->isDoubleList());
225           // NB: now, tracer doesn't support tracing double list. We add
226           // special handling here, since in our case, we assume that all the
227           // doubles in the list are constants
228           auto value = iter->toDoubleVector();
229           std::vector<Value*> info(value.size());
230           for (const auto value_index : c10::irange(value.size())) {
231             info[value_index] = graph->insertConstant(value[value_index]);
232             tracer::recordSourceLocation(info[value_index]->node());
233           }
234           node->addInput(
235               graph->insertNode(graph->createList(FloatType::get(), info))
236                   ->output());
237         } else if (elem_type->kind() == TypeKind::IntType) {
238           AT_ASSERT(iter->isIntList());
239           tracer::addInputs(
240               node,
241               args[i].name().c_str(),
242               c10::IntArrayRef(iter->toIntVector()));
243         } else if (elem_type->kind() == TypeKind::BoolType) {
244           AT_ASSERT(iter->isBoolList());
245           tracer::addInputs(
246               node, args[i].name().c_str(), iter->toBoolList().vec());
247         } else {
248           throw std::runtime_error(
249               "unsupported input list type: " + elem_type->str());
250         }
251       } else if (iter->isObject()) {
252         tracer::addInputs(node, args[i].name().c_str(), iter->toObject());
253       } else {
254         throw std::runtime_error("unsupported input type: " + type->str());
255       }
256     }
257     graph->insertNode(node);
258 
259     tracer::setTracingState(nullptr);
260   }
261 
262   op.callBoxed(stack);
263 
264   if (tracer_state) {
265     tracer::setTracingState(std::move(tracer_state));
266     int i = 0;
267     for (auto iter = stack->end() - static_cast<std::ptrdiff_t>(output_size);
268          iter != stack->end();
269          ++iter, ++i) {
270       const auto& type = op.schema().returns()[i].type();
271       if (type->isSubtypeOf(*TensorType::get())) {
272         AT_ASSERT(iter->isTensor());
273         tracer::addOutput(node, iter->toTensor());
274       } else if (type->kind() == TypeKind::ListType) {
275         const auto& elem_type = type->expectRef<ListType>().getElementType();
276         if (elem_type->isSubtypeOf(*TensorType::get())) {
277           AT_ASSERT(iter->isTensorList());
278           tracer::addOutput(node, iter->toTensorList());
279         } else {
280           throw std::runtime_error(
281               "unsupported ouptut list type: " + elem_type->str());
282         }
283       } else if (type->kind() == TypeKind::ClassType) {
284         AT_ASSERT(iter->isObject());
285         tracer::addOutput(node, iter->toObject());
286       } else {
287         throw std::runtime_error(
288             "unsupported output type: " + type->str() +
289             ", from operator: " + toString(op.operator_name()));
290       }
291     }
292   }
293 }
// Install general_trace_function as the boxed fallback kernel for the Tracer
// dispatch key.
TORCH_LIBRARY_IMPL(_, Tracer, m) {
  auto boxed_trace_fallback =
      CppFunction::makeFromBoxedFunction<&general_trace_function>();
  m.fallback(std::move(boxed_trace_fallback));
}
297 
298 } // namespace torch::jit
299