xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/fuser/fallback.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/codegen/fuser/fallback.h>
2 
3 #include <ATen/core/functional.h> //fmap
4 #include <ATen/core/stack.h>
5 #include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 #include <torch/csrc/jit/runtime/custom_operator.h>
8 #include <torch/csrc/jit/runtime/interpreter.h>
9 
10 #include <stdexcept>
11 
12 namespace torch::jit::fuser {
13 
14 namespace {
aliasAnalysisIsSpecialCase()15 c10::AliasAnalysisKind aliasAnalysisIsSpecialCase() {
16   return AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
17 }
18 } // namespace
19 
20 // Registers fused operators so that fused graphs can properly generate fallback
21 // code.
22 RegisterOperators reg_fused_operators({Operator(
23     prim::FusedConcat,
__anon9268f9170202(const Node* node) 24     [](const Node* node) -> Operation {
25       int64_t dim = node->i(attr::dim);
26       int64_t num_inputs = node->inputs().size();
27       return [dim, num_inputs](Stack& stack) {
28         auto result = at::cat(
29             fmap(
30                 last(stack, num_inputs),
31                 [](const IValue& i) { return i.toTensor(); }),
32             dim);
33         drop(stack, num_inputs);
34         pack(stack, std::move(result));
35       };
36     },
37     aliasAnalysisIsSpecialCase())});
38 
runFallback(int64_t key,Stack & stack)39 void runFallback(int64_t key, Stack& stack) {
40   auto maybe_spec = retrieve(key);
41   if (!maybe_spec)
42     throw std::runtime_error("Failed to find fusion spec to run fallback.");
43 
44   InterpreterState{(*maybe_spec)->code()}.run(stack);
45 }
46 
47 } // namespace torch::jit::fuser
48