xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/external_dataset.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
"""Provides the 'ExternalDataset' implementation of tf.data.Dataset.

This wraps the generated 'ExternalDataset' op, exposed through the generated
Python module `gen_external_dataset_py` (from external_dataset_py_wrapper).
"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24
25import tensorflow as tf
26
27from fcp.tensorflow import gen_external_dataset_py
28
# Locate the custom-op shared library among this module's data files and load
# it. Per the `tf.load_op_library` contract, the ops and kernels registered by
# the library become available to the TensorFlow process -- presumably
# including the kernel behind `gen_external_dataset_py.ExternalDataset`. The
# returned module handle is not referenced by the code below; the call is kept
# for its loading side effect.
_external_dataset_so = tf.load_op_library(
    library_filename=tf.compat.v1.resource_loader.get_path_to_datafile(
        path="./_external_dataset_op.so"))
32
33
class ExternalDataset(tf.data.Dataset):
  """A `tf.data.Dataset` whose contents are supplied by the graph's runner.

  This class does not define where the data comes from; whoever runs the
  graph does. The graph must be fed two string inputs:

  * `token` says which external dataset to use.
  * `selector` is an opaque string that only the external implementation
    interprets.

  Elements are declared (see `element_spec`) as scalar `tf.string` tensors.
  """

  def __init__(self, token, selector):
    """Creates the dataset by invoking the generated `ExternalDataset` op.

    Args:
      token: Anything convertible to a `tf.string` tensor; identifies the
        external dataset to use.
      selector: Anything convertible to a `tf.string` tensor; forwarded
        uninterpreted to the external implementation.
    """
    # Normalise both inputs to string tensors first so plain Python values
    # are accepted; the `name`s label the resulting graph nodes.
    token_tensor = tf.convert_to_tensor(token, dtype=tf.string, name="token")
    selector_tensor = tf.convert_to_tensor(
        selector, dtype=tf.string, name="selector")
    # The generated op yields the variant tensor that backs this dataset.
    super(ExternalDataset, self).__init__(
        gen_external_dataset_py.ExternalDataset(
            token=token_tensor, selector=selector_tensor))

  @property
  def element_spec(self):
    """Spec of a single element: a scalar (shape `[]`) `tf.string`."""
    return tf.TensorSpec(shape=[], dtype=tf.string)

  def _inputs(self):
    """Lists the datasets this one is derived from.

    This dataset has no upstream datasets, so the list is always empty.
    """
    return []
55