1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Base classes for probability distributions.""" 16 17import abc 18import contextlib 19import types 20 21import numpy as np 22 23from tensorflow.python.eager import context 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_shape 27from tensorflow.python.framework import tensor_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import math_ops 30from tensorflow.python.ops.distributions import kullback_leibler 31from tensorflow.python.ops.distributions import util 32from tensorflow.python.util import deprecation 33from tensorflow.python.util import tf_inspect 34from tensorflow.python.util.tf_export import tf_export 35 36 37__all__ = [ 38 "ReparameterizationType", 39 "FULLY_REPARAMETERIZED", 40 "NOT_REPARAMETERIZED", 41 "Distribution", 42] 43 44_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [ 45 "batch_shape", 46 "batch_shape_tensor", 47 "cdf", 48 "covariance", 49 "cross_entropy", 50 "entropy", 51 "event_shape", 52 "event_shape_tensor", 53 "kl_divergence", 54 "log_cdf", 55 "log_prob", 56 "log_survival_function", 57 "mean", 58 "mode", 59 "prob", 60 "sample", 61 "stddev", 62 "survival_function", 63 "variance", 64] 65 66 67class _BaseDistribution(metaclass=abc.ABCMeta): 68 """Abstract base class needed for resolving subclass hierarchy.""" 69 pass 70 71 72def _copy_fn(fn): 73 """Create a deep copy of fn. 74 75 Args: 76 fn: a callable 77 78 Returns: 79 A `FunctionType`: a deep copy of fn. 80 81 Raises: 82 TypeError: if `fn` is not a callable. 83 """ 84 if not callable(fn): 85 raise TypeError("fn is not callable: %s" % fn) 86 # The blessed way to copy a function. copy.deepcopy fails to create a 87 # non-reference copy. Since: 88 # types.FunctionType == type(lambda: None), 89 # and the docstring for the function type states: 90 # 91 # function(code, globals[, name[, argdefs[, closure]]]) 92 # 93 # Create a function object from a code object and a dictionary. 94 # ... 95 # 96 # Here we can use this to create a new function with the old function's 97 # code, globals, closure, etc. 98 return types.FunctionType( 99 code=fn.__code__, globals=fn.__globals__, 100 name=fn.__name__, argdefs=fn.__defaults__, 101 closure=fn.__closure__) 102 103 104def _update_docstring(old_str, append_str): 105 """Update old_str by inserting append_str just before the "Args:" section.""" 106 old_str = old_str or "" 107 old_str_lines = old_str.split("\n") 108 109 # Step 0: Prepend spaces to all lines of append_str. This is 110 # necessary for correct markdown generation. 111 append_str = "\n".join(" %s" % line for line in append_str.split("\n")) 112 113 # Step 1: Find mention of "Args": 114 has_args_ix = [ 115 ix for ix, line in enumerate(old_str_lines) 116 if line.strip().lower() == "args:"] 117 if has_args_ix: 118 final_args_ix = has_args_ix[-1] 119 return ("\n".join(old_str_lines[:final_args_ix]) 120 + "\n\n" + append_str + "\n\n" 121 + "\n".join(old_str_lines[final_args_ix:])) 122 else: 123 return old_str + "\n\n" + append_str 124 125 126def _convert_to_tensor(value, name=None, preferred_dtype=None): 127 """Converts to tensor avoiding an eager bug that loses float precision.""" 128 # TODO(b/116672045): Remove this function. 129 if (context.executing_eagerly() and preferred_dtype is not None and 130 (preferred_dtype.is_integer or preferred_dtype.is_bool)): 131 v = ops.convert_to_tensor(value, name=name) 132 if v.dtype.is_floating: 133 return v 134 return ops.convert_to_tensor( 135 value, name=name, preferred_dtype=preferred_dtype) 136 137 138class _DistributionMeta(abc.ABCMeta): 139 140 def __new__(mcs, classname, baseclasses, attrs): 141 """Control the creation of subclasses of the Distribution class. 142 143 The main purpose of this method is to properly propagate docstrings 144 from private Distribution methods, like `_log_prob`, into their 145 public wrappers as inherited by the Distribution base class 146 (e.g. `log_prob`). 147 148 Args: 149 classname: The name of the subclass being created. 150 baseclasses: A tuple of parent classes. 151 attrs: A dict mapping new attributes to their values. 152 153 Returns: 154 The class object. 155 156 Raises: 157 TypeError: If `Distribution` is not a subclass of `BaseDistribution`, or 158 the new class is derived via multiple inheritance and the first 159 parent class is not a subclass of `BaseDistribution`. 160 AttributeError: If `Distribution` does not implement e.g. `log_prob`. 161 ValueError: If a `Distribution` public method lacks a docstring. 162 """ 163 if not baseclasses: # Nothing to be done for Distribution 164 raise TypeError("Expected non-empty baseclass. Does Distribution " 165 "not subclass _BaseDistribution?") 166 which_base = [ 167 base for base in baseclasses 168 if base == _BaseDistribution or issubclass(base, Distribution)] 169 base = which_base[0] 170 if base == _BaseDistribution: # Nothing to be done for Distribution 171 return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs) 172 if not issubclass(base, Distribution): 173 raise TypeError("First parent class declared for %s must be " 174 "Distribution, but saw '%s'" % (classname, base.__name__)) 175 for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS: 176 special_attr = "_%s" % attr 177 class_attr_value = attrs.get(attr, None) 178 if attr in attrs: 179 # The method is being overridden, do not update its docstring 180 continue 181 base_attr_value = getattr(base, attr, None) 182 if not base_attr_value: 183 raise AttributeError( 184 "Internal error: expected base class '%s' to implement method '%s'" 185 % (base.__name__, attr)) 186 class_special_attr_value = attrs.get(special_attr, None) 187 if class_special_attr_value is None: 188 # No _special method available, no need to update the docstring. 189 continue 190 class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value) 191 if not class_special_attr_docstring: 192 # No docstring to append. 193 continue 194 class_attr_value = _copy_fn(base_attr_value) 195 class_attr_docstring = tf_inspect.getdoc(base_attr_value) 196 if class_attr_docstring is None: 197 raise ValueError( 198 "Expected base class fn to contain a docstring: %s.%s" 199 % (base.__name__, attr)) 200 class_attr_value.__doc__ = _update_docstring( 201 class_attr_value.__doc__, 202 ("Additional documentation from `%s`:\n\n%s" 203 % (classname, class_special_attr_docstring))) 204 attrs[attr] = class_attr_value 205 206 return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs) 207 208 209@tf_export(v1=["distributions.ReparameterizationType"]) 210class ReparameterizationType: 211 """Instances of this class represent how sampling is reparameterized. 212 213 Two static instances exist in the distributions library, signifying 214 one of two possible properties for samples from a distribution: 215 216 `FULLY_REPARAMETERIZED`: Samples from the distribution are fully 217 reparameterized, and straight-through gradients are supported. 218 219 `NOT_REPARAMETERIZED`: Samples from the distribution are not fully 220 reparameterized, and straight-through gradients are either partially 221 unsupported or are not supported at all. In this case, for purposes of 222 e.g. RL or variational inference, it is generally safest to wrap the 223 sample results in a `stop_gradients` call and use policy 224 gradients / surrogate loss instead. 225 """ 226 227 @deprecation.deprecated( 228 "2019-01-01", 229 "The TensorFlow Distributions library has moved to " 230 "TensorFlow Probability " 231 "(https://github.com/tensorflow/probability). You " 232 "should update all references to use `tfp.distributions` " 233 "instead of `tf.distributions`.", 234 warn_once=True) 235 def __init__(self, rep_type): 236 self._rep_type = rep_type 237 238 def __repr__(self): 239 return "<Reparameterization Type: %s>" % self._rep_type 240 241 def __eq__(self, other): 242 """Determine if this `ReparameterizationType` is equal to another. 243 244 Since ReparameterizationType instances are constant static global 245 instances, equality checks if two instances' id() values are equal. 246 247 Args: 248 other: Object to compare against. 249 250 Returns: 251 `self is other`. 252 """ 253 return self is other 254 255 256# Fully reparameterized distribution: samples from a fully 257# reparameterized distribution support straight-through gradients with 258# respect to all parameters. 259FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED") 260tf_export(v1=["distributions.FULLY_REPARAMETERIZED"]).export_constant( 261 __name__, "FULLY_REPARAMETERIZED") 262 263 264# Not reparameterized distribution: samples from a non- 265# reparameterized distribution do not support straight-through gradients for 266# at least some of the parameters. 267NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED") 268tf_export(v1=["distributions.NOT_REPARAMETERIZED"]).export_constant( 269 __name__, "NOT_REPARAMETERIZED") 270 271 272@tf_export(v1=["distributions.Distribution"]) 273class Distribution(_BaseDistribution, metaclass=_DistributionMeta): 274 """A generic probability distribution base class. 275 276 `Distribution` is a base class for constructing and organizing properties 277 (e.g., mean, variance) of random variables (e.g, Bernoulli, Gaussian). 278 279 #### Subclassing 280 281 Subclasses are expected to implement a leading-underscore version of the 282 same-named function. The argument signature should be identical except for 283 the omission of `name="..."`. For example, to enable `log_prob(value, 284 name="log_prob")` a subclass should implement `_log_prob(value)`. 285 286 Subclasses can append to public-level docstrings by providing 287 docstrings for their method specializations. For example: 288 289 ```python 290 @util.AppendDocstring("Some other details.") 291 def _log_prob(self, value): 292 ... 293 ``` 294 295 would add the string "Some other details." to the `log_prob` function 296 docstring. This is implemented as a simple decorator to avoid python 297 linter complaining about missing Args/Returns/Raises sections in the 298 partial docstrings. 299 300 #### Broadcasting, batching, and shapes 301 302 All distributions support batches of independent distributions of that type. 303 The batch shape is determined by broadcasting together the parameters. 304 305 The shape of arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and 306 `log_prob` reflect this broadcasting, as does the return value of `sample` and 307 `sample_n`. 308 309 `sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape` is 310 the shape of the `Tensor` returned from `sample_n`, `n` is the number of 311 samples, `batch_shape` defines how many independent distributions there are, 312 and `event_shape` defines the shape of samples from each of those independent 313 distributions. Samples are independent along the `batch_shape` dimensions, but 314 not necessarily so along the `event_shape` dimensions (depending on the 315 particulars of the underlying distribution). 316 317 Using the `Uniform` distribution as an example: 318 319 ```python 320 minval = 3.0 321 maxval = [[4.0, 6.0], 322 [10.0, 12.0]] 323 324 # Broadcasting: 325 # This instance represents 4 Uniform distributions. Each has a lower bound at 326 # 3.0 as the `minval` parameter was broadcasted to match `maxval`'s shape. 327 u = Uniform(minval, maxval) 328 329 # `event_shape` is `TensorShape([])`. 330 event_shape = u.event_shape 331 # `event_shape_t` is a `Tensor` which will evaluate to []. 332 event_shape_t = u.event_shape_tensor() 333 334 # Sampling returns a sample per distribution. `samples` has shape 335 # [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5, 336 # batch_shape=[2, 2], and event_shape=[]. 337 samples = u.sample_n(5) 338 339 # The broadcasting holds across methods. Here we use `cdf` as an example. The 340 # same holds for `log_cdf` and the likelihood functions. 341 342 # `cum_prob` has shape [2, 2] as the `value` argument was broadcasted to the 343 # shape of the `Uniform` instance. 344 cum_prob_broadcast = u.cdf(4.0) 345 346 # `cum_prob`'s shape is [2, 2], one per distribution. No broadcasting 347 # occurred. 348 cum_prob_per_dist = u.cdf([[4.0, 5.0], 349 [6.0, 7.0]]) 350 351 # INVALID as the `value` argument is not broadcastable to the distribution's 352 # shape. 353 cum_prob_invalid = u.cdf([4.0, 5.0, 6.0]) 354 ``` 355 356 #### Shapes 357 358 There are three important concepts associated with TensorFlow Distributions 359 shapes: 360 - Event shape describes the shape of a single draw from the distribution; 361 it may be dependent across dimensions. For scalar distributions, the event 362 shape is `[]`. For a 5-dimensional MultivariateNormal, the event shape is 363 `[5]`. 364 - Batch shape describes independent, not identically distributed draws, aka a 365 "collection" or "bunch" of distributions. 366 - Sample shape describes independent, identically distributed draws of batches 367 from the distribution family. 368 369 The event shape and the batch shape are properties of a Distribution object, 370 whereas the sample shape is associated with a specific call to `sample` or 371 `log_prob`. 372 373 For detailed usage examples of TensorFlow Distributions shapes, see 374 [this tutorial]( 375 https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb) 376 377 #### Parameter values leading to undefined statistics or distributions. 378 379 Some distributions do not have well-defined statistics for all initialization 380 parameter values. For example, the beta distribution is parameterized by 381 positive real numbers `concentration1` and `concentration0`, and does not have 382 well-defined mode if `concentration1 < 1` or `concentration0 < 1`. 383 384 The user is given the option of raising an exception or returning `NaN`. 385 386 ```python 387 a = tf.exp(tf.matmul(logits, weights_a)) 388 b = tf.exp(tf.matmul(logits, weights_b)) 389 390 # Will raise exception if ANY batch member has a < 1 or b < 1. 391 dist = distributions.beta(a, b, allow_nan_stats=False) 392 mode = dist.mode().eval() 393 394 # Will return NaN for batch members with either a < 1 or b < 1. 395 dist = distributions.beta(a, b, allow_nan_stats=True) # Default behavior 396 mode = dist.mode().eval() 397 ``` 398 399 In all cases, an exception is raised if *invalid* parameters are passed, e.g. 400 401 ```python 402 # Will raise an exception if any Op is run. 403 negative_a = -1.0 * a # beta distribution by definition has a > 0. 404 dist = distributions.beta(negative_a, b, allow_nan_stats=True) 405 dist.mean().eval() 406 ``` 407 408 """ 409 410 @deprecation.deprecated( 411 "2019-01-01", 412 "The TensorFlow Distributions library has moved to " 413 "TensorFlow Probability " 414 "(https://github.com/tensorflow/probability). You " 415 "should update all references to use `tfp.distributions` " 416 "instead of `tf.distributions`.", 417 warn_once=True) 418 def __init__(self, 419 dtype, 420 reparameterization_type, 421 validate_args, 422 allow_nan_stats, 423 parameters=None, 424 graph_parents=None, 425 name=None): 426 """Constructs the `Distribution`. 427 428 **This is a private method for subclass use.** 429 430 Args: 431 dtype: The type of the event samples. `None` implies no type-enforcement. 432 reparameterization_type: Instance of `ReparameterizationType`. 433 If `distributions.FULLY_REPARAMETERIZED`, this 434 `Distribution` can be reparameterized in terms of some standard 435 distribution with a function whose Jacobian is constant for the support 436 of the standard distribution. If `distributions.NOT_REPARAMETERIZED`, 437 then no such reparameterization is available. 438 validate_args: Python `bool`, default `False`. When `True` distribution 439 parameters are checked for validity despite possibly degrading runtime 440 performance. When `False` invalid inputs may silently render incorrect 441 outputs. 442 allow_nan_stats: Python `bool`, default `True`. When `True`, statistics 443 (e.g., mean, mode, variance) use the value "`NaN`" to indicate the 444 result is undefined. When `False`, an exception is raised if one or 445 more of the statistic's batch members are undefined. 446 parameters: Python `dict` of parameters used to instantiate this 447 `Distribution`. 448 graph_parents: Python `list` of graph prerequisites of this 449 `Distribution`. 450 name: Python `str` name prefixed to Ops created by this class. Default: 451 subclass name. 452 453 Raises: 454 ValueError: if any member of graph_parents is `None` or not a `Tensor`. 455 """ 456 graph_parents = [] if graph_parents is None else graph_parents 457 for i, t in enumerate(graph_parents): 458 if t is None or not tensor_util.is_tf_type(t): 459 raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t)) 460 if not name or name[-1] != "/": # `name` is not a name scope 461 non_unique_name = name or type(self).__name__ 462 with ops.name_scope(non_unique_name) as name: 463 pass 464 self._dtype = dtype 465 self._reparameterization_type = reparameterization_type 466 self._allow_nan_stats = allow_nan_stats 467 self._validate_args = validate_args 468 self._parameters = parameters or {} 469 self._graph_parents = graph_parents 470 self._name = name 471 472 @property 473 def _parameters(self): 474 return self._parameter_dict 475 476 @_parameters.setter 477 def _parameters(self, value): 478 """Intercept assignments to self._parameters to avoid reference cycles. 479 480 Parameters are often created using locals(), so we need to clean out any 481 references to `self` before assigning it to an attribute. 482 483 Args: 484 value: A dictionary of parameters to assign to the `_parameters` property. 485 """ 486 if "self" in value: 487 del value["self"] 488 self._parameter_dict = value 489 490 @classmethod 491 def param_shapes(cls, sample_shape, name="DistributionParamShapes"): 492 """Shapes of parameters given the desired shape of a call to `sample()`. 493 494 This is a class method that describes what key/value arguments are required 495 to instantiate the given `Distribution` so that a particular shape is 496 returned for that instance's call to `sample()`. 497 498 Subclasses should override class method `_param_shapes`. 499 500 Args: 501 sample_shape: `Tensor` or python list/tuple. Desired shape of a call to 502 `sample()`. 503 name: name to prepend ops with. 504 505 Returns: 506 `dict` of parameter name to `Tensor` shapes. 507 """ 508 with ops.name_scope(name, values=[sample_shape]): 509 return cls._param_shapes(sample_shape) 510 511 @classmethod 512 def param_static_shapes(cls, sample_shape): 513 """param_shapes with static (i.e. `TensorShape`) shapes. 514 515 This is a class method that describes what key/value arguments are required 516 to instantiate the given `Distribution` so that a particular shape is 517 returned for that instance's call to `sample()`. Assumes that the sample's 518 shape is known statically. 519 520 Subclasses should override class method `_param_shapes` to return 521 constant-valued tensors when constant values are fed. 522 523 Args: 524 sample_shape: `TensorShape` or python list/tuple. Desired shape of a call 525 to `sample()`. 526 527 Returns: 528 `dict` of parameter name to `TensorShape`. 529 530 Raises: 531 ValueError: if `sample_shape` is a `TensorShape` and is not fully defined. 532 """ 533 if isinstance(sample_shape, tensor_shape.TensorShape): 534 if not sample_shape.is_fully_defined(): 535 raise ValueError("TensorShape sample_shape must be fully defined") 536 sample_shape = sample_shape.as_list() 537 538 params = cls.param_shapes(sample_shape) 539 540 static_params = {} 541 for name, shape in params.items(): 542 static_shape = tensor_util.constant_value(shape) 543 if static_shape is None: 544 raise ValueError( 545 "sample_shape must be a fully-defined TensorShape or list/tuple") 546 static_params[name] = tensor_shape.TensorShape(static_shape) 547 548 return static_params 549 550 @staticmethod 551 def _param_shapes(sample_shape): 552 raise NotImplementedError("_param_shapes not implemented") 553 554 @property 555 def name(self): 556 """Name prepended to all ops created by this `Distribution`.""" 557 return self._name 558 559 @property 560 def dtype(self): 561 """The `DType` of `Tensor`s handled by this `Distribution`.""" 562 return self._dtype 563 564 @property 565 def parameters(self): 566 """Dictionary of parameters used to instantiate this `Distribution`.""" 567 # Remove "self", "__class__", or other special variables. These can appear 568 # if the subclass used: 569 # `parameters = dict(locals())`. 570 return {k: v for k, v in self._parameters.items() 571 if not k.startswith("__") and k != "self"} 572 573 @property 574 def reparameterization_type(self): 575 """Describes how samples from the distribution are reparameterized. 576 577 Currently this is one of the static instances 578 `distributions.FULLY_REPARAMETERIZED` 579 or `distributions.NOT_REPARAMETERIZED`. 580 581 Returns: 582 An instance of `ReparameterizationType`. 583 """ 584 return self._reparameterization_type 585 586 @property 587 def allow_nan_stats(self): 588 """Python `bool` describing behavior when a stat is undefined. 589 590 Stats return +/- infinity when it makes sense. E.g., the variance of a 591 Cauchy distribution is infinity. However, sometimes the statistic is 592 undefined, e.g., if a distribution's pdf does not achieve a maximum within 593 the support of the distribution, the mode is undefined. If the mean is 594 undefined, then by definition the variance is undefined. E.g. the mean for 595 Student's T for df = 1 is undefined (no clear way to say it is either + or - 596 infinity), so the variance = E[(X - mean)**2] is also undefined. 597 598 Returns: 599 allow_nan_stats: Python `bool`. 600 """ 601 return self._allow_nan_stats 602 603 @property 604 def validate_args(self): 605 """Python `bool` indicating possibly expensive checks are enabled.""" 606 return self._validate_args 607 608 def copy(self, **override_parameters_kwargs): 609 """Creates a deep copy of the distribution. 610 611 Note: the copy distribution may continue to depend on the original 612 initialization arguments. 613 614 Args: 615 **override_parameters_kwargs: String/value dictionary of initialization 616 arguments to override with new values. 617 618 Returns: 619 distribution: A new instance of `type(self)` initialized from the union 620 of self.parameters and override_parameters_kwargs, i.e., 621 `dict(self.parameters, **override_parameters_kwargs)`. 622 """ 623 parameters = dict(self.parameters, **override_parameters_kwargs) 624 return type(self)(**parameters) 625 626 def _batch_shape_tensor(self): 627 raise NotImplementedError( 628 "batch_shape_tensor is not implemented: {}".format(type(self).__name__)) 629 630 def batch_shape_tensor(self, name="batch_shape_tensor"): 631 """Shape of a single sample from a single event index as a 1-D `Tensor`. 632 633 The batch dimensions are indexes into independent, non-identical 634 parameterizations of this distribution. 635 636 Args: 637 name: name to give to the op 638 639 Returns: 640 batch_shape: `Tensor`. 641 """ 642 with self._name_scope(name): 643 if self.batch_shape.is_fully_defined(): 644 return ops.convert_to_tensor(self.batch_shape.as_list(), 645 dtype=dtypes.int32, 646 name="batch_shape") 647 return self._batch_shape_tensor() 648 649 def _batch_shape(self): 650 return tensor_shape.TensorShape(None) 651 652 @property 653 def batch_shape(self): 654 """Shape of a single sample from a single event index as a `TensorShape`. 655 656 May be partially defined or unknown. 657 658 The batch dimensions are indexes into independent, non-identical 659 parameterizations of this distribution. 660 661 Returns: 662 batch_shape: `TensorShape`, possibly unknown. 663 """ 664 return tensor_shape.as_shape(self._batch_shape()) 665 666 def _event_shape_tensor(self): 667 raise NotImplementedError( 668 "event_shape_tensor is not implemented: {}".format(type(self).__name__)) 669 670 def event_shape_tensor(self, name="event_shape_tensor"): 671 """Shape of a single sample from a single batch as a 1-D int32 `Tensor`. 672 673 Args: 674 name: name to give to the op 675 676 Returns: 677 event_shape: `Tensor`. 678 """ 679 with self._name_scope(name): 680 if self.event_shape.is_fully_defined(): 681 return ops.convert_to_tensor(self.event_shape.as_list(), 682 dtype=dtypes.int32, 683 name="event_shape") 684 return self._event_shape_tensor() 685 686 def _event_shape(self): 687 return tensor_shape.TensorShape(None) 688 689 @property 690 def event_shape(self): 691 """Shape of a single sample from a single batch as a `TensorShape`. 692 693 May be partially defined or unknown. 694 695 Returns: 696 event_shape: `TensorShape`, possibly unknown. 697 """ 698 return tensor_shape.as_shape(self._event_shape()) 699 700 def is_scalar_event(self, name="is_scalar_event"): 701 """Indicates that `event_shape == []`. 702 703 Args: 704 name: Python `str` prepended to names of ops created by this function. 705 706 Returns: 707 is_scalar_event: `bool` scalar `Tensor`. 708 """ 709 with self._name_scope(name): 710 return ops.convert_to_tensor( 711 self._is_scalar_helper(self.event_shape, self.event_shape_tensor), 712 name="is_scalar_event") 713 714 def is_scalar_batch(self, name="is_scalar_batch"): 715 """Indicates that `batch_shape == []`. 716 717 Args: 718 name: Python `str` prepended to names of ops created by this function. 719 720 Returns: 721 is_scalar_batch: `bool` scalar `Tensor`. 722 """ 723 with self._name_scope(name): 724 return ops.convert_to_tensor( 725 self._is_scalar_helper(self.batch_shape, self.batch_shape_tensor), 726 name="is_scalar_batch") 727 728 def _sample_n(self, n, seed=None): 729 raise NotImplementedError("sample_n is not implemented: {}".format( 730 type(self).__name__)) 731 732 def _call_sample_n(self, sample_shape, seed, name, **kwargs): 733 with self._name_scope(name, values=[sample_shape]): 734 sample_shape = ops.convert_to_tensor( 735 sample_shape, dtype=dtypes.int32, name="sample_shape") 736 sample_shape, n = self._expand_sample_shape_to_vector( 737 sample_shape, "sample_shape") 738 samples = self._sample_n(n, seed, **kwargs) 739 batch_event_shape = array_ops.shape(samples)[1:] 740 final_shape = array_ops.concat([sample_shape, batch_event_shape], 0) 741 samples = array_ops.reshape(samples, final_shape) 742 samples = self._set_sample_static_shape(samples, sample_shape) 743 return samples 744 745 def sample(self, sample_shape=(), seed=None, name="sample"): 746 """Generate samples of the specified shape. 747 748 Note that a call to `sample()` without arguments will generate a single 749 sample. 750 751 Args: 752 sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples. 753 seed: Python integer seed for RNG 754 name: name to give to the op. 755 756 Returns: 757 samples: a `Tensor` with prepended dimensions `sample_shape`. 758 """ 759 return self._call_sample_n(sample_shape, seed, name) 760 761 def _log_prob(self, value): 762 raise NotImplementedError("log_prob is not implemented: {}".format( 763 type(self).__name__)) 764 765 def _call_log_prob(self, value, name, **kwargs): 766 with self._name_scope(name, values=[value]): 767 value = _convert_to_tensor( 768 value, name="value", preferred_dtype=self.dtype) 769 try: 770 return self._log_prob(value, **kwargs) 771 except NotImplementedError as original_exception: 772 try: 773 return math_ops.log(self._prob(value, **kwargs)) 774 except NotImplementedError: 775 raise original_exception 776 777 def log_prob(self, value, name="log_prob"): 778 """Log probability density/mass function. 779 780 Args: 781 value: `float` or `double` `Tensor`. 782 name: Python `str` prepended to names of ops created by this function. 783 784 Returns: 785 log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 786 values of type `self.dtype`. 787 """ 788 return self._call_log_prob(value, name) 789 790 def _prob(self, value): 791 raise NotImplementedError("prob is not implemented: {}".format( 792 type(self).__name__)) 793 794 def _call_prob(self, value, name, **kwargs): 795 with self._name_scope(name, values=[value]): 796 value = _convert_to_tensor( 797 value, name="value", preferred_dtype=self.dtype) 798 try: 799 return self._prob(value, **kwargs) 800 except NotImplementedError as original_exception: 801 try: 802 return math_ops.exp(self._log_prob(value, **kwargs)) 803 except NotImplementedError: 804 raise original_exception 805 806 def prob(self, value, name="prob"): 807 """Probability density/mass function. 808 809 Args: 810 value: `float` or `double` `Tensor`. 811 name: Python `str` prepended to names of ops created by this function. 812 813 Returns: 814 prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 815 values of type `self.dtype`. 816 """ 817 return self._call_prob(value, name) 818 819 def _log_cdf(self, value): 820 raise NotImplementedError("log_cdf is not implemented: {}".format( 821 type(self).__name__)) 822 823 def _call_log_cdf(self, value, name, **kwargs): 824 with self._name_scope(name, values=[value]): 825 value = _convert_to_tensor( 826 value, name="value", preferred_dtype=self.dtype) 827 try: 828 return self._log_cdf(value, **kwargs) 829 except NotImplementedError as original_exception: 830 try: 831 return math_ops.log(self._cdf(value, **kwargs)) 832 except NotImplementedError: 833 raise original_exception 834 835 def log_cdf(self, value, name="log_cdf"): 836 """Log cumulative distribution function. 837 838 Given random variable `X`, the cumulative distribution function `cdf` is: 839 840 ```none 841 log_cdf(x) := Log[ P[X <= x] ] 842 ``` 843 844 Often, a numerical approximation can be used for `log_cdf(x)` that yields 845 a more accurate answer than simply taking the logarithm of the `cdf` when 846 `x << -1`. 847 848 Args: 849 value: `float` or `double` `Tensor`. 850 name: Python `str` prepended to names of ops created by this function. 851 852 Returns: 853 logcdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 854 values of type `self.dtype`. 855 """ 856 return self._call_log_cdf(value, name) 857 858 def _cdf(self, value): 859 raise NotImplementedError("cdf is not implemented: {}".format( 860 type(self).__name__)) 861 862 def _call_cdf(self, value, name, **kwargs): 863 with self._name_scope(name, values=[value]): 864 value = _convert_to_tensor( 865 value, name="value", preferred_dtype=self.dtype) 866 try: 867 return self._cdf(value, **kwargs) 868 except NotImplementedError as original_exception: 869 try: 870 return math_ops.exp(self._log_cdf(value, **kwargs)) 871 except NotImplementedError: 872 raise original_exception 873 874 def cdf(self, value, name="cdf"): 875 """Cumulative distribution function. 876 877 Given random variable `X`, the cumulative distribution function `cdf` is: 878 879 ```none 880 cdf(x) := P[X <= x] 881 ``` 882 883 Args: 884 value: `float` or `double` `Tensor`. 885 name: Python `str` prepended to names of ops created by this function. 886 887 Returns: 888 cdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 889 values of type `self.dtype`. 890 """ 891 return self._call_cdf(value, name) 892 893 def _log_survival_function(self, value): 894 raise NotImplementedError( 895 "log_survival_function is not implemented: {}".format( 896 type(self).__name__)) 897 898 def _call_log_survival_function(self, value, name, **kwargs): 899 with self._name_scope(name, values=[value]): 900 value = _convert_to_tensor( 901 value, name="value", preferred_dtype=self.dtype) 902 try: 903 return self._log_survival_function(value, **kwargs) 904 except NotImplementedError as original_exception: 905 try: 906 return math_ops.log1p(-self.cdf(value, **kwargs)) 907 except NotImplementedError: 908 raise original_exception 909 910 def log_survival_function(self, value, name="log_survival_function"): 911 """Log survival function. 912 913 Given random variable `X`, the survival function is defined: 914 915 ```none 916 log_survival_function(x) = Log[ P[X > x] ] 917 = Log[ 1 - P[X <= x] ] 918 = Log[ 1 - cdf(x) ] 919 ``` 920 921 Typically, different numerical approximations can be used for the log 922 survival function, which are more accurate than `1 - cdf(x)` when `x >> 1`. 923 924 Args: 925 value: `float` or `double` `Tensor`. 926 name: Python `str` prepended to names of ops created by this function. 927 928 Returns: 929 `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type 930 `self.dtype`. 931 """ 932 return self._call_log_survival_function(value, name) 933 934 def _survival_function(self, value): 935 raise NotImplementedError("survival_function is not implemented: {}".format( 936 type(self).__name__)) 937 938 def _call_survival_function(self, value, name, **kwargs): 939 with self._name_scope(name, values=[value]): 940 value = _convert_to_tensor( 941 value, name="value", preferred_dtype=self.dtype) 942 try: 943 return self._survival_function(value, **kwargs) 944 except NotImplementedError as original_exception: 945 try: 946 return 1. - self.cdf(value, **kwargs) 947 except NotImplementedError: 948 raise original_exception 949 950 def survival_function(self, value, name="survival_function"): 951 """Survival function. 952 953 Given random variable `X`, the survival function is defined: 954 955 ```none 956 survival_function(x) = P[X > x] 957 = 1 - P[X <= x] 958 = 1 - cdf(x). 959 ``` 960 961 Args: 962 value: `float` or `double` `Tensor`. 963 name: Python `str` prepended to names of ops created by this function. 964 965 Returns: 966 `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type 967 `self.dtype`. 968 """ 969 return self._call_survival_function(value, name) 970 971 def _entropy(self): 972 raise NotImplementedError("entropy is not implemented: {}".format( 973 type(self).__name__)) 974 975 def entropy(self, name="entropy"): 976 """Shannon entropy in nats.""" 977 with self._name_scope(name): 978 return self._entropy() 979 980 def _mean(self): 981 raise NotImplementedError("mean is not implemented: {}".format( 982 type(self).__name__)) 983 984 def mean(self, name="mean"): 985 """Mean.""" 986 with self._name_scope(name): 987 return self._mean() 988 989 def _quantile(self, value): 990 raise NotImplementedError("quantile is not implemented: {}".format( 991 type(self).__name__)) 992 993 def _call_quantile(self, value, name, **kwargs): 994 with self._name_scope(name, values=[value]): 995 value = _convert_to_tensor( 996 value, name="value", preferred_dtype=self.dtype) 997 return self._quantile(value, **kwargs) 998 999 def quantile(self, value, name="quantile"): 1000 """Quantile function. Aka "inverse cdf" or "percent point function". 1001 1002 Given random variable `X` and `p in [0, 1]`, the `quantile` is: 1003 1004 ```none 1005 quantile(p) := x such that P[X <= x] == p 1006 ``` 1007 1008 Args: 1009 value: `float` or `double` `Tensor`. 1010 name: Python `str` prepended to names of ops created by this function. 1011 1012 Returns: 1013 quantile: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with 1014 values of type `self.dtype`. 1015 """ 1016 return self._call_quantile(value, name) 1017 1018 def _variance(self): 1019 raise NotImplementedError("variance is not implemented: {}".format( 1020 type(self).__name__)) 1021 1022 def variance(self, name="variance"): 1023 """Variance. 1024 1025 Variance is defined as, 1026 1027 ```none 1028 Var = E[(X - E[X])**2] 1029 ``` 1030 1031 where `X` is the random variable associated with this distribution, `E` 1032 denotes expectation, and `Var.shape = batch_shape + event_shape`. 1033 1034 Args: 1035 name: Python `str` prepended to names of ops created by this function. 1036 1037 Returns: 1038 variance: Floating-point `Tensor` with shape identical to 1039 `batch_shape + event_shape`, i.e., the same shape as `self.mean()`. 1040 """ 1041 with self._name_scope(name): 1042 try: 1043 return self._variance() 1044 except NotImplementedError as original_exception: 1045 try: 1046 return math_ops.square(self._stddev()) 1047 except NotImplementedError: 1048 raise original_exception 1049 1050 def _stddev(self): 1051 raise NotImplementedError("stddev is not implemented: {}".format( 1052 type(self).__name__)) 1053 1054 def stddev(self, name="stddev"): 1055 """Standard deviation. 1056 1057 Standard deviation is defined as, 1058 1059 ```none 1060 stddev = E[(X - E[X])**2]**0.5 1061 ``` 1062 1063 where `X` is the random variable associated with this distribution, `E` 1064 denotes expectation, and `stddev.shape = batch_shape + event_shape`. 1065 1066 Args: 1067 name: Python `str` prepended to names of ops created by this function. 1068 1069 Returns: 1070 stddev: Floating-point `Tensor` with shape identical to 1071 `batch_shape + event_shape`, i.e., the same shape as `self.mean()`. 1072 """ 1073 1074 with self._name_scope(name): 1075 try: 1076 return self._stddev() 1077 except NotImplementedError as original_exception: 1078 try: 1079 return math_ops.sqrt(self._variance()) 1080 except NotImplementedError: 1081 raise original_exception 1082 1083 def _covariance(self): 1084 raise NotImplementedError("covariance is not implemented: {}".format( 1085 type(self).__name__)) 1086 1087 def covariance(self, name="covariance"): 1088 """Covariance. 1089 1090 Covariance is (possibly) defined only for non-scalar-event distributions. 1091 1092 For example, for a length-`k`, vector-valued distribution, it is calculated 1093 as, 1094 1095 ```none 1096 Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])] 1097 ``` 1098 1099 where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E` 1100 denotes expectation. 1101 1102 Alternatively, for non-vector, multivariate distributions (e.g., 1103 matrix-valued, Wishart), `Covariance` shall return a (batch of) matrices 1104 under some vectorization of the events, i.e., 1105 1106 ```none 1107 Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above] 1108 ``` 1109 1110 where `Cov` is a (batch of) `k' x k'` matrices, 1111 `0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function 1112 mapping indices of this distribution's event dimensions to indices of a 1113 length-`k'` vector. 1114 1115 Args: 1116 name: Python `str` prepended to names of ops created by this function. 1117 1118 Returns: 1119 covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']` 1120 where the first `n` dimensions are batch coordinates and 1121 `k' = reduce_prod(self.event_shape)`. 1122 """ 1123 with self._name_scope(name): 1124 return self._covariance() 1125 1126 def _mode(self): 1127 raise NotImplementedError("mode is not implemented: {}".format( 1128 type(self).__name__)) 1129 1130 def mode(self, name="mode"): 1131 """Mode.""" 1132 with self._name_scope(name): 1133 return self._mode() 1134 1135 def _cross_entropy(self, other): 1136 return kullback_leibler.cross_entropy( 1137 self, other, allow_nan_stats=self.allow_nan_stats) 1138 1139 def cross_entropy(self, other, name="cross_entropy"): 1140 """Computes the (Shannon) cross entropy. 1141 1142 Denote this distribution (`self`) by `P` and the `other` distribution by 1143 `Q`. Assuming `P, Q` are absolutely continuous with respect to 1144 one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, (Shanon) 1145 cross entropy is defined as: 1146 1147 ```none 1148 H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x) 1149 ``` 1150 1151 where `F` denotes the support of the random variable `X ~ P`. 1152 1153 Args: 1154 other: `tfp.distributions.Distribution` instance. 1155 name: Python `str` prepended to names of ops created by this function. 1156 1157 Returns: 1158 cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]` 1159 representing `n` different calculations of (Shanon) cross entropy. 1160 """ 1161 with self._name_scope(name): 1162 return self._cross_entropy(other) 1163 1164 def _kl_divergence(self, other): 1165 return kullback_leibler.kl_divergence( 1166 self, other, allow_nan_stats=self.allow_nan_stats) 1167 1168 def kl_divergence(self, other, name="kl_divergence"): 1169 """Computes the Kullback--Leibler divergence. 1170 1171 Denote this distribution (`self`) by `p` and the `other` distribution by 1172 `q`. Assuming `p, q` are absolutely continuous with respect to reference 1173 measure `r`, the KL divergence is defined as: 1174 1175 ```none 1176 KL[p, q] = E_p[log(p(X)/q(X))] 1177 = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x) 1178 = H[p, q] - H[p] 1179 ``` 1180 1181 where `F` denotes the support of the random variable `X ~ p`, `H[., .]` 1182 denotes (Shanon) cross entropy, and `H[.]` denotes (Shanon) entropy. 1183 1184 Args: 1185 other: `tfp.distributions.Distribution` instance. 1186 name: Python `str` prepended to names of ops created by this function. 1187 1188 Returns: 1189 kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]` 1190 representing `n` different calculations of the Kullback-Leibler 1191 divergence. 1192 """ 1193 with self._name_scope(name): 1194 return self._kl_divergence(other) 1195 1196 def __str__(self): 1197 return ("tfp.distributions.{type_name}(" 1198 "\"{self_name}\"" 1199 "{maybe_batch_shape}" 1200 "{maybe_event_shape}" 1201 ", dtype={dtype})".format( 1202 type_name=type(self).__name__, 1203 self_name=self.name, 1204 maybe_batch_shape=(", batch_shape={}".format(self.batch_shape) 1205 if self.batch_shape.ndims is not None 1206 else ""), 1207 maybe_event_shape=(", event_shape={}".format(self.event_shape) 1208 if self.event_shape.ndims is not None 1209 else ""), 1210 dtype=self.dtype.name)) 1211 1212 def __repr__(self): 1213 return ("<tfp.distributions.{type_name} " 1214 "'{self_name}'" 1215 " batch_shape={batch_shape}" 1216 " event_shape={event_shape}" 1217 " dtype={dtype}>".format( 1218 type_name=type(self).__name__, 1219 self_name=self.name, 1220 batch_shape=self.batch_shape, 1221 event_shape=self.event_shape, 1222 dtype=self.dtype.name)) 1223 1224 @contextlib.contextmanager 1225 def _name_scope(self, name=None, values=None): 1226 """Helper function to standardize op scope.""" 1227 with ops.name_scope(self.name): 1228 with ops.name_scope(name, values=( 1229 ([] if values is None else values) + self._graph_parents)) as scope: 1230 yield scope 1231 1232 def _expand_sample_shape_to_vector(self, x, name): 1233 """Helper to `sample` which ensures input is 1D.""" 1234 x_static_val = tensor_util.constant_value(x) 1235 if x_static_val is None: 1236 prod = math_ops.reduce_prod(x) 1237 else: 1238 prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype()) 1239 1240 ndims = x.get_shape().ndims # != sample_ndims 1241 if ndims is None: 1242 # Maybe expand_dims. 1243 ndims = array_ops.rank(x) 1244 expanded_shape = util.pick_vector( 1245 math_ops.equal(ndims, 0), 1246 np.array([1], dtype=np.int32), array_ops.shape(x)) 1247 x = array_ops.reshape(x, expanded_shape) 1248 elif ndims == 0: 1249 # Definitely expand_dims. 1250 if x_static_val is not None: 1251 x = ops.convert_to_tensor( 1252 np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()), 1253 name=name) 1254 else: 1255 x = array_ops.reshape(x, [1]) 1256 elif ndims != 1: 1257 raise ValueError("Input is neither scalar nor vector.") 1258 1259 return x, prod 1260 1261 def _set_sample_static_shape(self, x, sample_shape): 1262 """Helper to `sample`; sets static shape info.""" 1263 # Set shape hints. 1264 sample_shape = tensor_shape.TensorShape( 1265 tensor_util.constant_value(sample_shape)) 1266 1267 ndims = x.get_shape().ndims 1268 sample_ndims = sample_shape.ndims 1269 batch_ndims = self.batch_shape.ndims 1270 event_ndims = self.event_shape.ndims 1271 1272 # Infer rank(x). 1273 if (ndims is None and 1274 sample_ndims is not None and 1275 batch_ndims is not None and 1276 event_ndims is not None): 1277 ndims = sample_ndims + batch_ndims + event_ndims 1278 x.set_shape([None] * ndims) 1279 1280 # Infer sample shape. 1281 if ndims is not None and sample_ndims is not None: 1282 shape = sample_shape.concatenate([None]*(ndims - sample_ndims)) 1283 x.set_shape(x.get_shape().merge_with(shape)) 1284 1285 # Infer event shape. 1286 if ndims is not None and event_ndims is not None: 1287 shape = tensor_shape.TensorShape( 1288 [None]*(ndims - event_ndims)).concatenate(self.event_shape) 1289 x.set_shape(x.get_shape().merge_with(shape)) 1290 1291 # Infer batch shape. 1292 if batch_ndims is not None: 1293 if ndims is not None: 1294 if sample_ndims is None and event_ndims is not None: 1295 sample_ndims = ndims - batch_ndims - event_ndims 1296 elif event_ndims is None and sample_ndims is not None: 1297 event_ndims = ndims - batch_ndims - sample_ndims 1298 if sample_ndims is not None and event_ndims is not None: 1299 shape = tensor_shape.TensorShape([None]*sample_ndims).concatenate( 1300 self.batch_shape).concatenate([None]*event_ndims) 1301 x.set_shape(x.get_shape().merge_with(shape)) 1302 1303 return x 1304 1305 def _is_scalar_helper(self, static_shape, dynamic_shape_fn): 1306 """Implementation for `is_scalar_batch` and `is_scalar_event`.""" 1307 if static_shape.ndims is not None: 1308 return static_shape.ndims == 0 1309 shape = dynamic_shape_fn() 1310 if (shape.get_shape().ndims is not None and 1311 shape.get_shape().dims[0].value is not None): 1312 # If the static_shape is correctly written then we should never execute 1313 # this branch. We keep it just in case there's some unimagined corner 1314 # case. 1315 return shape.get_shape().as_list() == [0] 1316 return math_ops.equal(array_ops.shape(shape)[0], 0) 1317