1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"
17
18 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
19 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
20 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
21 #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
22 #include "tensorflow/core/framework/rng_alg.h"
23
24 namespace mlir {
25 namespace TF {
26
27 namespace {
28
29 // Returns subtype of `resource` if present. Otherwise an unranked tensor type
30 // of `element_type` is returned.
GetResourceSubtypeOrDefault(Value resource,Type element_type)31 static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) {
32 auto resource_type = resource.getType()
33 .cast<TensorType>()
34 .getElementType()
35 .cast<ResourceType>();
36 if (resource_type.getSubtypes().size() == 1)
37 return resource_type.getSubtypes().front();
38
39 return UnrankedTensorType::get(element_type);
40 }
41
HasResourceSubtype(Value resource)42 static bool HasResourceSubtype(Value resource) {
43 return resource.getType()
44 .cast<TensorType>()
45 .getElementType()
46 .cast<ResourceType>()
47 .getSubtypes()
48 .size() == 1;
49 }
50
GetResourceSubtype(Value resource)51 static Type GetResourceSubtype(Value resource) {
52 return resource.getType()
53 .cast<TensorType>()
54 .getElementType()
55 .cast<ResourceType>()
56 .getSubtypes()
57 .front();
58 }
59
60 // Decompose tf.RngReadAndSkip.
61 //
62 // For Philox, the resource variable holds a tensor<3xi64> with the state:
63 // [counter_lo, counter_hi, key]
64 //
65 // RngReadAndSkip increments the 128 bit counter value by 256 * delta and
66 // returns the original state value.
67 //
68 // For Threefry, the resource variable holds a tensor<2xi64> with the state:
69 // [counter, key]
70 //
71 // RngReadAndSkip increments the 64 bit counter value by 256 * delta and
72 // returns a tensor<3xi64> value [counter, key, 0].
class DecomposeRngReadAndSkipOp : public RewritePattern {
 public:
  // Benefit 1; the trailing list declares every op kind this pattern may
  // create, so the pattern driver can track them.
  explicit DecomposeRngReadAndSkipOp(MLIRContext *context)
      : RewritePattern(RngReadAndSkipOp::getOperationName(), 1, context,
                       {
                           AddV2Op::getOperationName(),
                           AssignVariableOp::getOperationName(),
                           CastOp::getOperationName(),
                           ConstOp::getOperationName(),
                           LessOp::getOperationName(),
                           MulOp::getOperationName(),
                           PadOp::getOperationName(),
                           PackOp::getOperationName(),
                           ReadVariableOp::getOperationName(),
                           SelectV2Op::getOperationName(),
                           UnpackOp::getOperationName(),
                       }) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto rng_op = cast<RngReadAndSkipOp>(op);

    // The algorithm operand must be a compile-time constant; it determines
    // the state layout (Philox vs. Threefry) of the decomposition.
    DenseIntElementsAttr alg_constant;
    if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
      return rewriter.notifyMatchFailure(
          op, "unable to determine algorithm statically");
    }

    if (alg_constant.getNumElements() != 1) {
      return rewriter.notifyMatchFailure(op, "expected alg to be a scalar");
    }

    // Only Philox and Threefry are supported; anything else is left intact.
    uint64_t alg_value = ((*alg_constant.value_begin<APInt>()).getZExtValue());
    tensorflow::Algorithm alg;
    if (tensorflow::RNG_ALG_PHILOX == alg_value) {
      alg = tensorflow::RNG_ALG_PHILOX;
    } else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
      alg = tensorflow::RNG_ALG_THREEFRY;
    } else {
      return rewriter.notifyMatchFailure(op, "unsupported alg");
    }

    // The op always produces a fixed-size tensor<(MAX_COUNTER+KEY)xi64>
    // regardless of the algorithm; bail out on anything else.
    Type state_element_type = rewriter.getI64Type();
    RankedTensorType op_type = RankedTensorType::get(
        {tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
        state_element_type);
    if (op_type != rng_op.getType()) {
      return rewriter.notifyMatchFailure(op, "unexpected op type");
    }

    if (!HasResourceSubtype(rng_op.resource())) {
      return rewriter.notifyMatchFailure(op, "missing resource subtype");
    }

    // The resource must hold exactly [counter..., key] as i64, where the
    // counter width depends on the algorithm (2 for Philox, 1 for Threefry).
    int counter_size = tensorflow::GetCounterSize(alg);
    int state_size = counter_size + tensorflow::RNG_KEY_SIZE;
    RankedTensorType res_type =
        RankedTensorType::get({state_size}, state_element_type);
    if (res_type != GetResourceSubtype(rng_op.resource())) {
      return rewriter.notifyMatchFailure(op, "unexpected resource subtype");
    }

    Location loc = op->getLoc();

    // Read the state value from the resource.
    Value state =
        rewriter.create<ReadVariableOp>(loc, res_type, rng_op.resource());

    // Extract the key and counter from the state. Unpack splits the 1-D
    // state tensor into scalar words; the key is the word after the counter.
    RankedTensorType word_type = RankedTensorType::get({}, state_element_type);
    auto unpacked = rewriter.create<UnpackOp>(
        loc, SmallVector<Type, 4>(state_size, word_type), state, 0);
    Value key = unpacked.getResult(counter_size);

    SmallVector<Value, 4> counter;
    for (int i = 0; i < counter_size; ++i) {
      counter.push_back(unpacked.getResult(i));
    }

    // Set the increment to 256 * delta. Arithmetic is done in unsigned i64
    // so that wraparound is well-defined for the carry detection below.
    Type u64 = rewriter.getIntegerType(64, /*isSigned=*/false);
    RankedTensorType u64_scalar = RankedTensorType::get({}, u64);
    Value step_size = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 256));
    Value increment =
        rewriter.create<MulOp>(loc, u64_scalar, step_size, rng_op.delta());

    // Increment the counter with carry propagation across words: add the
    // increment to the low word; if the unsigned addition wrapped
    // (new < old), carry 1 into the next word, otherwise carry 0.
    SmallVector<Value, 4> pack_args;
    RankedTensorType word_u64_type = RankedTensorType::get({}, u64);
    Value zero_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 0));
    Value one_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 1));
    for (int i = 0; i < counter_size; ++i) {
      Value word = counter[i];
      Value word_u64 = rewriter.create<CastOp>(loc, word_u64_type, word);
      Value new_word_u64 = rewriter.create<AddV2Op>(loc, word_u64, increment);
      Value new_word = rewriter.create<CastOp>(loc, word_type, new_word_u64);
      pack_args.push_back(new_word);

      // Unsigned overflow check; the result becomes the next word's increment.
      Value overflow = rewriter.create<LessOp>(loc, new_word_u64, word_u64);
      increment = rewriter.create<SelectV2Op>(loc, overflow, one_u64, zero_u64);
    }

    // Save the new state value to the resource. The key is unchanged.
    pack_args.push_back(key);
    Value new_state = rewriter.create<PackOp>(loc, res_type, pack_args);
    rewriter.create<AssignVariableOp>(loc, rng_op.resource(), new_state);

    // Pad the original (pre-increment) state with trailing zeros as
    // necessary to fill the fixed-size output shape; for Threefry this
    // yields [counter, key, 0].
    int pad = tensorflow::RNG_MAX_COUNTER_SIZE - counter_size;
    Type i64 = rewriter.getI64Type();
    RankedTensorType paddings_ty = RankedTensorType::get({1, 2}, i64);
    std::vector<int64_t> paddings_values = {0, pad};
    Value paddings = rewriter.create<ConstOp>(
        loc, DenseIntElementsAttr::get(paddings_ty, paddings_values));
    Value output = rewriter.create<PadOp>(loc, op_type, state, paddings);

    rewriter.replaceOp(op, output);
    return success();
  }
};
193
194 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
195 } // namespace
196
PopulateDecomposeResourceOpsPatterns(MLIRContext * context,RewritePatternSet * patterns)197 void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
198 RewritePatternSet *patterns) {
199 patterns->add<DecomposeRngReadAndSkipOp>(context);
200 populateWithGenerated(*patterns);
201 }
202
203 } // namespace TF
204 } // namespace mlir
205