/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TF_TYPE_ATTRIBUTES
#define TF_TYPE_ATTRIBUTES

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/SubElementInterfaces.td"
include "tensorflow/core/ir/types/dialect.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"

// Base class for TFType dialect attributes.
class TFType_Attr<string name, list<Trait> traits = []>
    : AttrDef<TFTypeDialect, name, traits>;

// Type identifiers of the FullType type system. Keep this enum (case names
// and numeric values) in sync with ../../framework/full_type.proto, as
// required by the LINT.IfChange / LINT.ThenChange markers around it.
// LINT.IfChange
def TFType_FullTypeId : I32EnumAttr<"FullTypeId", "", [
  // The default represents an uninitialized value.
  I32EnumAttrCase<"TFT_UNSET", 0, "unset">,

  // Type symbols. Used to construct more complex type expressions like
  // algebraic data types.

  // Type variables may serve as placeholders for any other type ID in type
  // templates.
  //
  // Examples:
  //   TFT_DATASET[TFT_VAR["T"]] is a Dataset returning a type indicated by "T".
  //   TFT_TENSOR[TFT_VAR["T"]] is a Tensor of an element type indicated by
  //     "T".
  //   TFT_TENSOR[TFT_VAR["T"]], TFT_TENSOR[TFT_VAR["T"]] are two tensors of
  //     identical element types.
  //   TFT_TENSOR[TFT_VAR["P"]], TFT_TENSOR[TFT_VAR["Q"]] are two tensors of
  //     independent element types.
  //
  I32EnumAttrCase<"TFT_VAR", 1, "var">,

  // Wildcard type. Describes a parameter of unknown type. In TensorFlow, that
  // can mean either a "Top" type (accepts any type), or a dynamically typed
  // object whose type is unknown in context.
  // Important: "unknown" does not necessarily mean undeterminable!
  I32EnumAttrCase<"TFT_ANY", 2, "any">,

  // The algebraic product type. This is an algebraic type that may be used
  // just for logical grouping. Not to be confused with TFT_TUPLE, which
  // describes a concrete object of several elements.
  //
  // Example:
  //   TFT_DATASET[TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT64]]]
  //     is a Dataset producing two tensors, an integer one and a float one.
  //
  I32EnumAttrCase<"TFT_PRODUCT", 3, "product">,

  // Represents a named field, with the name stored in the attribute.
  //
  // Parametrization:
  //   TFT_NAMED[<type>]{<name>}
  //   * <type> is the type of the field
  //   * <name> is the field name, as a string (though it can theoretically be
  //     an int as well)
  //
  // Example:
  //   TFT_RECORD[
  //     TFT_NAMED[TFT_TENSOR[TFT_INT32]]{'foo'},
  //     TFT_NAMED[TFT_TENSOR[TFT_FLOAT32]]{'bar'},
  //   ]
  //     is a structure with two fields, an int tensor "foo" and a float tensor
  //     "bar".
  I32EnumAttrCase<"TFT_NAMED", 4, "named">,

  // Template definition. Expands the variables by repeating a template as
  // arguments of a container.
  //
  // Parametrization:
  //   TFT_FOR_EACH[<container_type>, <template>, <expansions>]
  //   * <container_type> is the type of the container that the template will
  //     be expanded into
  //   * <template> is any type definition that potentially contains type
  //     variables
  //   * <expansions> is a TFT_VAR and may include more types in the future
  //
  // Example:
  //   TFT_FOR_EACH[
  //     TFT_PRODUCT,
  //     TFT_TENSOR[TFT_VAR["T"]],
  //     TFT_VAR["T"]
  //   ]
  //     will substitute T = TFT_INT32 to TFT_PRODUCT[TFT_TENSOR[TFT_INT32]]
  //     and T = (TFT_INT32, TFT_INT64) to
  //     TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_INT64]].
  I32EnumAttrCase<"TFT_FOR_EACH", 20, "for_each">,

  // Callable types describe functions and ops.
  //
  // Parametrization:
  //   TFT_CALLABLE[<arg type>, <return type>]
  //   * <arg type> is the type of the arguments; TFT_PRODUCT represents
  //     multiple arguments.
  //   * <return type> is the return type; TFT_PRODUCT represents multiple
  //     return values (that means that callables returning multiple things
  //     don't necessarily return a single tuple).
  //
  // Example:
  //   TFT_CALLABLE[
  //     TFT_ANY,
  //     TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT64]],
  //   ]
  //     is a callable with unspecified (for now) input arguments, and
  //     two return values of type tensor.
  //
  I32EnumAttrCase<"TFT_CALLABLE", 100, "callable">,

  // Concrete type IDs, representing "proper" data types that can describe
  // runtime TensorFlow objects.

  // The usual Tensor. This is a parametric type.
  //
  // Parametrization:
  //   TFT_TENSOR[<element type>, <shape type>]
  //   * <element type> is currently limited to one of the element types
  //     defined below.
  //   * <shape type> is not yet defined, and may only be TFT_UNKNOWN for now.
  //
  // A TFT_SHAPE type will be defined in the future.
  //
  // Example:
  //   TFT_TENSOR[TFT_INT32, TFT_UNKNOWN]
  //     is a Tensor of int32 element type and unknown shape.
  //
  // TODO(mdan): Define TFT_SHAPE and add more examples.
  I32EnumAttrCase<"TFT_TENSOR", 1000, "tensor">,

  // Array (or tensorflow::TensorList in the variant type registry).
  // Note: this is not to be confused with the deprecated `TensorArray*` ops
  // which are not supported by FullType.
  // This type represents a random-access list whose elements can be
  // described by a single type. Although immutable, Array is expected to
  // support efficient mutation semantics (i.e. element update) in the
  // user-facing API.
  // The element type may be generic or even TFT_ANY for a heterogeneous list.
  //
  // Parametrization:
  //   TFT_ARRAY[<element type>]
  //   * <element type> may be any concrete type.
  //
  // Examples:
  //   TFT_ARRAY[TFT_TENSOR[TFT_INT32]] is a TensorArray holding int32 Tensors
  //     of any shape.
  //   TFT_ARRAY[TFT_TENSOR[TFT_UNKNOWN]] is a TensorArray holding Tensors of
  //     mixed element types.
  //   TFT_ARRAY[TFT_UNKNOWN] is a TensorArray holding any element type.
  //   TFT_ARRAY[] is equivalent to TFT_ARRAY[TFT_UNKNOWN].
  //   TFT_ARRAY[TFT_ARRAY[]] is an array of arrays (of unknown types).
  I32EnumAttrCase<"TFT_ARRAY", 1001, "array">,

  // Optional (or tensorflow::OptionalVariant in the variant type registry).
  // This type represents a value that may either hold an element of a single
  // specified type, or nothing at all.
  //
  // Parametrization:
  //   TFT_OPTIONAL[<element type>]
  //   * <element type> may be any concrete type.
  //
  // Examples:
  //   TFT_OPTIONAL[TFT_TENSOR[TFT_INT32]] is an Optional holding an int32
  //     Tensor of any shape.
  I32EnumAttrCase<"TFT_OPTIONAL", 1002, "optional">,

  // Literal types describe compile-time constant values.
  // Literal types may also participate in dependent types.
  //
  // Parametrization:
  //   TFT_LITERAL[<value type>]{<value>}
  //   * <value type> may be any concrete type that can hold <value>
  //   * <value> is the type's attribute, and holds the actual literal value
  //
  // Examples:
  //   TFT_LITERAL[TFT_INT32]{1} is the compile-time constant 1.
  I32EnumAttrCase<"TFT_LITERAL", 1003, "literal">,

  // Encoding types describe a value of a certain type, encoded as a different
  // type.
  //
  // Parametrization:
  //   TFT_ENCODED[<encoded type>, <encoding type>]
  //   * <encoded type> may be any type
  //   * <encoding type> may be any type
  //
  // Examples:
  //   TFT_ENCODED[TFT_INT32, TFT_STRING] is an integer encoded as a string.
  I32EnumAttrCase<"TFT_ENCODED", 1004, "encoded">,

  // Type attributes. These always appear in the parametrization of a type,
  // never alone. For example, there is no such thing as a "bool" TensorFlow
  // object (for now).

  // The bool element type.
  // TODO(mdan): Quantized types, legacy representations (e.g. ref)
  I32EnumAttrCase<"TFT_BOOL", 200, "bool">,
  // Integer element types.
  I32EnumAttrCase<"TFT_UINT8", 201, "uint8">,
  I32EnumAttrCase<"TFT_UINT16", 202, "uint16">,
  I32EnumAttrCase<"TFT_UINT32", 203, "uint32">,
  I32EnumAttrCase<"TFT_UINT64", 204, "uint64">,
  I32EnumAttrCase<"TFT_INT8", 205, "int8">,
  I32EnumAttrCase<"TFT_INT16", 206, "int16">,
  I32EnumAttrCase<"TFT_INT32", 207, "int32">,
  I32EnumAttrCase<"TFT_INT64", 208, "int64">,
  // Floating-point element types.
  I32EnumAttrCase<"TFT_HALF", 209, "half">,
  I32EnumAttrCase<"TFT_FLOAT", 210, "float">,
  I32EnumAttrCase<"TFT_DOUBLE", 211, "double">,
  I32EnumAttrCase<"TFT_BFLOAT16", 215, "bfloat16">,
  // TODO(mdan): Represent as TFT_COMPLEX[TFT_DOUBLE] instead?
  I32EnumAttrCase<"TFT_COMPLEX64", 212, "complex64">,
  I32EnumAttrCase<"TFT_COMPLEX128", 213, "complex128">,
  // The string element type.
  I32EnumAttrCase<"TFT_STRING", 214, "string">,

  // Other types that we don't know yet whether they will become part of the
  // core type system or be considered third-party (and consequently moved to
  // user-defined type mechanisms). Presently, they are effectively in the core
  // type system, because key compilation passes like Placer account for their
  // existence.

  // Datasets created by tf.data ops and APIs. Datasets have generator/iterable
  // semantics, that is, one can construct an iterator from them. Like
  // Array, they are considered to return elements that can be described
  // by a single type. Unlike Array, they do not support random access or
  // mutation, and can potentially produce an infinite number of elements.
  // A dataset can produce logical structures (e.g. multiple elements). This
  // is expressed using TFT_PRODUCT.
  //
  // Parametrization: TFT_DATASET[<element type>].
  //   * <element type> may be a concrete type or a type symbol. It represents
  //     the data type of the elements produced by the dataset.
  //
  // Examples:
  //   TFT_DATASET[TFT_TENSOR[TFT_INT32]] is a Dataset producing single int32
  //     Tensors of unknown shape.
  //   TFT_DATASET[TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT32]]]
  //     is a Dataset producing pairs of Tensors, one integer and one float.
  // Note: The high ID number is to prepare for the eventuality that Datasets
  // will be supported by user types in the future.
  I32EnumAttrCase<"TFT_DATASET", 10102, "dataset">,

  // A ragged tensor created by tf.ragged ops and APIs.
  //
  // Parametrization: TFT_RAGGED[<element_type>].
  I32EnumAttrCase<"TFT_RAGGED", 10103, "ragged">,

  // Iterators created by tf.data ops and APIs. Very similar to Datasets,
  // except they are mutable.
  //
  // Parametrization: TFT_ITERATOR[<element type>].
  //   * <element type> may be a concrete type or a type symbol. It represents
  //     the data type of the elements produced by the iterator.
  I32EnumAttrCase<"TFT_ITERATOR", 10104, "iterator">,

  // A mutex lock tensor, produced by tf.raw_ops.MutexLock.
  // Unlike strict execution models, where ownership of a lock is denoted by
  // "running after the lock has been acquired", in non-strict mode, lock
  // ownership is in the true sense: "the op argument representing the lock is
  // available".
  // Mutex locks are the dynamic counterpart of control dependencies.
  // TODO(mdan): Properly document this thing.
  //
  // Parametrization: TFT_MUTEX_LOCK[].
  I32EnumAttrCase<"TFT_MUTEX_LOCK", 10202, "mutex_lock">,

  // The equivalent of a Tensor with DT_VARIANT dtype, kept here to simplify
  // translation. This type should not normally appear after type inference.
  // Note that LEGACY_VARIANT != ANY: TENSOR[INT32] is a subtype of ANY, but is
  // not a subtype of LEGACY_VARIANT.
  I32EnumAttrCase<"TFT_LEGACY_VARIANT", 10203, "legacy_variant">
]> {
  let cppNamespace = "::mlir::tf_type";
  // This enum is used directly as the `type_id` parameter of
  // TFType_FullTypeAttr below; `cppType` gives that parameter a plain int32_t
  // C++ type (presumably required for use as an AttrDef parameter — confirm).
  string cppType = "int32_t";
  // No dedicated enum attribute class is generated for this enum.
  let genSpecializedAttr = 0;
}

// Parameter holding the list of type arguments of a FullType attribute.
def TFType_FullTypeArgsAttr : ArrayRefParameter<"::mlir::tf_type::FullTypeAttr", "args">;

// Optional attribute payload of a FullType attribute. The predicate accepts
// either a string attribute or a signed 64-bit integer attribute (e.g. the
// field name of a TFT_NAMED type or the value of a TFT_LITERAL type).
def TFType_FullTypeAttrAttr : Attr<Or<[StrAttr.predicate, SI64Attr.predicate]>,
                                   "FullType literal attr"> {
  let storageType = "Attribute";
  let returnType = "Attribute";
  let convertFromStorage = "$_self";
  let constBuilderCall = "$0";
  // C++ type used when this Attr is consumed as an AttrDef parameter.
  string cppType = "Attribute";
  // The payload may be absent (see the `($attr?)` in the format note below).
  let isOptional = 1;
}

// A FullType type expression: a type id, its type arguments, and an optional
// attribute payload.
def TFType_FullTypeAttr : AttrDef<TFTypeDialect, "FullType"> {
  let summary = "FullType";
  let parameters = (ins
    TFType_FullTypeId:$type_id,
    TFType_FullTypeArgsAttr:$args,
    TFType_FullTypeAttrAttr:$attr
  );
  let mnemonic = "full_type";
  let hasCustomAssemblyFormat = 1;
  // Format is effectively: type_id ('<' $args^ '>')? ($attr?)
}
// LINT.ThenChange(../../framework/full_type.proto)

//===----------------------------------------------------------------------===//
// FuncAttr
//===----------------------------------------------------------------------===//

def TFType_FuncAttr : TFType_Attr<"Func", [
    DeclareAttrInterfaceMethods<SubElementAttrInterface,
                                ["replaceImmediateSubElements"]>
  ]> {
  let mnemonic = "func";
  let summary = "Models the `AttrValue.value.func` proto attribute value as a "
                "pair of SymbolRef and DictionaryAttr";
  let description = [{
    This attribute matches the protobuf `AttrValue.value.func` with a
    `SymbolRefAttr` for the `NameAttrList.name` `string`, and a
    `DictionaryAttr` for the `NameAttrList.attr` `map<string, AttrValue>`. It
    is currently printed and parsed in the following format:

      #tf_type.func<@symbol, {attr = "value"}>

    where the first element is the `SymbolRefAttr` and the second element is
    the `DictionaryAttr`.

    So that the symbol reference and any symbol references nested in the
    `DictionaryAttr` are visible to symbol tables, this attribute implements
    the `SubElementAttrInterface`.
  }];

  let parameters = (ins
    "SymbolRefAttr":$name,
    "DictionaryAttr":$attrs
  );
  let builders = [
    // Builds the attribute from a plain symbol name, wrapping it in a
    // SymbolRefAttr.
    AttrBuilder<(ins "StringRef":$name, "DictionaryAttr":$attr), [{
      return $_get($_ctxt, SymbolRefAttr::get($_ctxt, name), attr);
    }]>
  ];
  let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Placeholder
//===----------------------------------------------------------------------===//

def TFType_PlaceholderAttr : TFType_Attr<"Placeholder"> {
  let mnemonic = "placeholder";
  let summary = "Placeholder attributes are string referring to a function "
                "attribute to be substituted on instantiation";
  let description = [{
    This matches the `placeholder` Attribute type in protobuf storage. It is
    just a string, but a dedicated attribute type is needed for round-trip
    purposes.
  }];
  let parameters = (ins
    StringRefParameter<"value">:$value
  );
  let hasCustomAssemblyFormat = 1;
}

// Attribute constraint accepting either a TypeAttr or a
// TFType_PlaceholderAttr.
def TFGraph_TypeOrPlaceholder
    : Attr<Or<[TypeAttr.predicate, TFType_PlaceholderAttr.predicate]>,
           "a type or placeholder attribute"> {
  let returnType = "::mlir::Attribute";
  let convertFromStorage = "$_self";
}

//===----------------------------------------------------------------------===//
// ShapeAttr
//===----------------------------------------------------------------------===//

def TFType_ShapeAttrDef : TFType_Attr<"Shape"> {
  let mnemonic = "shape";
  let summary = "A shape either unranked or is modelled an array of int64";
  let description = [{
    This attribute captures the shape content of an MLIR `ShapedType` as an
    attribute value. It contains a flag indicating whether it is unranked and,
    if ranked, exposes an array of integers modeling the individual
    dimensions. A value of `ShapedType::kDynamicDim` indicates a dynamic
    dimension.
  }];

  let parameters = (ins
    ArrayRefParameter<"int64_t">:$shape,
    "bool":$unranked
  );
  let builders = [
    // Returns a shape attribute for the provided `dimensions` array. If
    // `dimensions` is not provided, the shape attribute is unranked.
    // For ranked shapes, each individual dimension size must be >= 0 or
    // `ShapedType::kDynamicDim`. A value of `ShapedType::kDynamicDim` means
    // the dimension is dynamic; otherwise, the dimension is static.
    AttrBuilder<(ins "llvm::Optional<ArrayRef<int64_t>>":$dimensions)>,
    // Returns a Shape attribute from a TensorFlow ShapedType type.
    AttrBuilder<(ins "ShapedType":$shaped_type)>
  ];
  let extraClassDeclaration = [{
    // Returns true if this shape is ranked and has only known dimension sizes.
    bool hasStaticShape() const;

    // Returns true if this shape attribute has a statically known rank.
    bool hasRank() const;

    // Returns the rank. Aborts if unranked.
    int64_t getRank() const;

    // Returns the shape array if ranked, or None if unranked.
    llvm::Optional<ArrayRef<int64_t>> getValue() const;
  }];
  let hasCustomAssemblyFormat = 1;
}

// An array of TF shapes.
def TFGraph_ShapesAttr
    : TypedArrayAttrBase<TFType_ShapeAttrDef, "An array of shapes.">;

//===----------------------------------------------------------------------===//
// VersionAttr
//===----------------------------------------------------------------------===//

// Versioning information of a TensorFlow Graph: the producer version, the
// minimum consumer version, and a list of bad consumer versions.
// NOTE(review): presumably mirrors TensorFlow's `VersionDef` proto — confirm.
def TFType_VersionAttr : TFType_Attr<"Version"> {
  let mnemonic = "version";
  let summary = "An Attribute describing the version for a TensorFlow Graph";
  let parameters = (ins
    "int32_t":$producer,
    "int32_t":$minConsumer,
    ArrayRefParameter<"int32_t">:$badConsumers
  );
  let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// TensorFlow device metadata
//===----------------------------------------------------------------------===//

// TensorFlow GPU device metadata: the major and minor components of the GPU's
// compute capability, printed as a `<`...`>`-enclosed struct of parameters.
def TFType_GpuDeviceMetadata : TFType_Attr<"GpuDeviceMetadata"> {
  let mnemonic = "gpu_device_metadata";
  let summary = "Attribute that specifies a GPU's compute capability";
  let parameters = (ins "int32_t":$cc_major, "int32_t":$cc_minor);
  let assemblyFormat = "`<` struct(params) `>`";
}

//===----------------------------------------------------------------------===//
// TensorProtoAttr
//===----------------------------------------------------------------------===//

// Stores a TensorProto as a string together with its shaped type. The `type`
// parameter is an AttributeSelfTypeParameter, i.e. it also serves as the
// attribute's own type (exposed through TypedAttrInterface).
def TF_TensorProtoAttr : TFType_Attr<"TensorProto", [ElementsAttrInterface, TypedAttrInterface]> {
  let mnemonic = "tensor_proto";

  let summary = "Attribute that stores TensorFlow TensorProto debug string";

  let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
                        StringRefParameter<"">:$value);
  let builders = [
    // Builds the attribute, taking the MLIRContext from `type`.
    AttrBuilderWithInferredContext<(ins "ShapedType":$type,
                                        "StringRef":$value), [{
      return $_get(type.getContext(), type, value);
    }]>,
  ];
  let extraClassDeclaration = [{
    // NOTE(review): presumably the element value type expected by
    // ElementsAttrInterface — confirm.
    using ValueType = StringRef;
  }];

  let hasCustomAssemblyFormat = 1;
}


#endif // TF_TYPE_ATTRIBUTES