1# Lint as: python2, python3 2from __future__ import absolute_import 3from __future__ import division 4from __future__ import print_function 5__author__ = "[email protected] (Travis Miller)" 6 7 8import re, collections, six, sys, unittest 9 10import six 11from six.moves import zip 12 13 14class StubNotFoundError(Exception): 15 'Raised when god is asked to unstub an attribute that was not stubbed' 16 pass 17 18 19class CheckPlaybackError(Exception): 20 'Raised when mock playback does not match recorded calls.' 21 pass 22 23 24class SaveDataAfterCloseStringIO(six.StringIO): 25 """Saves the contents in a final_data property when close() is called. 26 27 Useful as a mock output file object to test both that the file was 28 closed and what was written. 29 30 Properties: 31 final_data: Set to the StringIO's getvalue() data when close() is 32 called. None if close() has not been called. 33 """ 34 35 final_data = None 36 37 38 def __enter__(self): 39 return self 40 41 def __exit__(self, type, value, traceback): 42 self.close() 43 44 def close(self): 45 self.final_data = self.getvalue() 46 six.StringIO.close(self) 47 48 49class SaveDataAfterCloseBytesIO(six.BytesIO): 50 """Saves the contents in a final_data property when close() is called. 51 52 Useful as a mock output file object to test both that the file was 53 closed and what was written. 54 55 Properties: 56 final_data: Set to the BytesIO's getvalue() data when close() is 57 called. None if close() has not been called. 58 """ 59 final_data = None 60 61 62 def __enter__(self): 63 return self 64 65 66 def __exit__(self, type, value, traceback): 67 self.close() 68 69 70 def close(self): 71 self.final_data = self.getvalue() 72 six.BytesIO.close(self) 73 74 75class argument_comparator(object): 76 def is_satisfied_by(self, parameter): 77 raise NotImplementedError 78 79 80class equality_comparator(argument_comparator): 81 def __init__(self, value): 82 self.value = value 83 84 85 @staticmethod 86 def _types_match(arg1, arg2): 87 if isinstance(arg1, six.string_types) and isinstance( 88 arg2, six.string_types): 89 return True 90 return type(arg1) == type(arg2) 91 92 93 @classmethod 94 def _compare(cls, actual_arg, expected_arg): 95 if isinstance(expected_arg, argument_comparator): 96 return expected_arg.is_satisfied_by(actual_arg) 97 if not cls._types_match(expected_arg, actual_arg): 98 return False 99 100 if isinstance(expected_arg, list) or isinstance(expected_arg, tuple): 101 # recurse on lists/tuples 102 if len(actual_arg) != len(expected_arg): 103 return False 104 for actual_item, expected_item in zip(actual_arg, expected_arg): 105 if not cls._compare(actual_item, expected_item): 106 return False 107 elif isinstance(expected_arg, dict): 108 # recurse on dicts 109 if not cls._compare(sorted(actual_arg.keys()), 110 sorted(expected_arg.keys())): 111 return False 112 for key, value in six.iteritems(actual_arg): 113 if not cls._compare(value, expected_arg[key]): 114 return False 115 elif actual_arg != expected_arg: 116 return False 117 118 return True 119 120 121 def is_satisfied_by(self, parameter): 122 return self._compare(parameter, self.value) 123 124 125 def __str__(self): 126 if isinstance(self.value, argument_comparator): 127 return str(self.value) 128 return repr(self.value) 129 130 131class regex_comparator(argument_comparator): 132 def __init__(self, pattern, flags=0): 133 self.regex = re.compile(pattern, flags) 134 135 136 def is_satisfied_by(self, parameter): 137 return self.regex.search(parameter) is not None 138 139 140 def __str__(self): 141 return self.regex.pattern 142 143 144class is_string_comparator(argument_comparator): 145 def is_satisfied_by(self, parameter): 146 return isinstance(parameter, six.string_types) 147 148 149 def __str__(self): 150 return "a string" 151 152 153class is_instance_comparator(argument_comparator): 154 def __init__(self, cls): 155 self.cls = cls 156 157 158 def is_satisfied_by(self, parameter): 159 return isinstance(parameter, self.cls) 160 161 162 def __str__(self): 163 return "is a %s" % self.cls 164 165 166class anything_comparator(argument_comparator): 167 def is_satisfied_by(self, parameter): 168 return True 169 170 171 def __str__(self): 172 return 'anything' 173 174 175class base_mapping(object): 176 def __init__(self, symbol, return_obj, *args, **dargs): 177 self.return_obj = return_obj 178 self.symbol = symbol 179 self.args = [equality_comparator(arg) for arg in args] 180 self.dargs = dict((key, equality_comparator(value)) 181 for key, value in six.iteritems(dargs)) 182 self.error = None 183 184 185 def match(self, *args, **dargs): 186 if len(args) != len(self.args) or len(dargs) != len(self.dargs): 187 return False 188 189 for i, expected_arg in enumerate(self.args): 190 if not expected_arg.is_satisfied_by(args[i]): 191 return False 192 193 # check for incorrect dargs 194 for key, value in six.iteritems(dargs): 195 if key not in self.dargs: 196 return False 197 if not self.dargs[key].is_satisfied_by(value): 198 return False 199 200 # check for missing dargs 201 for key in six.iterkeys(self.dargs): 202 if key not in dargs: 203 return False 204 205 return True 206 207 208 def __str__(self): 209 return _dump_function_call(self.symbol, self.args, self.dargs) 210 211 212class function_mapping(base_mapping): 213 def __init__(self, symbol, return_val, *args, **dargs): 214 super(function_mapping, self).__init__(symbol, return_val, *args, 215 **dargs) 216 217 218 def and_return(self, return_obj): 219 self.return_obj = return_obj 220 221 222 def and_raises(self, error): 223 self.error = error 224 225 226class function_any_args_mapping(function_mapping): 227 """A mock function mapping that doesn't verify its arguments.""" 228 def match(self, *args, **dargs): 229 return True 230 231 232class mock_function(object): 233 def __init__(self, symbol, default_return_val=None, 234 record=None, playback=None): 235 self.default_return_val = default_return_val 236 self.num_calls = 0 237 self.args = [] 238 self.dargs = [] 239 self.symbol = symbol 240 self.record = record 241 self.playback = playback 242 self.__name__ = symbol 243 244 245 def __call__(self, *args, **dargs): 246 self.num_calls += 1 247 self.args.append(args) 248 self.dargs.append(dargs) 249 if self.playback: 250 return self.playback(self.symbol, *args, **dargs) 251 else: 252 return self.default_return_val 253 254 255 def expect_call(self, *args, **dargs): 256 mapping = function_mapping(self.symbol, None, *args, **dargs) 257 if self.record: 258 self.record(mapping) 259 260 return mapping 261 262 263 def expect_any_call(self): 264 """Like expect_call but don't give a hoot what arguments are passed.""" 265 mapping = function_any_args_mapping(self.symbol, None) 266 if self.record: 267 self.record(mapping) 268 269 return mapping 270 271 272class mask_function(mock_function): 273 def __init__(self, symbol, original_function, default_return_val=None, 274 record=None, playback=None): 275 super(mask_function, self).__init__(symbol, 276 default_return_val, 277 record, playback) 278 self.original_function = original_function 279 280 281 def run_original_function(self, *args, **dargs): 282 return self.original_function(*args, **dargs) 283 284 285class mock_class(object): 286 def __init__(self, cls, name, default_ret_val=None, 287 record=None, playback=None): 288 self.__name = name 289 self.__record = record 290 self.__playback = playback 291 292 for symbol in dir(cls): 293 if symbol.startswith("_"): 294 continue 295 296 orig_symbol = getattr(cls, symbol) 297 if callable(orig_symbol): 298 f_name = "%s.%s" % (self.__name, symbol) 299 func = mock_function(f_name, default_ret_val, 300 self.__record, self.__playback) 301 setattr(self, symbol, func) 302 else: 303 setattr(self, symbol, orig_symbol) 304 305 306 def __repr__(self): 307 return '<mock_class: %s>' % self.__name 308 309 310class mock_god(object): 311 NONEXISTENT_ATTRIBUTE = object() 312 313 def __init__(self, debug=False, fail_fast=True, ut=None): 314 """ 315 With debug=True, all recorded method calls will be printed as 316 they happen. 317 With fail_fast=True, unexpected calls will immediately cause an 318 exception to be raised. With False, they will be silently recorded and 319 only reported when check_playback() is called. 320 """ 321 self.recording = collections.deque() 322 self.errors = [] 323 self._stubs = [] 324 self._debug = debug 325 self._fail_fast = fail_fast 326 self._ut = ut 327 328 329 def set_fail_fast(self, fail_fast): 330 self._fail_fast = fail_fast 331 332 333 def create_mock_class_obj(self, cls, name, default_ret_val=None): 334 record = self.__record_call 335 playback = self.__method_playback 336 errors = self.errors 337 338 339 class RecordingMockMeta(type): 340 """Metaclass to override default class invocation behavior. 341 342 Normally, calling a class like a function creates and initializes an 343 instance of that class. This metaclass causes class invocation to 344 have no side effects and to return nothing, instead recording the 345 call in the mock_god object to be inspected or asserted against as a 346 part of a test. 347 """ 348 def __call__(self, *args, **kwargs): 349 return playback(name, *args, **kwargs) 350 351 352 @six.add_metaclass(RecordingMockMeta) 353 class cls_sub(cls): 354 cls_count = 0 355 356 # overwrite the initializer 357 def __init__(self, *args, **dargs): 358 pass 359 360 361 @classmethod 362 def expect_new(typ, *args, **dargs): 363 obj = typ.make_new(*args, **dargs) 364 mapping = base_mapping(name, obj, *args, **dargs) 365 record(mapping) 366 return obj 367 368 369 @classmethod 370 def make_new(typ, *args, **dargs): 371 obj = super(cls_sub, typ).__new__(typ, *args, **dargs) 372 373 typ.cls_count += 1 374 obj_name = "%s_%s" % (name, typ.cls_count) 375 for symbol in dir(obj): 376 if (symbol.startswith("__") and 377 symbol.endswith("__")): 378 continue 379 380 if isinstance(getattr(typ, symbol, None), property): 381 continue 382 383 orig_symbol = getattr(obj, symbol) 384 if callable(orig_symbol): 385 f_name = ("%s.%s" % 386 (obj_name, symbol)) 387 func = mock_function(f_name, 388 default_ret_val, 389 record, 390 playback) 391 setattr(obj, symbol, func) 392 else: 393 setattr(obj, symbol, 394 orig_symbol) 395 396 return obj 397 398 return cls_sub 399 400 401 def create_mock_class(self, cls, name, default_ret_val=None): 402 """ 403 Given something that defines a namespace cls (class, object, 404 module), and a (hopefully unique) name, will create a 405 mock_class object with that name and that possessess all 406 the public attributes of cls. default_ret_val sets the 407 default_ret_val on all methods of the cls mock. 408 """ 409 return mock_class(cls, name, default_ret_val, 410 self.__record_call, self.__method_playback) 411 412 413 def create_mock_function(self, symbol, default_return_val=None): 414 """ 415 create a mock_function with name symbol and default return 416 value of default_ret_val. 417 """ 418 return mock_function(symbol, default_return_val, 419 self.__record_call, self.__method_playback) 420 421 422 def mock_up(self, obj, name, default_ret_val=None): 423 """ 424 Given an object (class instance or module) and a registration 425 name, then replace all its methods with mock function objects 426 (passing the orignal functions to the mock functions). 427 """ 428 for symbol in dir(obj): 429 if symbol.startswith("__"): 430 continue 431 432 orig_symbol = getattr(obj, symbol) 433 if callable(orig_symbol): 434 f_name = "%s.%s" % (name, symbol) 435 func = mask_function(f_name, orig_symbol, 436 default_ret_val, 437 self.__record_call, 438 self.__method_playback) 439 setattr(obj, symbol, func) 440 441 442 def stub_with(self, namespace, symbol, new_attribute): 443 original_attribute = getattr(namespace, symbol, 444 self.NONEXISTENT_ATTRIBUTE) 445 446 # You only want to save the original attribute in cases where it is 447 # directly associated with the object in question. In cases where 448 # the attribute is actually inherited via some sort of hierarchy 449 # you want to delete the stub (restoring the original structure) 450 attribute_is_inherited = (hasattr(namespace, '__dict__') and 451 symbol not in namespace.__dict__) 452 if attribute_is_inherited: 453 original_attribute = self.NONEXISTENT_ATTRIBUTE 454 455 newstub = (namespace, symbol, original_attribute, new_attribute) 456 self._stubs.append(newstub) 457 setattr(namespace, symbol, new_attribute) 458 459 460 def stub_function(self, namespace, symbol): 461 mock_attribute = self.create_mock_function(symbol) 462 self.stub_with(namespace, symbol, mock_attribute) 463 464 465 def stub_class_method(self, cls, symbol): 466 mock_attribute = self.create_mock_function(symbol) 467 self.stub_with(cls, symbol, staticmethod(mock_attribute)) 468 469 470 def stub_class(self, namespace, symbol): 471 attr = getattr(namespace, symbol) 472 mock_class = self.create_mock_class_obj(attr, symbol) 473 self.stub_with(namespace, symbol, mock_class) 474 475 476 def stub_function_to_return(self, namespace, symbol, object_to_return): 477 """Stub out a function with one that always returns a fixed value. 478 479 @param namespace The namespace containing the function to stub out. 480 @param symbol The attribute within the namespace to stub out. 481 @param object_to_return The value that the stub should return whenever 482 it is called. 483 """ 484 self.stub_with(namespace, symbol, 485 lambda *args, **dargs: object_to_return) 486 487 488 def _perform_unstub(self, stub): 489 namespace, symbol, orig_attr, new_attr = stub 490 if orig_attr == self.NONEXISTENT_ATTRIBUTE: 491 delattr(namespace, symbol) 492 else: 493 setattr(namespace, symbol, orig_attr) 494 495 496 def unstub(self, namespace, symbol): 497 for stub in reversed(self._stubs): 498 if (namespace, symbol) == (stub[0], stub[1]): 499 self._perform_unstub(stub) 500 self._stubs.remove(stub) 501 return 502 503 raise StubNotFoundError() 504 505 506 def unstub_all(self): 507 self._stubs.reverse() 508 for stub in self._stubs: 509 self._perform_unstub(stub) 510 self._stubs = [] 511 512 513 def __method_playback(self, symbol, *args, **dargs): 514 if self._debug: 515 print((' * Mock call: ' + 516 _dump_function_call(symbol, args, dargs)), 517 file=sys.__stdout__) 518 519 if len(self.recording) != 0: 520 func_call = self.recording[0] 521 if func_call.symbol != symbol: 522 msg = ("Unexpected call: %s\nExpected: %s" 523 % (_dump_function_call(symbol, args, dargs), 524 func_call)) 525 self._append_error(msg) 526 return None 527 528 if not func_call.match(*args, **dargs): 529 msg = ("Incorrect call: %s\nExpected: %s" 530 % (_dump_function_call(symbol, args, dargs), 531 func_call)) 532 self._append_error(msg) 533 return None 534 535 # this is the expected call so pop it and return 536 self.recording.popleft() 537 if func_call.error: 538 raise func_call.error 539 else: 540 return func_call.return_obj 541 else: 542 msg = ("unexpected call: %s" 543 % (_dump_function_call(symbol, args, dargs))) 544 self._append_error(msg) 545 return None 546 547 548 def __record_call(self, mapping): 549 self.recording.append(mapping) 550 551 552 def _append_error(self, error): 553 if self._debug: 554 print(' *** ' + error, file=sys.__stdout__) 555 if self._fail_fast: 556 raise CheckPlaybackError(error) 557 self.errors.append(error) 558 559 560 def check_playback(self): 561 """ 562 Report any errors that were encounterd during calls 563 to __method_playback(). 564 """ 565 if len(self.errors) > 0: 566 if self._debug: 567 print('\nPlayback errors:') 568 for error in self.errors: 569 print(error, file=sys.__stdout__) 570 571 if self._ut: 572 self._ut.fail('\n'.join(self.errors)) 573 574 raise CheckPlaybackError 575 elif len(self.recording) != 0: 576 errors = [] 577 for func_call in self.recording: 578 error = "%s not called" % (func_call,) 579 errors.append(error) 580 print(error, file=sys.__stdout__) 581 582 if self._ut: 583 self._ut.fail('\n'.join(errors)) 584 585 raise CheckPlaybackError 586 self.recording.clear() 587 588 589 def mock_io(self): 590 """Mocks and saves the stdout & stderr output""" 591 self.orig_stdout = sys.stdout 592 self.orig_stderr = sys.stderr 593 594 self.mock_streams_stdout = six.StringIO('') 595 self.mock_streams_stderr = six.StringIO('') 596 597 sys.stdout = self.mock_streams_stdout 598 sys.stderr = self.mock_streams_stderr 599 600 601 def unmock_io(self): 602 """Restores the stdout & stderr, and returns both 603 output strings""" 604 sys.stdout = self.orig_stdout 605 sys.stderr = self.orig_stderr 606 values = (self.mock_streams_stdout.getvalue(), 607 self.mock_streams_stderr.getvalue()) 608 609 self.mock_streams_stdout.close() 610 self.mock_streams_stderr.close() 611 return values 612 613 614def _arg_to_str(arg): 615 if isinstance(arg, argument_comparator): 616 return str(arg) 617 return repr(arg) 618 619 620def _dump_function_call(symbol, args, dargs): 621 arg_vec = [] 622 for arg in args: 623 arg_vec.append(_arg_to_str(arg)) 624 for key, val in six.iteritems(dargs): 625 arg_vec.append("%s=%s" % (key, _arg_to_str(val))) 626 return "%s(%s)" % (symbol, ', '.join(arg_vec)) 627