1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ 18 19 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 20 #include "tensorflow/compiler/tf2xla/xla_context.h" 21 #include "tensorflow/compiler/tf2xla/xla_expression.h" 22 #include "tensorflow/compiler/tf2xla/xla_resource.h" 23 #include "tensorflow/compiler/xla/client/value_inference.h" 24 #include "tensorflow/compiler/xla/client/xla_builder.h" 25 #include "tensorflow/compiler/xla/client/xla_computation.h" 26 #include "tensorflow/compiler/xla/xla_data.pb.h" 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/platform/macros.h" 29 30 namespace tensorflow { 31 32 class XlaOpKernelContext; 33 34 // Implementations of operators that generate XLA code should usually subclass 35 // XlaOpKernel and implement the Compile() method. Unlike a regular OpKernel, 36 // an XlaOpKernel produces and consumes symbolic values during compilation. 37 // 38 // See the comments in xla_context.h for more details. 39 class XlaOpKernel : public OpKernel { 40 public: 41 explicit XlaOpKernel(OpKernelConstruction* construction); 42 43 // Subclasses should implement Compile(), much as standard OpKernels implement 44 // Compute(). 45 virtual void Compile(XlaOpKernelContext* context) = 0; 46 47 private: 48 void Compute(OpKernelContext* context) final; 49 }; 50 51 // The context passed to the Compile() method of XlaOpKernel. An 52 // XlaOpKernelContext is a variant of the standard OpKernel class, tailored for 53 // implementing operators that perform symbolic execution as part of the XLA 54 // compiler. The key difference is that XlaOpKernelContext produces and consumes 55 // data as XLA computations, rather than as standard Tensors. 56 // 57 // Under the hood, symbolic execution communicates using special Tensors that 58 // wrap XlaExpression objects, however this is an implementation detail that 59 // this class hides. The *only* correct way to allocate a Tensor during 60 // compilation is using the XlaOpKernelContext methods, since they ensure there 61 // is a valid XlaExpression backing the tensor. No Op should ever call 62 // allocate_output or allocate_temp directly on the underlying OpKernelContext. 63 class XlaOpKernelContext { 64 public: 65 explicit XlaOpKernelContext(OpKernelContext* context); 66 67 XlaContext* xla_context() const; 68 69 // Returns the XLA XlaBuilder containing the output of compilation. 70 xla::XlaBuilder* builder() const; 71 72 xla::ValueInference& value_inference(); 73 74 // Inputs 75 76 // Returns the number of inputs to the operator. num_inputs()77 int num_inputs() const { return context_->num_inputs(); } 78 79 // Returns the type of input `index`. 80 DataType input_type(int index) const; 81 82 // Returns the type of input `name`. 83 DataType InputType(absl::string_view name); 84 85 // Returns the type of input `index` as an xla::PrimitiveType. If the type 86 // is not representable as an XLA type, sets an error status and returns 87 // xla::PRIMITIVE_TYPE_INVALID. 88 xla::PrimitiveType input_xla_type(int index); 89 90 // Returns the type of input `name` as an xla::PrimitiveType. If the type 91 // is not representable as an XLA type, sets an error status and returns 92 // xla::PRIMITIVE_TYPE_INVALID. 93 xla::PrimitiveType InputXlaType(absl::string_view name); 94 95 // Returns the shape of input at `index` or input the given `name`. Note that 96 // in case the shape of the input is not static, then the returned shape has 97 // bounds as the dimension size instead of having unknown dimensions. Use 98 // InputXlaShape instead that provides shapes with dynamism information. 99 // 100 ABSL_DEPRECATED( 101 "Prefer InputXlaShape which handles dynamic shapes accurately.") 102 TensorShape InputShape(int index); 103 ABSL_DEPRECATED( 104 "Prefer InputXlaShape which handles dynamic shapes accurately.") 105 TensorShape InputShape(absl::string_view name); 106 107 // Returns input `index` as a XlaOp. Unlike 108 // OpKernelContext::Input returns a symbolic value rather than a concrete 109 // Tensor. 110 xla::XlaOp Input(int index); 111 // Returns input `name` as a XlaOp. 112 xla::XlaOp Input(absl::string_view name); 113 114 // Returns the xla input shape for a given index. 115 StatusOr<xla::Shape> InputXlaShape(int index); 116 StatusOr<xla::Shape> InputXlaShape(absl::string_view name); 117 118 // Returns true if all inputs are the same shape, otherwise sets the 119 // status to a non-OK value and returns false. 120 // Usage: if (!context->ValidateInputsAreSameShape(this)) return; 121 bool ValidateInputsAreSameShape(OpKernel* op) TF_MUST_USE_RESULT; 122 123 // Returns the named list-valued immutable input in "list", as 124 // defined in the OpDef. If the named output is not list-valued, 125 // returns a one-element list. 126 Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles, 127 std::vector<TensorShape>* shapes); 128 // Evaluates input and returns their dynamism vector in a vector of 129 // predicates. 130 Status ResolveInputDynamismIntoPredVector(int index, std::vector<bool>* out); 131 Status ResolveInputDynamismIntoPred(int index, bool* out); 132 Status ResolveInputDynamismIntoPredVector(absl::string_view name, 133 std::vector<bool>* out); 134 Status ResolveInputDynamismIntoPred(absl::string_view name, bool* out); 135 136 Status ResolveInputDynamism(int index, xla::Literal* dynamism_literal); 137 Status ResolveInputDynamism(absl::string_view name, 138 xla::Literal* dynamism_literal); 139 140 Status ResolveInputDynamismReshaped(int index, 141 absl::Span<const int64_t> new_dims, 142 xla::Literal* dynamism_literal); 143 // Helper methods for constant inputs. 144 145 // Evaluates input `index` and stores it in `*constant_literal`. If the 146 // expression cannot be evaluated, e.g., because it depends on unbound 147 // parameters, returns a non-OK status. This function can also be used to 148 // infer constant input upper or lower bounds, by changing the `mode` 149 // parameter. 150 Status ConstantInput( 151 int index, xla::Literal* constant_literal, 152 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 153 Status ConstantInput( 154 absl::string_view name, xla::Literal* constant_literal, 155 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 156 157 // Converts a constant scalar int32 or int64 tensor into an int64. 158 Status ConstantInputAsIntScalar( 159 int index, int64_t* out, 160 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 161 Status ConstantInputAsIntScalar( 162 absl::string_view name, int64_t* out, 163 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 164 165 // Converts a constant scalar float32 or float64 tensor into a float64. 166 Status ConstantInputAsFloatScalar( 167 int index, double* out, 168 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 169 170 // Converts a constant 1D int32 or int64 tensor into a vector of int64s. 171 Status ConstantInputAsIntVector( 172 int index, std::vector<int64_t>* out, 173 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 174 Status ConstantInputAsIntVector( 175 absl::string_view name, std::vector<int64_t>* out, 176 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 177 178 // Reshapes and converts a constant int32 or int64 tensor into a vector of 179 // int64s. 180 Status ConstantInputReshapedToIntVector( 181 int index, std::vector<int64_t>* out, 182 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 183 Status ConstantInputReshapedToIntVector( 184 absl::string_view name, std::vector<int64_t>* out, 185 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 186 187 // Converts a constant int32 or int64 Tensor into an xla int64 Literal. 188 Status ConstantInputAsInt64Literal( 189 int index, xla::Literal* out, 190 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 191 Status ConstantInputAsInt64Literal( 192 absl::string_view name, xla::Literal* out, 193 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 194 195 // Converts a constant 1D int32 or int64 tensor into a TensorShape. 196 Status ConstantInputAsShape( 197 int index, TensorShape* shape, 198 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 199 200 // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1 201 // into a PartialTensorShape. 202 Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape); 203 204 // Returns the named list-valued immutable input in "list", as 205 // defined in the OpDef. If the named output is not list-valued, 206 // returns a one-element list. 207 Status ConstantInputList( 208 absl::string_view name, std::vector<xla::Literal>* outputs, 209 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 210 211 // Returns the Tensor representation of the constant input. 212 StatusOr<Tensor> ConstantInputTensor( 213 int index, 214 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 215 216 // Returns an XlaExpression describing the value of 'index'. 217 const XlaExpression& InputExpression(int index); 218 const XlaExpression& InputExpression(absl::string_view name); 219 220 // Outputs 221 num_outputs()222 int num_outputs() const { return context_->num_outputs(); } expected_output_dtype(int index)223 DataType expected_output_dtype(int index) const { 224 return context_->expected_output_dtype(index); 225 } 226 227 // Returns the type of output `index` as an xla::PrimitiveType. If the type 228 // is not representable as an XLA type, sets an error status and returns 229 // xla::PRIMITIVE_TYPE_INVALID. 230 xla::PrimitiveType output_xla_type(int index); 231 232 // Sets output `index` to the XlaOp `handle`. 233 // All outputs should be set using SetOutput and SetConstantOutput, not 234 // via the underlying OpKernelContext. 235 void SetOutput(int index, const xla::XlaOp& handle); 236 237 // Sets output `index` to compile-time constant `host_tensor`, where 238 // `host_tensor` is a tensor in host memory. It is preferable to use 239 // SetConstantOutput where possible. 240 void SetConstantOutput(int index, const Tensor& host_tensor); 241 242 // Returns an XlaExpression describing the value of 'index'. 243 void SetOutputExpression(int index, const XlaExpression& expression); 244 245 // Sets output `index` to the Tensor List `handle`. 246 void SetTensorListOutput(int index, const xla::XlaOp& handle); 247 248 // Status handling. SetStatus(const Status & status)249 void SetStatus(const Status& status) { context_->SetStatus(status); } status()250 Status status() { return context_->status(); } 251 252 // Variables 253 254 // Sets `*resource` to the resource associated with input `index`. 255 Status GetResourceInput(int index, XlaResource** resource); 256 257 // Sets output `index` to be a reference to resource `resource`. 258 void SetResourceOutput(int index, XlaResource* resource); 259 260 // Sets `*type` and `*shape` to the current type and shape of a variable's 261 // value. 262 Status GetVariableTypeAndShape(int index, DataType* type, 263 TensorShape* shape) const; 264 265 // When dynamic_dimension_is_minus_one is set, querying a dynamic dimension 266 // returns "-1", this is useful when the underlying ops expect explicit 267 // dynamic index like reshape. set_dynamic_dimension_is_minus_one(bool value)268 void set_dynamic_dimension_is_minus_one(bool value) { 269 dynamic_dimension_is_minus_one_ = value; 270 } 271 dynamic_dimension_is_minus_one()272 bool dynamic_dimension_is_minus_one() const { 273 return dynamic_dimension_is_minus_one_; 274 } 275 is_dynamic_dimension(int64_t dim_size)276 bool is_dynamic_dimension(int64_t dim_size) { return dim_size == -1; } 277 278 // Reads the current value of the resource variable referred to by input 279 // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the 280 // variable. Returns an error if the variable has not been initialized, or if 281 // its type does not match `type`. 282 Status ReadVariableInput(int index, DataType type, TensorShape* shape, 283 xla::XlaOp* value); 284 // Reads the current value of the resource variable referred to by input 285 // `name`. 286 Status ReadVariableInput(absl::string_view name, DataType type, 287 TensorShape* shape, xla::XlaOp* value); 288 289 // Assigns the value `handle` to the variable referenced by input 290 // `input_index`. The variable must be of `type`. Returns an error if the 291 // variable has been initialized with a different type or with a 292 // different shape. 293 Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); 294 // Assigns the value `handle` to the variable referenced by input `name`. 295 Status AssignVariable(absl::string_view name, DataType type, 296 xla::XlaOp handle); 297 298 // Helper routines for the OP_REQUIRES macros 299 void CtxFailure(const Status& s); 300 void CtxFailureWithWarning(const Status& s); 301 void CtxFailure(const char* file, int line, const Status& s); 302 void CtxFailureWithWarning(const char* file, int line, const Status& s); 303 304 // If this kernel invocation is within a function execution, 305 // call_frame() returns the call frame for the function call. call_frame()306 CallFrameInterface* call_frame() const { return context_->call_frame(); } 307 function_library()308 FunctionLibraryRuntime* function_library() const { 309 return context_->function_library(); 310 } 311 op_kernel()312 const OpKernel& op_kernel() const { return context_->op_kernel(); } 313 314 // Returns the underlying OpKernelContext. Use rarely. op_kernel_context()315 OpKernelContext* op_kernel_context() const { return context_; } 316 317 // Returns the XlaCompiler that is performing the compilation. Used for, e.g., 318 // While to compile nested computations. 319 XlaCompiler* compiler() const; 320 321 // TODO(phawkins): find a better home for these helpers. 322 323 // Gets an XLA lambda to compute Max. This is cached in the 324 // XlaContext since it may be used by multiple Ops. There is a 325 // separate specialization of the computation for each DataType. 326 const xla::XlaComputation* GetOrCreateMax(const DataType type); 327 328 // Gets an XLA lambda to compute Min. This is cached in the 329 // XlaContext since it may be used by multiple Ops. There is a 330 // separate specialization of the computation for each DataType. 331 const xla::XlaComputation* GetOrCreateMin(const DataType type); 332 333 // Gets an XLA lambda to compute Add. This is cached in the 334 // XlaContext since it may be used by multiple Ops. There is a 335 // separate specialization of the computation for each DataType. 336 const xla::XlaComputation* GetOrCreateAdd(const DataType type); 337 338 // Gets an XLA lambda to compute Mul. This is cached in the 339 // XlaContext since it may be used by multiple Ops. There is a 340 // separate specialization of the computation for each DataType. 341 const xla::XlaComputation* GetOrCreateMul(const DataType type); 342 343 // Returns stack trace encoded as a string at a given module, or an empty 344 // string if none found. 345 std::string StackTrace() const; 346 347 private: 348 // Returns the tensor of input `name`. 349 const Tensor& GetInputTensorByName(absl::string_view name); 350 // Evaluates input `index`, reshapes it to `new_shape` if new_shape != 351 // InputShape(index), and stores it in `*constant_literal`. If the input 352 // cannot be evaluated, e.g., because it depends on unbound parameters, 353 // returns a non-Ok status. If InputShape(index).num_elements() != 354 // new_shape.num_elements(), returns an error status. 355 Status ConstantInputReshaped( 356 int index, absl::Span<const int64_t> new_dims, 357 xla::Literal* constant_literal, 358 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); 359 360 OpKernelContext* const context_; 361 bool dynamic_dimension_is_minus_one_; 362 xla::ValueInference value_inference_; 363 }; 364 365 } // namespace tensorflow 366 367 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ 368