# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for reader Datasets."""
import collections
import csv
import functools
import gzip

import numpy as np

from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.experimental.ops import parsing_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util.tf_export import tf_export

_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32,
                         dtypes.int64, dtypes.string)


def _is_valid_int32(str_val):
  try:
    # Checks equality to prevent int32 overflow
    return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype(
        str_val)
  except (ValueError, OverflowError):
    return False

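# A minimal illustration of the overflow check above (hypothetical values,
# shown as a comment so importing this module stays side-effect free):
#
#   _is_valid_int32("5000000000")  # False: the value does not fit in int32.
#   _is_valid_int64("5000000000")  # True: it does fit in int64.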

def _is_valid_int64(str_val):
  try:
    dtypes.int64.as_numpy_dtype(str_val)
    return True
  except (ValueError, OverflowError):
    return False


def _is_valid_float(str_val, float_dtype):
  try:
    return float_dtype.as_numpy_dtype(str_val) < np.inf
  except ValueError:
    return False


def _infer_type(str_val, na_value, prev_type):
  """Given a string, infers its tensor type.

  Infers the type of a value by picking the least 'permissive' type possible,
  while still allowing the previous type inference for this column to be valid.

  Args:
    str_val: String value to infer the type of.
    na_value: Additional string to recognize as a NA/NaN CSV value.
    prev_type: Type previously inferred based on values of this column that
      we've seen up till now.
  Returns:
    Inferred dtype.
  """
  if str_val in ("", na_value):
    # If the field is null, it gives no extra information about its type
    return prev_type

  type_list = [
      dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string
  ]  # list of types to try, ordered from least permissive to most

  type_functions = [
      _is_valid_int32,
      _is_valid_int64,
      lambda str_val: _is_valid_float(str_val, dtypes.float32),
      lambda str_val: _is_valid_float(str_val, dtypes.float64),
      lambda str_val: True,
  ]  # Corresponding list of validation functions

  for i in range(len(type_list)):
    validation_fn = type_functions[i]
    if validation_fn(str_val) and (prev_type is None or
                                   prev_type in type_list[:i + 1]):
      return type_list[i]

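# A small sketch of how `_infer_type` widens a column's dtype as more values
# are seen (hypothetical values, kept as a comment so nothing runs at import
# time):
#
#   inferred = None
#   for value in ("7", "3000000000", "2.5", "abc"):
#     inferred = _infer_type(value, na_value="", prev_type=inferred)
#   # inferred progresses tf.int32 -> tf.int64 -> tf.float32 -> tf.string.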

def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header,
                  file_io_fn):
  """Generator that yields rows of CSV file(s) in order."""
  for fn in filenames:
    with file_io_fn(fn) as f:
      rdr = csv.reader(
          f,
          delimiter=field_delim,
          quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE)
      row_num = 1
      if header:
        next(rdr)  # Skip header lines
        row_num += 1

      for csv_row in rdr:
        if len(csv_row) != num_cols:
          raise ValueError(
              f"Problem inferring types: CSV row {row_num} has {len(csv_row)} "
              f"fields. Expected: {num_cols} fields.")
        row_num += 1
        yield csv_row


def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
                           na_value, header, num_rows_for_inference,
                           select_columns, file_io_fn):
  """Infers column types from the first N valid CSV records of files."""
  if select_columns is None:
    select_columns = range(num_cols)
  inferred_types = [None] * len(select_columns)

  for i, csv_row in enumerate(
      _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header,
                    file_io_fn)):
    if num_rows_for_inference is not None and i >= num_rows_for_inference:
      break

    for j, col_index in enumerate(select_columns):
      inferred_types[j] = _infer_type(csv_row[col_index], na_value,
                                      inferred_types[j])

  # Replace None's with a default type
  inferred_types = [t or dtypes.string for t in inferred_types]
  # Default to 0 or '' for null values
  return [
      constant_op.constant([0 if t is not dtypes.string else ""], dtype=t)
      for t in inferred_types
  ]

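# A sketch of the result (hypothetical data): if the first rows of a
# two-column file are "1,a" and "2,b", the returned defaults are
#
#   [constant_op.constant([0], dtype=dtypes.int32),
#    constant_op.constant([""], dtype=dtypes.string)]
#
# i.e. every column is treated as optional with a zero/empty default.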

def _infer_column_names(filenames, field_delim, use_quote_delim, file_io_fn):
  """Infers column names from first rows of files."""
  csv_kwargs = {
      "delimiter": field_delim,
      "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE
  }
  with file_io_fn(filenames[0]) as f:
    try:
      column_names = next(csv.reader(f, **csv_kwargs))
    except StopIteration:
      raise ValueError("Failed when reading the header line of "
                       f"{filenames[0]}. Is it an empty file?")

  for name in filenames[1:]:
    with file_io_fn(name) as f:
      try:
        if next(csv.reader(f, **csv_kwargs)) != column_names:
          raise ValueError(
              "All input CSV files should have the same column names in the "
              f"header row. File {name} has different column names.")
      except StopIteration:
        raise ValueError("Failed when reading the header line of "
                         f"{name}. Is it an empty file?")
  return column_names


def _get_sorted_col_indices(select_columns, column_names):
  """Transforms select_columns argument into sorted column indices."""
  names_to_indices = {n: i for i, n in enumerate(column_names)}
  num_cols = len(column_names)

  results = []
  for v in select_columns:
    # If value is already an int, check if it's valid.
    if isinstance(v, int):
      if v < 0 or v >= num_cols:
        raise ValueError(
            f"Column index {v} specified in `select_columns` should be >= 0 "
            f"and < {num_cols}, which is the number of columns.")
      results.append(v)
    # Otherwise, check that it's a valid column name and convert to the
    # relevant column index.
    elif v not in names_to_indices:
      raise ValueError(
          f"Column {v} specified in `select_columns` must be one of the "
          f"columns: {names_to_indices.keys()}.")
    else:
      results.append(names_to_indices[v])

  # Sort and ensure there are no duplicates. Duplicates are detected on the
  # pre-deduplicated indices so the error message can report them.
  deduplicated = sorted(set(results))
  if len(deduplicated) != len(select_columns):
    sorted_indices = sorted(results)
    duplicate_columns = set([a for a, b in zip(
        sorted_indices[:-1], sorted_indices[1:]) if a == b])
    raise ValueError("The `select_columns` argument contains duplicate "
                     f"columns: {duplicate_columns}.")
  return deduplicated

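# A small sketch of how `_get_sorted_col_indices` resolves a mixed list of
# column names and indices (hypothetical values):
#
#   _get_sorted_col_indices(["height", 0],
#                           column_names=["name", "age", "height"])
#   # -> [0, 2]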

def _maybe_shuffle_and_repeat(
    dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed):
  """Optionally shuffle and repeat dataset, as requested."""
  if shuffle:
    dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed)
  if num_epochs != 1:
    dataset = dataset.repeat(num_epochs)
  return dataset


def make_tf_record_dataset(file_pattern,
                           batch_size,
                           parser_fn=None,
                           num_epochs=None,
                           shuffle=True,
                           shuffle_buffer_size=None,
                           shuffle_seed=None,
                           prefetch_buffer_size=None,
                           num_parallel_reads=None,
                           num_parallel_parser_calls=None,
                           drop_final_batch=False):
  """Reads and optionally parses TFRecord files into a dataset.

  Provides common functionality such as batching, optional parsing, shuffling,
  and performant defaults.

  Args:
    file_pattern: List of files or patterns of TFRecord file paths.
      See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    parser_fn: (Optional.) A function accepting string input to parse
      and process the record contents. This function must map records
      to components of a fixed shape, so they may be batched. By
      default, uses the record contents unmodified.
    num_epochs: (Optional.) An int specifying the number of times this
      dataset is repeated.  If None (the default), cycles through the
      dataset forever.
    shuffle: (Optional.) A bool that indicates whether the input
      should be shuffled. Defaults to `True`.
    shuffle_buffer_size: (Optional.) Buffer size to use for
      shuffling. A large buffer size ensures better shuffling, but
      increases memory usage and startup time.
    shuffle_seed: (Optional.) Randomization seed to use for shuffling.
    prefetch_buffer_size: (Optional.) An int specifying the number of
      feature batches to prefetch for performance improvement.
      Defaults to auto-tune. Set to 0 to disable prefetching.
    num_parallel_reads: (Optional.) Number of threads used to read
      records from files. If greater than 1, records from different
      files will be interleaved. Defaults to `24`.
    num_parallel_parser_calls: (Optional.) Number of records to parse
      in parallel. Defaults to `batch_size`.
    drop_final_batch: (Optional.) Whether the last batch should be
      dropped in case its size is smaller than `batch_size`; the
      default behavior is not to drop the smaller batch.

  Returns:
    A dataset, where each element matches the output of `parser_fn`
    except it will have an additional leading `batch_size` dimension,
    or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
    unspecified.
  """
  if num_parallel_reads is None:
    # NOTE: We considered auto-tuning this value, but there is a concern
    # that this affects the mixing of records from different files, which
    # could affect training convergence/accuracy, so we are defaulting to
    # a constant for now.
    num_parallel_reads = 24

  if num_parallel_parser_calls is None:
    # TODO(josh11b): if num_parallel_parser_calls is None, use some function
    # of num cores instead of `batch_size`.
    num_parallel_parser_calls = batch_size

  if prefetch_buffer_size is None:
    prefetch_buffer_size = dataset_ops.AUTOTUNE

  files = dataset_ops.Dataset.list_files(
      file_pattern, shuffle=shuffle, seed=shuffle_seed)

  dataset = core_readers.TFRecordDataset(
      files, num_parallel_reads=num_parallel_reads)

  if shuffle_buffer_size is None:
    # TODO(josh11b): Auto-tune this value when not specified
    shuffle_buffer_size = 10000
  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  drop_final_batch = drop_final_batch or num_epochs is None

  if parser_fn is None:
    dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
  else:
    dataset = dataset.map(
        parser_fn, num_parallel_calls=num_parallel_parser_calls)
    dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)

  if prefetch_buffer_size == 0:
    return dataset
  else:
    return dataset.prefetch(buffer_size=prefetch_buffer_size)

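# Example usage (an illustrative sketch from a caller's perspective, assuming
# `import tensorflow as tf`; the file pattern and feature spec are
# hypothetical). Note that `parser_fn` receives one serialized record because
# parsing happens before batching:
#
#   feature_spec = {"x": tf.io.FixedLenFeature([], tf.float32)}
#   dataset = make_tf_record_dataset(
#       file_pattern="/tmp/train-*.tfrecord",
#       batch_size=32,
#       parser_fn=lambda record: tf.io.parse_single_example(
#           record, feature_spec),
#       num_epochs=1)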

@tf_export("data.experimental.make_csv_dataset", v1=[])
def make_csv_dataset_v2(
    file_pattern,
    batch_size,
    column_names=None,
    column_defaults=None,
    label_name=None,
    select_columns=None,
    field_delim=",",
    use_quote_delim=True,
    na_value="",
    header=True,
    num_epochs=None,  # TODO(aaudibert): Change default to 1 when graduating.
    shuffle=True,
    shuffle_buffer_size=10000,
    shuffle_seed=None,
    prefetch_buffer_size=None,
    num_parallel_reads=None,
    sloppy=False,
    num_rows_for_inference=100,
    compression_type=None,
    ignore_errors=False,
    encoding="utf-8",
):
  """Reads CSV files into a dataset.

  Reads CSV files into a dataset, where each element of the dataset is a
  (features, labels) tuple that corresponds to a batch of CSV rows. The features
  dictionary maps feature column names to `Tensor`s containing the corresponding
  feature data, and labels is a `Tensor` containing the batch's label data.

  By default, the first rows of the CSV files are expected to be headers listing
  the column names. If the first rows are not headers, set `header=False` and
  provide the column names with the `column_names` argument.

  By default, the dataset is repeated indefinitely, reshuffling the order each
  time. This behavior can be modified by setting the `num_epochs` and `shuffle`
  arguments.

  For example, suppose you have a CSV file containing

  | Feature_A | Feature_B |
  | --------- | --------- |
  | 1         | "a"       |
  | 2         | "b"       |
  | 3         | "c"       |
  | 4         | "d"       |

  ```
  # No label column specified
  dataset = tf.data.experimental.make_csv_dataset(filename, batch_size=2)
  iterator = dataset.as_numpy_iterator()
  print(dict(next(iterator)))
  # prints a dictionary of batched features:
  # OrderedDict([('Feature_A', array([1, 4], dtype=int32)),
  #              ('Feature_B', array([b'a', b'd'], dtype=object))])
  ```

  ```
  # Set Feature_B as label column
  dataset = tf.data.experimental.make_csv_dataset(
      filename, batch_size=2, label_name="Feature_B")
  iterator = dataset.as_numpy_iterator()
  print(next(iterator))
  # prints (features, labels) tuple:
  # (OrderedDict([('Feature_A', array([1, 2], dtype=int32))]),
  #  array([b'a', b'b'], dtype=object))
  ```

  See the
  [Load CSV data guide](https://www.tensorflow.org/tutorials/load_data/csv) for
  more examples of using `make_csv_dataset` to read CSV data.

  Args:
    file_pattern: List of files or patterns of file paths containing CSV
      records. See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    column_names: An optional list of strings that corresponds to the CSV
      columns, in order. One per column of the input record. If this is not
      provided, infers the column names from the first row of the records.
      These names will be the keys of the features dict of each dataset element.
    column_defaults: An optional list of default values for the CSV fields. One
      item per selected column of the input record. Each item in the list is
      either a valid CSV dtype (float32, float64, int32, int64, or string), or a
      `Tensor` with one of the aforementioned types. The tensor can either be
      a scalar default value (if the column is optional), or an empty tensor (if
      the column is required). If a dtype is provided instead of a tensor, the
      column is also treated as required. If this list is not provided, tries
      to infer types based on reading the first num_rows_for_inference rows of
      files specified, and assumes all columns are optional, defaulting to `0`
      for numeric values and `""` for string values. If both this and
      `select_columns` are specified, these must have the same lengths, and
      `column_defaults` is assumed to be sorted in order of increasing column
      index.
    label_name: An optional string corresponding to the label column. If
      provided, the data for this column is returned as a separate `Tensor` from
      the features dictionary, so that the dataset complies with the format
      expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input
      function.
    select_columns: An optional list of integer indices or string column
      names that specifies a subset of columns of CSV data to select. If
      column names are provided, these must correspond to names provided in
      `column_names` or inferred from the file header lines. When this argument
      is specified, only a subset of CSV columns will be parsed and returned,
      corresponding to the columns specified. Using this results in faster
      parsing and lower memory usage. If both this and `column_defaults` are
      specified, these must have the same lengths, and `column_defaults` is
      assumed to be sorted in order of increasing column index.
    field_delim: An optional `string`. Defaults to `","`. Char delimiter to
      separate fields in a record.
    use_quote_delim: An optional bool. Defaults to `True`. If false, treats
      double quotation marks as regular characters inside of the string fields.
    na_value: Additional string to recognize as NA/NaN.
    header: A bool that indicates whether the first rows of provided CSV files
      correspond to header lines with column names, and should not be included
      in the data.
    num_epochs: An int specifying the number of times this dataset is repeated.
      If None, cycles through the dataset forever.
    shuffle: A bool that indicates whether the input should be shuffled.
    shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
      ensures better shuffling, but increases memory usage and startup time.
    shuffle_seed: Randomization seed to use for shuffling.
    prefetch_buffer_size: An int specifying the number of feature
      batches to prefetch for performance improvement. Recommended value is the
      number of batches consumed per training step. Defaults to auto-tune.
    num_parallel_reads: Number of threads used to read CSV records from files.
      If >1, the results will be interleaved. Defaults to `1`.
    sloppy: If `True`, reading performance will be improved at
      the cost of non-deterministic ordering. If `False`, the order of elements
      produced is deterministic prior to shuffling (elements are still
      randomized if `shuffle=True`. Note that if the seed is set, then the
      order of elements after shuffling is deterministic). Defaults to `False`.
    num_rows_for_inference: Number of rows of a file to use for type inference
      if `column_defaults` is not provided. If None, reads all the rows of all
      the files. Defaults to 100.
    compression_type: (Optional.) A `tf.string` scalar evaluating to one of
      `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression.
    ignore_errors: (Optional.) If `True`, ignores errors with CSV file parsing,
      such as malformed data or empty lines, and moves on to the next valid
      CSV record. Otherwise, the dataset raises an error and stops processing
      when encountering any invalid records. Defaults to `False`.
    encoding: Encoding to use when reading. Defaults to `"utf-8"`.

  Returns:
    A dataset, where each element is a (features, labels) tuple that corresponds
    to a batch of `batch_size` CSV rows. The features dictionary maps feature
    column names to `Tensor`s containing the corresponding column data, and
    labels is a `Tensor` containing the column data for the label column
    specified by `label_name`.

  Raises:
    ValueError: If any of the arguments is malformed.
  """
  if num_parallel_reads is None:
    num_parallel_reads = 1

  if prefetch_buffer_size is None:
    prefetch_buffer_size = dataset_ops.AUTOTUNE

  # Create dataset of all matching filenames
  filenames = _get_file_names(file_pattern, False)
  dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
  if shuffle:
    dataset = dataset.shuffle(len(filenames), shuffle_seed)

  # Clean arguments; figure out column names and defaults
  if column_names is None or column_defaults is None:
    # Find out which io function to open the file
    file_io_fn = lambda filename: file_io.FileIO(  # pylint: disable=g-long-lambda
        filename, "r", encoding=encoding)
    if compression_type is not None:
      compression_type_value = tensor_util.constant_value(compression_type)
      if compression_type_value is None:
        raise ValueError(
            f"Received unknown `compression_type` {compression_type}. "
            'Expected: GZIP, ZLIB or "" (empty string).')
      if compression_type_value == "GZIP":
        file_io_fn = lambda filename: gzip.open(  # pylint: disable=g-long-lambda
            filename, "rt", encoding=encoding)
      elif compression_type_value == "ZLIB":
        raise ValueError(
            f"`compression_type` {compression_type} is not supported for "
            "probing columns.")
      elif compression_type_value != "":
        raise ValueError(
            f"Received unknown `compression_type` {compression_type}. "
            'Expected: GZIP, ZLIB or "" (empty string).')
  if column_names is None:
    if not header:
      raise ValueError("Expected `column_names` or `header` arguments. Neither "
                       "is provided.")
    # If column names are not provided, infer from the header lines
    column_names = _infer_column_names(filenames, field_delim, use_quote_delim,
                                       file_io_fn)
  if len(column_names) != len(set(column_names)):
    sorted_names = sorted(column_names)
    duplicate_columns = set([a for a, b in zip(
        sorted_names[:-1], sorted_names[1:]) if a == b])
    raise ValueError(
        "Either `column_names` argument or CSV header row contains duplicate "
        f"column names: {duplicate_columns}.")

  if select_columns is not None:
    select_columns = _get_sorted_col_indices(select_columns, column_names)

  if column_defaults is not None:
    column_defaults = [
        constant_op.constant([], dtype=x)
        if not tensor_util.is_tf_type(x) and x in _ACCEPTABLE_CSV_TYPES else x
        for x in column_defaults
    ]
  else:
    # If column defaults are not provided, infer from records at graph
    # construction time
    column_defaults = _infer_column_defaults(filenames, len(column_names),
                                             field_delim, use_quote_delim,
                                             na_value, header,
                                             num_rows_for_inference,
                                             select_columns, file_io_fn)

  if select_columns is not None and len(column_defaults) != len(select_columns):
    raise ValueError(
        "If specified, `column_defaults` and `select_columns` must have the "
        f"same length: `column_defaults` has length {len(column_defaults)}, "
        f"`select_columns` has length {len(select_columns)}.")
  if select_columns is not None and len(column_names) > len(select_columns):
    # Pick the relevant subset of column names
    column_names = [column_names[i] for i in select_columns]

  if label_name is not None and label_name not in column_names:
    raise ValueError("`label_name` provided must be one of the columns: "
                     f"{column_names}. Received: {label_name}.")

  def filename_to_dataset(filename):
    dataset = CsvDataset(
        filename,
        record_defaults=column_defaults,
        field_delim=field_delim,
        use_quote_delim=use_quote_delim,
        na_value=na_value,
        select_cols=select_columns,
        header=header,
        compression_type=compression_type
    )
    if ignore_errors:
      dataset = dataset.apply(error_ops.ignore_errors())
    return dataset

  def map_fn(*columns):
    """Organizes columns into a features dictionary.

    Args:
      *columns: list of `Tensor`s corresponding to one csv record.
    Returns:
      An OrderedDict of feature names to values for that particular record. If
      label_name is provided, extracts the label feature to be returned as the
      second element of the tuple.
    """
    features = collections.OrderedDict(zip(column_names, columns))
    if label_name is not None:
      label = features.pop(label_name)
      return features, label
    return features

  if num_parallel_reads == dataset_ops.AUTOTUNE:
    dataset = dataset.interleave(
        filename_to_dataset, num_parallel_calls=num_parallel_reads)
    options = options_lib.Options()
    options.deterministic = not sloppy
    dataset = dataset.with_options(options)
  else:
    # Read files sequentially (if num_parallel_reads=1) or in parallel
    def apply_fn(dataset):
      return core_readers.ParallelInterleaveDataset(
          dataset,
          filename_to_dataset,
          cycle_length=num_parallel_reads,
          block_length=1,
          sloppy=sloppy,
          buffer_output_elements=None,
          prefetch_input_elements=None)

    dataset = dataset.apply(apply_fn)

  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # Apply batch before map for perf, because map has high overhead relative
  # to the size of the computation in each map.
  # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  dataset = dataset.batch(batch_size=batch_size,
                          drop_remainder=num_epochs is None)
  dataset = dataset_ops.MapDataset(
      dataset, map_fn, use_inter_op_parallelism=False)
  dataset = dataset.prefetch(prefetch_buffer_size)

  return dataset

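# A hedged sketch of specifying per-column defaults (file pattern and column
# names are hypothetical): a scalar default tensor marks a column as optional,
# while a bare dtype (or empty tensor) marks it as required.
#
#   dataset = make_csv_dataset_v2(
#       "/tmp/people-*.csv",
#       batch_size=8,
#       column_names=["name", "age"],
#       column_defaults=[
#           dtypes.string,                                   # "name" required
#           constant_op.constant([0], dtype=dtypes.int32),   # "age" defaults to 0
#       ],
#       label_name="age")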

@tf_export(v1=["data.experimental.make_csv_dataset"])
def make_csv_dataset_v1(
    file_pattern,
    batch_size,
    column_names=None,
    column_defaults=None,
    label_name=None,
    select_columns=None,
    field_delim=",",
    use_quote_delim=True,
    na_value="",
    header=True,
    num_epochs=None,
    shuffle=True,
    shuffle_buffer_size=10000,
    shuffle_seed=None,
    prefetch_buffer_size=None,
    num_parallel_reads=None,
    sloppy=False,
    num_rows_for_inference=100,
    compression_type=None,
    ignore_errors=False,
    encoding="utf-8",
):  # pylint: disable=missing-docstring
  return dataset_ops.DatasetV1Adapter(
      make_csv_dataset_v2(file_pattern, batch_size, column_names,
                          column_defaults, label_name, select_columns,
                          field_delim, use_quote_delim, na_value, header,
                          num_epochs, shuffle, shuffle_buffer_size,
                          shuffle_seed, prefetch_buffer_size,
                          num_parallel_reads, sloppy, num_rows_for_inference,
                          compression_type, ignore_errors, encoding))
make_csv_dataset_v1.__doc__ = make_csv_dataset_v2.__doc__


_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024  # 4 MB


@tf_export("data.experimental.CsvDataset", v1=[])
class CsvDatasetV2(dataset_ops.DatasetSource):
  r"""A Dataset comprising lines from one or more CSV files.

  The `tf.data.experimental.CsvDataset` class provides a minimal CSV Dataset
  interface. There is also a richer `tf.data.experimental.make_csv_dataset`
  function which provides additional convenience features such as column header
  parsing, column type-inference, automatic shuffling, and file interleaving.

  The elements of this dataset correspond to records from the file(s).
  RFC 4180 format is expected for CSV files
  (https://tools.ietf.org/html/rfc4180)
  Note that we allow leading and trailing spaces for int or float fields.

  For example, suppose we have a file 'my_file0.csv' with four CSV columns of
  different data types:

  >>> with open('/tmp/my_file0.csv', 'w') as f:
  ...   f.write('abcdefg,4.28E10,5.55E6,12\n')
  ...   f.write('hijklmn,-5.3E14,,2\n')

  We can construct a CsvDataset from it as follows:

  >>> dataset = tf.data.experimental.CsvDataset(
  ...   "/tmp/my_file0.csv",
  ...   [tf.float32,  # Required field, use dtype or empty tensor
  ...    tf.constant([0.0], dtype=tf.float32),  # Optional field, default to 0.0
  ...    tf.int32,  # Required field, use dtype or empty tensor
  ...   ],
  ...   select_cols=[1,2,3]  # Only parse last three columns
  ... )

  The expected output of its iterations is:

  >>> for element in dataset.as_numpy_iterator():
  ...   print(element)
  (4.28e10, 5.55e6, 12)
  (-5.3e14, 0.0, 2)

  See
  https://www.tensorflow.org/tutorials/load_data/csv#tfdataexperimentalcsvdataset
  for more in-depth example usage.
  """

  def __init__(self,
               filenames,
               record_defaults,
               compression_type=None,
               buffer_size=None,
               header=False,
               field_delim=",",
               use_quote_delim=True,
               na_value="",
               select_cols=None,
               exclude_cols=None):
    """Creates a `CsvDataset` by reading and decoding CSV files.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_defaults: A list of default values for the CSV fields. Each item in
        the list is either a valid CSV `DType` (float32, float64, int32, int64,
        string), or a `Tensor` object with one of the above types. One per
        column of CSV data, with either a scalar `Tensor` default value for the
        column if it is optional, or `DType` or empty `Tensor` if required. If
        both this and `select_columns` are specified, these must have the same
        lengths, and `column_defaults` is assumed to be sorted in order of
        increasing column index. If both this and `exclude_cols` are specified,
        the sum of lengths of record_defaults and exclude_cols should equal
        the total number of columns in the CSV file.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
        compression.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer while reading files. Defaults to 4MB.
      header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
        have header line(s) that should be skipped when parsing. Defaults to
        `False`.
      field_delim: (Optional.) A `tf.string` scalar containing the delimiter
        character that separates fields in a record. Defaults to `","`.
      use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
        double quotation marks as regular characters inside of string fields
        (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
      na_value: (Optional.) A `tf.string` scalar indicating a value that will
        be treated as NA/NaN.
      select_cols: (Optional.) A sorted list of column indices to select from
        the input data. If specified, only this subset of columns will be
        parsed. Defaults to parsing all columns. At most one of `select_cols`
        and `exclude_cols` can be specified.
      exclude_cols: (Optional.) A sorted list of column indices to exclude from
        the input data. If specified, only the complement of this set of columns
        will be parsed. Defaults to parsing all columns. At most one of
        `select_cols` and `exclude_cols` can be specified.

    Raises:
      InvalidArgumentError: If `exclude_cols` is not None and
        len(exclude_cols) + len(record_defaults) does not match the total
        number of columns in the file(s).
    """
    self._filenames = ops.convert_to_tensor(
        filenames, dtype=dtypes.string, name="filenames")
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    record_defaults = [
        constant_op.constant([], dtype=x)
        if not tensor_util.is_tf_type(x) and x in _ACCEPTABLE_CSV_TYPES else x
        for x in record_defaults
    ]
    self._record_defaults = ops.convert_n_to_tensor(
        record_defaults, name="record_defaults")
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
    self._header = ops.convert_to_tensor(
        header, dtype=dtypes.bool, name="header")
    self._field_delim = ops.convert_to_tensor(
        field_delim, dtype=dtypes.string, name="field_delim")
    self._use_quote_delim = ops.convert_to_tensor(
        use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
    self._na_value = ops.convert_to_tensor(
        na_value, dtype=dtypes.string, name="na_value")
    self._select_cols = convert.optional_param_to_tensor(
        "select_cols",
        select_cols,
        argument_default=[],
        argument_dtype=dtypes.int64,
    )
    self._exclude_cols = convert.optional_param_to_tensor(
        "exclude_cols",
        exclude_cols,
        argument_default=[],
        argument_dtype=dtypes.int64,
    )
    self._element_spec = tuple(
        tensor_spec.TensorSpec([], d.dtype) for d in self._record_defaults)
    variant_tensor = gen_experimental_dataset_ops.csv_dataset_v2(
        filenames=self._filenames,
        record_defaults=self._record_defaults,
        buffer_size=self._buffer_size,
        header=self._header,
        output_shapes=self._flat_shapes,
        field_delim=self._field_delim,
        use_quote_delim=self._use_quote_delim,
        na_value=self._na_value,
        select_cols=self._select_cols,
        exclude_cols=self._exclude_cols,
        compression_type=self._compression_type)
    super(CsvDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec

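# A hedged example of `exclude_cols` using the four-column file from the class
# docstring: excluding column 0 means `record_defaults` must describe the
# remaining three columns, so len(record_defaults) + len(exclude_cols) equals
# the total column count.
#
#   dataset = CsvDatasetV2(
#       "/tmp/my_file0.csv",
#       record_defaults=[dtypes.float32,
#                        constant_op.constant([0.0], dtype=dtypes.float32),
#                        dtypes.int32],
#       exclude_cols=[0])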

@tf_export(v1=["data.experimental.CsvDataset"])
class CsvDatasetV1(dataset_ops.DatasetV1Adapter):
  """A Dataset comprising lines from one or more CSV files."""

  @functools.wraps(CsvDatasetV2.__init__, ("__module__", "__name__"))
  def __init__(self,
               filenames,
               record_defaults,
               compression_type=None,
               buffer_size=None,
               header=False,
               field_delim=",",
               use_quote_delim=True,
               na_value="",
               select_cols=None):
    """Creates a `CsvDataset` by reading and decoding CSV files.

    The elements of this dataset correspond to records from the file(s).
    RFC 4180 format is expected for CSV files
    (https://tools.ietf.org/html/rfc4180)
    Note that we allow leading and trailing spaces for int or float fields.

    For example, suppose we have a file 'my_file0.csv' with four CSV columns of
    different data types:
    ```
    abcdefg,4.28E10,5.55E6,12
    hijklmn,-5.3E14,,2
    ```

    We can construct a CsvDataset from it as follows:

    ```python
    dataset = tf.data.experimental.CsvDataset(
        "my_file*.csv",
        [tf.float32,  # Required field, use dtype or empty tensor
         tf.constant([0.0], dtype=tf.float32),  # Optional field, default to 0.0
         tf.int32,  # Required field, use dtype or empty tensor
         ],
        select_cols=[1,2,3]  # Only parse last three columns
    )
    ```

    The expected output of its iterations is:

    ```python
    for element in dataset:
      print(element)

    >> (4.28e10, 5.55e6, 12)
    >> (-5.3e14, 0.0, 2)
    ```

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_defaults: A list of default values for the CSV fields. Each item in
        the list is either a valid CSV `DType` (float32, float64, int32, int64,
        string), or a `Tensor` object with one of the above types. One per
        column of CSV data, with either a scalar `Tensor` default value for the
        column if it is optional, or `DType` or empty `Tensor` if required. If
        both this and `select_columns` are specified, these must have the same
        lengths, and `column_defaults` is assumed to be sorted in order of
        increasing column index. If both this and `exclude_cols` are specified,
        the sum of lengths of record_defaults and exclude_cols should equal the
        total number of columns in the CSV file.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
        compression.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer while reading files. Defaults to 4MB.
      header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
        have header line(s) that should be skipped when parsing. Defaults to
        `False`.
      field_delim: (Optional.) A `tf.string` scalar containing the delimiter
        character that separates fields in a record. Defaults to `","`.
      use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats double
        quotation marks as regular characters inside of string fields (ignoring
        RFC 4180, Section 2, Bullet 5). Defaults to `True`.
      na_value: (Optional.) A `tf.string` scalar indicating a value that will be
        treated as NA/NaN.
      select_cols: (Optional.) A sorted list of column indices to select from
        the input data. If specified, only this subset of columns will be
        parsed. Defaults to parsing all columns. At most one of `select_cols`
        and `exclude_cols` can be specified.
    """
    wrapped = CsvDatasetV2(filenames, record_defaults, compression_type,
                           buffer_size, header, field_delim, use_quote_delim,
                           na_value, select_cols)
    super(CsvDatasetV1, self).__init__(wrapped)


@tf_export("data.experimental.make_batched_features_dataset", v1=[])
def make_batched_features_dataset_v2(file_pattern,
                                     batch_size,
                                     features,
                                     reader=None,
                                     label_key=None,
                                     reader_args=None,
                                     num_epochs=None,
                                     shuffle=True,
                                     shuffle_buffer_size=10000,
                                     shuffle_seed=None,
                                     prefetch_buffer_size=None,
                                     reader_num_threads=None,
                                     parser_num_threads=None,
                                     sloppy_ordering=False,
                                     drop_final_batch=False):
  """Returns a `Dataset` of feature dictionaries from `Example` protos.

  If the `label_key` argument is provided, returns a `Dataset` of tuples
  comprising a feature dictionary and a label.

  Example:

  ```
  serialized_examples = [
    features {
      feature { key: "age" value { int64_list { value: [ 0 ] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
      feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
    },
    features {
      feature { key: "age" value { int64_list { value: [] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
      feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
    }
  ]
  ```

  We can use arguments:

  ```
  features: {
    "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
    "gender": FixedLenFeature([], dtype=tf.string),
    "kws": VarLenFeature(dtype=tf.string),
  }
  ```

  And the expected output is:

  ```python
  {
    "age": [[0], [-1]],
    "gender": [["f"], ["f"]],
    "kws": SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=["code", "art", "sports"],
      dense_shape=[2, 2]),
  }
  ```

  Args:
    file_pattern: List of files or patterns of file paths containing
      `Example` records. See `tf.io.gfile.glob` for pattern rules.
    batch_size: An int representing the number of records to combine
      in a single batch.
    features: A `dict` mapping feature keys to `FixedLenFeature` or
      `VarLenFeature` values. See `tf.io.parse_example`.
    reader: A function or class that can be
      called with a `filenames` tensor and (optional) `reader_args` and returns
      a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
    label_key: (Optional) A string corresponding to the key labels are stored in
      `tf.Examples`. If provided, it must be one of the `features` keys,
      otherwise results in `ValueError`.
    reader_args: Additional arguments to pass to the reader class.
    num_epochs: Integer specifying the number of times to read through the
      dataset. If None, cycles through the dataset forever. Defaults to `None`.
    shuffle: A boolean, indicates whether the input should be shuffled. Defaults
      to `True`.
    shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity
      ensures better shuffling but would increase memory usage and startup time.
    shuffle_seed: Randomization seed to use for shuffling.
    prefetch_buffer_size: Number of feature batches to prefetch in order to
      improve performance. Recommended value is the number of batches consumed
      per training step. Defaults to auto-tune.
    reader_num_threads: Number of threads used to read `Example` records. If >1,
      the results will be interleaved. Defaults to `1`.
    parser_num_threads: Number of threads to use for parsing `Example` tensors
      into a dictionary of `Feature` tensors. Defaults to `2`.
    sloppy_ordering: If `True`, reading performance will be improved at
      the cost of non-deterministic ordering. If `False`, the order of elements
      produced is deterministic prior to shuffling (elements are still
      randomized if `shuffle=True`. Note that if the seed is set, then the
      order of elements after shuffling is deterministic). Defaults to `False`.
    drop_final_batch: If `True`, and the batch size does not evenly divide the
      input dataset size, the final smaller batch will be dropped. Defaults to
      `False`.

  Returns:
    A dataset of `dict` elements (or a tuple of `dict` elements and label).
    Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.

  Raises:
    TypeError: If `reader` is of the wrong type.
    ValueError: If `label_key` is not one of the `features` keys.
  """
  if reader is None:
    reader = core_readers.TFRecordDataset

  if reader_num_threads is None:
    reader_num_threads = 1
  if parser_num_threads is None:
    parser_num_threads = 2
  if prefetch_buffer_size is None:
    prefetch_buffer_size = dataset_ops.AUTOTUNE

  # Create dataset of all matching filenames
  dataset = dataset_ops.Dataset.list_files(
      file_pattern, shuffle=shuffle, seed=shuffle_seed)

  if isinstance(reader, type) and issubclass(reader, io_ops.ReaderBase):
    raise TypeError("The `reader` argument must return a `Dataset` object. "
                    "`tf.ReaderBase` subclasses are not supported. For "
                    "example, pass `tf.data.TFRecordDataset` instead of "
                    "`tf.TFRecordReader`.")

  # Read `Example` records from files as tensor objects.
  if reader_args is None:
    reader_args = []

  if reader_num_threads == dataset_ops.AUTOTUNE:
    dataset = dataset.interleave(
        lambda filename: reader(filename, *reader_args),
        num_parallel_calls=reader_num_threads)
    options = options_lib.Options()
    options.deterministic = not sloppy_ordering
    dataset = dataset.with_options(options)
  else:
    # Read files sequentially (if reader_num_threads=1) or in parallel
    def apply_fn(dataset):
      return core_readers.ParallelInterleaveDataset(
          dataset,
          lambda filename: reader(filename, *reader_args),
          cycle_length=reader_num_threads,
          block_length=1,
          sloppy=sloppy_ordering,
          buffer_output_elements=None,
          prefetch_input_elements=None)

    dataset = dataset.apply(apply_fn)

  # Extract values if the `Example` tensors are stored as key-value tuples.
  if dataset_ops.get_legacy_output_types(dataset) == (
      dtypes.string, dtypes.string):
    dataset = dataset_ops.MapDataset(
        dataset, lambda _, v: v, use_inter_op_parallelism=False)

  # Apply dataset repeat and shuffle transformations.
  dataset = _maybe_shuffle_and_repeat(
      dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)

  # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
  # improve the shape inference, because it makes the batch dimension static.
  # It is safe to do this because in that case we are repeating the input
  # indefinitely, and all batches will be full-sized.
  dataset = dataset.batch(
      batch_size, drop_remainder=drop_final_batch or num_epochs is None)

  # Parse `Example` tensors to a dictionary of `Feature` tensors.
  dataset = dataset.apply(
      parsing_ops.parse_example_dataset(
          features, num_parallel_calls=parser_num_threads))

  if label_key:
    if label_key not in features:
      raise ValueError(
          f"The `label_key` provided ({label_key}) must be one of the "
          f"`features` keys: {features.keys()}.")
    dataset = dataset.map(lambda x: (x, x.pop(label_key)))

  dataset = dataset.prefetch(prefetch_buffer_size)
  return dataset

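# Example usage (a hedged sketch from a caller's perspective, assuming
# `import tensorflow as tf`; the file pattern and feature spec are
# hypothetical):
#
#   feature_spec = {
#       "age": tf.io.FixedLenFeature([], tf.int64, default_value=-1),
#       "gender": tf.io.FixedLenFeature([], tf.string),
#   }
#   dataset = tf.data.experimental.make_batched_features_dataset(
#       file_pattern="/tmp/examples-*.tfrecord",
#       batch_size=32,
#       features=feature_spec,
#       label_key="gender",
#       num_epochs=1)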

@tf_export(v1=["data.experimental.make_batched_features_dataset"])
def make_batched_features_dataset_v1(file_pattern,  # pylint: disable=missing-docstring
                                     batch_size,
                                     features,
                                     reader=None,
                                     label_key=None,
                                     reader_args=None,
                                     num_epochs=None,
                                     shuffle=True,
                                     shuffle_buffer_size=10000,
                                     shuffle_seed=None,
                                     prefetch_buffer_size=None,
                                     reader_num_threads=None,
                                     parser_num_threads=None,
                                     sloppy_ordering=False,
                                     drop_final_batch=False):
  return dataset_ops.DatasetV1Adapter(make_batched_features_dataset_v2(
      file_pattern, batch_size, features, reader, label_key, reader_args,
      num_epochs, shuffle, shuffle_buffer_size, shuffle_seed,
      prefetch_buffer_size, reader_num_threads, parser_num_threads,
      sloppy_ordering, drop_final_batch))
make_batched_features_dataset_v1.__doc__ = (
    make_batched_features_dataset_v2.__doc__)


def _get_file_names(file_pattern, shuffle):
  """Parse list of file names from pattern, optionally shuffled.

  Args:
    file_pattern: File glob pattern, or list of glob patterns.
    shuffle: Whether to shuffle the order of file names.

  Returns:
    List of file names matching `file_pattern`.

  Raises:
    ValueError: If `file_pattern` is empty, or pattern matches no files.
  """
  if isinstance(file_pattern, list):
    if not file_pattern:
      raise ValueError("Argument `file_pattern` should not be empty.")
    file_names = []
    for entry in file_pattern:
      file_names.extend(gfile.Glob(entry))
  else:
    file_names = list(gfile.Glob(file_pattern))

  if not file_names:
    raise ValueError(f"No files match `file_pattern` {file_pattern}.")

  # Sort files so it will be deterministic for unit tests.
  if not shuffle:
    file_names = sorted(file_names)
  return file_names

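# Example (hypothetical paths):
#
#   _get_file_names(["/data/train-*.csv"], shuffle=False)
#   # -> a lexicographically sorted list such as
#   #    ["/data/train-00000.csv", "/data/train-00001.csv"]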

@tf_export("data.experimental.SqlDataset", v1=[])
class SqlDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` consisting of the results from a SQL query.

  `SqlDataset` allows a user to read data from the result set of a SQL query.
  For example:

  ```python
  dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3",
                                            "SELECT name, age FROM people",
                                            (tf.string, tf.int32))
  # Prints the rows of the result set of the above query.
  for element in dataset:
    print(element)
  ```
  """

  def __init__(self, driver_name, data_source_name, query, output_types):
    """Creates a `SqlDataset`.

    Args:
      driver_name: A 0-D `tf.string` tensor containing the database type.
        Currently, the only supported value is 'sqlite'.
      data_source_name: A 0-D `tf.string` tensor containing a connection string
        to connect to the database.
      query: A 0-D `tf.string` tensor containing the SQL query to execute.
      output_types: A tuple of `tf.DType` objects representing the types of the
        columns returned by `query`.
    """
    self._driver_name = ops.convert_to_tensor(
        driver_name, dtype=dtypes.string, name="driver_name")
    self._data_source_name = ops.convert_to_tensor(
        data_source_name, dtype=dtypes.string, name="data_source_name")
    self._query = ops.convert_to_tensor(
        query, dtype=dtypes.string, name="query")
    self._element_spec = nest.map_structure(
        lambda dtype: tensor_spec.TensorSpec([], dtype), output_types)
    variant_tensor = gen_experimental_dataset_ops.sql_dataset(
        self._driver_name, self._data_source_name, self._query,
        **self._flat_structure)
    super(SqlDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


@tf_export(v1=["data.experimental.SqlDataset"])
class SqlDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` consisting of the results from a SQL query."""

  @functools.wraps(SqlDatasetV2.__init__)
  def __init__(self, driver_name, data_source_name, query, output_types):
    wrapped = SqlDatasetV2(driver_name, data_source_name, query, output_types)
    super(SqlDatasetV1, self).__init__(wrapped)


if tf2.enabled():
  CsvDataset = CsvDatasetV2
  SqlDataset = SqlDatasetV2
  make_batched_features_dataset = make_batched_features_dataset_v2
  make_csv_dataset = make_csv_dataset_v2
else:
  CsvDataset = CsvDatasetV1
  SqlDataset = SqlDatasetV1
  make_batched_features_dataset = make_batched_features_dataset_v1
  make_csv_dataset = make_csv_dataset_v1