xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir/transforms/runtime/specialization.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_MLIR_RUNTIME_SPECIALIZATION_H_
17 #define XLA_MLIR_RUNTIME_SPECIALIZATION_H_
18 
19 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
20 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
21 #include "mlir/IR/Types.h"  // from @llvm-project
22 #include "tensorflow/compiler/xla/runtime/arguments.h"
23 #include "tensorflow/compiler/xla/runtime/constraints.h"
24 #include "tensorflow/compiler/xla/runtime/symbolic_shape.h"
25 
26 namespace xla {
27 namespace runtime {
28 
// TODO(ezhulenev): Much of the specialization code assumes that only Tensor
// arguments can be specialized. Make it extensible to support user-defined
// types and user-defined specializations.
32 
// Name of the argument attribute used to attach a symbolic shape to a function
// operand during specialization.
//
// TODO(ezhulenev): The attribute is intended to be a dense i64 array attribute;
// its current type does not match that yet and should be changed.
static constexpr const char* kSymbolicShapeAttrName = "rt.symbolic_shape";
36 
37 // Listener class to control notifications during specialization.
38 struct SpecializationListener {
~SpecializationListenerSpecializationListener39   virtual ~SpecializationListener() {}
40 
41   // Called at the end of module specialization.
42   // - 'operands' is a reference to the specialized operands' types.
43   // - `attrs` is a list of attributes attached to operands.
notifyModuleSpecializedSpecializationListener44   virtual void notifyModuleSpecialized(
45       llvm::ArrayRef<mlir::Type> operands,
46       llvm::ArrayRef<mlir::DictionaryAttr> attrs) const {}
47 
48   // Called once for every value-specialized argument.
notifyValueSpecializedSpecializationListener49   virtual void notifyValueSpecialized(unsigned index, mlir::Type type,
50                                       mlir::Attribute value) const {}
51 };
52 
53 // Specializes function to the runtime arguments:
54 //
55 // - updates all unknown dimensions according to the resolved symbolic shapes
56 // - attaches symbolic shape attribute to the operands
57 // - for value-specialized operands sinks small constants into the function body
58 //
59 // Returns error if arguments are not compatible with the function signature.
60 llvm::Error SpecializeFunction(
61     mlir::func::FuncOp func, ArgumentsRef arguments,
62     llvm::ArrayRef<SymbolicShapesResolver::SymbolicShape> symbolic_shapes,
63     llvm::ArrayRef<ArgumentConstraint> constraints,
64     const SpecializationListener* listener = nullptr);
65 
66 }  // namespace runtime
67 }  // namespace xla
68 
69 #endif  // XLA_MLIR_RUNTIME_SPECIALIZATION_H_
70