xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ir/types/attributes.td (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TF_TYPE_ATTRIBUTES
17#define TF_TYPE_ATTRIBUTES
18
19include "mlir/IR/AttrTypeBase.td"
20include "mlir/IR/SubElementInterfaces.td"
21include "tensorflow/core/ir/types/dialect.td"
22include "mlir/IR/BuiltinAttributeInterfaces.td"
23
// Common base for every attribute defined in the TFType dialect.
//   * `name`   - the attribute name handed to AttrDef (the generated C++ class
//                is `<name>Attr`, e.g. "Func" -> FuncAttr).
//   * `traits` - optional traits / interfaces attached to the attribute.
class TFType_Attr<string name, list<Trait> traits = []>
    : AttrDef<TFTypeDialect,
              name,
              traits>;
27
28// LINT.IfChange
// FullType type IDs. The numeric values must stay in sync with the FullTypeId
// enum in framework/full_type.proto (tracked by the LINT.IfChange /
// LINT.ThenChange markers around this section).
def TFType_FullTypeId : I32EnumAttr<"FullTypeId", "", [
  // The default represents an uninitialized value.
  I32EnumAttrCase<"TFT_UNSET", 0, "unset">,

  // Type symbols. Used to construct more complex type expressions like
  // algebraic data types.

  // Type variables may serve as placeholders for any other type ID in type
  // templates.
  //
  // Examples:
  //   TFT_DATASET[TFT_VAR["T"]] is a Dataset returning a type indicated by "T".
  //   TFT_TENSOR[TFT_VAR["T"]] is a Tensor of an element type indicated by "T".
  //   TFT_TENSOR[TFT_VAR["T"]], TFT_TENSOR[TFT_VAR["T"]] are two tensors of
  //     identical element types.
  //   TFT_TENSOR[TFT_VAR["P"]], TFT_TENSOR[TFT_VAR["Q"]] are two tensors of
  //     independent element types.
  //
  I32EnumAttrCase<"TFT_VAR", 1, "var">,

  // Wildcard type. Describes a parameter of unknown type. In TensorFlow, that
  // can mean either a "Top" type (accepts any type), or a dynamically typed
  // object whose type is unknown in context.
  // Important: "unknown" does not necessarily mean undeterminable!
  I32EnumAttrCase<"TFT_ANY", 2, "any">,

  // The algebraic product type. This is an algebraic type that may be used just
  // for logical grouping. Not to be confused with TFT_TUPLE, which describes a
  // concrete object of several elements.
  //
  // Example:
  //   TFT_DATASET[TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT64]]]
  //     is a Dataset producing two tensors, an integer one and a float one.
  //
  I32EnumAttrCase<"TFT_PRODUCT", 3, "product">,

  // Represents a named field, with the name stored in the attribute.
  //
  // Parametrization:
  //   TFT_NAMED[<type>]{<name>}
  //   * <type> is the type of the field
  //   * <name> is the field name, as a string (though it can theoretically be
  //     an int as well)
  //
  // Example:
  //   TFT_RECORD[
  //     TFT_NAMED[TFT_TENSOR[TFT_INT32]]{'foo'},
  //     TFT_NAMED[TFT_TENSOR[TFT_FLOAT32]]{'bar'},
  //   ]
  //     is a structure with two fields, an int tensor "foo" and a float tensor
  //     "bar".
  I32EnumAttrCase<"TFT_NAMED", 4, "named">,

  // Template definition. Expands the variables by repeating a template as
  // arguments of a container.
  //
  // Parametrization:
  //   TFT_FOR_EACH[<container_type>, <template>, <expansions>]
  //   * <container_type> is the type of the container that the template will be
  //     expanded into
  //   * <template> is any type definition that potentially contains type
  //     variables
  //   * <expansions> is a TFT_VAR and may include more types in the future
  //
  // Example:
  //   TFT_FOR_EACH[
  //         TFT_PRODUCT,
  //         TFT_TENSOR[TFT_VAR["t"]],
  //         TFT_VAR["t"]
  //     ]
  //     with t = TFT_INT32 expands to TFT_PRODUCT[TFT_TENSOR[TFT_INT32]],
  //     and with t = (TFT_INT32, TFT_INT64) expands to
  //     TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_INT64]].
  I32EnumAttrCase<"TFT_FOR_EACH", 20, "for_each">,

  // Callable types describe functions and ops.
  //
  // Parametrization:
  //   TFT_CALLABLE[<arg type>, <return type>]
  //   * <arg type> is the type of the arguments; TFT_PRODUCT represents
  //     multiple arguments.
  //   * <return type> is the return type; TFT_PRODUCT represents multiple
  //     return values (that means that callables returning multiple things
  //     don't necessarily return a single tuple).
  //
  // Example:
  //   TFT_CALLABLE[
  //     TFT_ANY,
  //     TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT64]],
  //   ]
  //     is a callable with unspecified (for now) input arguments, and
  //     two return values of type tensor.
  //
  I32EnumAttrCase<"TFT_CALLABLE", 100, "callable">,

  // Concrete type IDs, representing "proper" data types that can describe
  // runtime TensorFlow objects.

  // The usual Tensor. This is a parametric type.
  //
  // Parametrization:
  //   TFT_TENSOR[<element type>, <shape type>]
  //   * <element type> is currently limited to one of the element types
  //     defined below.
  //   * <shape type> is not yet defined, and may only be TFT_UNKNOWN for now.
  //
  // A TFT_SHAPE type will be defined in the future.
  //
  // Example:
  //   TFT_TENSOR[TFT_INT32, TFT_UNKNOWN]
  //     is a Tensor of int32 element type and unknown shape.
  //
  // TODO(mdan): Define TFT_SHAPE and add more examples.
  I32EnumAttrCase<"TFT_TENSOR", 1000, "tensor">,

  // Array (or tensorflow::TensorList in the variant type registry).
  // Note: this is not to be confused with the deprecated `TensorArray*` ops
  // which are not supported by FullType.
  // This type represents a random-access list whose elements can be
  // described by a single type. Although immutable, Array is expected to
  // support efficient mutation semantics (i.e. element update) in the
  // user-facing API.
  // The element type may be generic or even TFT_ANY for a heterogeneous list.
  //
  // Parametrization:
  //   TFT_ARRAY[<element type>]
  //   * <element type> may be any concrete type.
  //
  // Examples:
  //   TFT_ARRAY[TFT_TENSOR[TFT_INT32]] is an Array holding int32 Tensors
  //     of any shape.
  //   TFT_ARRAY[TFT_TENSOR[TFT_UNKNOWN]] is an Array holding Tensors of
  //     mixed element types.
  //   TFT_ARRAY[TFT_UNKNOWN] is an Array holding any element type.
  //   TFT_ARRAY[] is equivalent to TFT_ARRAY[TFT_UNKNOWN].
  //   TFT_ARRAY[TFT_ARRAY[]] is an array of arrays (of unknown types).
  I32EnumAttrCase<"TFT_ARRAY", 1001, "array">,

  // Optional (or tensorflow::OptionalVariant in the variant type registry).
  // This type represents a value that may either hold an element of a single
  // specified type, or nothing at all.
  //
  // Parametrization:
  //   TFT_OPTIONAL[<element type>]
  //   * <element type> may be any concrete type.
  //
  // Examples:
  //   TFT_OPTIONAL[TFT_TENSOR[TFT_INT32]] is an Optional holding an int32
  //     Tensor of any shape.
  I32EnumAttrCase<"TFT_OPTIONAL", 1002, "optional">,

  // Literal types describe compile-time constant values.
  // Literal types may also participate in dependent types.
  //
  // Parametrization:
  //   TFT_LITERAL[<value type>]{<value>}
  //   * <value type> may be any concrete type that can hold <value>
  //   * <value> is the type's attribute, and holds the actual literal value
  //
  // Examples:
  //   TFT_LITERAL[TFT_INT32]{1} is the compile-time constant 1.
  I32EnumAttrCase<"TFT_LITERAL", 1003, "literal">,

  // Encoding types describe a value of a certain type, encoded as a different
  // type.
  //
  // Parametrization:
  //   TFT_ENCODED[<encoded type>, <encoding type>]
  //   * <encoded type> may be any type
  //   * <encoding type> may be any type
  //
  // Examples:
  //   TFT_ENCODED[TFT_INT32, TFT_STRING] is an integer encoded as string.
  I32EnumAttrCase<"TFT_ENCODED", 1004, "encoded">,

  // Type attributes. These always appear in the parametrization of a type,
  // never alone. For example, there is no such thing as a "bool" TensorFlow
  // object (for now).

  // The bool element type.
  // TODO(mdan): Quantized types, legacy representations (e.g. ref)
  I32EnumAttrCase<"TFT_BOOL", 200, "bool">,
  // Integer element types.
  I32EnumAttrCase<"TFT_UINT8", 201, "uint8">,
  I32EnumAttrCase<"TFT_UINT16", 202, "uint16">,
  I32EnumAttrCase<"TFT_UINT32", 203, "uint32">,
  I32EnumAttrCase<"TFT_UINT64", 204, "uint64">,
  I32EnumAttrCase<"TFT_INT8", 205, "int8">,
  I32EnumAttrCase<"TFT_INT16", 206, "int16">,
  I32EnumAttrCase<"TFT_INT32", 207, "int32">,
  I32EnumAttrCase<"TFT_INT64", 208, "int64">,
  // Floating-point element types.
  // Note: TFT_BFLOAT16 uses ID 215; it is listed here next to the other
  // floating-point types rather than in numeric order.
  I32EnumAttrCase<"TFT_HALF", 209, "half">,
  I32EnumAttrCase<"TFT_FLOAT", 210, "float">,
  I32EnumAttrCase<"TFT_DOUBLE", 211, "double">,
  I32EnumAttrCase<"TFT_BFLOAT16", 215, "bfloat16">,
  // TODO(mdan): Represent as TFT_COMPLEX[TFT_DOUBLE] instead?
  I32EnumAttrCase<"TFT_COMPLEX64", 212, "complex64">,
  I32EnumAttrCase<"TFT_COMPLEX128", 213, "complex128">,
  // The string element type.
  I32EnumAttrCase<"TFT_STRING", 214, "string">,

  // Other types that we don't know yet whether they will become part of the
  // core type system or be considered third-party (and consequently moved to
  // user-defined type mechanisms). Presently, they are effectively in the core
  // type system, because key compilation passes like Placer account for their
  // existence.

  // Datasets created by tf.data ops and APIs. Datasets have generator/iterable
  // semantics, that is, one can construct an iterator from them. Like
  // Array, they are considered to return elements that can be described
  // by a single type. Unlike Array, they do not support random access or
  // mutation, and can potentially produce an infinite number of elements.
  // A dataset can produce logical structures (e.g. multiple elements). This
  // is expressed using TFT_PRODUCT.
  //
  // Parametrization: TFT_DATASET[<element type>].
  //   * <element type> may be a concrete type or a type symbol. It represents
  //     the data type of the elements produced by the dataset.
  //
  // Examples:
  //   TFT_DATASET[TFT_TENSOR[TFT_INT32]] is a Dataset producing single int32
  //     Tensors of unknown shape.
  //   TFT_DATASET[TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT32]]]
  //     is a Dataset producing pairs of Tensors, one integer and one float.
  // Note: The high ID number is to prepare for the eventuality that Datasets
  // will be supported by user types in the future.
  I32EnumAttrCase<"TFT_DATASET", 10102, "dataset">,

  // A ragged tensor created by tf.ragged ops and APIs.
  //
  // Parametrization: TFT_RAGGED[<element_type>].
  I32EnumAttrCase<"TFT_RAGGED", 10103, "ragged">,

  // Iterators created by tf.data ops and APIs. Very similar to Datasets, except
  // they are mutable.
  //
  // Parametrization: TFT_ITERATOR[<element type>].
  //   * <element type> may be a concrete type or a type symbol. It represents
  //     the data type of the elements produced by the iterator.
  I32EnumAttrCase<"TFT_ITERATOR", 10104, "iterator">,

  // A mutex lock tensor, produced by tf.raw_ops.MutexLock.
  // Unlike strict execution models, where ownership of a lock is denoted by
  // "running after the lock has been acquired", in non-strict mode, lock
  // ownership is in the true sense: "the op argument representing the lock is
  // available".
  // Mutex locks are the dynamic counterpart of control dependencies.
  // TODO(mdan): Properly document this thing.
  //
  // Parametrization: TFT_MUTEX_LOCK[].
  I32EnumAttrCase<"TFT_MUTEX_LOCK", 10202, "mutex_lock">,

  // The equivalent of a Tensor with DT_VARIANT dtype, kept here to simplify
  // translation. This type should not normally appear after type inference.
  // Note that LEGACY_VARIANT != ANY: TENSOR[INT32] is a subtype of ANY, but is
  // not a subtype of LEGACY_VARIANT.
  I32EnumAttrCase<"TFT_LEGACY_VARIANT", 10203, "legacy_variant">
]> {
  let cppNamespace = "::mlir::tf_type";
  // C++ type used when this enum is a parameter of an AttrDef (it is used as
  // `$type_id` of TFType_FullTypeAttr); presumably picked up by mlir-tblgen's
  // parameter handling -- confirm if changing.
  string cppType = "int32_t";
  // Do not generate a dedicated `FullTypeIdAttr` attribute class; the enum is
  // only used as a parameter.
  let genSpecializedAttr = 0;
}
295
// AttrDef parameter holding the type arguments of a FullType attribute: an
// array of nested `::mlir::tf_type::FullTypeAttr` values.
def TFType_FullTypeArgsAttr
    : ArrayRefParameter<"::mlir::tf_type::FullTypeAttr", "args">;
297
// Optional literal payload carried by a FullType attribute. The value must be
// either a string attribute or a signed 64-bit integer attribute (compare the
// TFT_NAMED and TFT_LITERAL parametrizations).
def TFType_FullTypeAttrAttr
    : Attr<Or<[StrAttr.predicate, SI64Attr.predicate]>,
           "FullType literal attr"> {
  // C++ type of the value when used as an AttrDef parameter.
  string cppType = "Attribute";

  // Stored and handed back as an untyped `Attribute`, with no conversion.
  let storageType = "Attribute";
  let returnType = "Attribute";
  let convertFromStorage = "$_self";
  let constBuilderCall = "$0";

  // The payload may be absent.
  let isOptional = 1;
}
307
// The FullType attribute: a type ID, its (possibly empty) list of nested type
// arguments, and an optional literal payload (string or int64).
// Uses the dialect's TFType_Attr base like every other attribute in this file;
// TFType_Attr<"FullType"> expands to AttrDef<TFTypeDialect, "FullType", []>.
def TFType_FullTypeAttr : TFType_Attr<"FullType"> {
  let mnemonic = "full_type";
  let summary = "FullType";
  let parameters = (ins
      // Which kind of type this is (e.g. tensor, product, an element type).
      TFType_FullTypeId:$type_id,
      // Nested type arguments.
      TFType_FullTypeArgsAttr:$args,
      // Optional literal payload (e.g. a field name or a literal value).
      TFType_FullTypeAttrAttr:$attr
  );
  // Parsed and printed by hand-written C++. The format is effectively:
  //   type_id ('<' $args^ '>')? ($attr?)
  let hasCustomAssemblyFormat = 1;
}
319// LINT.ThenChange(../../framework/full_type.proto)
320
321//===----------------------------------------------------------------------===//
322// FuncAttr
323//===----------------------------------------------------------------------===//
324
def TFType_FuncAttr : TFType_Attr<"Func", [
    DeclareAttrInterfaceMethods<SubElementAttrInterface,
        ["replaceImmediateSubElements"]>
  ]> {
  let mnemonic = "func";
  let summary = "Models the `AttrValue.value.func` proto attribute value as a "
    "pair of SymbolRef and DictionaryAttr";
  let description = [{
    This attribute matches the protobuf `AttrValue.value.func`, with a
    `SymbolRefAttr` for the `NameAttrList.name` `string` and a `DictionaryAttr`
    for the `NameAttrList.attr` `map<string, AttrValue>`. It is currently
    printed and parsed in the following format:

      #tf_type.func<@symbol, {attr = "value"}>

    where the first element is the `SymbolRefAttr` and the second element is the
    `DictionaryAttr`.

    This attribute implements the `SubElementAttrInterface` so that the symbol
    reference, and any symbol references nested in the `DictionaryAttr`, are
    visible to symbol tables.
  }];

  let parameters = (ins
    "SymbolRefAttr":$name,
    "DictionaryAttr":$attrs
  );
  let builders = [
    // Convenience builder taking the function name as a plain string, which is
    // wrapped with `SymbolRefAttr::get`. The argument is named `attrs` to match
    // the `$attrs` parameter above.
    AttrBuilder<(ins "StringRef":$name, "DictionaryAttr":$attrs), [{
      return $_get($_ctxt, SymbolRefAttr::get($_ctxt, name), attrs);
    }]>
  ];
  let hasCustomAssemblyFormat = 1;
}
359
360//===----------------------------------------------------------------------===//
361// Placeholder
362//===----------------------------------------------------------------------===//
363
def TFType_PlaceholderAttr : TFType_Attr<"Placeholder"> {
  let mnemonic = "placeholder";
  let summary = "Placeholder attributes are strings referring to a function "
    "attribute to be substituted on instantiation";
  let description = [{
    This matches the `placeholder` Attribute type in protobuf storage. It is
    just a string, but a dedicated attribute type is needed so that it
    round-trips as a placeholder rather than as a plain string.
  }];
  let parameters = (ins
    StringRefParameter<"value">:$value
  );
  let hasCustomAssemblyFormat = 1;
}
377
// Attribute constraint satisfied by either a `TypeAttr` or a
// `#tf_type.placeholder` attribute. Accessors hand the value back unconverted,
// as a generic `::mlir::Attribute`.
def TFGraph_TypeOrPlaceholder : Attr<
    Or<[TypeAttr.predicate, TFType_PlaceholderAttr.predicate]>,
    "a type or placeholder attribute"> {
  let convertFromStorage = "$_self";
  let returnType = "::mlir::Attribute";
}
384
385//===----------------------------------------------------------------------===//
386// ShapeAttr
387//===----------------------------------------------------------------------===//
388
def TFType_ShapeAttrDef : TFType_Attr<"Shape"> {
  let mnemonic = "shape";
  let summary = "A shape that is either unranked or modelled as an array of "
    "int64";
  let description = [{
    This attribute captures the shape content of an MLIR `ShapedType` as an
    attribute value. It contains a flag indicating whether the shape is
    unranked and, if it is ranked, exposes an array of integers modeling the
    individual dimensions. A value of `ShapedType::kDynamicDim` indicates a
    dynamic dimension.
  }];

  let parameters = (ins
    // Dimension sizes, used when the shape is ranked.
    ArrayRefParameter<"int64_t">:$shape,
    // True when the rank is unknown.
    "bool":$unranked
  );
  let builders = [
    // Returns a shape attribute for the provided `dimensions` array. If
    // `dimensions` is not provided, the shape attribute is unranked.
    // For ranked shapes, the value of each individual dimension size must
    // be >= 0 or `ShapedType::kDynamicDim`. The value of
    // `ShapedType::kDynamicDim` means the dimension is dynamic; any other
    // value means the dimension is static.
    AttrBuilder<(ins "llvm::Optional<ArrayRef<int64_t>>":$dimensions)>,
    // Returns a Shape attribute from a TensorFlow ShapedType type.
    AttrBuilder<(ins "ShapedType":$shaped_type)>
  ];
  let extraClassDeclaration = [{
    // Returns true if this shape is ranked and all of its dimension sizes are
    // known.
    bool hasStaticShape() const;

    // Returns true if this shape attribute has a statically known rank.
    bool hasRank() const;

    // Returns the rank. Aborts if unranked.
    int64_t getRank() const;

    // Returns the shape array if ranked, or None if unranked.
    llvm::Optional<ArrayRef<int64_t>> getValue() const;
  }];
  let hasCustomAssemblyFormat = 1;
}
429
// Constraint for an array attribute whose every element is a TF shape
// attribute (TFType_ShapeAttrDef).
def TFGraph_ShapesAttr : TypedArrayAttrBase<
    TFType_ShapeAttrDef, "An array of shapes.">;
433
434//===----------------------------------------------------------------------===//
435// VersionAttr
436//===----------------------------------------------------------------------===//
437
def TFType_VersionAttr : TFType_Attr<"Version"> {
  let mnemonic = "version";
  let hasCustomAssemblyFormat = 1;
  let summary = "An Attribute describing the version for a TensorFlow Graph";
  // NOTE(review): these presumably mirror the `producer`, `min_consumer` and
  // `bad_consumers` fields of TensorFlow's VersionDef proto -- confirm against
  // versions.proto before relying on that.
  let parameters = (ins
    "int32_t":$producer,
    "int32_t":$minConsumer,
    ArrayRefParameter<"int32_t">:$badConsumers
  );
}
448
449//===----------------------------------------------------------------------===//
// TensorFlow device metadata
451//===----------------------------------------------------------------------===//
452
// TensorFlow GPU device metadata: the GPU's compute capability, split into its
// major (`cc_major`) and minor (`cc_minor`) components.
def TFType_GpuDeviceMetadata : TFType_Attr<"GpuDeviceMetadata"> {
  let mnemonic = "gpu_device_metadata";
  let summary = "Attribute that specifies a GPU's compute capability";
  let parameters = (ins
    "int32_t":$cc_major,
    "int32_t":$cc_minor
  );
  // Declarative format: both parameters as `name = value` pairs inside `<>`.
  let assemblyFormat = "`<` struct(params) `>`";
}
460
461//===----------------------------------------------------------------------===//
462// TensorProtoAttr
463//===----------------------------------------------------------------------===//
464
// Attribute pairing a `ShapedType` with a TensorFlow TensorProto debug string.
def TF_TensorProtoAttr
    : TFType_Attr<"TensorProto", [ElementsAttrInterface, TypedAttrInterface]> {
  let mnemonic = "tensor_proto";
  let hasCustomAssemblyFormat = 1;

  let summary = "Attribute that stores TensorFlow TensorProto debug string";

  let parameters = (ins
    // The attribute's own type.
    AttributeSelfTypeParameter<"", "ShapedType">:$type,
    // The TensorProto debug string.
    StringRefParameter<"">:$value
  );

  let builders = [
    // Builder that takes the MLIRContext from `type` instead of requiring it
    // explicitly.
    AttrBuilderWithInferredContext<
        (ins "ShapedType":$type, "StringRef":$value), [{
      return $_get(type.getContext(), type, value);
    }]>
  ];

  let extraClassDeclaration = [{
    using ValueType = StringRef;
  }];
}
484
485
486#endif
487