/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"

#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
#include "tensorflow/core/framework/rng_alg.h"

namespace mlir {
namespace TF {

namespace {

// Returns the subtype of `resource` if present; otherwise returns an unranked
// tensor type of `element_type`.
static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) {
  auto resource_type = resource.getType()
                           .cast<TensorType>()
                           .getElementType()
                           .cast<ResourceType>();
  if (resource_type.getSubtypes().size() == 1)
    return resource_type.getSubtypes().front();

  return UnrankedTensorType::get(element_type);
}

// Returns true if `resource` has exactly one subtype.
static bool HasResourceSubtype(Value resource) {
  return resource.getType()
             .cast<TensorType>()
             .getElementType()
             .cast<ResourceType>()
             .getSubtypes()
             .size() == 1;
}

// Returns the single subtype of `resource`. Only valid to call when
// `HasResourceSubtype` returns true.
static Type GetResourceSubtype(Value resource) {
  return resource.getType()
      .cast<TensorType>()
      .getElementType()
      .cast<ResourceType>()
      .getSubtypes()
      .front();
}

// Decompose tf.RngReadAndSkip.
//
// For Philox, the resource variable holds a tensor<3xi64> with the state:
//   [counter_lo, counter_hi, key]
//
//   RngReadAndSkip increments the 128 bit counter value by 256 * delta and
//   returns the original state value.
//
// For Threefry, the resource variable holds a tensor<2xi64> with the state:
//   [counter, key]
//
//   RngReadAndSkip increments the 64 bit counter value by 256 * delta and
//   returns a tensor<3xi64> value [counter, key, 0].
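//
// For example, with Philox and delta = d, the stored state
//   [counter_lo, counter_hi, key]
// is updated to
//   [counter_lo + 256 * d, counter_hi + carry, key]
// where carry is 1 if the low 64-bit word wrapped around and 0 otherwise,
// while the value returned by the op is the state before the update.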
class DecomposeRngReadAndSkipOp : public RewritePattern {
 public:
  // The ops that this pattern may generate are listed explicitly.
  explicit DecomposeRngReadAndSkipOp(MLIRContext *context)
      : RewritePattern(RngReadAndSkipOp::getOperationName(), 1, context,
                       {
                           AddV2Op::getOperationName(),
                           AssignVariableOp::getOperationName(),
                           CastOp::getOperationName(),
                           ConstOp::getOperationName(),
                           LessOp::getOperationName(),
                           MulOp::getOperationName(),
                           PadOp::getOperationName(),
                           PackOp::getOperationName(),
                           ReadVariableOp::getOperationName(),
                           SelectV2Op::getOperationName(),
                           UnpackOp::getOperationName(),
                       }) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto rng_op = cast<RngReadAndSkipOp>(op);

    DenseIntElementsAttr alg_constant;
    if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
      return rewriter.notifyMatchFailure(
          op, "unable to determine algorithm statically");
    }

    if (alg_constant.getNumElements() != 1) {
      return rewriter.notifyMatchFailure(op, "expected alg to be a scalar");
    }

    uint64_t alg_value = ((*alg_constant.value_begin<APInt>()).getZExtValue());
    tensorflow::Algorithm alg;
    if (tensorflow::RNG_ALG_PHILOX == alg_value) {
      alg = tensorflow::RNG_ALG_PHILOX;
    } else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
      alg = tensorflow::RNG_ALG_THREEFRY;
    } else {
      return rewriter.notifyMatchFailure(op, "unsupported alg");
    }

    Type state_element_type = rewriter.getI64Type();
    RankedTensorType op_type = RankedTensorType::get(
        {tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
        state_element_type);
    if (op_type != rng_op.getType()) {
      return rewriter.notifyMatchFailure(op, "unexpected op type");
    }

    if (!HasResourceSubtype(rng_op.resource())) {
      return rewriter.notifyMatchFailure(op, "missing resource subtype");
    }

    int counter_size = tensorflow::GetCounterSize(alg);
    int state_size = counter_size + tensorflow::RNG_KEY_SIZE;
    RankedTensorType res_type =
        RankedTensorType::get({state_size}, state_element_type);
    if (res_type != GetResourceSubtype(rng_op.resource())) {
      return rewriter.notifyMatchFailure(op, "unexpected resource subtype");
    }

    Location loc = op->getLoc();

    // Read the state value from the resource.
    Value state =
        rewriter.create<ReadVariableOp>(loc, res_type, rng_op.resource());
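    // Note that `state` (the value before the update) is what the op returns;
    // the incremented counter is only written back to the resource below.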

    // Extract the key and counter from the state.
    RankedTensorType word_type = RankedTensorType::get({}, state_element_type);
    auto unpacked = rewriter.create<UnpackOp>(
        loc, SmallVector<Type, 4>(state_size, word_type), state, 0);
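    // The key is stored immediately after the counter words, so it sits at
    // index `counter_size` in the unpacked state.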
    Value key = unpacked.getResult(counter_size);

    SmallVector<Value, 4> counter;
    for (int i = 0; i < counter_size; ++i) {
      counter.push_back(unpacked.getResult(i));
    }

    // Set the increment to 256 * delta.
    Type u64 = rewriter.getIntegerType(64, /*isSigned=*/false);
    RankedTensorType u64_scalar = RankedTensorType::get({}, u64);
    Value step_size = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 256));
    Value increment =
        rewriter.create<MulOp>(loc, u64_scalar, step_size, rng_op.delta());

    // Increment the counter.
    SmallVector<Value, 4> pack_args;
    RankedTensorType word_u64_type = RankedTensorType::get({}, u64);
    Value zero_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 0));
    Value one_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 1));
    for (int i = 0; i < counter_size; ++i) {
      Value word = counter[i];
      Value word_u64 = rewriter.create<CastOp>(loc, word_u64_type, word);
      Value new_word_u64 = rewriter.create<AddV2Op>(loc, word_u64, increment);
      Value new_word = rewriter.create<CastOp>(loc, word_type, new_word_u64);
      pack_args.push_back(new_word);

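      // The addition is done in unsigned 64-bit so that wrap-around can be
      // detected with a less-than comparison; if the word overflowed, carry
      // a 1 into the next (higher) counter word, otherwise add 0.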
      Value overflow = rewriter.create<LessOp>(loc, new_word_u64, word_u64);
      increment = rewriter.create<SelectV2Op>(loc, overflow, one_u64, zero_u64);
    }

    // Save the new state value to the resource.
    pack_args.push_back(key);
    Value new_state = rewriter.create<PackOp>(loc, res_type, pack_args);
    rewriter.create<AssignVariableOp>(loc, rng_op.resource(), new_state);

    // Pad the original state as necessary to fill the output shape.
    int pad = tensorflow::RNG_MAX_COUNTER_SIZE - counter_size;
    Type i64 = rewriter.getI64Type();
    RankedTensorType paddings_ty = RankedTensorType::get({1, 2}, i64);
    std::vector<int64_t> paddings_values = {0, pad};
    Value paddings = rewriter.create<ConstOp>(
        loc, DenseIntElementsAttr::get(paddings_ty, paddings_values));
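    // For Philox the pad is zero; for Threefry a single zero is appended,
    // producing the documented [counter, key, 0] result.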
    Value output = rewriter.create<PadOp>(loc, op_type, state, paddings);

    rewriter.replaceOp(op, output);
    return success();
  }
};

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
}  // namespace

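// A typical way to drive these patterns (a sketch; the actual pass lives in a
// separate file and may differ):
//   RewritePatternSet patterns(context);
//   PopulateDecomposeResourceOpsPatterns(context, &patterns);
//   (void)applyPatternsAndFoldGreedily(func_op, std::move(patterns));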
void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
                                          RewritePatternSet *patterns) {
  patterns->add<DecomposeRngReadAndSkipOp>(context);
  populateWithGenerated(*patterns);
}

}  // namespace TF
}  // namespace mlir