/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/tpu_embedding_load_retrieve_ops.h"

#include <stddef.h>

#include <array>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"
#include "tensorflow/core/tpu/ops/tpu_embedding_shape_util.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite.h"
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"

using tensorflow::tpu::TPUEmbeddingConfiguration;
using tensorflow::tpu::TpuEmbeddingShapeUtil;

namespace tensorflow {

// Computes (and VLOGs) the expected shapes of the embedding table shards.
Status ComputeExpectedTableShardShapes(const TPUEmbeddingConfiguration& config,
                                       int shard_id, int num_shards,
                                       const string& op_name,
                                       std::vector<TensorShape>* table_shapes) {
  std::vector<TensorShapeProto> shape_protos;
  const int num_tables = config.table_descriptor_size();
  TF_RETURN_IF_ERROR(TpuEmbeddingShapeUtil::ComputeTableShapes(
      config, shard_id, num_shards, &shape_protos));
  if (num_tables != shape_protos.size()) {
    return errors::InvalidArgument(
        op_name, ": The size of the shape_protos vector ", shape_protos.size(),
        " must be the same as the number of tables ", num_tables);
  }
  for (int table_id = 0; table_id < num_tables; ++table_id) {
    const TensorShape shape(shape_protos[table_id]);
    table_shapes->push_back(shape);

    const auto& table_descriptor = config.table_descriptor(table_id);
    VLOG(1) << "Table " << table_id << " (name " << table_descriptor.name()
            << ") has shape: " << shape.DebugString()
            << " on shard: " << shard_id << " (of " << num_shards << ").";
  }

  return OkStatus();
}
// Logs min/max/avg for the specified state_variable array.
void LogRangeStatistics(int32 table_id, int32 state_variable_index,
                        absl::Span<const float> state_variable) {
  if (VLOG_IS_ON(5)) {
    float min = std::numeric_limits<float>::infinity();
    float max = -std::numeric_limits<float>::infinity();
    double avg = 0.0;
    for (int elt = 0; elt < state_variable.size(); ++elt) {
      if (state_variable[elt] < min) min = state_variable[elt];
      if (state_variable[elt] > max) max = state_variable[elt];
      avg += state_variable[elt];
    }
    LOG(INFO) << "Table " << table_id << " state_variable "
              << state_variable_index << " min " << min << " max " << max
              << " avg " << avg / state_variable.size() << " total elts "
              << state_variable.size();
  }
}

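// LoadAllTPUEmbeddingParametersOp loads a complete set of embedding tables
// and their optimizer state variables from host tensors into the TPU
// embedding engine. The `config` attr carries a serialized
// TPUEmbeddingConfiguration proto describing each table; the `shard_id` and
// `num_shards` attrs select which shard of each table this host handles.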
LoadAllTPUEmbeddingParametersOp::LoadAllTPUEmbeddingParametersOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  string config_string;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &config_string));

  OP_REQUIRES(
      ctx, config_.ParseFromString(config_string),
      errors::InvalidArgument("LoadAllTPUEmbeddingParametersOp: Failed to "
                              "parse TPUEmbeddingConfiguration "
                              "proto from config attr"));
  // Auto-populate the feature descriptor.
  // TODO(b/201806244): Remove this logic once the config proto initialization
  // change lands.
  OP_REQUIRES_OK(ctx, PopulateMissingFieldsInTPUEmbeddingConfig(&config_));

  int num_shards;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("num_shards", &num_shards));
  int shard_id;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shard_id", &shard_id));

  OP_REQUIRES_OK(ctx, ComputeExpectedTableShardShapes(
                          config_, shard_id, num_shards,
                          "LoadAllTPUEmbeddingParametersOp", &table_shapes_));
}

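// Reads the op's input lists into per-table spans: index 0 holds the
// embedding parameters themselves and indices 1..kMaxAuxiliaryParameterCount
// hold the optimizer's auxiliary state variables (empty for slots the
// optimizer does not use). Every tensor shape is validated against the
// shard shapes derived from the config.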
void LoadAllTPUEmbeddingParametersOp::GetStateVariables(
    OpKernelContext* ctx,
    std::array<std::vector<absl::Span<const float>>,
               tpu::kMaxAuxiliaryParameterCount + 1>& state_variable_vector) {
  std::array<OpInputList, tpu::kMaxAuxiliaryParameterCount + 1>
      state_variable;
  OP_REQUIRES_OK(ctx, ctx->input_list("parameters", &state_variable[0]));
  for (int i = 1; i <= tpu::kMaxAuxiliaryParameterCount; ++i) {
    OP_REQUIRES_OK(ctx, ctx->input_list(absl::StrCat("auxiliary", i),
                                        &state_variable[i]));
  }
  const int num_tables = state_variable[0].size();
  // This should be enforced by TensorFlow's type system.
  for (int i = 1; i <= tpu::kMaxAuxiliaryParameterCount; ++i) {
    CHECK_EQ(num_tables, state_variable[i].size());
  }

  OP_REQUIRES(ctx, num_tables == table_shapes_.size(),
              errors::InvalidArgument(
                  "LoadAllTPUEmbeddingParametersOp has ", num_tables,
                  " inputs in lists but config specifies ",
                  table_shapes_.size(), " embedding tables."));

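  // Per-table validation: each table must supply exactly the state variables
  // its optimization algorithm requires, all shaped like the table shard;
  // the remaining auxiliary inputs must be empty.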
  CHECK_EQ(num_tables, config_.table_descriptor_size());
  for (int table_id = 0; table_id < num_tables; ++table_id) {
    const auto& table_descriptor = config_.table_descriptor(table_id);
    std::vector<tpu::StateVariableSpecification> state_variable_specs;
    Status status = tpu::GetOptimizationAlgorithmStateVariables(
        table_descriptor.optimization_parameters(), &state_variable_specs);
    OP_REQUIRES(ctx, status.ok(),
                errors::InvalidArgument(
                    "LoadAllTPUEmbeddingParametersOp: No optimization "
                    "algorithm specified for table ",
                    table_id, " (named ", table_descriptor.name(), ")"));
    OP_REQUIRES(
        ctx, state_variable[0][table_id].shape() == table_shapes_[table_id],
        errors::InvalidArgument(
            "LoadAllTPUEmbeddingParametersOp: Embeddings for table ",
            table_id, " (named ", table_descriptor.name(), ") have shape ",
            state_variable[0][table_id].shape().DebugString(),
            " but config specifies table shape ",
            table_shapes_[table_id].DebugString()));
    for (int i = 1; i < state_variable_specs.size(); ++i) {
      OP_REQUIRES(
          ctx, state_variable[i][table_id].shape() == table_shapes_[table_id],
          errors::InvalidArgument(
              "LoadAllTPUEmbeddingParametersOp: Auxiliary ", i - 1,
              " for table ", table_id, " has shape ",
              state_variable[i][table_id].shape().DebugString(),
              " but config specifies table shape ",
              table_shapes_[table_id].DebugString()));
    }
    const int64 num_elements = state_variable[0][table_id].NumElements();
    VLOG(1) << "Table " << table_id << " (name " << table_descriptor.name()
            << ") has shape: " << table_shapes_[table_id].DebugString()
            << ", number of elements: " << num_elements;
    for (int i = 0; i < state_variable_specs.size(); ++i) {
      OP_REQUIRES(
          ctx, state_variable[i][table_id].NumElements() == num_elements,
          errors::InvalidArgument(
              "LoadAllTPUEmbeddingParametersOp: Embeddings/auxiliary ", i,
              " for table ", table_id, " has element count ",
              state_variable[i][table_id].NumElements(),
              " but config requires count ", num_elements));
      const float* state_variable_i_ptr =
          state_variable[i][table_id].flat<float>().data();
      state_variable_vector[i].push_back(absl::MakeConstSpan(
          state_variable_i_ptr, static_cast<size_t>(num_elements)));
      LogRangeStatistics(
          table_id, i,
          absl::MakeConstSpan(state_variable_i_ptr, num_elements));
    }
    for (int i = state_variable_specs.size();
         i <= tpu::kMaxAuxiliaryParameterCount; ++i) {
      OP_REQUIRES(ctx, state_variable[i][table_id].NumElements() == 0,
                  errors::InvalidArgument(
                      "LoadAllTPUEmbeddingParametersOp: Auxiliary ", i,
                      " for table ", table_id, " has element count ",
                      state_variable[i][table_id].NumElements(),
                      " but config requires empty tensor"));
      state_variable_vector[i].push_back(absl::Span<const float>());
    }
  }
}

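// Compute gathers all input tensors, repackages them for the TPU C API, and
// issues a single WriteParameters call to push them onto the device.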
void LoadAllTPUEmbeddingParametersOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "LoadAllTPUEmbeddingParameters::Compute";

  std::array<std::vector<absl::Span<const float>>,
             tpu::kMaxAuxiliaryParameterCount + 1> state_variable_vector;

  GetStateVariables(ctx, state_variable_vector);
  const int num_tables = state_variable_vector[0].size();

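  // Wire the gathered spans into the C API's parameter struct. Each
  // FloatListRef only borrows a buffer (a size plus a raw pointer), so
  // nothing is copied here; the input tensors back the data for the duration
  // of the write call.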
  std::unique_ptr<ApiConverter::TpuEmbeddingEngineParametersData> params =
      ApiConverter::Create(num_tables);
  std::array<std::vector<FloatListRef>,
             tpu::kMaxAuxiliaryParameterCount + 1> params_data;
  for (size_t i = 0; i < tpu::kMaxAuxiliaryParameterCount + 1; i++) {
    params_data[i] = std::vector<FloatListRef>(num_tables);
    for (size_t table_id = 0; table_id < num_tables; table_id++) {
      params->c_params.parameters[i][table_id] = &(params_data[i][table_id]);
      params->c_params.parameters[i][table_id]->size =
          state_variable_vector[i][table_id].size();
      params->c_params.parameters[i][table_id]->ptr =
          const_cast<float*>(state_variable_vector[i][table_id].data());
    }
  }
  StatusHelper status;
  tpu::OpsApiFn()->TpuEmbeddingEngine_WriteParametersFn(&(params->c_params),
                                                        status.c_status);
  OP_REQUIRES_OK(ctx, status.status());

  VLOG(1) << "LoadAllTPUEmbeddingParameters::Compute done";
}

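// RetrieveAllTPUEmbeddingParametersOp is the inverse of the load op: it reads
// the embedding tables and optimizer state variables back out of the TPU
// embedding engine into freshly allocated host tensors, typically so they can
// be checkpointed.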
RetrieveAllTPUEmbeddingParametersOp::RetrieveAllTPUEmbeddingParametersOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  string config_string;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &config_string));

  OP_REQUIRES(
      ctx, config_.ParseFromString(config_string),
      errors::InvalidArgument("Failed to parse TPUEmbeddingConfiguration "
                              "proto from config attr"));

  // Auto-populate the feature descriptor.
  // TODO(b/201806244): Remove this logic once the config proto initialization
  // change lands.
  OP_REQUIRES_OK(ctx, PopulateMissingFieldsInTPUEmbeddingConfig(&config_));

  int num_shards;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("num_shards", &num_shards));
  int shard_id;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shard_id", &shard_id));

  OP_REQUIRES_OK(ctx,
                 ComputeExpectedTableShardShapes(
                     config_, shard_id, num_shards,
                     "RetrieveAllTPUEmbeddingParametersOp", &table_shapes_));
}

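// Allocates the op's output tensors and records a writable span for each
// (state variable, table) pair; the spans are later filled in directly by the
// TPU C API. num_state_variables records how many slots each table's
// optimizer actually uses, so unused auxiliary outputs become empty 2-D
// tensors.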
void RetrieveAllTPUEmbeddingParametersOp::GetStateVariables(
    OpKernelContext* ctx,
    std::array<std::vector<absl::Span<float>>,
               tpu::kMaxAuxiliaryParameterCount + 1>& state_variable_vector,
    std::vector<int>& num_state_variables) {
  std::array<OpOutputList, tpu::kMaxAuxiliaryParameterCount + 1>
      state_variable;
  OP_REQUIRES_OK(ctx, ctx->output_list("parameters", &state_variable[0]));
  for (int i = 1; i <= tpu::kMaxAuxiliaryParameterCount; ++i) {
    OP_REQUIRES_OK(ctx, ctx->output_list(absl::StrCat("auxiliary", i),
                                         &state_variable[i]));
  }
  const int num_tables = state_variable[0].size();
  // This should be enforced by TensorFlow's type system.
  for (int i = 1; i <= tpu::kMaxAuxiliaryParameterCount; ++i) {
    CHECK_EQ(num_tables, state_variable[i].size());
  }

  OP_REQUIRES(ctx, num_tables == table_shapes_.size(),
              errors::InvalidArgument(
                  "RetrieveAllTPUEmbeddingParametersOp has ", num_tables,
                  " outputs in lists but config specifies ",
                  table_shapes_.size(), " embedding tables."));

  for (auto& v : state_variable_vector) {
    v.resize(num_tables);
  }
  num_state_variables.resize(num_tables);

  // Get the locations to write the returned data to.
  for (int table_id = 0; table_id < num_tables; ++table_id) {
    const auto& table_descriptor = config_.table_descriptor(table_id);

    std::vector<tpu::StateVariableSpecification> state_variable_specs;
    Status status = tpu::GetOptimizationAlgorithmStateVariables(
        table_descriptor.optimization_parameters(), &state_variable_specs);
    OP_REQUIRES(
        ctx, status.ok(),
        errors::InvalidArgument("RetrieveAllTPUEmbeddingParametersOp: No "
                                "optimization algorithm specified for table ",
                                table_id));
    num_state_variables[table_id] = state_variable_specs.size();
    const int64 num_elements = table_shapes_[table_id].num_elements();
    for (int i = 0; i < state_variable_specs.size(); ++i) {
      Tensor* state_variable_tensor;
      OP_REQUIRES_OK(
          ctx, state_variable[i].allocate(table_id, table_shapes_[table_id],
                                          &state_variable_tensor));
      float* state_variable_ptr = state_variable_tensor->flat<float>().data();
      state_variable_vector[i][table_id] =
          absl::MakeSpan(state_variable_ptr, num_elements);
    }
    // Fill the auxiliary slots beyond the number actually used for table_id
    // with empty 2-D tensors.
    for (int i = state_variable_specs.size();
         i <= tpu::kMaxAuxiliaryParameterCount; ++i) {
      Tensor* auxiliary_tensor;
      TensorShape shape;
      std::array<int32, 2> dims = {{0, 0}};
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(dims, &shape));
      OP_REQUIRES_OK(ctx, state_variable[i].allocate(table_id, shape,
                                                     &auxiliary_tensor));
      state_variable_vector[i][table_id] = absl::Span<float>();
    }
  }
}

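// Compute allocates the outputs, hands the TPU C API writable views of their
// buffers, and issues a single ReadParameters call to copy the device-side
// tables back into them.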
void RetrieveAllTPUEmbeddingParametersOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "RetrieveAllTPUEmbeddingParameters::Compute";

  std::array<std::vector<absl::Span<float>>,
             tpu::kMaxAuxiliaryParameterCount + 1> state_variable_vector;
  std::vector<int> num_state_variables;

  GetStateVariables(ctx, state_variable_vector, num_state_variables);
  const int num_tables = state_variable_vector[0].size();

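  // As in the load path, the C API struct only borrows pointers, here into
  // the freshly allocated output tensors; the engine writes the retrieved
  // values into them in place.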
  std::unique_ptr<ApiConverter::TpuEmbeddingEngineParametersData> params =
      ApiConverter::Create(num_tables);
  std::array<std::vector<FloatListRef>,
             tpu::kMaxAuxiliaryParameterCount + 1> params_data;
  for (size_t i = 0; i < tpu::kMaxAuxiliaryParameterCount + 1; i++) {
    params_data[i] = std::vector<FloatListRef>(num_tables);
    for (size_t table_id = 0; table_id < num_tables; table_id++) {
      params->c_params.parameters[i][table_id] = &(params_data[i][table_id]);
      params->c_params.parameters[i][table_id]->size =
          state_variable_vector[i][table_id].size();
      params->c_params.parameters[i][table_id]->ptr =
          state_variable_vector[i][table_id].data();
    }
  }
  StatusHelper status;
  tpu::OpsApiFn()->TpuEmbeddingEngine_ReadParametersFn(&(params->c_params),
                                                       status.c_status);
  OP_REQUIRES_OK(ctx, status.status());

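  // At VLOG level 5, log min/max/avg of every retrieved state variable as a
  // sanity check on the returned values.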
  if (VLOG_IS_ON(5)) {
    for (int table_id = 0; table_id < num_tables; ++table_id) {
      for (int i = 0; i < num_state_variables[table_id]; ++i) {
        LogRangeStatistics(table_id, i, state_variable_vector[i][table_id]);
      }
    }
  }
}

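// Kernel registration. The ops run on the host CPU; the device-side work
// happens through the TPU C API calls above. Registration is limited to
// builds with the LIBTPU_ON_GCE define, i.e. those linking against libtpu.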
#ifdef LIBTPU_ON_GCE

REGISTER_KERNEL_BUILDER(
    Name("LoadAllTPUEmbeddingParameters").Device(DEVICE_CPU),
    LoadAllTPUEmbeddingParametersOp);
REGISTER_KERNEL_BUILDER(
    Name("RetrieveAllTPUEmbeddingParameters").Device(DEVICE_CPU),
    RetrieveAllTPUEmbeddingParametersOp);

#endif  // LIBTPU_ON_GCE
}  // namespace tensorflow