xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/tensorexpr_fuser.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <memory>
6 
7 namespace torch::jit {
8 
9 // Run the TensorExpressions-based fuser on `graph`.
10 // If add_composed_op is true, creates a single operation that
11 // performs both the runtime check that types align
12 // and then the dispatch to the kernel/unoptimized graph.
//
// Parameters (descriptions beyond the note above are inferred from names;
// NOTE(review): confirm against the definition):
//   graph                  - graph to fuse; taken by non-const reference to the
//                            shared_ptr, presumably rewritten in place.
//   min_group_size         - presumably the minimum number of nodes a fusion
//                            group must contain to be kept (default 2).
//   add_composed_op        - see above (default false).
//   fuse_to_dynamic_shapes - presumably lets fusion target dynamic (symbolic)
//                            shapes rather than static ones (default false).
13 TORCH_API void FuseTensorExprs(
14     std::shared_ptr<Graph>& graph,
15     size_t min_group_size = 2,
16     bool add_composed_op = false,
17     bool fuse_to_dynamic_shapes = false);
18 
// Setter/getter pairs for TensorExpr fuser options.
// NOTE(review): these presumably control process-wide settings; confirm the
// storage and thread-safety in the definitions.
// Whether the TensorExpr fuser is enabled.
19 TORCH_API void setTensorExprFuserEnabled(bool val);
20 TORCH_API bool tensorExprFuserEnabled();
// Whether TensorExpr fusion with dynamic shapes is enabled.
21 TORCH_API void setTensorExprDynamicShapeFusionEnabled(bool val);
22 TORCH_API bool tensorExprDynamicShapeFusionEnabled();
// Whether reductions are enabled for TensorExpr fusion.
// NOTE(review): unlike the other setters this one returns bool. What the
// returned value means (new vs. previous setting) is not visible in this
// header; check the definition before relying on it.
23 TORCH_API bool setTexprReductionsEnabled(bool value);
24 TORCH_API bool texprReductionsEnabled();
25 
// Helpers around tensor type specialization.
// NOTE(review): the behavior below is inferred from the names only; confirm
// against the definitions.
//   RemoveProfileNodesAndSpecializeTypes - presumably removes profiling nodes
//     from `graph` and specializes value types from the profiled information.
//   hasTensorTypeSpecialization - presumably true if `v` carries a specialized
//     tensor type.
//   RemoveTensorTypeSpecializations / removeTensorTypeSpecializations - graph-
//     and block-level variants that presumably strip such specializations.
26 TORCH_API void RemoveProfileNodesAndSpecializeTypes(
27     std::shared_ptr<Graph>& graph);
28 TORCH_API bool hasTensorTypeSpecialization(Value* v);
29 TORCH_API void RemoveTensorTypeSpecializations(std::shared_ptr<Graph>& graph);
30 TORCH_API void removeTensorTypeSpecializations(Block* block);
31 
// Callable mapping a TensorType to a (possibly coarser) TensorType; taken by
// insertTypeGuard below. c10::function_ref is a non-owning reference to a
// callable, so the referenced callable must outlive every use of the reference.
32 using tensor_type_converter_t =
33     c10::function_ref<TensorTypePtr(const TensorTypePtr& t)>;
34 
35 // Inserts a TypeCheck pattern around `guarded_node`.
36 //
37 // `guarded_node` is expected to have a Subgraph attribute; this inserts the
38 // pattern:
39 //   if TypeCheck(...):
40 //     guarded_node
41 //   else:
42 //     FallbackGraph(...)
43 //
44 // The TypeCheck includes the types of all Tensor inputs to the guarded_node,
45 // as processed by `type_converter`, a callable of signature
46 // TensorTypePtr(const TensorTypePtr& t). This allows erasing irrelevant
47 // aspects of the type.
48 //
49 // The FallbackGraph will have the same subgraph as the guarded node (with the
50 // expectation that the guarded_node's subgraph will then be optimized).
//
// `kind` is presumably the node kind (Symbol) used for the inserted
// type-check node. NOTE(review): confirm in the definition.
51 TORCH_API void insertTypeGuard(
52     Node* guarded_node,
53     tensor_type_converter_t type_converter,
54     c10::Symbol kind);
55 
// NOTE(review): the semantics below are inferred from names only; confirm in
// the definitions.
//   usedOnlyInSize - presumably true if every use of `v` is a size query.
//   broadcastSizes - presumably produces a Value holding the broadcast of the
//     given size values, consulting `db` for alias information.
56 TORCH_API bool usedOnlyInSize(Value* v);
57 TORCH_API Value* broadcastSizes(at::ArrayRef<Value*> sizes, AliasDb* db);
58 
59 namespace tensorexpr {
// Presumably reports whether `node` can be handled by the TensorExpr fuser.
// NOTE(review): confirm the criteria in the definition.
60 TORCH_API bool isSupported(Node* node);
61 
62 /// Get the modifiable custom operator set object.
63 ///
64 /// For static shapes, if a custom operator has been added to the custom
65 /// operator set, it will be pulled into the NNC fusion group. This does not
66 /// work with dynamic shapes, however, unless the operator's shape function is
67 /// explicitly registered via `torch::jit::RegisterShapeComputeGraphForSchema`.
68 ///
69 /// @return Reference to the custom operator set.
70 ///
71 TORCH_API OperatorSet& getCustomOperatorSet();
72 } // namespace tensorexpr
73 } // namespace torch::jit
74