# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility to retrieve function args."""

import functools

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def _is_bound_method(fn):
  """Returns whether `fn`, once TF decorators are stripped, is a bound method.

  Args:
    fn: Any callable, possibly wrapped via `tf_decorator`.

  Returns:
    `True` if the unwrapped target is a method whose `__self__` is not `None`.
  """
  _, fn = tf_decorator.unwrap(fn)
  return tf_inspect.ismethod(fn) and (fn.__self__ is not None)


def _is_callable_object(obj):
  """Returns whether `obj` has a `__call__` attribute that is a method.

  This matches, e.g., instances of classes that define `__call__`; such
  objects are introspected through their bound `__call__` method.
  """
  return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__)


def fn_args(fn):
  """Get argument names for function-like object.

  Names already bound by a `functools.partial` are excluded: the first
  `len(partial.args)` names are dropped (they are filled positionally), as are
  any names supplied in `partial.keywords`. For bound methods and callable
  objects, a leading positional `self`/`cls` parameter is dropped as well.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.
  """
  if isinstance(fn, functools.partial):
    args = fn_args(fn.func)
    args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
  else:
    if _is_callable_object(fn):
      fn = fn.__call__
    # Copy before mutating: `getfullargspec` may hand back an argspec object
    # that is cached/shared (e.g. one recorded on a TF decorator), so popping
    # from its `args` list in place could corrupt it for later callers.
    args = list(tf_inspect.getfullargspec(fn).args)
    if _is_bound_method(fn) and args:
      # If it's a bound method, it may or may not have a self/cls first
      # argument; for example, self could be captured in *args.
      # If it does have a positional argument, it is self/cls.
      args.pop(0)
  return tuple(args)


def has_kwargs(fn):
  """Returns whether the passed callable has **kwargs in its signature.

  Every `functools.partial` layer is peeled off before inspection, and
  callable objects are inspected through their `__call__` method, mirroring
  how `fn_args` resolves its argument.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `bool`: if `fn` has **kwargs in its signature.

  Raises:
    `TypeError`: If fn is not a Function, or function-like object.
  """
  # Unwrap nested partials so the inspected signature is the wrapped
  # function's own (and any TF decorator argspec on it is honored), rather
  # than that of an inner partial object.
  while isinstance(fn, functools.partial):
    fn = fn.func
  if _is_callable_object(fn):
    fn = fn.__call__
  elif not callable(fn):
    raise TypeError(
        'Argument `fn` should be a callable. '
        f'Received: fn={fn} (of type {type(fn)})')
  return tf_inspect.getfullargspec(fn).varkw is not None


def get_func_name(func):
  """Returns name of passed callable.

  After stripping TF decorators:
    * plain functions yield their `__name__`;
    * bound methods yield `'<ClassOfSelf>.<method name>'`;
    * any other callable yields `str(type(func))`.

  Args:
    func: A callable, possibly wrapped via `tf_decorator`.

  Returns:
    A `str` name for `func`.

  Raises:
    ValueError: If `func` is not callable.
  """
  _, func = tf_decorator.unwrap(func)
  if callable(func):
    if tf_inspect.isfunction(func):
      return func.__name__
    elif tf_inspect.ismethod(func):
      return '%s.%s' % (six.get_method_self(func).__class__.__name__,
                        six.get_method_function(func).__name__)
    else:  # Probably a class instance with __call__
      return str(type(func))
  else:
    raise ValueError(
        'Argument `func` must be a callable. '
        f'Received func={func} (of type {type(func)})')


def get_func_code(func):
  """Returns func_code of passed callable, or None if not available.

  Args:
    func: A callable, possibly wrapped via `tf_decorator`.

  Returns:
    The code object of the unwrapped function/method, or of its `__call__`
    method for callable objects; `None` if no code object can be found.

  Raises:
    ValueError: If `func` is not callable.
  """
  _, func = tf_decorator.unwrap(func)
  if callable(func):
    if tf_inspect.isfunction(func) or tf_inspect.ismethod(func):
      return six.get_function_code(func)
    # Since the object is not a function or method, but is a callable, we will
    # try to access the `__call__` method as a function. This works with
    # callable classes but fails with `functools.partial` objects despite their
    # `__call__` attribute.
    try:
      return six.get_function_code(func.__call__)
    except AttributeError:
      return None
  else:
    raise ValueError(
        'Argument `func` must be a callable. '
        f'Received func={func} (of type {type(func)})')


# Lazily populated cache for `get_disabled_rewriter_config()`.
_rewriter_config_optimizer_disabled = None


def get_disabled_rewriter_config():
  """Returns a serialized `ConfigProto` with the meta optimizer disabled.

  The proto has `graph_options.rewrite_options.disable_meta_optimizer` set to
  `True`. It is built and serialized on the first call; later calls return
  the cached bytes.

  Returns:
    The `ConfigProto.SerializeToString()` result.
  """
  global _rewriter_config_optimizer_disabled
  if _rewriter_config_optimizer_disabled is None:
    config = config_pb2.ConfigProto()
    rewriter_config = config.graph_options.rewrite_options
    rewriter_config.disable_meta_optimizer = True
    _rewriter_config_optimizer_disabled = config.SerializeToString()
  return _rewriter_config_optimizer_disabled