1# Copyright 2016 Google Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import re
16import unittest
17
18from mobly import signals
19
# Module-wide unittest.TestCase instance. The assert_* helpers below borrow its
# assertXxx methods so the comparison logic and failure messages of Python's
# own unittest can be reused instead of being reimplemented here.
_pyunit_proxy = unittest.TestCase()
# maxDiff = None tells unittest not to cap the length of the diffs it puts in
# failure messages, so the full difference is reported.
_pyunit_proxy.maxDiff = None
24
25
def _call_unittest_assertion(
    assertion_method, *args, msg=None, extras=None, **kwargs
):
  """Runs a unittest assertion and reports its failure as a Mobly failure.

  Args:
    assertion_method: unittest.TestCase assertion method to call.
    *args: Positional arguments forwarded to the assertion call.
    msg: A string that adds additional info about the failure. When truthy,
      it is appended to the unittest failure text, separated by a space.
    extras: An optional field for extra information to be included in
      test result.
    **kwargs: Keyword arguments forwarded to the assertion call.

  Raises:
    signals.TestFailure: If the assertion method raised AssertionError.
  """
  failure_details = None
  try:
    assertion_method(*args, **kwargs)
  except AssertionError as error:
    assertion_text = str(error)
    failure_details = f'{assertion_text} {msg}' if msg else assertion_text

  if failure_details is None:
    return

  # Raising here, after the except block has finished, keeps Python 3 from
  # chaining the original AssertionError and printing two tracebacks.
  raise signals.TestFailure(failure_details, extras=extras)
51
52
def assert_equal(first, second, msg=None, extras=None):
  """Fails the test unless the two objects compare equal.

  The comparison and the default failure text ("first != second") come from
  unittest's assertEqual. Additional explanation can be supplied through
  `msg`.

  Args:
    first: The first object to compare.
    second: The second object to compare.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If the objects are not equal.
  """
  check = _pyunit_proxy.assertEqual
  _call_unittest_assertion(check, first, second, msg=msg, extras=extras)
69
70
def assert_not_equal(first, second, msg=None, extras=None):
  """Fails the test unless first is not equal (!=) to second.

  Args:
    first: The first object to compare.
    second: The second object to compare.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If the objects compare equal.
  """
  check = _pyunit_proxy.assertNotEqual
  _call_unittest_assertion(check, first, second, msg=msg, extras=extras)
76
77
def assert_almost_equal(
    first, second, places=None, msg=None, delta=None, extras=None
):
  """Fails the test unless first is almost equal to second.

  The check is delegated to unittest's assertAlmostEqual: the two values
  are considered almost equal when their difference, rounded to `places`
  decimal places (default 7), is zero, or when their difference is no
  more than `delta`. Values that compare equal are always almost equal.

  Args:
    first: The first value to compare.
    second: The second value to compare.
    places: How many decimal places to take into account for comparison.
      Note that decimal places (from zero) are usually not the same
      as significant digits (measured from the most significant digit).
    msg: A string that adds additional info about the failure.
    delta: Delta to use for comparison instead of decimal places.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If the values are not almost equal.
  """
  tolerance = {'places': places, 'delta': delta}
  _call_unittest_assertion(
      _pyunit_proxy.assertAlmostEqual,
      first,
      second,
      msg=msg,
      extras=extras,
      **tolerance,
  )
110
111
def assert_not_almost_equal(
    first, second, places=None, msg=None, delta=None, extras=None
):
  """Fails the test if first is almost equal to second.

  The check is delegated to unittest's assertNotAlmostEqual, which applies
  the same `places` / `delta` tolerance rules as assertAlmostEqual.

  Args:
    first: The first value to compare.
    second: The second value to compare.
    places: How many decimal places to take into account for comparison.
      Note that decimal places (from zero) are usually not the same
      as significant digits (measured from the most significant digit).
    msg: A string that adds additional info about the failure.
    delta: Delta to use for comparison instead of decimal places.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If the values are almost equal.
  """
  tolerance = {'places': places, 'delta': delta}
  _call_unittest_assertion(
      _pyunit_proxy.assertNotAlmostEqual,
      first,
      second,
      msg=msg,
      extras=extras,
      **tolerance,
  )
137
138
def assert_in(member, container, msg=None, extras=None):
  """Fails the test unless member is in container.

  Args:
    member: The object expected to be found.
    container: The container that is checked with the `in` operator.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If member is not in container.
  """
  check = _pyunit_proxy.assertIn
  _call_unittest_assertion(check, member, container, msg=msg, extras=extras)
144
145
def assert_not_in(member, container, msg=None, extras=None):
  """Fails the test if member is in container.

  Args:
    member: The object expected to be absent.
    container: The container that is checked with the `in` operator.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If member is in container.
  """
  check = _pyunit_proxy.assertNotIn
  _call_unittest_assertion(check, member, container, msg=msg, extras=extras)
151
152
def assert_is(expr1, expr2, msg=None, extras=None):
  """Fails the test unless expr1 is expr2 (identity, not equality).

  Args:
    expr1: The first object.
    expr2: The second object.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If expr1 and expr2 are not the same object.
  """
  check = _pyunit_proxy.assertIs
  _call_unittest_assertion(check, expr1, expr2, msg=msg, extras=extras)
158
159
def assert_is_not(expr1, expr2, msg=None, extras=None):
  """Fails the test if expr1 is expr2 (identity, not equality).

  Args:
    expr1: The first object.
    expr2: The second object.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If expr1 and expr2 are the same object.
  """
  check = _pyunit_proxy.assertIsNot
  _call_unittest_assertion(check, expr1, expr2, msg=msg, extras=extras)
165
166
def assert_count_equal(first, second, msg=None, extras=None):
  """Fails the test unless both iterables hold the same elements.

  Elements must occur the same number of times in both iterables; their
  order is ignored. Similar to
  assert_equal(Counter(list(first)), Counter(list(second))).

  Args:
    first: The first iterable to compare.
    second: The second iterable to compare.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If the element counts differ.

  Example:
    assert_count_equal([0, 1, 1], [1, 0, 1]) passes the assertion.
    assert_count_equal([0, 0, 1], [0, 1]) fails the test.
  """
  check = _pyunit_proxy.assertCountEqual
  _call_unittest_assertion(check, first, second, msg=msg, extras=extras)
187
188
def assert_less(a, b, msg=None, extras=None):
  """Fails the test unless a < b.

  Args:
    a: The left-hand operand.
    b: The right-hand operand.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If a is not less than b.
  """
  check = _pyunit_proxy.assertLess
  _call_unittest_assertion(check, a, b, msg=msg, extras=extras)
194
195
def assert_less_equal(a, b, msg=None, extras=None):
  """Fails the test unless a <= b.

  Args:
    a: The left-hand operand.
    b: The right-hand operand.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If a is greater than b.
  """
  check = _pyunit_proxy.assertLessEqual
  _call_unittest_assertion(check, a, b, msg=msg, extras=extras)
201
202
def assert_greater(a, b, msg=None, extras=None):
  """Fails the test unless a > b.

  Args:
    a: The left-hand operand.
    b: The right-hand operand.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If a is not greater than b.
  """
  check = _pyunit_proxy.assertGreater
  _call_unittest_assertion(check, a, b, msg=msg, extras=extras)
208
209
def assert_greater_equal(a, b, msg=None, extras=None):
  """Fails the test unless a >= b.

  Args:
    a: The left-hand operand.
    b: The right-hand operand.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If a is less than b.
  """
  check = _pyunit_proxy.assertGreaterEqual
  _call_unittest_assertion(check, a, b, msg=msg, extras=extras)
215
216
def assert_is_none(obj, msg=None, extras=None):
  """Fails the test unless obj is None.

  Args:
    obj: The object to check.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If obj is not None.
  """
  check = _pyunit_proxy.assertIsNone
  _call_unittest_assertion(check, obj, msg=msg, extras=extras)
222
223
def assert_is_not_none(obj, msg=None, extras=None):
  """Fails the test if obj is None.

  Args:
    obj: The object to check.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If obj is None.
  """
  check = _pyunit_proxy.assertIsNotNone
  _call_unittest_assertion(check, obj, msg=msg, extras=extras)
229
230
def assert_is_instance(obj, cls, msg=None, extras=None):
  """Fails the test unless obj is an instance of cls.

  Args:
    obj: The object to check.
    cls: The class (or tuple of classes, as accepted by isinstance) that
      obj is expected to be an instance of.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If obj is not an instance of cls.
  """
  check = _pyunit_proxy.assertIsInstance
  _call_unittest_assertion(check, obj, cls, msg=msg, extras=extras)
236
237
def assert_not_is_instance(obj, cls, msg=None, extras=None):
  """Fails the test if obj is an instance of cls.

  Args:
    obj: The object to check.
    cls: The class (or tuple of classes, as accepted by isinstance) that
      obj must not be an instance of.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If obj is an instance of cls.
  """
  check = _pyunit_proxy.assertNotIsInstance
  _call_unittest_assertion(check, obj, cls, msg=msg, extras=extras)
243
244
def assert_regex(text, expected_regex, msg=None, extras=None):
  """Fails the test unless the text matches the regular expression.

  Args:
    text: The text to search.
    expected_regex: The regular expression, in any form accepted by
      unittest's assertRegex, that must be found in text.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If the regular expression does not match.
  """
  check = _pyunit_proxy.assertRegex
  _call_unittest_assertion(
      check, text, expected_regex, msg=msg, extras=extras
  )
250
251
def assert_not_regex(text, unexpected_regex, msg=None, extras=None):
  """Fails the test if the text matches the regular expression.

  Args:
    text: The text to search.
    unexpected_regex: The regular expression, in any form accepted by
      unittest's assertNotRegex, that must not be found in text.
    msg: A string that adds additional info about the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: If the regular expression matches.
  """
  check = _pyunit_proxy.assertNotRegex
  _call_unittest_assertion(
      check, text, unexpected_regex, msg=msg, extras=extras
  )
261
262
def assert_raises(expected_exception, extras=None, *args, **kwargs):
  """Returns a context manager asserting that an exception is raised.

  If the body of the `with` statement raises no exception, the test fails.
  If it raises an exception that is not of the expected type, that
  exception is let through.

  This should only be used as a context manager:
    with assert_raises(Exception):
      func()

  Args:
    expected_exception: An exception class that is expected to be
      raised.
    extras: An optional field for extra information to be included in
      test result.
    *args: Accepted but not used by this function.
    **kwargs: Accepted but not used by this function.

  Returns:
    The context manager that performs the check on exit.
  """
  return _AssertRaisesContext(expected_exception, extras=extras)
281
282
def assert_raises_regex(
    expected_exception, expected_regex, extras=None, *args, **kwargs
):
  """Returns a context manager asserting an exception and its message.

  If the body of the `with` statement raises no exception, the test fails.
  If it raises an exception that is not of the expected type, that
  exception is let through. If an exception of the expected type is raised
  but its message does not match expected_regex, the test fails.

  This should only be used as a context manager:
    with assert_raises_regex(Exception, 'some regex'):
      func()

  Args:
    expected_exception: An exception class that is expected to be
      raised.
    expected_regex: A regular expression string, or a compiled pattern,
      that is searched for in the string form of the raised exception.
    extras: An optional field for extra information to be included in
      test result.
    *args: Accepted but not used by this function.
    **kwargs: Accepted but not used by this function.

  Returns:
    The context manager that performs the check on exit.
  """
  return _AssertRaisesContext(
      expected_exception, expected_regexp=expected_regex, extras=extras
  )
307
308
def assert_true(expr, msg, extras=None):
  """Fails the test unless the expression evaluates to true.

  Args:
    expr: The expression that is evaluated.
    msg: A string explaining the details in case of failure.
    extras: An optional field for extra information to be included in
      test result.
  """
  if expr:
    return
  fail(msg, extras)
320
321
def assert_false(expr, msg, extras=None):
  """Fails the test unless the expression evaluates to false.

  Args:
    expr: The expression that is evaluated.
    msg: A string explaining the details in case of failure.
    extras: An optional field for extra information to be included in
      test result.
  """
  if not expr:
    return
  fail(msg, extras)
333
334
def skip(reason, extras=None):
  """Skips the current test.

  Args:
    reason: The reason this test is skipped.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestSkip: Mark a test as skipped.
  """
  skip_signal = signals.TestSkip(reason, extras)
  raise skip_signal
347
348
def skip_if(expr, reason, extras=None):
  """Skips the current test when the expression evaluates to True.

  Args:
    expr: The expression that is evaluated.
    reason: The reason this test is skipped.
    extras: An optional field for extra information to be included in
      test result.
  """
  if not expr:
    return
  skip(reason, extras)
360
361
def abort_class(reason, extras=None):
  """Aborts all subsequent tests within the same test class in one iteration.

  If one test class is requested multiple times in a test run, this can
  only abort one of the requested executions, NOT all.

  Args:
    reason: The reason to abort.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestAbortClass: Abort all subsequent tests in a test class.
  """
  abort_signal = signals.TestAbortClass(reason, extras)
  raise abort_signal
377
378
def abort_class_if(expr, reason, extras=None):
  """Aborts the rest of the test class if the expression evaluates to True.

  All subsequent tests within the same test class in one iteration are
  aborted. If one test class is requested multiple times in a test run,
  this can only abort one of the requested executions, NOT all.

  Args:
    expr: The expression that is evaluated.
    reason: The reason to abort.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestAbortClass: Abort all subsequent tests in a test class.
  """
  if not expr:
    return
  abort_class(reason, extras)
397
398
def abort_all(reason, extras=None):
  """Aborts all subsequent tests.

  This includes the tests that are not in this test class or iteration.

  Args:
    reason: The reason to abort.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestAbortAll: Abort all subsequent tests.
  """
  abort_signal = signals.TestAbortAll(reason, extras)
  raise abort_signal
412
413
def abort_all_if(expr, reason, extras=None):
  """Aborts all subsequent tests if the expression evaluates to True.

  Args:
    expr: The expression that is evaluated.
    reason: The reason to abort.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestAbortAll: Abort all subsequent tests.
  """
  if not expr:
    return
  abort_all(reason, extras)
428
429
def fail(msg, extras=None):
  """Fails the current test explicitly.

  Args:
    msg: A string explaining the details of the failure.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestFailure: Mark a test as failed.
  """
  failure = signals.TestFailure(msg, extras)
  raise failure
442
443
def explicit_pass(msg, extras=None):
  """Passes the current test explicitly.

  This will pass the test explicitly regardless of any other error happened
  in the test body. E.g. even if errors have been recorded with `expects`,
  the test will still be marked pass if this is called.

  A test without an uncaught exception passes implicitly, so this should
  be used sparingly.

  Args:
    msg: A string explaining the details of the passed test.
    extras: An optional field for extra information to be included in
      test result.

  Raises:
    signals.TestPass: Mark a test as passed.
  """
  pass_signal = signals.TestPass(msg, extras)
  raise pass_signal
463
464
class _AssertRaisesContext:
  """Context manager behind assert_raises and assert_raises_regex.

  On exit it checks that the `with` body raised an exception of the expected
  type and, when a regular expression was given, that the exception's string
  form matches it. A matching exception is swallowed and kept on the
  `exception` attribute.
  """

  def __init__(self, expected, expected_regexp=None, extras=None):
    """Stores what the `with` body is expected to raise.

    Args:
      expected: The exception class expected to be raised.
      expected_regexp: Optional regular expression string or compiled
        pattern that the exception's string form must match.
      extras: An optional field for extra information to be included in
        test result.
    """
    self.expected = expected
    self.failureException = signals.TestFailure
    self.expected_regexp = expected_regexp
    self.extras = extras

  def __enter__(self):
    return self

  def _expected_name(self):
    """Returns a printable name for the expected exception."""
    if hasattr(self.expected, '__name__'):
      return self.expected.__name__
    # Fall back to the string form for objects that have no __name__.
    return str(self.expected)

  def __exit__(self, exc_type, exc_value, tb):
    """Validates the exception, if any, raised by the `with` body.

    Returns:
      True to swallow an exception that met the expectation, False to let
      an exception of an unexpected type propagate.

    Raises:
      signals.TestFailure: If nothing was raised, or the raised exception's
        string form does not match the expected regular expression.
    """
    if exc_type is None:
      raise signals.TestFailure(
          '%s not raised' % self._expected_name(), extras=self.extras
      )

    if not issubclass(exc_type, self.expected):
      # Not the exception we are looking for; let it pass through.
      return False

    # Keep the exception so callers can inspect it after the `with` block.
    self.exception = exc_value

    pattern = self.expected_regexp
    if pattern is None:
      return True
    if isinstance(pattern, str):
      pattern = re.compile(pattern)
    if pattern.search(str(exc_value)):
      return True

    raise signals.TestFailure(
        '"%s" does not match "%s"' % (pattern.pattern, str(exc_value)),
        extras=self.extras,
    )
501