xref: /aosp_15_r20/external/autotest/client/common_lib/test_utils/mock.py (revision 9c5db1993ded3edbeafc8092d69fe5de2ee02df7)
1# Lint as: python2, python3
2from __future__ import absolute_import
3from __future__ import division
4from __future__ import print_function
5__author__ = "[email protected] (Travis Miller)"
6
7
8import re, collections, six, sys, unittest
9
10import six
11from six.moves import zip
12
13
class StubNotFoundError(Exception):
    """Raised by mock_god.unstub() when no stub matches the namespace/symbol."""
17
18
class CheckPlaybackError(Exception):
    """Raised when calls made on mocks disagree with the recorded expectations."""
22
23
class SaveDataAfterCloseStringIO(six.StringIO):
    """Saves the contents in a final_data property when close() is called.

    Useful as a mock output file object to test both that the file was
    closed and what was written.

    Properties:
      final_data: Set to the StringIO's getvalue() data when close() is
          called.  None if close() has not been called.
    """

    final_data = None


    def __enter__(self):
        return self


    def __exit__(self, type, value, traceback):
        self.close()


    def close(self):
        """Capture the buffer contents into final_data, then close.

        Safe to call more than once (e.g. an explicit close() inside a
        ``with`` block followed by __exit__): getvalue() raises ValueError on
        a closed buffer, so only the first close() captures final_data and
        later calls leave it untouched.
        """
        if not self.closed:
            self.final_data = self.getvalue()
        # Explicit base call (not super()) because on Python 2 six.StringIO
        # is StringIO.StringIO, a classic class.
        six.StringIO.close(self)
47
48
class SaveDataAfterCloseBytesIO(six.BytesIO):
    """Saves the contents in a final_data property when close() is called.

    Useful as a mock output file object to test both that the file was
    closed and what was written.

    Properties:
      final_data: Set to the BytesIO's getvalue() data when close() is
          called.  None if close() has not been called.
    """
    final_data = None


    def __enter__(self):
        return self


    def __exit__(self, type, value, traceback):
        self.close()


    def close(self):
        """Capture the buffer contents into final_data, then close.

        Safe to call more than once: getvalue() raises ValueError on a
        closed buffer, so only the first close() captures final_data and
        later calls leave it untouched.
        """
        if not self.closed:
            self.final_data = self.getvalue()
        six.BytesIO.close(self)
73
74
class argument_comparator(object):
    """Abstract matcher for a single mock-call argument.

    Subclasses override is_satisfied_by() to decide whether an actual
    argument value is acceptable.
    """

    def is_satisfied_by(self, parameter):
        """Return True if ``parameter`` is acceptable; subclasses must override."""
        raise NotImplementedError
78
79
class equality_comparator(argument_comparator):
    """Matches an actual argument against an expected value.

    Matching requires the actual and expected values to have the same type
    (any two string types count as the same type).  Lists, tuples and dicts
    are compared recursively, and argument_comparator instances nested
    anywhere inside the expected value are delegated to.
    """

    def __init__(self, value):
        self.value = value


    @staticmethod
    def _types_match(arg1, arg2):
        """Return True if arg1 and arg2 have the same type (strings interchangeable)."""
        if isinstance(arg1, six.string_types) and isinstance(
                arg2, six.string_types):
            return True
        return type(arg1) == type(arg2)


    @classmethod
    def _compare(cls, actual_arg, expected_arg):
        """Recursively test whether actual_arg satisfies expected_arg."""
        if isinstance(expected_arg, argument_comparator):
            return expected_arg.is_satisfied_by(actual_arg)
        if not cls._types_match(expected_arg, actual_arg):
            return False

        if isinstance(expected_arg, (list, tuple)):
            # recurse on lists/tuples element by element
            if len(actual_arg) != len(expected_arg):
                return False
            for actual_item, expected_item in zip(actual_arg, expected_arg):
                if not cls._compare(actual_item, expected_item):
                    return False
        elif isinstance(expected_arg, dict):
            # Match keys by hash lookup rather than by sorting: sorting raises
            # TypeError on Python 3 when keys have unorderable mixed types.
            if len(actual_arg) != len(expected_arg):
                return False
            # Maps each expected key to itself, so an equal actual key can be
            # resolved to the stored expected key object.
            expected_keys = dict((key, key) for key in expected_arg)
            for key, value in six.iteritems(actual_arg):
                if key not in expected_keys:
                    return False
                # Keep the same type-strictness for keys as for values.
                if not cls._compare(key, expected_keys[key]):
                    return False
                if not cls._compare(value, expected_arg[key]):
                    return False
        elif actual_arg != expected_arg:
            return False

        return True


    def is_satisfied_by(self, parameter):
        return self._compare(parameter, self.value)


    def __str__(self):
        if isinstance(self.value, argument_comparator):
            return str(self.value)
        return repr(self.value)
129
130
class regex_comparator(argument_comparator):
    """Matches arguments in which the compiled ``pattern`` finds a match."""

    def __init__(self, pattern, flags=0):
        self.regex = re.compile(pattern, flags)


    def is_satisfied_by(self, parameter):
        # re.search raises TypeError for arguments it cannot scan (e.g. None,
        # an int, or bytes against a str pattern).  Treat those as a
        # mismatch so playback reports an "Incorrect call" rather than
        # crashing with an unrelated TypeError.
        try:
            return self.regex.search(parameter) is not None
        except TypeError:
            return False


    def __str__(self):
        return self.regex.pattern
142
143
class is_string_comparator(argument_comparator):
    """Accepts any argument whose type is one of six.string_types."""

    def is_satisfied_by(self, parameter):
        accepted_types = six.string_types
        return isinstance(parameter, accepted_types)

    def __str__(self):
        return "a string"
151
152
class is_instance_comparator(argument_comparator):
    """Matches arguments that are instances of ``cls``.

    ``cls`` may be a single class or a tuple of classes, exactly as accepted
    by isinstance().
    """

    def __init__(self, cls):
        self.cls = cls


    def is_satisfied_by(self, parameter):
        return isinstance(parameter, self.cls)


    def __str__(self):
        # Wrap in a 1-tuple so a tuple of classes is formatted as one value
        # instead of being unpacked as multiple %-format arguments.
        return "is a %s" % (self.cls,)
164
165
class anything_comparator(argument_comparator):
    """Wildcard comparator: every argument value is accepted."""

    def __str__(self):
        return 'anything'

    def is_satisfied_by(self, parameter):
        # Deliberately unconditional; the argument is never inspected.
        return True
172        return 'anything'
173
174
class base_mapping(object):
    """One expected call: ``symbol`` invoked with particular arguments.

    Each positional and keyword argument is wrapped in an
    equality_comparator, so plain values and argument_comparator instances
    are matched uniformly.  ``return_obj`` is what playback hands back when
    this expectation is consumed, unless ``error`` is set, in which case
    playback raises it.
    """

    def __init__(self, symbol, return_obj, *args, **dargs):
        self.symbol = symbol
        self.return_obj = return_obj
        self.args = list(map(equality_comparator, args))
        self.dargs = dict((name, equality_comparator(expected))
                          for name, expected in dargs.items())
        self.error = None

    def match(self, *args, **dargs):
        """Return True if a call with (args, dargs) satisfies this expectation."""
        arity_matches = (len(args) == len(self.args)
                         and len(dargs) == len(self.dargs))
        if not arity_matches:
            return False

        for comparator, actual in zip(self.args, args):
            if not comparator.is_satisfied_by(actual):
                return False

        # Every keyword actually passed must be expected and acceptable.
        for name, actual in dargs.items():
            comparator = self.dargs.get(name)
            if comparator is None or not comparator.is_satisfied_by(actual):
                return False

        # Every expected keyword must actually have been passed.
        return all(name in dargs for name in self.dargs)

    def __str__(self):
        return _dump_function_call(self.symbol, self.args, self.dargs)
210
211
class function_mapping(base_mapping):
    """Expected call on a mock function whose outcome can be set afterwards.

    and_return() chooses the value playback returns; and_raises() makes
    playback raise the given error instead.
    """

    def __init__(self, symbol, return_val, *args, **dargs):
        super(function_mapping, self).__init__(
            symbol, return_val, *args, **dargs)

    def and_return(self, return_obj):
        """Make playback of this expected call return ``return_obj``."""
        self.return_obj = return_obj

    def and_raises(self, error):
        """Make playback of this expected call raise ``error``."""
        self.error = error
224
225
class function_any_args_mapping(function_mapping):
    """Expected call that accepts any arguments whatsoever."""

    def match(self, *args, **dargs):
        # Arguments are intentionally ignored: any call to the symbol matches.
        return True
230
231
class mock_function(object):
    """Callable stand-in that logs every invocation.

    Each call appends its positional args to ``args`` and its keyword args
    to ``dargs``, and bumps ``num_calls``.  If a ``playback`` callback is
    set, the call is forwarded to it (prefixed by ``symbol``) and its result
    is returned; otherwise ``default_return_val`` is returned.  ``record``,
    when set, receives each expectation created by expect_call() or
    expect_any_call().
    """

    def __init__(self, symbol, default_return_val=None,
                 record=None, playback=None):
        self.symbol = symbol
        self.__name__ = symbol
        self.record = record
        self.playback = playback
        self.default_return_val = default_return_val
        self.num_calls = 0
        self.args = []
        self.dargs = []

    def __call__(self, *args, **dargs):
        self.num_calls += 1
        self.args.append(args)
        self.dargs.append(dargs)
        if not self.playback:
            return self.default_return_val
        return self.playback(self.symbol, *args, **dargs)

    def expect_call(self, *args, **dargs):
        """Register and return an expectation for a call with these arguments."""
        expectation = function_mapping(self.symbol, None, *args, **dargs)
        if self.record:
            self.record(expectation)
        return expectation

    def expect_any_call(self):
        """Like expect_call but don't give a hoot what arguments are passed."""
        expectation = function_any_args_mapping(self.symbol, None)
        if self.record:
            self.record(expectation)
        return expectation
270
271
class mask_function(mock_function):
    """mock_function that also keeps a reference to the callable it replaced."""

    def __init__(self, symbol, original_function, default_return_val=None,
                 record=None, playback=None):
        super(mask_function, self).__init__(
            symbol, default_return_val, record, playback)
        self.original_function = original_function

    def run_original_function(self, *args, **dargs):
        """Invoke the replaced (real) callable directly, bypassing the mock."""
        real = self.original_function
        return real(*args, **dargs)
283
284
class mock_class(object):
    """Object mirroring the public attributes of ``cls``.

    Every attribute of ``cls`` whose name does not start with an underscore
    is copied onto the instance.  Callable ones are replaced by a
    mock_function named "<name>.<attribute>" that shares ``record`` and
    ``playback``.
    """

    def __init__(self, cls, name, default_ret_val=None,
                 record=None, playback=None):
        self.__name = name
        self.__record = record
        self.__playback = playback

        public_names = [attr for attr in dir(cls) if not attr.startswith("_")]
        for attr in public_names:
            value = getattr(cls, attr)
            if not callable(value):
                setattr(self, attr, value)
                continue
            mocked = mock_function("%s.%s" % (self.__name, attr),
                                   default_ret_val,
                                   self.__record, self.__playback)
            setattr(self, attr, mocked)

    def __repr__(self):
        return '<mock_class: %s>' % self.__name
308
309
class mock_god(object):
    """Records expected calls on mocks and checks actual calls against them.

    All mocks created through one mock_god share a single FIFO queue of
    expectations (``recording``).  When a mock is called, the call is
    compared with the oldest outstanding expectation.  A match consumes it
    and yields its return value, or raises its error.  A mismatch is
    reported through _append_error().  The god also remembers attribute
    stubs installed via stub_with() so they can be undone later.
    """

    # Sentinel stored as the "original" value when the stubbed attribute did
    # not exist or was only inherited; unstubbing then deletes the attribute
    # instead of restoring a value.
    NONEXISTENT_ATTRIBUTE = object()

    def __init__(self, debug=False, fail_fast=True, ut=None):
        """
        With debug=True, all recorded method calls will be printed as
        they happen.
        With fail_fast=True, unexpected calls will immediately cause an
        exception to be raised.  With False, they will be silently recorded and
        only reported when check_playback() is called.
        If ut is given (presumably a unittest.TestCase), check_playback()
        reports failures through ut.fail() before raising.
        """
        self.recording = collections.deque()
        self.errors = []
        self._stubs = []
        self._debug = debug
        self._fail_fast = fail_fast
        self._ut = ut


    def set_fail_fast(self, fail_fast):
        """Switch between raising on the first bad call and collecting errors."""
        self._fail_fast = fail_fast


    def create_mock_class_obj(self, cls, name, default_ret_val=None):
        """Return a recording subclass of ``cls`` to stand in for the class.

        Calling the returned class constructs nothing.  The call is routed
        through playback under ``name``, so it must match an expectation
        recorded with expect_new().

        expect_new(*args, **dargs) builds an instance, records the expected
        construction, and returns that instance.  On the instance, every
        non-dunder, non-property callable attribute is replaced by a
        mock_function named "<name>_<n>.<attr>".
        """
        record = self.__record_call
        playback = self.__method_playback


        class RecordingMockMeta(type):
            """Metaclass to override default class invocation behavior.

            Normally, calling a class like a function creates and initializes an
            instance of that class. This metaclass causes class invocation to
            have no side effects and to return nothing, instead recording the
            call in the mock_god object to be inspected or asserted against as a
            part of a test.
            """
            def __call__(self, *args, **kwargs):
                return playback(name, *args, **kwargs)


        @six.add_metaclass(RecordingMockMeta)
        class cls_sub(cls):
            # Counter used to give each instance built by make_new a unique name.
            cls_count = 0

            # overwrite the initializer
            def __init__(self, *args, **dargs):
                pass


            @classmethod
            def expect_new(typ, *args, **dargs):
                """Record an expected construction and return its instance."""
                obj = typ.make_new(*args, **dargs)
                mapping = base_mapping(name, obj, *args, **dargs)
                record(mapping)
                return obj


            @classmethod
            def make_new(typ, *args, **dargs):
                """Allocate an instance whose methods are mock_functions."""
                obj = super(cls_sub, typ).__new__(typ, *args, **dargs)

                typ.cls_count += 1
                obj_name = "%s_%s" % (name, typ.cls_count)
                for symbol in dir(obj):
                    if (symbol.startswith("__") and
                        symbol.endswith("__")):
                        continue

                    # Setting a property on the instance would invoke its
                    # setter (or fail), so properties are left alone.
                    if isinstance(getattr(typ, symbol, None), property):
                        continue

                    orig_symbol = getattr(obj, symbol)
                    if callable(orig_symbol):
                        f_name = ("%s.%s" %
                                  (obj_name, symbol))
                        func = mock_function(f_name,
                                        default_ret_val,
                                        record,
                                        playback)
                        setattr(obj, symbol, func)
                    else:
                        setattr(obj, symbol,
                                orig_symbol)

                return obj

        return cls_sub


    def create_mock_class(self, cls, name, default_ret_val=None):
        """
        Given something that defines a namespace cls (class, object,
        module), and a (hopefully unique) name, will create a
        mock_class object with that name and that possesses all
        the public attributes of cls.  default_ret_val sets the
        default_ret_val on all methods of the cls mock.
        """
        return mock_class(cls, name, default_ret_val,
                          self.__record_call, self.__method_playback)


    def create_mock_function(self, symbol, default_return_val=None):
        """
        Create a mock_function with name symbol and default return
        value of default_return_val.
        """
        return mock_function(symbol, default_return_val,
                             self.__record_call, self.__method_playback)


    def mock_up(self, obj, name, default_ret_val=None):
        """
        Given an object (class instance or module) and a registration
        name, then replace all its methods with mock function objects
        (passing the original functions to the mock functions).
        Attributes whose names start with "__" are skipped.
        """
        for symbol in dir(obj):
            if symbol.startswith("__"):
                continue

            orig_symbol = getattr(obj, symbol)
            if callable(orig_symbol):
                f_name = "%s.%s" % (name, symbol)
                func = mask_function(f_name, orig_symbol,
                                     default_ret_val,
                                     self.__record_call,
                                     self.__method_playback)
                setattr(obj, symbol, func)


    def stub_with(self, namespace, symbol, new_attribute):
        """Set namespace.symbol to new_attribute, remembering how to undo it."""
        original_attribute = getattr(namespace, symbol,
                                     self.NONEXISTENT_ATTRIBUTE)

        # You only want to save the original attribute in cases where it is
        # directly associated with the object in question. In cases where
        # the attribute is actually inherited via some sort of hierarchy
        # you want to delete the stub (restoring the original structure)
        attribute_is_inherited = (hasattr(namespace, '__dict__') and
                                  symbol not in namespace.__dict__)
        if attribute_is_inherited:
            original_attribute = self.NONEXISTENT_ATTRIBUTE

        newstub = (namespace, symbol, original_attribute, new_attribute)
        self._stubs.append(newstub)
        setattr(namespace, symbol, new_attribute)


    def stub_function(self, namespace, symbol):
        """Replace namespace.symbol with a recording mock_function."""
        mock_attribute = self.create_mock_function(symbol)
        self.stub_with(namespace, symbol, mock_attribute)


    def stub_class_method(self, cls, symbol):
        """Replace cls.symbol with a mock_function wrapped in staticmethod."""
        mock_attribute = self.create_mock_function(symbol)
        self.stub_with(cls, symbol, staticmethod(mock_attribute))


    def stub_class(self, namespace, symbol):
        """Replace the class namespace.symbol with create_mock_class_obj()."""
        attr = getattr(namespace, symbol)
        mock_class = self.create_mock_class_obj(attr, symbol)
        self.stub_with(namespace, symbol, mock_class)


    def stub_function_to_return(self, namespace, symbol, object_to_return):
        """Stub out a function with one that always returns a fixed value.

        @param namespace The namespace containing the function to stub out.
        @param symbol The attribute within the namespace to stub out.
        @param object_to_return The value that the stub should return whenever
            it is called.
        """
        self.stub_with(namespace, symbol,
                       lambda *args, **dargs: object_to_return)


    def _perform_unstub(self, stub):
        """Undo one stub tuple recorded by stub_with()."""
        namespace, symbol, orig_attr, new_attr = stub
        # Identity, not equality: an original attribute with a permissive
        # __eq__ must not be mistaken for the sentinel and deleted.
        if orig_attr is self.NONEXISTENT_ATTRIBUTE:
            delattr(namespace, symbol)
        else:
            setattr(namespace, symbol, orig_attr)


    def unstub(self, namespace, symbol):
        """Undo the most recent stub of namespace.symbol.

        Raises StubNotFoundError if no such stub is recorded.
        """
        for index in range(len(self._stubs) - 1, -1, -1):
            stub = self._stubs[index]
            if (namespace, symbol) == (stub[0], stub[1]):
                self._perform_unstub(stub)
                # Delete by position: list.remove() would drop the first
                # *equal* stub, which need not be the one just undone.
                del self._stubs[index]
                return

        raise StubNotFoundError()


    def unstub_all(self):
        """Undo every recorded stub, newest first, and forget them."""
        # Newest first so stacked stubs of one attribute restore the true
        # original value last.
        for stub in reversed(self._stubs):
            self._perform_unstub(stub)
        self._stubs = []


    def __method_playback(self, symbol, *args, **dargs):
        """Check a real call against the next expectation and play it back."""
        if self._debug:
            print((' * Mock call: ' +
                   _dump_function_call(symbol, args, dargs)),
                  file=sys.__stdout__)

        if self.recording:
            func_call = self.recording[0]
            if func_call.symbol != symbol:
                msg = ("Unexpected call: %s\nExpected: %s"
                       % (_dump_function_call(symbol, args, dargs),
                          func_call))
                self._append_error(msg)
                return None

            if not func_call.match(*args, **dargs):
                msg = ("Incorrect call: %s\nExpected: %s"
                       % (_dump_function_call(symbol, args, dargs),
                          func_call))
                self._append_error(msg)
                return None

            # this is the expected call so pop it and return
            self.recording.popleft()
            if func_call.error:
                raise func_call.error
            else:
                return func_call.return_obj
        else:
            msg = ("unexpected call: %s"
                   % (_dump_function_call(symbol, args, dargs)))
            self._append_error(msg)
            return None


    def __record_call(self, mapping):
        """Append an expectation to the end of the playback queue."""
        self.recording.append(mapping)


    def _append_error(self, error):
        """Report a playback error.

        Raises CheckPlaybackError if fail_fast, otherwise stores the error.
        """
        if self._debug:
            print(' *** ' + error, file=sys.__stdout__)
        if self._fail_fast:
            raise CheckPlaybackError(error)
        self.errors.append(error)


    def check_playback(self):
        """
        Report any errors that were encountered during calls
        to __method_playback(), or expected calls that never happened.

        Raises CheckPlaybackError if there is anything to report, after
        calling ut.fail() when a unittest object was supplied.
        """
        if self.errors:
            if self._debug:
                # Same stream as every other diagnostic here, so the header
                # is not swallowed by mock_io()'s captured stdout.
                print('\nPlayback errors:', file=sys.__stdout__)
            for error in self.errors:
                print(error, file=sys.__stdout__)

            if self._ut:
                self._ut.fail('\n'.join(self.errors))

            raise CheckPlaybackError
        elif self.recording:
            errors = []
            for func_call in self.recording:
                error = "%s not called" % (func_call,)
                errors.append(error)
                print(error, file=sys.__stdout__)

            if self._ut:
                self._ut.fail('\n'.join(errors))

            raise CheckPlaybackError
        self.recording.clear()


    def mock_io(self):
        """Mocks and saves the stdout & stderr output"""
        self.orig_stdout = sys.stdout
        self.orig_stderr = sys.stderr

        self.mock_streams_stdout = six.StringIO('')
        self.mock_streams_stderr = six.StringIO('')

        sys.stdout = self.mock_streams_stdout
        sys.stderr = self.mock_streams_stderr


    def unmock_io(self):
        """Restores the stdout & stderr, and returns both
        output strings"""
        sys.stdout = self.orig_stdout
        sys.stderr = self.orig_stderr
        values = (self.mock_streams_stdout.getvalue(),
                  self.mock_streams_stderr.getvalue())

        self.mock_streams_stdout.close()
        self.mock_streams_stderr.close()
        return values
612
613
def _arg_to_str(arg):
    """Render one call argument: str() for comparators, repr() for values."""
    return str(arg) if isinstance(arg, argument_comparator) else repr(arg)
618
619
def _dump_function_call(symbol, args, dargs):
    """Format a call as "symbol(a, b, key=value)" for diagnostic messages."""
    rendered = [_arg_to_str(value) for value in args]
    rendered.extend("%s=%s" % (key, _arg_to_str(value))
                    for key, value in dargs.items())
    return "%s(%s)" % (symbol, ', '.join(rendered))
627