// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/xla_op_kernel.h
// (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_

#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class XlaOpKernelContext;

34 // Implementations of operators that generate XLA code should usually subclass
35 // XlaOpKernel and implement the Compile() method. Unlike a regular OpKernel,
36 // an XlaOpKernel produces and consumes symbolic values during compilation.
37 //
38 // See the comments in xla_context.h for more details.
39 class XlaOpKernel : public OpKernel {
40  public:
41   explicit XlaOpKernel(OpKernelConstruction* construction);
42 
43   // Subclasses should implement Compile(), much as standard OpKernels implement
44   // Compute().
45   virtual void Compile(XlaOpKernelContext* context) = 0;
46 
47  private:
48   void Compute(OpKernelContext* context) final;
49 };
50 
51 // The context passed to the Compile() method of XlaOpKernel. An
52 // XlaOpKernelContext is a variant of the standard OpKernel class, tailored for
53 // implementing operators that perform symbolic execution as part of the XLA
54 // compiler. The key difference is that XlaOpKernelContext produces and consumes
55 // data as XLA computations, rather than as standard Tensors.
56 //
57 // Under the hood, symbolic execution communicates using special Tensors that
58 // wrap XlaExpression objects, however this is an implementation detail that
59 // this class hides. The *only* correct way to allocate a Tensor during
60 // compilation is using the XlaOpKernelContext methods, since they ensure there
61 // is a valid XlaExpression backing the tensor. No Op should ever call
62 // allocate_output or allocate_temp directly on the underlying OpKernelContext.
63 class XlaOpKernelContext {
64  public:
65   explicit XlaOpKernelContext(OpKernelContext* context);
66 
67   XlaContext* xla_context() const;
68 
69   // Returns the XLA XlaBuilder containing the output of compilation.
70   xla::XlaBuilder* builder() const;
71 
72   xla::ValueInference& value_inference();
73 
74   // Inputs
75 
76   // Returns the number of inputs to the operator.
num_inputs()77   int num_inputs() const { return context_->num_inputs(); }
78 
79   // Returns the type of input `index`.
80   DataType input_type(int index) const;
81 
82   // Returns the type of input `name`.
83   DataType InputType(absl::string_view name);
84 
85   // Returns the type of input `index` as an xla::PrimitiveType. If the type
86   // is not representable as an XLA type, sets an error status and returns
87   // xla::PRIMITIVE_TYPE_INVALID.
88   xla::PrimitiveType input_xla_type(int index);
89 
90   // Returns the type of input `name` as an xla::PrimitiveType. If the type
91   // is not representable as an XLA type, sets an error status and returns
92   // xla::PRIMITIVE_TYPE_INVALID.
93   xla::PrimitiveType InputXlaType(absl::string_view name);
94 
95   // Returns the shape of input at `index` or input the given `name`. Note that
96   // in case the shape of the input is not static, then the returned shape has
97   // bounds as the dimension size instead of having unknown dimensions. Use
98   // InputXlaShape instead that provides shapes with dynamism information.
99   //
100   ABSL_DEPRECATED(
101       "Prefer InputXlaShape which handles dynamic shapes accurately.")
102   TensorShape InputShape(int index);
103   ABSL_DEPRECATED(
104       "Prefer InputXlaShape which handles dynamic shapes accurately.")
105   TensorShape InputShape(absl::string_view name);
106 
107   // Returns input `index` as a XlaOp. Unlike
108   // OpKernelContext::Input returns a symbolic value rather than a concrete
109   // Tensor.
110   xla::XlaOp Input(int index);
111   // Returns input `name` as a XlaOp.
112   xla::XlaOp Input(absl::string_view name);
113 
114   // Returns the xla input shape for a given index.
115   StatusOr<xla::Shape> InputXlaShape(int index);
116   StatusOr<xla::Shape> InputXlaShape(absl::string_view name);
117 
118   // Returns true if all inputs are the same shape, otherwise sets the
119   // status to a non-OK value and returns false.
120   // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
121   bool ValidateInputsAreSameShape(OpKernel* op) TF_MUST_USE_RESULT;
122 
123   // Returns the named list-valued immutable input in "list", as
124   // defined in the OpDef.  If the named output is not list-valued,
125   // returns a one-element list.
126   Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
127                    std::vector<TensorShape>* shapes);
128   // Evaluates input and returns their dynamism vector in a vector of
129   // predicates.
130   Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out);
131   Status ResolveInputDynamismIntoPred(int index, bool* out);
132   Status ResolveInputDynamismIntoPredVector(absl::string_view name,
133                                             std::vector<bool>* out);
134   Status ResolveInputDynamismIntoPred(absl::string_view name, bool* out);
135 
136   Status ResolveInputDynamism(int index, xla::Literal* dynamism_literal);
137   Status ResolveInputDynamism(absl::string_view name,
138                               xla::Literal* dynamism_literal);
139 
140   Status ResolveInputDynamismReshaped(int index,
141                                       absl::Span<const int64_t> new_dims,
142                                       xla::Literal* dynamism_literal);
143   // Helper methods for constant inputs.
144 
145   // Evaluates input `index` and stores it in `*constant_literal`. If the
146   // expression cannot be evaluated, e.g., because it depends on unbound
147   // parameters, returns a non-OK status. This function can also be used to
148   // infer constant input upper or lower bounds, by changing the `mode`
149   // parameter.
150   Status ConstantInput(
151       int index, xla::Literal* constant_literal,
152       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
153   Status ConstantInput(
154       absl::string_view name, xla::Literal* constant_literal,
155       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
156 
157   // Converts a constant scalar int32 or int64 tensor into an int64.
158   Status ConstantInputAsIntScalar(
159       int index, int64_t* out,
160       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
161   Status ConstantInputAsIntScalar(
162       absl::string_view name, int64_t* out,
163       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
164 
165   // Converts a constant scalar float32 or float64 tensor into a float64.
166   Status ConstantInputAsFloatScalar(
167       int index, double* out,
168       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
169 
170   // Converts a constant 1D int32 or int64 tensor into a vector of int64s.
171   Status ConstantInputAsIntVector(
172       int index, std::vector<int64_t>* out,
173       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
174   Status ConstantInputAsIntVector(
175       absl::string_view name, std::vector<int64_t>* out,
176       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
177 
178   // Reshapes and converts a constant int32 or int64 tensor into a vector of
179   // int64s.
180   Status ConstantInputReshapedToIntVector(
181       int index, std::vector<int64_t>* out,
182       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
183   Status ConstantInputReshapedToIntVector(
184       absl::string_view name, std::vector<int64_t>* out,
185       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
186 
187   // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
188   Status ConstantInputAsInt64Literal(
189       int index, xla::Literal* out,
190       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
191   Status ConstantInputAsInt64Literal(
192       absl::string_view name, xla::Literal* out,
193       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
194 
195   // Converts a constant 1D int32 or int64 tensor into a TensorShape.
196   Status ConstantInputAsShape(
197       int index, TensorShape* shape,
198       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
199 
200   // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1
201   // into a PartialTensorShape.
202   Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape);
203 
204   // Returns the named list-valued immutable input in "list", as
205   // defined in the OpDef.  If the named output is not list-valued,
206   // returns a one-element list.
207   Status ConstantInputList(
208       absl::string_view name, std::vector<xla::Literal>* outputs,
209       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
210 
211   // Returns the Tensor representation of the constant input.
212   StatusOr<Tensor> ConstantInputTensor(
213       int index,
214       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
215 
216   // Returns an XlaExpression describing the value of 'index'.
217   const XlaExpression& InputExpression(int index);
218   const XlaExpression& InputExpression(absl::string_view name);
219 
220   // Outputs
221 
num_outputs()222   int num_outputs() const { return context_->num_outputs(); }
expected_output_dtype(int index)223   DataType expected_output_dtype(int index) const {
224     return context_->expected_output_dtype(index);
225   }
226 
227   // Returns the type of output `index` as an xla::PrimitiveType. If the type
228   // is not representable as an XLA type, sets an error status and returns
229   // xla::PRIMITIVE_TYPE_INVALID.
230   xla::PrimitiveType output_xla_type(int index);
231 
232   // Sets output `index` to the XlaOp `handle`.
233   // All outputs should be set using SetOutput and SetConstantOutput, not
234   // via the underlying OpKernelContext.
235   void SetOutput(int index, const xla::XlaOp& handle);
236 
237   // Sets output `index` to compile-time constant `host_tensor`, where
238   // `host_tensor` is a tensor in host memory. It is preferable to use
239   // SetConstantOutput where possible.
240   void SetConstantOutput(int index, const Tensor& host_tensor);
241 
242   // Returns an XlaExpression describing the value of 'index'.
243   void SetOutputExpression(int index, const XlaExpression& expression);
244 
245   // Sets output `index` to the Tensor List `handle`.
246   void SetTensorListOutput(int index, const xla::XlaOp& handle);
247 
248   // Status handling.
SetStatus(const Status & status)249   void SetStatus(const Status& status) { context_->SetStatus(status); }
status()250   Status status() { return context_->status(); }
251 
252   // Variables
253 
254   // Sets `*resource` to the resource associated with input `index`.
255   Status GetResourceInput(int index, XlaResource** resource);
256 
257   // Sets output `index` to be a reference to resource `resource`.
258   void SetResourceOutput(int index, XlaResource* resource);
259 
260   // Sets `*type` and `*shape` to the current type and shape of a variable's
261   // value.
262   Status GetVariableTypeAndShape(int index, DataType* type,
263                                  TensorShape* shape) const;
264 
265   // When dynamic_dimension_is_minus_one is set, querying a dynamic dimension
266   // returns "-1", this is useful when the underlying ops expect explicit
267   // dynamic index like reshape.
set_dynamic_dimension_is_minus_one(bool value)268   void set_dynamic_dimension_is_minus_one(bool value) {
269     dynamic_dimension_is_minus_one_ = value;
270   }
271 
dynamic_dimension_is_minus_one()272   bool dynamic_dimension_is_minus_one() const {
273     return dynamic_dimension_is_minus_one_;
274   }
275 
is_dynamic_dimension(int64_t dim_size)276   bool is_dynamic_dimension(int64_t dim_size) { return dim_size == -1; }
277 
278   // Reads the current value of the resource variable referred to by input
279   // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
280   // variable. Returns an error if the variable has not been initialized, or if
281   // its type does not match `type`.
282   Status ReadVariableInput(int index, DataType type, TensorShape* shape,
283                            xla::XlaOp* value);
284   // Reads the current value of the resource variable referred to by input
285   // `name`.
286   Status ReadVariableInput(absl::string_view name, DataType type,
287                            TensorShape* shape, xla::XlaOp* value);
288 
289   // Assigns the value `handle` to the variable referenced by input
290   // `input_index`. The variable must be of `type`. Returns an error if the
291   // variable has been initialized with a different type or with a
292   // different shape.
293   Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
294   // Assigns the value `handle` to the variable referenced by input `name`.
295   Status AssignVariable(absl::string_view name, DataType type,
296                         xla::XlaOp handle);
297 
298   // Helper routines for the OP_REQUIRES macros
299   void CtxFailure(const Status& s);
300   void CtxFailureWithWarning(const Status& s);
301   void CtxFailure(const char* file, int line, const Status& s);
302   void CtxFailureWithWarning(const char* file, int line, const Status& s);
303 
304   // If this kernel invocation is within a function execution,
305   // call_frame() returns the call frame for the function call.
call_frame()306   CallFrameInterface* call_frame() const { return context_->call_frame(); }
307 
function_library()308   FunctionLibraryRuntime* function_library() const {
309     return context_->function_library();
310   }
311 
op_kernel()312   const OpKernel& op_kernel() const { return context_->op_kernel(); }
313 
314   // Returns the underlying OpKernelContext. Use rarely.
op_kernel_context()315   OpKernelContext* op_kernel_context() const { return context_; }
316 
317   // Returns the XlaCompiler that is performing the compilation. Used for, e.g.,
318   // While to compile nested computations.
319   XlaCompiler* compiler() const;
320 
321   // TODO(phawkins): find a better home for these helpers.
322 
323   // Gets an XLA lambda to compute Max. This is cached in the
324   // XlaContext since it may be used by multiple Ops. There is a
325   // separate specialization of the computation for each DataType.
326   const xla::XlaComputation* GetOrCreateMax(const DataType type);
327 
328   // Gets an XLA lambda to compute Min. This is cached in the
329   // XlaContext since it may be used by multiple Ops. There is a
330   // separate specialization of the computation for each DataType.
331   const xla::XlaComputation* GetOrCreateMin(const DataType type);
332 
333   // Gets an XLA lambda to compute Add. This is cached in the
334   // XlaContext since it may be used by multiple Ops. There is a
335   // separate specialization of the computation for each DataType.
336   const xla::XlaComputation* GetOrCreateAdd(const DataType type);
337 
338   // Gets an XLA lambda to compute Mul. This is cached in the
339   // XlaContext since it may be used by multiple Ops. There is a
340   // separate specialization of the computation for each DataType.
341   const xla::XlaComputation* GetOrCreateMul(const DataType type);
342 
343   // Returns stack trace encoded as a string at a given module, or an empty
344   // string if none found.
345   std::string StackTrace() const;
346 
347  private:
348   // Returns the tensor of input `name`.
349   const Tensor& GetInputTensorByName(absl::string_view name);
350   // Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
351   // InputShape(index), and stores it in `*constant_literal`. If the input
352   // cannot be evaluated, e.g., because it depends on unbound parameters,
353   // returns a non-Ok status. If InputShape(index).num_elements() !=
354   // new_shape.num_elements(), returns an error status.
355   Status ConstantInputReshaped(
356       int index, absl::Span<const int64_t> new_dims,
357       xla::Literal* constant_literal,
358       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue);
359 
360   OpKernelContext* const context_;
361   bool dynamic_dimension_is_minus_one_;
362   xla::ValueInference value_inference_;
363 };

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_