xref: /aosp_15_r20/external/tink/python/examples/streaming_aead/streaming_aead.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS-IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# [START streaming-aead-example]
"""A command-line utility to encrypt or decrypt a file with streaming AEAD.

It loads cleartext keys from disk - this is not recommended!

It requires 4 flags (and accepts one optional one):
  mode: either 'encrypt' or 'decrypt'
  keyset_path: name of the file with the keyset to be used for encryption or
    decryption
  input_path: name of the file with the input data to be encrypted or decrypted
  output_path: name of the file to write the ciphertext (when encrypting) or
    the plaintext (when decrypting) to
  [optional] associated_data: the associated data used for encryption or
    decryption, provided as a string.
"""
30
import sys
from typing import BinaryIO
from typing import Optional

from absl import app
from absl import flags
from absl import logging
import tink
from tink import cleartext_keyset_handle
from tink import streaming_aead
39
FLAGS = flags.FLAGS
# Upper bound on how many bytes the CLI tool reads or writes in one go (1 MiB).
BLOCK_SIZE = 1 << 20

flags.DEFINE_enum(
    name='mode',
    default=None,
    enum_values=['encrypt', 'decrypt'],
    help='Selects if the file should be encrypted or decrypted.')
flags.DEFINE_string(
    name='keyset_path',
    default=None,
    help='Path to the keyset used for encryption or decryption.')
flags.DEFINE_string(
    name='input_path', default=None, help='Path to the input file.')
flags.DEFINE_string(
    name='output_path', default=None, help='Path to the output file.')
flags.DEFINE_string(
    name='associated_data',
    default=None,
    help='Associated data used for the encryption or decryption.')
51
52
def read_as_blocks(file: BinaryIO, block_size: Optional[int] = None):
  """Generator that reads a binary file in chunks of at most block_size bytes.

  Args:
    file: The file object to read from.
    block_size: Maximum number of bytes requested per read() call. When
      omitted (None), the module-level BLOCK_SIZE is used.

  Yields:
    Successive chunks of up to block_size bytes, in file order, until read()
    returns b''.
  """
  if block_size is None:
    block_size = BLOCK_SIZE
  # iter(callable, sentinel) keeps calling file.read until it returns a value
  # equal to b''. EOF is detected by equality with b'' rather than falsiness:
  # if the file was opened in raw IO mode, EOF is only reached when b'' is
  # returned, so a None result is not treated as end of file.
  yield from iter(lambda: file.read(block_size), b'')
70
71
def encrypt_file(input_file: BinaryIO, output_file: BinaryIO,
                 associated_data: bytes,
                 primitive: streaming_aead.StreamingAead):
  """Writes the streaming-AEAD encryption of input_file into output_file.

  The plaintext is read in chunks of at most BLOCK_SIZE bytes (via
  read_as_blocks) rather than all at once.

  Args:
    input_file: Binary file holding the plaintext to encrypt.
    output_file: Binary file that receives the ciphertext.
    associated_data: Associated data bound to the ciphertext by the AEAD.
    primitive: Streaming AEAD primitive that performs the encryption.
  """
  stream_ctx = primitive.new_encrypting_stream(
      output_file, associated_data)
  # Used as a context manager so the stream is finalized on exit, even if a
  # write fails. Presumably finalization is what emits the last ciphertext
  # segment -- NOTE(review): confirm against the Tink StreamingAead docs.
  with stream_ctx as ciphertext_sink:
    for plaintext_chunk in read_as_blocks(input_file):
      ciphertext_sink.write(plaintext_chunk)
87
88
def decrypt_file(input_file: BinaryIO, output_file: BinaryIO,
                 associated_data: bytes,
                 primitive: streaming_aead.StreamingAead):
  """Decrypts a file with the given streaming AEAD primitive.

  This function will cause the program to exit with status 1 if the decryption
  fails. Any plaintext decrypted before the failure has already been written
  to output_file; this function does not remove it.

  Args:
    input_file: File to read the ciphertext from.
    output_file: File to write the plaintext to.
    associated_data: Associated data provided for the AEAD.
    primitive: The streaming AEAD primitive used for decryption.
  """
  try:
    with primitive.new_decrypting_stream(input_file,
                                         associated_data) as dec_stream:
      for data_block in read_as_blocks(dec_stream):
        output_file.write(data_block)
  except tink.TinkError as e:
    logging.exception('Error decrypting ciphertext: %s', e)
    # sys.exit instead of the site-provided exit() builtin, which is intended
    # for the interactive interpreter and is missing under `python -S`.
    sys.exit(1)
110
111
def main(argv):
  """Encrypts or decrypts --input_path into --output_path per the flags.

  Args:
    argv: Positional command-line arguments passed in by app.run; unused.

  Returns:
    1 if Tink cannot be initialised, the keyset cannot be parsed, or the
    streaming AEAD primitive cannot be created; otherwise None.
  """
  del argv  # Unused.

  # An unset flag and an empty string both mean "no associated data".
  associated_data = (FLAGS.associated_data or '').encode('utf-8')

  # Initialise Tink.
  try:
    streaming_aead.register()
  except tink.TinkError as e:
    logging.exception('Error initialising Tink: %s', e)
    return 1

  # Read the keyset into a keyset_handle. The encoding is pinned to UTF-8 so
  # that decoding the JSON keyset does not depend on the platform's locale.
  with open(FLAGS.keyset_path, 'rt', encoding='utf-8') as keyset_file:
    text = keyset_file.read()
  try:
    keyset_handle = cleartext_keyset_handle.read(tink.JsonKeysetReader(text))
  except tink.TinkError as e:
    logging.exception('Error reading key: %s', e)
    return 1

  # Get the primitive.
  try:
    streaming_aead_primitive = keyset_handle.primitive(
        streaming_aead.StreamingAead)
  except tink.TinkError as e:
    logging.exception('Error creating streaming AEAD primitive from keyset: %s',
                      e)
    return 1

  # Encrypt or decrypt the file. The input is opened first, so a missing input
  # file fails before the output file is created or truncated.
  with open(FLAGS.input_path, 'rb') as input_file:
    with open(FLAGS.output_path, 'wb') as output_file:
      if FLAGS.mode == 'encrypt':
        encrypt_file(input_file, output_file, associated_data,
                     streaming_aead_primitive)
      elif FLAGS.mode == 'decrypt':
        decrypt_file(input_file, output_file, associated_data,
                     streaming_aead_primitive)
151                     streaming_aead_primitive)
152
153
if __name__ == '__main__':
  # All flags except associated_data must be supplied on the command line.
  flags.mark_flags_as_required(
      ['mode', 'keyset_path', 'input_path', 'output_path'])
  app.run(main)
160
161# [END streaming-aead-example]
162