1 #include <ATen/TracerMode.h>
2 #include <ATen/core/op_registration/op_registration.h>
3 #include <c10/core/ScalarType.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/frontend/tracer.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 #include <torch/library.h>
8 #include <optional>
9
10 using namespace at;
11
12 namespace torch::TraceType {
13
14 namespace {
15
copy_(Tensor & self,const Tensor & src,bool non_blocking)16 Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
17 jit::Value* output = nullptr;
18 if (torch::jit::tracer::isTracing()) {
19 const jit::tracer::TracingState& state = *jit::tracer::getTracingState();
20 auto& graph = state.graph;
21 if (state.force_outplace && self.storage().use_count() <= 1) {
22 // if you have no views of self, then an in place copy is equivalent to
23 // making sure we expand src to the same size as self
24 jit::Node* node = graph->create(jit::aten::expand_as, /*num_outputs=*/1);
25 jit::tracer::addInputs(node, "src", src);
26 jit::tracer::addInputs(node, "self", self);
27 graph->insertNode(node);
28 output = node->output();
29 } else {
30 output = graph->insert(
31 jit::aten::copy_,
32 {jit::tracer::getValueTrace(self), jit::tracer::getValueTrace(src)});
33 jit::tracer::recordSourceLocation(output->node());
34 }
35 jit::tracer::ensureUniqueIfOutOfPlaced(
36 "copy_ (possibly due to an assignment)", self);
37 }
38
39 {
40 at::tracer::impl::NoTracerDispatchMode tracer_guard;
41 self.copy_(src, non_blocking);
42 }
43
44 if (torch::jit::tracer::isTracing()) {
45 jit::tracer::setOutput(output, self);
46 }
47 return self;
48 }
49
// Tracing override for resize_. Resizing invalidates any traced value for
// `self` (the tensor's contents become undefined), so the trace for `self` is
// dropped and a resize warning is emitted; the real resize then runs with the
// tracer dispatch key disabled.
const Tensor& resize_(
    const Tensor& self,
    IntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  if (torch::jit::tracer::isTracing()) {
    // Discard any stashed "size" argument so it is not attached to a node.
    if (jit::tracer::ArgumentStash::hasIntArrayRef("size")) {
      jit::tracer::ArgumentStash::popIntArrayRef("size");
    }
    jit::tracer::warn("resize_", jit::tracer::WARN_RESIZE);
    jit::tracer::delValueTrace(self);
  }

  {
    at::tracer::impl::NoTracerDispatchMode no_trace_guard;
    self.resize_(size, optional_memory_format);
  }
  return self;
}
68
// Tracing override for resize_as_. As with resize_, the traced value for
// `self` is invalidated before the real resize runs untraced.
const Tensor& resize_as_(
    const Tensor& self,
    const Tensor& the_template,
    std::optional<MemoryFormat> optional_memory_format) {
  if (torch::jit::tracer::isTracing()) {
    jit::tracer::warn("resize_as_", jit::tracer::WARN_RESIZE);
    jit::tracer::delValueTrace(self);
  }

  {
    at::tracer::impl::NoTracerDispatchMode no_trace_guard;
    self.resize_as_(the_template, optional_memory_format);
  }
  return self;
}
84
// Tracing override for out-of-place detach. Records an aten::detach node with
// `self` as input, runs the real detach with the tracer dispatch key
// disabled, and attaches the result as the node's output.
Tensor detach(const Tensor& self) {
  torch::jit::Node* detach_node = nullptr;
  if (jit::tracer::isTracing()) {
    auto& graph = jit::tracer::getTracingState()->graph;
    // Created with zero outputs; addOutput below attaches the result.
    detach_node = graph->create(jit::aten::detach, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(detach_node);
    jit::tracer::addInputs(detach_node, "self", self);
    graph->insertNode(detach_node);
  }

  Tensor result;
  {
    at::tracer::impl::NoTracerDispatchMode no_trace_guard;
    result = self.detach();
  }

  if (jit::tracer::isTracing()) {
    jit::tracer::addOutput(detach_node, result);
  }
  return result;
}
105
detach_(Tensor & self)106 Tensor& detach_(Tensor& self) {
107 torch::jit::Node* node = nullptr;
108 if (jit::tracer::isTracing()) {
109 auto& graph = jit::tracer::getTracingState()->graph;
110 node = graph->create(jit::aten::detach, /*num_outputs=*/0);
111 jit::tracer::recordSourceLocation(node);
112 jit::tracer::addInputs(node, "self", self);
113 graph->insertNode(node);
114 jit::tracer::ensureUniqueIfOutOfPlaced("detach_", self);
115 }
116
117 {
118 at::tracer::impl::NoTracerDispatchMode tracer_guard;
119 self.detach_();
120 }
121
122 if (jit::tracer::isTracing() && node) {
123 jit::tracer::addOutput(node, self);
124 }
125 return self;
126 }
127
128 // Invariant:
129 // - Ops registered to DispatchKey::Tracer below must be included in
130 // `MANUAL_TRACER` in tools/autograd/gen_variable_type.py
TORCH_LIBRARY_IMPL(aten, Tracer, m) {
  // Hand-written tracing kernels defined above for ops whose tracing
  // behavior cannot be expressed by the autogenerated kernels.
  m.impl("resize_", resize_);
  m.impl("resize_as_", resize_as_);
  // NOTE(review): `detach` is wrapped in TORCH_FN while the others are raw
  // function pointers — presumably to disambiguate from at::detach pulled in
  // by `using namespace at`; confirm before normalizing.
  m.impl("detach", TORCH_FN(detach));
  m.impl("detach_", detach_);
  m.impl("copy_", copy_);

  // Skip tracing for the following ops by registering fallthrough kernel
  // explicitly.
  m.impl("_backward", CppFunction::makeFallthrough());
  m.impl("set_data", CppFunction::makeFallthrough());
  m.impl("data", CppFunction::makeFallthrough());
  m.impl("is_leaf", CppFunction::makeFallthrough());
  m.impl("output_nr", CppFunction::makeFallthrough());
  m.impl("_version", CppFunction::makeFallthrough());
  m.impl("requires_grad_", CppFunction::makeFallthrough());
  m.impl("retain_grad", CppFunction::makeFallthrough());
  m.impl("_fw_primal", CppFunction::makeFallthrough());
  m.impl("_make_dual", CppFunction::makeFallthrough());
}
151
152 } // namespace
153
154 } // namespace torch::TraceType
155
156 namespace torch::jit {
general_trace_function(const c10::OperatorHandle & op,Stack * stack)157 static void general_trace_function(
158 const c10::OperatorHandle& op,
159 Stack* stack) {
160 const auto input_size = op.schema().arguments().size();
161 const auto output_size = op.schema().returns().size();
162
163 Node* node = nullptr;
164 std::shared_ptr<tracer::TracingState> tracer_state;
165
166 // trace the input before unwrapping, otherwise we may lose
167 // the input information
168 if (tracer::isTracing()) {
169 tracer_state = tracer::getTracingState();
170 auto symbol = Symbol::fromQualString(op.schema().name());
171 const auto& graph = tracer::getTracingState()->graph;
172 node = graph->create(symbol, 0);
173 tracer::recordSourceLocation(node);
174 const auto& args = op.schema().arguments();
175 int i = 0;
176 for (auto iter = stack->end() - static_cast<std::ptrdiff_t>(input_size);
177 iter != stack->end();
178 ++iter, ++i) {
179 // TODO we need to refactor graph APIs (e.g., addInputs)
180 // appropriately; after that, we can get rid of the giant if-else
181 // block we will clean this tech debt together in the following PRs
182 auto type = args[i].type();
183 if (type->kind() == TypeKind::OptionalType) {
184 if (iter->isNone()) {
185 Value* none = graph->insertNode(graph->createNone())->output();
186 node->addInput(none);
187 continue;
188 } else {
189 type = type->expectRef<OptionalType>().getElementType();
190 }
191 }
192 if (type->isSubtypeOf(*TensorType::get())) {
193 AT_ASSERT(iter->isTensor());
194 tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
195 } else if (type->kind() == TypeKind::FloatType) {
196 AT_ASSERT(iter->isDouble());
197 tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
198 } else if (type->kind() == TypeKind::IntType) {
199 AT_ASSERT(iter->isInt());
200 tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
201 } else if (type->kind() == TypeKind::BoolType) {
202 AT_ASSERT(iter->isBool());
203 tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
204 } else if (type->kind() == TypeKind::StringType) {
205 AT_ASSERT(iter->isString());
206 tracer::addInputs(node, args[i].name().c_str(), iter->toStringView());
207 } else if (type->kind() == TypeKind::NumberType) {
208 tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
209 } else if (type->kind() == TypeKind::ListType) {
210 const auto& elem_type = type->expectRef<ListType>().getElementType();
211 if (elem_type->isSubtypeOf(*TensorType::get())) {
212 AT_ASSERT(iter->isTensorList());
213 auto list = iter->toTensorVector();
214 tracer::addInputs(node, args[i].name().c_str(), list);
215 } else if (auto class_type = elem_type->cast<ClassType>()) {
216 AT_ASSERT(iter->isList());
217 auto list = iter->toList();
218 std::vector<c10::intrusive_ptr<c10::ivalue::Object>> objects;
219 for (IValue iv : list) {
220 objects.emplace_back(std::move(iv).toObject());
221 }
222 tracer::addInputs(node, args[i].name().c_str(), objects, class_type);
223 } else if (elem_type->kind() == TypeKind::FloatType) {
224 AT_ASSERT(iter->isDoubleList());
225 // NB: now, tracer doesn't support tracing double list. We add
226 // special handling here, since in our case, we assume that all the
227 // doubles in the list are constants
228 auto value = iter->toDoubleVector();
229 std::vector<Value*> info(value.size());
230 for (const auto value_index : c10::irange(value.size())) {
231 info[value_index] = graph->insertConstant(value[value_index]);
232 tracer::recordSourceLocation(info[value_index]->node());
233 }
234 node->addInput(
235 graph->insertNode(graph->createList(FloatType::get(), info))
236 ->output());
237 } else if (elem_type->kind() == TypeKind::IntType) {
238 AT_ASSERT(iter->isIntList());
239 tracer::addInputs(
240 node,
241 args[i].name().c_str(),
242 c10::IntArrayRef(iter->toIntVector()));
243 } else if (elem_type->kind() == TypeKind::BoolType) {
244 AT_ASSERT(iter->isBoolList());
245 tracer::addInputs(
246 node, args[i].name().c_str(), iter->toBoolList().vec());
247 } else {
248 throw std::runtime_error(
249 "unsupported input list type: " + elem_type->str());
250 }
251 } else if (iter->isObject()) {
252 tracer::addInputs(node, args[i].name().c_str(), iter->toObject());
253 } else {
254 throw std::runtime_error("unsupported input type: " + type->str());
255 }
256 }
257 graph->insertNode(node);
258
259 tracer::setTracingState(nullptr);
260 }
261
262 op.callBoxed(stack);
263
264 if (tracer_state) {
265 tracer::setTracingState(std::move(tracer_state));
266 int i = 0;
267 for (auto iter = stack->end() - static_cast<std::ptrdiff_t>(output_size);
268 iter != stack->end();
269 ++iter, ++i) {
270 const auto& type = op.schema().returns()[i].type();
271 if (type->isSubtypeOf(*TensorType::get())) {
272 AT_ASSERT(iter->isTensor());
273 tracer::addOutput(node, iter->toTensor());
274 } else if (type->kind() == TypeKind::ListType) {
275 const auto& elem_type = type->expectRef<ListType>().getElementType();
276 if (elem_type->isSubtypeOf(*TensorType::get())) {
277 AT_ASSERT(iter->isTensorList());
278 tracer::addOutput(node, iter->toTensorList());
279 } else {
280 throw std::runtime_error(
281 "unsupported ouptut list type: " + elem_type->str());
282 }
283 } else if (type->kind() == TypeKind::ClassType) {
284 AT_ASSERT(iter->isObject());
285 tracer::addOutput(node, iter->toObject());
286 } else {
287 throw std::runtime_error(
288 "unsupported output type: " + type->str() +
289 ", from operator: " + toString(op.operator_name()));
290 }
291 }
292 }
293 }
TORCH_LIBRARY_IMPL(_, Tracer, m) {
  // Boxed fallback: any op without an explicit Tracer kernel is traced
  // generically via its schema by general_trace_function above.
  m.fallback(CppFunction::makeFromBoxedFunction<&general_trace_function>());
}
297
298 } // namespace torch::jit
299