1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef XLA_MLIR_RUNTIME_SPECIALIZATION_H_ 17 #define XLA_MLIR_RUNTIME_SPECIALIZATION_H_ 18 19 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project 20 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project 21 #include "mlir/IR/Types.h" // from @llvm-project 22 #include "tensorflow/compiler/xla/runtime/arguments.h" 23 #include "tensorflow/compiler/xla/runtime/constraints.h" 24 #include "tensorflow/compiler/xla/runtime/symbolic_shape.h" 25 26 namespace xla { 27 namespace runtime { 28 29 // TODO(ezhulenev): A lot of specialization code is written with an assumption 30 // that we can only specialize Tensor arguments. Make this extendable 31 // to support user-defined types and user-defined specializations. 32 33 // Symbolic shape attached to the argument as a dense i64 array attribute. 34 // TODO(ezhulenev): Change symbolic shape attribute type to match the comment. 35 constexpr const char* kSymbolicShapeAttrName = "rt.symbolic_shape"; 36 37 // Listener class to control notifications during specialization. 38 struct SpecializationListener { ~SpecializationListenerSpecializationListener39 virtual ~SpecializationListener() {} 40 41 // Called at the end of module specialization. 42 // - 'operands' is a reference to the specialized operands' types. 43 // - `attrs` is a list of attributes attached to operands. notifyModuleSpecializedSpecializationListener44 virtual void notifyModuleSpecialized( 45 llvm::ArrayRef<mlir::Type> operands, 46 llvm::ArrayRef<mlir::DictionaryAttr> attrs) const {} 47 48 // Called once for every value-specialized argument. notifyValueSpecializedSpecializationListener49 virtual void notifyValueSpecialized(unsigned index, mlir::Type type, 50 mlir::Attribute value) const {} 51 }; 52 53 // Specializes function to the runtime arguments: 54 // 55 // - updates all unknown dimensions according to the resolved symbolic shapes 56 // - attaches symbolic shape attribute to the operands 57 // - for value-specialized operands sinks small constants into the function body 58 // 59 // Returns error if arguments are not compatible with the function signature. 60 llvm::Error SpecializeFunction( 61 mlir::func::FuncOp func, ArgumentsRef arguments, 62 llvm::ArrayRef<SymbolicShapesResolver::SymbolicShape> symbolic_shapes, 63 llvm::ArrayRef<ArgumentConstraint> constraints, 64 const SpecializationListener* listener = nullptr); 65 66 } // namespace runtime 67 } // namespace xla 68 69 #endif // XLA_MLIR_RUNTIME_SPECIALIZATION_H_ 70