1 #pragma once 2 3 #include <unordered_map> 4 #include <vector> 5 6 #include <torch/csrc/Export.h> 7 #include <torch/csrc/jit/tensorexpr/mem_dependency_checker.h> 8 9 namespace torch { 10 namespace jit { 11 namespace tensorexpr { 12 13 class Expr; 14 class Buf; 15 class Stmt; 16 17 enum C10_API_ENUM TensorAccessKind { kLoad, kStore, kMutate }; 18 19 struct TORCH_API TensorAccessBoundsInfo { 20 TensorAccessKind kind; 21 std::vector<ExprPtr> start; 22 std::vector<ExprPtr> stop; 23 }; 24 25 using BoundsInfo = 26 std::unordered_map<BufPtr, std::vector<TensorAccessBoundsInfo>>; 27 28 TORCH_API BoundsInfo 29 inferBounds(const StmtPtr& s, bool distinctAccessKinds = true); 30 31 // Bounds inference caching the analysis. The MemDependencyChecker must already 32 // have been run. 33 TORCH_API BoundsInfo getInferredBounds( 34 analysis::MemDependencyChecker& analyzer, 35 const StmtPtr& s, 36 bool distinctAccessKinds = true); 37 TORCH_API BoundsInfo getInferredBounds( 38 analysis::MemDependencyChecker& analyzer, 39 const ExprPtr& e, 40 bool distinctAccessKinds = true); 41 42 TORCH_API void printBoundsInfo(const BoundsInfo& v); 43 44 TORCH_API std::vector<ExprPtr> getBoundExtents( 45 const std::vector<TensorAccessBoundsInfo>& infos); 46 47 // The kind of dependency found, in increasing order of exclusivity. 48 enum class HazardKind { 49 ReadAfterWrite, 50 WriteAfterRead, 51 WriteAfterWrite, 52 NoDependency, 53 }; 54 TORCH_API HazardKind getPotentialHazards( 55 analysis::MemDependencyChecker& analyzer, 56 const StmtPtr& A, 57 const StmtPtr& B); 58 59 // Returns true if there is a conflicting overlap between accesses in 60 // statements A and B. A conflicting overlap is an overlap in buffer accesses 61 // where at least one of the accesses is a Store. 62 TORCH_API bool hasConflictingOverlap( 63 analysis::MemDependencyChecker& analyzer, 64 const StmtPtr& A, 65 const StmtPtr& B); 66 // Same as above, between accesses in stores S1 and S2. 67 TORCH_API bool isOverlapping( 68 analysis::MemDependencyChecker& analyzer, 69 const StorePtr& S1, 70 const StorePtr& S2); 71 // Same as above, between accesses in store S and load L. 72 TORCH_API bool isOverlapping( 73 analysis::MemDependencyChecker& analyzer, 74 const StorePtr& S, 75 const LoadPtr& L); 76 77 } // namespace tensorexpr 78 } // namespace jit 79 } // namespace torch 80