# Copyright 2014 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""
A test facility to assert call sequences while mocking their behavior.
"""

import unittest

from devil import devil_env

with devil_env.SysPath(devil_env.PYMOCK_PATH):
  import mock  # pylint: disable=import-error


class TestCase(unittest.TestCase):
  """Adds assertCalls to TestCase objects."""

  class _AssertCalls(object):
    """Context manager returned by TestCase.assertCalls.

    On entry it patches every watched target with a side effect that checks
    each received call against the next expected call. On exit it undoes the
    patches and, if the managed block raised no exception, asserts that every
    expected call was received.
    """

    def __init__(self, test_case, expected_calls, watched):
      """Prepares (but does not yet apply) the patches.

      Args:
        test_case: the TestCase used to report assertion failures and whose
          patch_call method builds the patches.
        expected_calls: a sequence whose items are either a mock.call instance
          or an (expected_call, action) pair.
        watched: a dict mapping call names to mock.call instances for extra
          targets to watch; it is copied, never modified.
      """

      def call_action(pair):
        # A bare mock.call is shorthand for (call, None).
        if isinstance(pair, type(mock.call)):
          return (pair, None)
        else:
          return pair

      def do_check(call):
        # Builds the side effect installed on the target identified by |call|:
        # it consumes the next expected call, checks it matches, and then
        # returns (or calls) the associated action.
        def side_effect(*args, **kwargs):
          received_call = call(*args, **kwargs)
          self._test_case.assertTrue(
              self._expected_calls,
              msg=('Unexpected call: %s' % str(received_call)))
          expected_call, action = self._expected_calls.pop(0)
          self._test_case.assertTrue(
              received_call == expected_call,
              msg=('Expected call mismatch:\n'
                   ' expected: %s\n'
                   ' received: %s\n' % (str(expected_call),
                                        str(received_call))))
          if callable(action):
            return action(*args, **kwargs)
          else:
            return action

        return side_effect

      self._test_case = test_case
      self._expected_calls = [call_action(pair) for pair in expected_calls]
      watched = watched.copy()  # do not pollute the caller's dict
      # The target of every expected call is always watched, keyed by name so
      # that each target is patched at most once.
      watched.update(
          (call.parent.name, call.parent) for call, _ in self._expected_calls)
      self._patched = [
          test_case.patch_call(call, side_effect=do_check(call))
          for call in watched.values()
      ]

    def __enter__(self):
      # Apply the patches in order. If any of them fails, undo the ones that
      # were already applied (most recent first) so that no mock is left
      # installed once the exception propagates.
      applied = []
      succeeded = False
      try:
        for patch in self._patched:
          patch.__enter__()
          applied.append(patch)
        succeeded = True
      finally:
        if not succeeded:
          for patch in reversed(applied):
            patch.__exit__(None, None, None)
      return self

    def __exit__(self, exc_type, exc_val, exc_tb):
      # Undo the patches in reverse (LIFO) order, as nested context managers
      # would, so that if two patches replaced the same attribute the value
      # restored last is the original one rather than an intermediate mock.
      for patch in reversed(self._patched):
        patch.__exit__(exc_type, exc_val, exc_tb)
      # Only check for missing calls when the block completed normally;
      # otherwise let the original exception propagate unchanged.
      if exc_type is None:
        missing = ''.join(
            ' expected: %s\n' % str(call) for call, _ in self._expected_calls)
        self._test_case.assertFalse(
            missing, msg='Expected calls not found:\n' + missing)

  def __init__(self, *args, **kwargs):
    super(TestCase, self).__init__(*args, **kwargs)
    self.call = mock.call.self
    self._watched = {}

  def call_target(self, call):
    """Resolve a self.call instance to the target it represents.

    Args:
      call: a self.call instance, e.g. self.call.adb.Shell

    Returns:
      The target object represented by the call, e.g. self.adb.Shell

    Raises:
      ValueError if the path of the call does not start with "self", i.e. the
        target of the call is external to the self object.
      AttributeError if the path of the call does not specify a valid
        chain of attributes (without any calls) starting from "self".
    """
    path = call.name.split('.')
    if path.pop(0) != 'self':
      raise ValueError("Target %r outside of 'self' object" % call.name)
    target = self
    for attr in path:
      target = getattr(target, attr)
    return target

  def patch_call(self, call, **kwargs):
    """Patch the target of a mock.call instance.

    Targets under "self." are patched with mock.patch.object on the resolved
    parent object; if the attribute is a property on the parent's class, the
    class attribute is replaced with a mock.PropertyMock instead. Any other
    target is patched by its dotted name with mock.patch.

    Args:
      call: a mock.call instance identifying a target to patch
      Extra keyword arguments are processed by mock.patch

    Returns:
      A context manager to mock/unmock the target of the call
    """
    if call.name.startswith('self.'):
      target = self.call_target(call.parent)
      _, attribute = call.name.rsplit('.', 1)
      if (hasattr(type(target), attribute)
          and isinstance(getattr(type(target), attribute), property)):
        return mock.patch.object(
            type(target), attribute, new_callable=mock.PropertyMock, **kwargs)
      else:
        return mock.patch.object(target, attribute, **kwargs)
    else:
      return mock.patch(call.name, **kwargs)

  def watchCalls(self, calls):
    """Add calls to the set of watched calls.

    Args:
      calls: a sequence of mock.call instances identifying targets to watch
    """
    self._watched.update((call.name, call) for call in calls)

  def watchMethodCalls(self, call, ignore=None):
    """Watch all public methods of the target identified by a self.call.

    Public methods are the names in dir() of the target's class that do not
    start with an underscore.

    Args:
      call: a self.call instance identifying an object
      ignore: a list of public methods to ignore when watching for calls
    """
    target = self.call_target(call)
    if ignore is None:
      ignore = []
    self.watchCalls(
        getattr(call, method) for method in dir(target.__class__)
        if not method.startswith('_') and method not in ignore)

  def clearWatched(self):
    """Clear the set of watched calls."""
    self._watched = {}

  def assertCalls(self, *calls):
    """A context manager to assert that a sequence of calls is made.

    During the assertion, a number of functions and methods will be "watched",
    and any calls made to them are expected to appear---in the exact same
    order, and with the exact same arguments---as specified by the argument
    |calls|.

    By default, the targets of all expected calls are watched. Further targets
    to watch may be added using watchCalls and watchMethodCalls.

    Optionally, each call may be accompanied by an action. If the action is a
    (non-callable) value, this value will be used as the return value given to
    the caller when the matching call is found. Alternatively, if the action is
    a callable, the action will then be called with the same arguments as the
    intercepted call, so that it can provide a return value or perform other
    side effects. If the action is missing, a return value of None is assumed.

    Note that mock.Mock objects are often convenient to use as a callable
    action, e.g. to raise exceptions or return other objects which are
    themselves callable.

    Args:
      calls: each argument is either a pair (expected_call, action) or just an
        expected_call, where expected_call is a mock.call instance.

    Raises:
      AssertionError if the watched targets do not receive the exact sequence
        of calls specified. Missing calls, extra calls, and calls with
        mismatching arguments, all cause the assertion to fail.
    """
    return self._AssertCalls(self, calls, self._watched)

  def assertCall(self, call, action=None):
    """Shorthand for assertCalls with a single (call, action) pair.

    Args:
      call: the expected mock.call instance.
      action: optional return value or callable, as described in assertCalls.

    Returns:
      The context manager returned by assertCalls.
    """
    return self.assertCalls((call, action))