1# Copyright 2018 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Tests for absl.flags.argparse_flags."""
16
17import io
18import os
19import subprocess
20import sys
21import tempfile
22from unittest import mock
23
24from absl import flags
25from absl import logging
26from absl.flags import argparse_flags
27from absl.testing import _bazelize_command
28from absl.testing import absltest
29from absl.testing import parameterized
30
31
class ArgparseFlagsTest(parameterized.TestCase):
  """Tests for argparse_flags.ArgumentParser using a private FlagValues."""

  def setUp(self):
    """Creates a fresh FlagValues holding the absl flags used by each test."""
    super().setUp()
    self._absl_flags = flags.FlagValues()
    flags.DEFINE_bool(
        'absl_bool', None, 'help for --absl_bool.',
        short_name='b', flag_values=self._absl_flags)
    # Add a boolean flag that starts with "no", to verify it can correctly
    # handle the "no" prefixes in boolean flags.
    flags.DEFINE_bool(
        'notice', None, 'help for --notice.',
        flag_values=self._absl_flags)
    flags.DEFINE_string(
        'absl_string', 'default', 'help for --absl_string=%.',
        short_name='s', flag_values=self._absl_flags)
    flags.DEFINE_integer(
        'absl_integer', 1, 'help for --absl_integer.',
        flag_values=self._absl_flags)
    flags.DEFINE_float(
        'absl_float', 1.0, 'help for --absl_float.',
        flag_values=self._absl_flags)
    flags.DEFINE_enum(
        'absl_enum', 'apple', ['apple', 'orange'], 'help for --absl_enum.',
        flag_values=self._absl_flags)

  def test_dash_as_prefix_char_only(self):
    """A prefix_chars value other than '-' (here '/') raises ValueError."""
    with self.assertRaises(ValueError):
      argparse_flags.ArgumentParser(prefix_chars='/')

  def test_default_inherited_absl_flags_value(self):
    """Without inherited_absl_flags the parser inherits flags.FLAGS."""
    parser = argparse_flags.ArgumentParser()
    self.assertIs(parser._inherited_absl_flags, flags.FLAGS)

  def test_parse_absl_flags(self):
    """Parsing sets inherited absl flags and marks only those as non-default."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    self.assertFalse(self._absl_flags.is_parsed())
    self.assertTrue(self._absl_flags['absl_string'].using_default_value)
    self.assertTrue(self._absl_flags['absl_integer'].using_default_value)
    self.assertTrue(self._absl_flags['absl_float'].using_default_value)
    self.assertTrue(self._absl_flags['absl_enum'].using_default_value)

    parser.parse_args(
        ['--absl_string=new_string', '--absl_integer', '2'])
    self.assertEqual(self._absl_flags.absl_string, 'new_string')
    self.assertEqual(self._absl_flags.absl_integer, 2)
    self.assertTrue(self._absl_flags.is_parsed())
    self.assertFalse(self._absl_flags['absl_string'].using_default_value)
    self.assertFalse(self._absl_flags['absl_integer'].using_default_value)
    self.assertTrue(self._absl_flags['absl_float'].using_default_value)
    self.assertTrue(self._absl_flags['absl_enum'].using_default_value)

  @parameterized.named_parameters(
      ('true', ['--absl_bool'], True),
      ('false', ['--noabsl_bool'], False),
      ('does_not_accept_equal_value', ['--absl_bool=true'], SystemExit),
      ('does_not_accept_space_value', ['--absl_bool', 'true'], SystemExit),
      ('long_name_single_dash', ['-absl_bool'], SystemExit),
      ('short_name', ['-b'], True),
      ('short_name_false', ['-nob'], SystemExit),
      ('short_name_double_dash', ['--b'], SystemExit),
      ('short_name_double_dash_false', ['--nob'], SystemExit),
  )
  def test_parse_boolean_flags(self, args, expected):
    """Checks which spellings of a boolean flag are accepted.

    Args:
      args: command line arguments to parse.
      expected: the expected flag value (bool), or the exception type that
        parse_args is expected to raise.
    """
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    self.assertIsNone(self._absl_flags['absl_bool'].value)
    self.assertIsNone(self._absl_flags['b'].value)
    if isinstance(expected, bool):
      parser.parse_args(args)
      self.assertEqual(expected, self._absl_flags.absl_bool)
      self.assertEqual(expected, self._absl_flags.b)
    else:
      with self.assertRaises(expected):
        parser.parse_args(args)

  @parameterized.named_parameters(
      ('true', ['--notice'], True),
      ('false', ['--nonotice'], False),
  )
  def test_parse_boolean_existing_no_prefix(self, args, expected):
    """A boolean flag whose own name starts with "no" parses correctly."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    self.assertIsNone(self._absl_flags['notice'].value)
    parser.parse_args(args)
    self.assertEqual(expected, self._absl_flags.notice)

  def test_unrecognized_flag(self):
    """An undefined flag makes parse_args exit."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    with self.assertRaises(SystemExit):
      parser.parse_args(['--unknown_flag=what'])

  def test_absl_validators(self):
    """A failing absl flag validator makes parse_args exit."""

    @flags.validator('absl_integer', flag_values=self._absl_flags)
    def ensure_positive(value):
      return value > 0

    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    with self.assertRaises(SystemExit):
      parser.parse_args(['--absl_integer', '-2'])

    # The validator is only registered for its side effect; drop the name to
    # keep linters from reporting an unused local.
    del ensure_positive

  @parameterized.named_parameters(
      ('regular_name_double_dash', '--absl_string=new_string', 'new_string'),
      ('regular_name_single_dash', '-absl_string=new_string', SystemExit),
      ('short_name_double_dash', '--s=new_string', SystemExit),
      ('short_name_single_dash', '-s=new_string', 'new_string'),
  )
  def test_dashes(self, argument, expected):
    """Long names need a double dash; short names need a single dash.

    Args:
      argument: the single command line argument to parse.
      expected: the expected value of --absl_string (str), or the exception
        type that parse_args is expected to raise.
    """
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    if isinstance(expected, str):
      parser.parse_args([argument])
      self.assertEqual(self._absl_flags.absl_string, expected)
    else:
      with self.assertRaises(expected):
        parser.parse_args([argument])

  def test_absl_flags_not_added_to_namespace(self):
    """Absl flags are not exposed as attributes of the returned namespace."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    args = parser.parse_args(['--absl_string=new_string'])
    self.assertIsNone(getattr(args, 'absl_string', None))

  def test_mixed_flags_and_positional(self):
    """Absl flags, argparse options and positionals can be interleaved."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    parser.add_argument('--header', help='Header message to print.')
    parser.add_argument('integers', metavar='N', type=int, nargs='+',
                        help='an integer for the accumulator')

    args = parser.parse_args(
        ['--absl_string=new_string', '--header=HEADER', '--absl_integer',
         '2', '3', '4'])
    self.assertEqual(self._absl_flags.absl_string, 'new_string')
    self.assertEqual(self._absl_flags.absl_integer, 2)
    self.assertEqual(args.header, 'HEADER')
    self.assertListEqual(args.integers, [3, 4])

  def test_subparsers(self):
    """Absl flags are accepted before and after a subcommand."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    parser.add_argument('--header', help='Header message to print.')
    subparsers = parser.add_subparsers(help='The command to execute.')

    # NOTE: The sub parsers don't work well with typing hence `type: ignore`.
    # See https://github.com/python/typeshed/issues/10082.
    sub_parser = subparsers.add_parser(  # type: ignore
        'sub_cmd', help='Sub command.', inherited_absl_flags=self._absl_flags
    )
    sub_parser.add_argument('--sub_flag', help='Sub command flag.')

    def sub_command_func():
      pass

    sub_parser.set_defaults(command=sub_command_func)

    args = parser.parse_args([
        '--header=HEADER', '--absl_string=new_value', 'sub_cmd',
        '--absl_integer=2', '--sub_flag=new_sub_flag_value'])

    self.assertEqual(args.header, 'HEADER')
    self.assertEqual(self._absl_flags.absl_string, 'new_value')
    self.assertEqual(args.command, sub_command_func)
    self.assertEqual(self._absl_flags.absl_integer, 2)
    self.assertEqual(args.sub_flag, 'new_sub_flag_value')

  def test_subparsers_no_inherit_in_subparser(self):
    """A subparser built with inherited_absl_flags=None rejects absl flags."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    subparsers = parser.add_subparsers(help='The command to execute.')

    # NOTE: The sub parsers don't work well with typing hence `type: ignore`.
    # See https://github.com/python/typeshed/issues/10082.
    subparsers.add_parser(  # type: ignore
        'sub_cmd',
        help='Sub command.',
        # Do not inherit absl flags in the subparser.
        # This is the behavior that this test exercises.
        inherited_absl_flags=None,
    )

    with self.assertRaises(SystemExit):
      parser.parse_args(['sub_cmd', '--absl_string=new_value'])

  def test_help_main_module_flags(self):
    """Help output lists the main module's absl flags and their help text."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    help_message = parser.format_help()

    # Only the short name is shown in the usage string.
    self.assertIn('[-s ABSL_STRING]', help_message)
    # Both names are included in the options section.
    self.assertIn('-s ABSL_STRING, --absl_string ABSL_STRING', help_message)
    # Verify help messages.
    self.assertIn('help for --absl_string=%.', help_message)
    self.assertIn('<apple|orange>: help for --absl_enum.', help_message)

  def test_help_non_main_module_flags(self):
    """Flags defined in another module are omitted from the help output."""
    flags.DEFINE_string(
        'non_main_module_flag', 'default', 'help',
        module_name='other.module', flag_values=self._absl_flags)
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    help_message = parser.format_help()

    # Non main module key flags are not printed in the help message.
    self.assertNotIn('non_main_module_flag', help_message)

  def test_help_non_main_module_key_flags(self):
    """Flags from another module declared as key flags appear in help."""
    flags.DEFINE_string(
        'non_main_module_flag', 'default', 'help',
        module_name='other.module', flag_values=self._absl_flags)
    flags.declare_key_flag('non_main_module_flag', flag_values=self._absl_flags)
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    help_message = parser.format_help()

    # Main module key flags are printed in the help message, even if the flag
    # is defined in another module.
    self.assertIn('non_main_module_flag', help_message)

  @parameterized.named_parameters(
      ('h', ['-h']),
      ('help', ['--help']),
      ('helpshort', ['--helpshort']),
      ('helpfull', ['--helpfull']),
  )
  def test_help_flags(self, args):
    """Every help flag makes parse_args exit."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    with self.assertRaises(SystemExit):
      parser.parse_args(args)

  @parameterized.named_parameters(
      ('h', ['-h']),
      ('help', ['--help']),
      ('helpshort', ['--helpshort']),
      ('helpfull', ['--helpfull']),
  )
  def test_no_help_flags(self, args):
    """With add_help=False, help flags still exit but never print help."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags, add_help=False)
    with mock.patch.object(parser, 'print_help') as print_help_mock:
      with self.assertRaises(SystemExit):
        parser.parse_args(args)
    print_help_mock.assert_not_called()

  def test_helpfull_message(self):
    """--helpfull prints flags from other modules and absl.flags itself."""
    flags.DEFINE_string(
        'non_main_module_flag', 'default', 'help',
        module_name='other.module', flag_values=self._absl_flags)
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    with self.assertRaises(SystemExit):
      with mock.patch.object(sys, 'stdout', new=io.StringIO()) as mock_stdout:
        parser.parse_args(['--helpfull'])
    stdout_message = mock_stdout.getvalue()
    logging.info('captured stdout message:\n%s', stdout_message)
    self.assertIn('--non_main_module_flag', stdout_message)
    self.assertIn('other.module', stdout_message)
    # Make sure the main module is not included.
    self.assertNotIn(sys.argv[0], stdout_message)
    # Special flags defined in absl.flags.
    self.assertIn('absl.flags:', stdout_message)
    self.assertIn('--flagfile', stdout_message)
    self.assertIn('--undefok', stdout_message)

  @parameterized.named_parameters(
      ('at_end',
       ('1', '--absl_string=value_from_cmd', '--flagfile='),
       'value_from_file'),
      ('at_beginning',
       ('--flagfile=', '1', '--absl_string=value_from_cmd'),
       'value_from_cmd'),
  )
  def test_flagfile(self, cmd_args, expected_absl_string_value):
    """The --flagfile position decides whether it overrides other arguments.

    Args:
      cmd_args: arguments where the literal '--flagfile=' is completed with
        the path of a temporary flag file.
      expected_absl_string_value: the expected final value of --absl_string.
    """
    # Set gnu_getopt to False, to verify it's ignored by argparse_flags.
    self._absl_flags.set_gnu_getopt(False)

    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    parser.add_argument('--header', help='Header message to print.')
    parser.add_argument('integers', metavar='N', type=int, nargs='+',
                        help='an integer for the accumulator')
    # create_tempfile registers its own cleanup, so no manual unlink is needed.
    flagfile = self.create_tempfile(content='''
# The flag file.
--absl_string=value_from_file
--absl_integer=1
--header=header_from_file
''').full_path

    cmd_args = [
        arg + flagfile if arg == '--flagfile=' else arg for arg in cmd_args
    ]
    args = parser.parse_args(cmd_args)

    self.assertEqual([1], args.integers)
    self.assertEqual('header_from_file', args.header)
    self.assertEqual(expected_absl_string_value, self._absl_flags.absl_string)

  @parameterized.parameters(
      ('positional', {'positional'}, False),
      ('--not_existed', {'existed'}, False),
      ('--empty', set(), False),
      ('-single_dash', {'single_dash'}, True),
      ('--double_dash', {'double_dash'}, True),
      ('--with_value=value', {'with_value'}, True),
  )
  def test_is_undefok(self, arg, undefok_names, is_undefok):
    """_is_undefok matches only dashed arguments whose name is listed."""
    self.assertEqual(is_undefok, argparse_flags._is_undefok(arg, undefok_names))

  @parameterized.named_parameters(
      ('single', 'single', ['--single'], []),
      ('multiple', 'first,second', ['--first', '--second'], []),
      ('single_dash', 'dash', ['-dash'], []),
      ('mixed_dash', 'mixed', ['-mixed', '--mixed'], []),
      ('value', 'name', ['--name=value'], []),
      ('boolean_positive', 'bool', ['--bool'], []),
      ('boolean_negative', 'bool', ['--nobool'], []),
      ('left_over', 'strip', ['--first', '--strip', '--last'],
       ['--first', '--last']),
  )
  def test_strip_undefok_args(self, undefok, args, expected_args):
    """_strip_undefok_args drops args named in undefok and keeps the rest."""
    actual_args = argparse_flags._strip_undefok_args(undefok, args)
    self.assertListEqual(expected_args, actual_args)

  @parameterized.named_parameters(
      ('at_end', ['--unknown', '--undefok=unknown']),
      ('at_beginning', ['--undefok=unknown', '--unknown']),
      ('multiple', ['--unknown', '--undefok=unknown,another_unknown']),
      ('with_value', ['--unknown=value', '--undefok=unknown']),
      ('maybe_boolean', ['--nounknown', '--undefok=unknown']),
      ('with_space', ['--unknown', '--undefok', 'unknown']),
  )
  def test_undefok_flag_correct_use(self, cmd_args):
    """Unknown flags listed in --undefok are accepted silently."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    args = parser.parse_args(cmd_args)  # Make sure it doesn't raise.
    # Make sure `undefok` is not exposed in namespace.
    sentinel = object()
    self.assertIs(sentinel, getattr(args, 'undefok', sentinel))

  def test_undefok_flag_existing(self):
    """Listing a defined flag in --undefok still lets it be set."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    parser.parse_args(
        ['--absl_string=new_value', '--undefok=absl_string'])
    self.assertEqual('new_value', self._absl_flags.absl_string)

  @parameterized.named_parameters(
      ('no_equal', ['--unknown', 'value', '--undefok=unknown']),
      ('single_dash', ['--unknown', '-undefok=unknown']),
  )
  def test_undefok_flag_incorrect_use(self, cmd_args):
    """Misused --undefok (space-separated value, single dash) makes parse exit."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    with self.assertRaises(SystemExit):
      parser.parse_args(cmd_args)

  def test_argument_default(self):
    """argument_default is honored for regular argparse arguments."""
    # Regression test for https://github.com/abseil/abseil-py/issues/171.
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags, argument_default=23)
    parser.add_argument(
        '--magic_number', type=int, help='The magic number to use.')
    args = parser.parse_args([])
    self.assertEqual(args.magic_number, 23)

  def test_empty_inherited_absl_flags(self):
    """--undefok and --flagfile work even with an empty FlagValues."""
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=flags.FlagValues()
    )
    parser.add_argument('--foo')
    flagfile = self.create_tempfile(content='--foo=bar').full_path
    # Make sure these flags are still available when inheriting an empty
    # FlagValues instance.
    ns = parser.parse_args([
        '--undefok=undefined_flag',
        '--undefined_flag=value',
        '--flagfile=' + flagfile,
    ])
    self.assertEqual(ns.foo, 'bar')
423
424
class ArgparseWithAppRunTest(parameterized.TestCase):
  """Runs argparse_flags_test_helper in a subprocess and checks its stdout."""

  @parameterized.named_parameters(
      ('simple',
       'main_simple', 'parse_flags_simple',
       ['--argparse_echo=I am argparse.', '--absl_echo=I am absl.'],
       ['I am argparse.', 'I am absl.']),
      ('subcommand_roll_dice',
       'main_subcommands', 'parse_flags_subcommands',
       ['--argparse_echo=I am argparse.', '--absl_echo=I am absl.',
        'roll_dice', '--num_faces=12'],
       ['I am argparse.', 'I am absl.', 'Rolled a dice: ']),
      ('subcommand_shuffle',
       'main_subcommands', 'parse_flags_subcommands',
       ['--argparse_echo=I am argparse.', '--absl_echo=I am absl.',
        'shuffle', 'a', 'b', 'c'],
       ['I am argparse.', 'I am absl.', 'Shuffled: ']),
  )
  def test_argparse_with_app_run(
      self, main_func_name, flags_parser_func_name, args, output_strings):
    """Runs the helper binary and checks its stdout for expected substrings.

    Args:
      main_func_name: passed to the helper through the MAIN_FUNC env var.
      flags_parser_func_name: passed to the helper through the
        FLAGS_PARSER_FUNC env var.
      args: command line arguments given to the helper.
      output_strings: substrings that must all appear in the helper's stdout.

    Raises:
      subprocess.CalledProcessError: if the helper exits with a non-zero
        status; diagnostic details are printed to stderr first.
    """
    env = os.environ.copy()
    env['MAIN_FUNC'] = main_func_name
    env['FLAGS_PARSER_FUNC'] = flags_parser_func_name
    helper = _bazelize_command.get_executable_path(
        'absl/flags/tests/argparse_flags_test_helper')
    try:
      stdout = subprocess.check_output(
          [helper] + args, env=env, universal_newlines=True)
    except subprocess.CalledProcessError as e:
      # Terminate the captured output with exactly one newline so the closing
      # separator always starts on its own line, including when the helper
      # produced no output at all.
      output = e.output or '<empty>'
      if not output.endswith('\n'):
        output += '\n'
      error_info = ('ERROR: argparse_helper failed\n'
                    'Command: {}\n'
                    'Exit code: {}\n'
                    '----- output -----\n{}'
                    '------------------')
      error_info = error_info.format(e.cmd, e.returncode, output)
      print(error_info, file=sys.stderr)
      raise

    for output_string in output_strings:
      self.assertIn(output_string, stdout)
466
467
if __name__ == '__main__':
  # Run the test cases in this module via absltest's entry point.
  absltest.main()
470