xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/sparse_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/framework/types.pb.h"
20 #include "tensorflow/core/platform/errors.h"
21 
22 namespace tensorflow {
23 
24 using shape_inference::DimensionHandle;
25 using shape_inference::InferenceContext;
26 using shape_inference::ShapeHandle;
27 
28 namespace {
29 
SparseSparseMinOrMaxShapeFn(InferenceContext * c)30 Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) {
31   ShapeHandle unused;
32   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // a_indices
33   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));  // a_values
34   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));  // a_shape
35   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &unused));  // b_indices
36   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused));  // b_values
37   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &unused));  // b_shape
38   c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
39                              InferenceContext::kUnknownDim));
40   c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
41   return OkStatus();
42 }
43 
44 }  // namespace
45 
46 REGISTER_OP("SparseAddGrad")
47     .Input("backprop_val_grad: T")
48     .Input("a_indices: int64")
49     .Input("b_indices: int64")
50     .Input("sum_indices: int64")
51     .Output("a_val_grad: T")
52     .Output("b_val_grad: T")
53     .Attr("T: numbertype")
__anone91b66370202(InferenceContext* c) 54     .SetShapeFn([](InferenceContext* c) {
55       ShapeHandle a_indices;
56       ShapeHandle b_indices;
57       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &a_indices));
58       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &b_indices));
59       c->set_output(0, c->Vector(c->Dim(a_indices, 0)));
60       c->set_output(1, c->Vector(c->Dim(b_indices, 0)));
61       return OkStatus();
62     });
63 
64 REGISTER_OP("SparseAdd")
65     .Input("a_indices: int64")
66     .Input("a_values: T")
67     .Input("a_shape: int64")
68     .Input("b_indices: int64")
69     .Input("b_values: T")
70     .Input("b_shape: int64")
71     .Input("thresh: Treal")
72     .Output("sum_indices: int64")
73     .Output("sum_values: T")
74     .Output("sum_shape: int64")
75     .Attr("T: numbertype")
76     .Attr("Treal: realnumbertype")
__anone91b66370302(InferenceContext* c) 77     .SetShapeFn([](InferenceContext* c) {
78       ShapeHandle a_shape;
79       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape));
80       c->set_output(
81           0, c->Matrix(InferenceContext::kUnknownDim, c->Dim(a_shape, 0)));
82       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
83       c->set_output(2, a_shape);
84       return OkStatus();
85     });
86 
87 REGISTER_OP("SparseTensorDenseMatMul")
88     .Input("a_indices: Tindices")
89     .Input("a_values: T")
90     .Input("a_shape: int64")
91     .Input("b: T")
92     .Output("product: T")
93     .Attr("T: type")
94     .Attr("Tindices: {int32,int64} = DT_INT64")
95     .Attr("adjoint_a: bool = false")
96     .Attr("adjoint_b: bool = false")
__anone91b66370402(InferenceContext* c) 97     .SetShapeFn([](InferenceContext* c) {
98       DimensionHandle unused_dim;
99       ShapeHandle unused;
100       ShapeHandle b;
101       ShapeHandle a_shape;
102       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // a_indices
103       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));  // a_values
104       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &a_shape));
105       TF_RETURN_IF_ERROR(c->WithRank(a_shape, 2, &a_shape));
106       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &b));
107 
108       bool adjoint_a;
109       bool adjoint_b;
110       TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
111       TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));
112 
113       DimensionHandle output_right = c->Dim(b, adjoint_b ? 0 : 1);
114       DimensionHandle output_left = c->Dim(a_shape, adjoint_a ? 1 : 0);
115       DimensionHandle inner_left = c->Dim(a_shape, adjoint_a ? 0 : 1);
116       DimensionHandle inner_right = c->Dim(b, adjoint_b ? 1 : 0);
117       TF_RETURN_IF_ERROR(c->Merge(inner_left, inner_right, &unused_dim));
118       c->set_output(0, c->Matrix(output_left, output_right));
119       return OkStatus();
120     });
121 
122 REGISTER_OP("SerializeSparse")
123     .Input("sparse_indices: int64")
124     .Input("sparse_values: T")
125     .Input("sparse_shape: int64")
126     .Attr("T: type")
127     .Output("serialized_sparse: out_type")
128     .Attr("out_type: {string, variant} = DT_STRING")
__anone91b66370502(InferenceContext* c) 129     .SetShapeFn([](InferenceContext* c) {
130       ShapeHandle unused;
131       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
132       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
133       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
134       c->set_output(0, c->Vector(3));
135       return OkStatus();
136     });
137 
138 REGISTER_OP("SerializeManySparse")
139     .Input("sparse_indices: int64")
140     .Input("sparse_values: T")
141     .Input("sparse_shape: int64")
142     .Attr("T: type")
143     .Output("serialized_sparse: out_type")
144     .Attr("out_type: {string, variant} = DT_STRING")
__anone91b66370602(InferenceContext* c) 145     .SetShapeFn([](InferenceContext* c) {
146       ShapeHandle unused;
147       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
148       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
149       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
150       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 3));
151       return OkStatus();
152     });
153 
154 REGISTER_OP("DeserializeSparse")
155     .Input("serialized_sparse: Tserialized")
156     .Output("sparse_indices: int64")
157     .Output("sparse_values: dtype")
158     .Output("sparse_shape: int64")
159     .Attr("dtype: type")
160     .Attr("Tserialized: {string, variant} = DT_STRING")
__anone91b66370702(InferenceContext* c) 161     .SetShapeFn([](InferenceContext* c) {
162       // serialized sparse is [?, ..., ?, 3] vector.
163       ShapeHandle unused_shape;
164       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused_shape));
165       DimensionHandle unused;
166       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &unused));
167       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
168                                  InferenceContext::kUnknownDim));
169       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
170       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
171       return OkStatus();
172     });
173 
174 REGISTER_OP("DeserializeManySparse")
175     .Input("serialized_sparse: string")
176     .Output("sparse_indices: int64")
177     .Output("sparse_values: dtype")
178     .Output("sparse_shape: int64")
179     .Attr("dtype: type")
__anone91b66370802(InferenceContext* c) 180     .SetShapeFn([](InferenceContext* c) {
181       // serialized sparse is [?,3] matrix.
182       ShapeHandle serialized_sparse;
183       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &serialized_sparse));
184       DimensionHandle unused;
185       TF_RETURN_IF_ERROR(
186           c->WithValue(c->Dim(serialized_sparse, 1), 3, &unused));
187 
188       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
189                                  InferenceContext::kUnknownDim));
190       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
191       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
192       return OkStatus();
193     });
194 
195 REGISTER_OP("SparseToDense")
196     .Input("sparse_indices: Tindices")
197     .Input("output_shape: Tindices")
198     .Input("sparse_values: T")
199     .Input("default_value: T")
200     .Attr("validate_indices: bool = true")
201     .Attr("T: type")
202     .Output("dense: T")
203     .Attr("Tindices: {int32, int64}")
__anone91b66370902(InferenceContext* c) 204     .SetShapeFn([](InferenceContext* c) {
205       ShapeHandle out;
206       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
207       c->set_output(0, out);
208       return OkStatus();
209     });
210 
211 REGISTER_OP("SparseConcat")
212     .Input("indices: N * int64")
213     .Input("values: N * T")
214     .Input("shapes: N * int64")
215     .Output("output_indices: int64")
216     .Output("output_values: T")
217     .Output("output_shape: int64")
218     .Attr("concat_dim: int")
219     .Attr("N: int >= 2")
220     .Attr("T: type")
__anone91b66370a02(InferenceContext* c) 221     .SetShapeFn([](InferenceContext* c) {
222       // These accumulates the sum.
223       DimensionHandle output_row_count = c->MakeDim(0ll);
224 
225       // These are only merged.
226       DimensionHandle output_ind_cols = c->UnknownDim();
227       ShapeHandle output_shape = c->UnknownShape();
228 
229       const int n = c->num_inputs() / 3;
230       for (int i = 0; i < n; i++) {
231         ShapeHandle ind;
232         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &ind));
233         ShapeHandle val;
234         TF_RETURN_IF_ERROR(c->WithRank(c->input(i + n), 1, &val));
235         ShapeHandle shape;
236         TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2 * n), 1, &shape));
237 
238         // Add to output_ind_rows.
239         DimensionHandle num_dim;
240         TF_RETURN_IF_ERROR(c->Merge(c->Dim(ind, 0), c->Dim(val, 0), &num_dim));
241         TF_RETURN_IF_ERROR(
242             c->Add(output_row_count, num_dim, &output_row_count));
243 
244         // Merge into output_ind_cols and output_shape.
245         TF_RETURN_IF_ERROR(
246             c->Merge(output_ind_cols, c->Dim(ind, 1), &output_ind_cols));
247         TF_RETURN_IF_ERROR(c->Merge(output_shape, shape, &output_shape));
248       }
249 
250       c->set_output(0, c->Matrix(output_row_count, output_ind_cols));
251       c->set_output(1, c->Vector(output_row_count));
252       c->set_output(2, output_shape);
253       return OkStatus();
254     });
255 
256 REGISTER_OP("SparseCross")
257     .Input("indices: N * int64")
258     .Input("values: sparse_types")
259     .Input("shapes: N * int64")
260     .Input("dense_inputs: dense_types")
261     .Output("output_indices: int64")
262     .Output("output_values: out_type")
263     .Output("output_shape: int64")
264     .Attr("N: int >= 0")
265     .Attr("hashed_output: bool")
266     .Attr("num_buckets: int >= 0")
267     .Attr("hash_key: int")
268     .Attr("sparse_types: list({int64, string}) >= 0")
269     .Attr("dense_types: list({int64, string}) >= 0")
270     .Attr("out_type: {int64, string}")
271     .Attr("internal_type: {int64, string}")
__anone91b66370b02(shape_inference::InferenceContext* c) 272     .SetShapeFn([](shape_inference::InferenceContext* c) {
273       c->set_output(0, c->Matrix(c->UnknownDim(), 2));
274       c->set_output(1, c->Vector(c->UnknownDim()));
275       c->set_output(2, c->Vector(2));
276       return OkStatus();
277     });
278 
279 REGISTER_OP("SparseCrossV2")
280     .Input("indices: N * int64")
281     .Input("values: sparse_types")
282     .Input("shapes: N * int64")
283     .Input("dense_inputs: dense_types")
284     .Input("sep: string")
285     .Output("output_indices: int64")
286     .Output("output_values: string")
287     .Output("output_shape: int64")
288     .Attr("N: int >= 0")
289     .Attr("sparse_types: list({int64, string}) >= 0")
290     .Attr("dense_types: list({int64, string}) >= 0")
__anone91b66370c02(shape_inference::InferenceContext* c) 291     .SetShapeFn([](shape_inference::InferenceContext* c) {
292       c->set_output(0, c->Matrix(c->UnknownDim(), 2));
293       c->set_output(1, c->Vector(c->UnknownDim()));
294       c->set_output(2, c->Vector(2));
295       return OkStatus();
296     });
297 
298 REGISTER_OP("SparseCrossHashed")
299     .Input("indices: N * int64")
300     .Input("values: sparse_types")
301     .Input("shapes: N * int64")
302     .Input("dense_inputs: dense_types")
303     .Input("num_buckets: int64")
304     .Input("strong_hash: bool")
305     .Input("salt: int64")
306     .Output("output_indices: int64")
307     .Output("output_values: int64")
308     .Output("output_shape: int64")
309     .Attr("N: int >= 0")
310     .Attr("sparse_types: list({int64, string}) >= 0")
311     .Attr("dense_types: list({int64, string}) >= 0")
__anone91b66370d02(shape_inference::InferenceContext* c) 312     .SetShapeFn([](shape_inference::InferenceContext* c) {
313       c->set_output(0, c->Matrix(c->UnknownDim(), 2));
314       c->set_output(1, c->Vector(c->UnknownDim()));
315       c->set_output(2, c->Vector(2));
316       return OkStatus();
317     });
318 
319 REGISTER_OP("SparseSplit")
320     .Input("split_dim: int64")
321     .Input("indices: int64")
322     .Input("values: T")
323     .Input("shape: int64")
324     .Output("output_indices: num_split * int64")
325     .Output("output_values:  num_split * T")
326     .Output("output_shape:   num_split * int64")
327     .Attr("num_split: int >= 1")
328     .Attr("T: type")
__anone91b66370e02(InferenceContext* c) 329     .SetShapeFn([](InferenceContext* c) {
330       ShapeHandle input_shape = c->input(3);
331       ShapeHandle output_indices =
332           c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
333       ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
334       ShapeHandle output_shape = input_shape;
335 
336       // Copy the outputs into the output ranges.
337       int num_splits = c->num_outputs() / 3;
338       int out_idx = 0;
339       for (int i = 0; i < num_splits; ++i)
340         c->set_output(out_idx++, output_indices);
341       for (int i = 0; i < num_splits; ++i)
342         c->set_output(out_idx++, output_values);
343       for (int i = 0; i < num_splits; ++i)
344         c->set_output(out_idx++, output_shape);
345       return OkStatus();
346     });
347 
348 REGISTER_OP("SparseSliceGrad")
349     .Input("backprop_val_grad: T")
350     .Input("input_indices: int64")
351     .Input("input_start: int64")
352     .Input("output_indices: int64")
353     .Output("val_grad: T")
354     .Attr("T: numbertype")
__anone91b66370f02(InferenceContext* c) 355     .SetShapeFn([](InferenceContext* c) {
356       ShapeHandle indices;
357       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &indices));
358       c->set_output(0, c->Vector(c->Dim(indices, 0)));
359       return OkStatus();
360     });
361 
362 REGISTER_OP("SparseSlice")
363     .Input("indices: int64")
364     .Input("values: T")
365     .Input("shape: int64")
366     .Input("start: int64")
367     .Input("size: int64")
368     .Output("output_indices: int64")
369     .Output("output_values: T")
370     .Output("output_shape: int64")
371     .Attr("T: type")
__anone91b66371002(InferenceContext* c) 372     .SetShapeFn([](InferenceContext* c) {
373       ShapeHandle input_shape = c->input(2);
374       ShapeHandle output_indices =
375           c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
376       ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
377       ShapeHandle output_shape = input_shape;
378 
379       c->set_output(0, output_indices);
380       c->set_output(1, output_values);
381       c->set_output(2, output_shape);
382       return OkStatus();
383     });
384 
385 REGISTER_OP("SparseReorder")
386     .Input("input_indices: int64")
387     .Input("input_values: T")
388     .Input("input_shape: int64")
389     .Output("output_indices: int64")
390     .Output("output_values: T")
391     .Attr("T: type")
__anone91b66371102(InferenceContext* c) 392     .SetShapeFn([](InferenceContext* c) {
393       ShapeHandle indices;
394       ShapeHandle values;
395       ShapeHandle unused;
396 
397       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
398       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));
399       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
400 
401       c->set_output(0, indices);
402       c->set_output(1, values);
403       return OkStatus();
404     });
405 
406 REGISTER_OP("SparseReshape")
407     .Input("input_indices: int64")
408     .Input("input_shape: int64")
409     .Input("new_shape: int64")
410     .Output("output_indices: int64")
411     .Output("output_shape: int64")
__anone91b66371202(InferenceContext* c) 412     .SetShapeFn([](InferenceContext* c) {
413       ShapeHandle indices;
414       ShapeHandle unused;
415       ShapeHandle new_shape;
416 
417       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
418       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
419       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &new_shape));
420 
421       c->set_output(0, c->Matrix(c->Dim(indices, 0), c->Dim(new_shape, 0)));
422       c->set_output(1, new_shape);
423       return OkStatus();
424     });
425 
426 REGISTER_OP("SparseTensorDenseAdd")
427     .Input("a_indices: Tindices")
428     .Input("a_values: T")
429     .Input("a_shape: Tindices")
430     .Input("b: T")
431     .Output("output: T")
432     .Attr("T: numbertype")
433     .Attr("Tindices: {int32, int64}")
__anone91b66371302(InferenceContext* c) 434     .SetShapeFn([](InferenceContext* c) {
435       c->set_output(0, c->input(3));
436       return OkStatus();
437     });
438 
// Max-reduction (per the op name) of a SparseTensor over `reduction_axes`,
// producing a single dense `output`. The output shape is computed by
// shape_inference::SparseReduceShapeFn.
REGISTER_OP("SparseReduceMax")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::SparseReduceShapeFn);
448 
// Like SparseReduceMax, but returns the result as a SparseTensor
// (output_indices, output_values, output_shape). No static shape information
// is inferred: all outputs get shape_inference::UnknownShape.
REGISTER_OP("SparseReduceMaxSparse")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::UnknownShape);
460 
// Sum-reduction (per the op name) of a SparseTensor over `reduction_axes`,
// producing a single dense `output`. The output shape is computed by
// shape_inference::SparseReduceShapeFn.
REGISTER_OP("SparseReduceSum")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output: T")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::SparseReduceShapeFn);
470 
// Like SparseReduceSum, but returns the result as a SparseTensor
// (output_indices, output_values, output_shape). No static shape information
// is inferred: all outputs get shape_inference::UnknownShape.
REGISTER_OP("SparseReduceSumSparse")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = False")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::UnknownShape);
482 
483 #define SPARSE_DENSE_CWISE_SIGNATURE()                           \
484   Input("sp_indices: int64")                                     \
485       .Input("sp_values: T")                                     \
486       .Input("sp_shape: int64")                                  \
487       .Input("dense: T")                                         \
488       .Output("output: T")                                       \
489       .Attr("T: numbertype")                                     \
490       .SetShapeFn([](InferenceContext* c) {                      \
491         ShapeHandle input;                                       \
492         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); \
493         c->set_output(0, c->Vector(c->Dim(input, 0)));           \
494         return OkStatus();                                       \
495       })
496 
497 REGISTER_OP("SparseDenseCwiseMul").SPARSE_DENSE_CWISE_SIGNATURE();
498 
499 REGISTER_OP("SparseDenseCwiseDiv").SPARSE_DENSE_CWISE_SIGNATURE();
500 
501 REGISTER_OP("SparseDenseCwiseAdd").SPARSE_DENSE_CWISE_SIGNATURE();
502 
503 #undef SPARSE_DENSE_CWISE_SIGNATURE
504 
505 REGISTER_OP("SparseSoftmax")
506     .Input("sp_indices: int64")
507     .Input("sp_values: T")
508     .Input("sp_shape: int64")
509     .Output("output: T")
510     .Attr("T: {float, double}")
__anone91b66371402(InferenceContext* c) 511     .SetShapeFn([](InferenceContext* c) {
512       ShapeHandle unused;
513       ShapeHandle values;
514       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // sp_indices
515       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));  // sp_values
516       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
517       c->set_output(0, values);
518       return OkStatus();
519     });
520 
// Element-wise maximum (per the op name) of two SparseTensors `a` and `b`.
// Shape inference is shared with SparseSparseMinimum via
// SparseSparseMinOrMaxShapeFn.
// NOTE(review): T is realnumbertype here but numbertype for
// SparseSparseMinimum; confirm whether the asymmetry is intended (changing it
// would alter the OpDef).
REGISTER_OP("SparseSparseMaximum")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(SparseSparseMinOrMaxShapeFn);
532 
// Element-wise minimum (per the op name) of two SparseTensors `a` and `b`.
// Shape inference is shared with SparseSparseMaximum via
// SparseSparseMinOrMaxShapeFn.
// NOTE(review): T is numbertype here but realnumbertype for
// SparseSparseMaximum; confirm whether the asymmetry is intended (changing it
// would alter the OpDef).
REGISTER_OP("SparseSparseMinimum")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: numbertype")
    .SetShapeFn(SparseSparseMinOrMaxShapeFn);
544 
545 REGISTER_OP("AddSparseToTensorsMap")
546     .Input("sparse_indices: int64")
547     .Input("sparse_values: T")
548     .Input("sparse_shape: int64")
549     .Output("sparse_handle: int64")
550     .Attr("T: type")
551     .Attr("container: string = ''")
552     .Attr("shared_name: string = ''")
553     .SetIsStateful()
__anone91b66371502(InferenceContext* c) 554     .SetShapeFn([](InferenceContext* c) {
555       ShapeHandle unused;
556       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
557       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
558       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
559       c->set_output(0, c->Scalar());
560       return OkStatus();
561     });
562 
563 REGISTER_OP("AddManySparseToTensorsMap")
564     .Input("sparse_indices: int64")
565     .Input("sparse_values: T")
566     .Input("sparse_shape: int64")
567     .Output("sparse_handles: int64")
568     .Attr("T: type")
569     .Attr("container: string = ''")
570     .Attr("shared_name: string = ''")
571     .SetIsStateful()
__anone91b66371602(InferenceContext* c) 572     .SetShapeFn([](InferenceContext* c) {
573       ShapeHandle unused;
574       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
575       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
576       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
577       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
578       return OkStatus();
579     });
580 
581 REGISTER_OP("TakeManySparseFromTensorsMap")
582     .Input("sparse_handles: int64")
583     .Output("sparse_indices: int64")
584     .Output("sparse_values: dtype")
585     .Output("sparse_shape: int64")
586     .Attr("dtype: type")
587     .Attr("container: string = ''")
588     .Attr("shared_name: string = ''")
589     .SetIsStateful()
__anone91b66371702(InferenceContext* c) 590     .SetShapeFn([](InferenceContext* c) {
591       // serialized sparse is [?,1] matrix.
592       ShapeHandle sparse_handles;
593       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &sparse_handles));
594 
595       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
596                                  InferenceContext::kUnknownDim));
597       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
598       c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
599       return OkStatus();
600     });
601 
602 REGISTER_OP("SparseFillEmptyRows")
603     .Input("indices: int64")
604     .Input("values: T")
605     .Input("dense_shape: int64")
606     .Input("default_value: T")
607     .Output("output_indices: int64")
608     .Output("output_values: T")
609     .Output("empty_row_indicator: bool")
610     .Output("reverse_index_map: int64")
611     .Attr("T: type")
__anone91b66371802(InferenceContext* c) 612     .SetShapeFn([](InferenceContext* c) {
613       ShapeHandle input_indices = c->input(0);
614       TF_RETURN_IF_ERROR(c->WithRank(input_indices, 2, &input_indices));
615       ShapeHandle input_values = c->input(1);
616       TF_RETURN_IF_ERROR(c->WithRank(input_values, 1, &input_values));
617       ShapeHandle input_shape = c->input(2);
618       TF_RETURN_IF_ERROR(c->WithRank(input_shape, 1, &input_shape));
619       ShapeHandle default_value = c->input(3);
620       TF_RETURN_IF_ERROR(c->WithRank(default_value, 0, &default_value));
621       DimensionHandle N = c->Dim(input_indices, 0);
622       TF_RETURN_IF_ERROR(c->Merge(N, c->Dim(input_values, 0), &N));
623       DimensionHandle unused_dim;
624       TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_indices, 1),
625                                   c->Dim(input_shape, 0), &unused_dim));
626       if (c->Value(c->NumElements(input_shape)) == 0)
627         return errors::InvalidArgument("dense_shape must not be empty");
628       ShapeHandle output_indices =
629           c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape));
630       ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
631       ShapeHandle constant_input_shape;
632       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &constant_input_shape));
633       ShapeHandle empty_row_indicator =
634           c->Vector(c->Dim(constant_input_shape, 0));
635       ShapeHandle reverse_index_map = c->Vector(N);
636       c->set_output(0, output_indices);
637       c->set_output(1, output_values);
638       c->set_output(2, empty_row_indicator);
639       c->set_output(3, reverse_index_map);
640       return OkStatus();
641     });
642 
643 REGISTER_OP("SparseFillEmptyRowsGrad")
644     .Input("reverse_index_map: int64")
645     .Input("grad_values: T")
646     .Output("d_values: T")
647     .Output("d_default_value: T")
648     .Attr("T: type")
__anone91b66371902(InferenceContext* c) 649     .SetShapeFn([](InferenceContext* c) {
650       ShapeHandle reverse_index_map = c->input(0);
651       TF_RETURN_IF_ERROR(c->WithRank(reverse_index_map, 1, &reverse_index_map));
652       ShapeHandle grad_values = c->input(1);
653       TF_RETURN_IF_ERROR(c->WithRank(grad_values, 1, &grad_values));
654       c->set_output(0, reverse_index_map);
655       c->set_output(1, c->Scalar());
656       return OkStatus();
657     });
658 
659 }  // namespace tensorflow
660