#include <torch/csrc/jit/codegen/fuser/fallback.h>

#include <ATen/core/functional.h> // fmap
#include <ATen/core/stack.h>
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/interpreter.h>

#include <stdexcept>

namespace torch::jit::fuser {

namespace {
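// Ops registered with this kind have their aliasing behavior hard-coded in
// the alias analysis pass instead of being inferred from an operator schema.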
c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
  return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
}
} // namespace

// Registers fused operators so that fused graphs can properly generate fallback
// code.
RegisterOperators reg_fused_operators({Operator(
    prim::FusedConcat,
    [](const Node* node) -> Operation {
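      // The concat dimension and input count are fixed when the fused graph
      // is constructed, so they can be captured by value below.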
      int64_t dim = node->i(attr::dim);
      int64_t num_inputs = node->inputs().size();
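      // The returned Operation pops the top num_inputs tensors off the
      // stack, concatenates them along dim, and pushes the result.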
      return [dim, num_inputs](Stack& stack) {
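        // last() only peeks at the top num_inputs stack entries; the
        // drop() below is what actually removes them.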
        auto result = at::cat(
            fmap(
                last(stack, num_inputs),
                [](const IValue& i) { return i.toTensor(); }),
            dim);
        drop(stack, num_inputs);
        pack(stack, std::move(result));
      };
    },
    aliasAnalysisIsSpecialCase())});

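// Looks up the fusion spec cached under key and interprets its fallback
// code, i.e. the original, unfused subgraph.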
void runFallback(int64_t key, Stack& stack) {
  auto maybe_spec = retrieve(key);
  if (!maybe_spec)
    throw std::runtime_error("Failed to find fusion spec to run fallback.");

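  // Run the cached fallback code on the interpreter, consuming the inputs
  // on the stack and leaving the outputs in their place.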
  InterpreterState{(*maybe_spec)->code()}.run(stack);
}

} // namespace torch::jit::fuser