#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
#include <memory>

namespace torch::jit {

// Run the TensorExpressions-based fuser on `graph`.
// If add_composed_op is true, creates a single operation that
// performs both the runtime check that types align
// and then the dispatch to the kernel/unoptimized graph.
// min_group_size is the minimum number of nodes required to form a
// fusion group; fuse_to_dynamic_shapes enables dynamic-shape fusion.
TORCH_API void FuseTensorExprs(
    std::shared_ptr<Graph>& graph,
    size_t min_group_size = 2,
    bool add_composed_op = false,
    bool fuse_to_dynamic_shapes = false);

// Process-wide toggles and queries for the TensorExpr fuser, its
// dynamic-shape fusion mode, and reduction fusion.
// NOTE(review): setTexprReductionsEnabled returns a bool (presumably the
// previous setting) — confirm against the implementation.
TORCH_API void setTensorExprFuserEnabled(bool val);
TORCH_API bool tensorExprFuserEnabled();
TORCH_API void setTensorExprDynamicShapeFusionEnabled(bool val);
TORCH_API bool tensorExprDynamicShapeFusionEnabled();
TORCH_API bool setTexprReductionsEnabled(bool value);
TORCH_API bool texprReductionsEnabled();

// Helpers for managing profiled tensor-type information on a graph:
// remove profile nodes and specialize value types from the profiled
// information; query whether a value carries a tensor-type
// specialization; and erase such specializations from a whole graph or
// from a single block.
TORCH_API void RemoveProfileNodesAndSpecializeTypes(
    std::shared_ptr<Graph>& graph);
TORCH_API bool hasTensorTypeSpecialization(Value* v);
TORCH_API void RemoveTensorTypeSpecializations(std::shared_ptr<Graph>& graph);
TORCH_API void removeTensorTypeSpecializations(Block* block);

// Callback used to post-process the TensorType of each tensor input when
// building a TypeCheck guard (e.g. to erase irrelevant type details).
using tensor_type_converter_t =
    c10::function_ref<TensorTypePtr(const TensorTypePtr& t)>;

// inserts a TypeCheck pattern
//
// around the guarded node that has a Subgraph attribute, this inserts a pattern
//
//   if TypeCheck(...):
//     guarded_node
//   else:
//     FallbackGraph(...)
//
// The TypeCheck includes the types of all Tensor inputs to the guarded_node,
// as processed by the type_converter, a lambda
// TensorTypePtr(const TensorTypePtr& t). This allows erasing irrelevant
// aspects of the type.
//
// The Fallback graph will have the same subgraph as the guarded node (with the
// expectation that the guarded_node's subgraph will then be optimized).
TORCH_API void insertTypeGuard(
    Node* guarded_node,
    tensor_type_converter_t type_converter,
    c10::Symbol kind);

// True when value `v` is used only in size computations (per the function
// name; see the implementation for the exact criteria).
TORCH_API bool usedOnlyInSize(Value* v);
// NOTE(review): presumably inserts a computation of the broadcast of
// `sizes` into the graph and returns the resulting Value, using `db` for
// alias analysis — confirm against the implementation.
TORCH_API Value* broadcastSizes(at::ArrayRef<Value*> sizes, AliasDb* db);

namespace tensorexpr {
// Whether `node` can be handled by the TensorExpr (NNC) fuser.
TORCH_API bool isSupported(Node* node);

/// Get the modifiable custom operator set object.
///
/// For static shapes, if a custom operator has been added to the custom
/// operator set, it will be pulled into the NNC fusion group. But it doesn't
/// work with dynamic shapes unless the shape function is explicitly
/// registered via `torch::jit::RegisterShapeComputeGraphForSchema` for the
/// custom operator.
///
/// @return Reference to the custom operator set
///
TORCH_API OperatorSet& getCustomOperatorSet();
} // namespace tensorexpr
} // namespace torch::jit