xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/distributions/distribution.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base classes for probability distributions."""

import abc
import contextlib
import types

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "ReparameterizationType",
    "FULLY_REPARAMETERIZED",
    "NOT_REPARAMETERIZED",
    "Distribution",
]

_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
    "batch_shape",
    "batch_shape_tensor",
    "cdf",
    "covariance",
    "cross_entropy",
    "entropy",
    "event_shape",
    "event_shape_tensor",
    "kl_divergence",
    "log_cdf",
    "log_prob",
    "log_survival_function",
    "mean",
    "mode",
    "prob",
    "sample",
    "stddev",
    "survival_function",
    "variance",
]


class _BaseDistribution(metaclass=abc.ABCMeta):
  """Abstract base class needed for resolving subclass hierarchy."""
  pass


def _copy_fn(fn):
  """Create a deep copy of fn.

  Args:
    fn: a callable

  Returns:
    A `FunctionType`: a deep copy of fn.

  Raises:
    TypeError: if `fn` is not a callable.
  """
  if not callable(fn):
    raise TypeError("fn is not callable: %s" % fn)
  # The blessed way to copy a function. copy.deepcopy fails to create a
  # non-reference copy. Since:
  #   types.FunctionType == type(lambda: None),
  # and the docstring for the function type states:
  #
  #   function(code, globals[, name[, argdefs[, closure]]])
  #
  #   Create a function object from a code object and a dictionary.
  #   ...
  #
  # Here we can use this to create a new function with the old function's
  # code, globals, closure, etc.
  return types.FunctionType(
      code=fn.__code__, globals=fn.__globals__,
      name=fn.__name__, argdefs=fn.__defaults__,
      closure=fn.__closure__)

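# For example (a sketch of `_copy_fn`'s semantics; `f` and `g` are
# hypothetical):
#
#   def f(x, y=2):
#     return x + y
#
#   g = _copy_fn(f)
#   g.__doc__ = "a new docstring"  # Does not affect `f`.
#   assert f(1) == g(1) == 3       # The copy shares the original code object.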

def _update_docstring(old_str, append_str):
  """Update old_str by inserting append_str just before the "Args:" section."""
  old_str = old_str or ""
  old_str_lines = old_str.split("\n")

  # Step 0: Prepend spaces to all lines of append_str. This is
  # necessary for correct markdown generation.
  append_str = "\n".join("    %s" % line for line in append_str.split("\n"))

  # Step 1: Find mention of "Args":
  has_args_ix = [
      ix for ix, line in enumerate(old_str_lines)
      if line.strip().lower() == "args:"]
  if has_args_ix:
    final_args_ix = has_args_ix[-1]
    return ("\n".join(old_str_lines[:final_args_ix])
            + "\n\n" + append_str + "\n\n"
            + "\n".join(old_str_lines[final_args_ix:]))
  else:
    return old_str + "\n\n" + append_str

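# For example (a sketch): given an `old_str` of
#
#   "Do the thing.\n\nArgs:\n  x: input."
#
# and an `append_str` of "Extra notes.", the result is
#
#   "Do the thing.\n\n    Extra notes.\n\nArgs:\n  x: input."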

def _convert_to_tensor(value, name=None, preferred_dtype=None):
  """Converts to tensor avoiding an eager bug that loses float precision."""
  # TODO(b/116672045): Remove this function.
  if (context.executing_eagerly() and preferred_dtype is not None and
      (preferred_dtype.is_integer or preferred_dtype.is_bool)):
    v = ops.convert_to_tensor(value, name=name)
    if v.dtype.is_floating:
      return v
  return ops.convert_to_tensor(
      value, name=name, preferred_dtype=preferred_dtype)

class _DistributionMeta(abc.ABCMeta):

  def __new__(mcs, classname, baseclasses, attrs):
    """Control the creation of subclasses of the Distribution class.

    The main purpose of this method is to properly propagate docstrings
    from private Distribution methods, like `_log_prob`, into their
    public wrappers as inherited by the Distribution base class
    (e.g. `log_prob`).

    Args:
      classname: The name of the subclass being created.
      baseclasses: A tuple of parent classes.
      attrs: A dict mapping new attributes to their values.

    Returns:
      The class object.

    Raises:
      TypeError: If `Distribution` is not a subclass of `_BaseDistribution`, or
        the new class is derived via multiple inheritance and the first
        parent class is not a subclass of `_BaseDistribution`.
      AttributeError: If `Distribution` does not implement e.g. `log_prob`.
      ValueError: If a `Distribution` public method lacks a docstring.
    """
    if not baseclasses:  # Nothing to be done for Distribution
      raise TypeError("Expected non-empty baseclass. Does Distribution "
                      "not subclass _BaseDistribution?")
    which_base = [
        base for base in baseclasses
        if base == _BaseDistribution or issubclass(base, Distribution)]
    base = which_base[0]
    if base == _BaseDistribution:  # Nothing to be done for Distribution
      return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
    if not issubclass(base, Distribution):
      raise TypeError("First parent class declared for %s must be "
                      "Distribution, but saw '%s'" % (classname, base.__name__))
    for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS:
      special_attr = "_%s" % attr
      class_attr_value = attrs.get(attr, None)
      if attr in attrs:
        # The method is being overridden, do not update its docstring
        continue
      base_attr_value = getattr(base, attr, None)
      if not base_attr_value:
        raise AttributeError(
            "Internal error: expected base class '%s' to implement method '%s'"
            % (base.__name__, attr))
      class_special_attr_value = attrs.get(special_attr, None)
      if class_special_attr_value is None:
        # No _special method available, no need to update the docstring.
        continue
      class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
      if not class_special_attr_docstring:
        # No docstring to append.
        continue
      class_attr_value = _copy_fn(base_attr_value)
      class_attr_docstring = tf_inspect.getdoc(base_attr_value)
      if class_attr_docstring is None:
        raise ValueError(
            "Expected base class fn to contain a docstring: %s.%s"
            % (base.__name__, attr))
      class_attr_value.__doc__ = _update_docstring(
          class_attr_value.__doc__,
          ("Additional documentation from `%s`:\n\n%s"
           % (classname, class_special_attr_docstring)))
      attrs[attr] = class_attr_value

    return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)

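# For example (a sketch): a subclass that defines
#
#   class MyDist(Distribution):
#
#     def _log_prob(self, value):
#       """Some other details."""
#       ...
#
# inherits a public `log_prob` whose docstring is `Distribution.log_prob`'s
# docstring with "Additional documentation from `MyDist`:" and the text above
# inserted just before its "Args:" section.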

@tf_export(v1=["distributions.ReparameterizationType"])
class ReparameterizationType:
  """Instances of this class represent how sampling is reparameterized.

  Two static instances exist in the distributions library, signifying
  one of two possible properties for samples from a distribution:

  `FULLY_REPARAMETERIZED`: Samples from the distribution are fully
    reparameterized, and straight-through gradients are supported.

  `NOT_REPARAMETERIZED`: Samples from the distribution are not fully
    reparameterized, and straight-through gradients are either partially
    unsupported or are not supported at all. In this case, for purposes of
    e.g. RL or variational inference, it is generally safest to wrap the
    sample results in a `stop_gradient` call and use policy
    gradients / surrogate loss instead.
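
  For example, a minimal sketch of how calling code might branch on this
  property (`dist`, `make_loss`, and `surrogate_loss` are hypothetical):

  ```python
  if dist.reparameterization_type == distributions.FULLY_REPARAMETERIZED:
    loss = make_loss(dist.sample(5))  # Gradients may flow through the samples.
  else:
    samples = tf.stop_gradient(dist.sample(5))
    loss = surrogate_loss(samples, dist.log_prob(samples))
  ```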
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self, rep_type):
    self._rep_type = rep_type

  def __repr__(self):
    return "<Reparameterization Type: %s>" % self._rep_type

  def __eq__(self, other):
    """Determine if this `ReparameterizationType` is equal to another.

    Since ReparameterizationType instances are constant static global
    instances, equality checks if two instances' id() values are equal.

    Args:
      other: Object to compare against.

    Returns:
      `self is other`.
    """
    return self is other


# Fully reparameterized distribution: samples from a fully
# reparameterized distribution support straight-through gradients with
# respect to all parameters.
FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED")
tf_export(v1=["distributions.FULLY_REPARAMETERIZED"]).export_constant(
    __name__, "FULLY_REPARAMETERIZED")


# Not reparameterized distribution: samples from a non-
# reparameterized distribution do not support straight-through gradients for
# at least some of the parameters.
NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED")
tf_export(v1=["distributions.NOT_REPARAMETERIZED"]).export_constant(
    __name__, "NOT_REPARAMETERIZED")

@tf_export(v1=["distributions.Distribution"])
class Distribution(_BaseDistribution, metaclass=_DistributionMeta):
  """A generic probability distribution base class.

  `Distribution` is a base class for constructing and organizing properties
  (e.g., mean, variance) of random variables (e.g., Bernoulli, Gaussian).

  #### Subclassing

  Subclasses are expected to implement a leading-underscore version of the
  same-named function. The argument signature should be identical except for
  the omission of `name="..."`. For example, to enable `log_prob(value,
  name="log_prob")` a subclass should implement `_log_prob(value)`.

  Subclasses can append to public-level docstrings by providing
  docstrings for their method specializations. For example:

  ```python
  @util.AppendDocstring("Some other details.")
  def _log_prob(self, value):
    ...
  ```

  would add the string "Some other details." to the `log_prob` function
  docstring. This is implemented as a simple decorator to avoid the Python
  linter complaining about missing Args/Returns/Raises sections in the
  partial docstrings.

  #### Broadcasting, batching, and shapes

  All distributions support batches of independent distributions of that type.
  The batch shape is determined by broadcasting together the parameters.

  The shapes of arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and
  `log_prob` reflect this broadcasting, as do the return values of `sample` and
  `sample_n`.

  `sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape` is
  the shape of the `Tensor` returned from `sample_n`, `n` is the number of
  samples, `batch_shape` defines how many independent distributions there are,
  and `event_shape` defines the shape of samples from each of those independent
  distributions. Samples are independent along the `batch_shape` dimensions, but
  not necessarily so along the `event_shape` dimensions (depending on the
  particulars of the underlying distribution).

  Using the `Uniform` distribution as an example:

  ```python
  minval = 3.0
  maxval = [[4.0, 6.0],
            [10.0, 12.0]]

  # Broadcasting:
  # This instance represents 4 Uniform distributions. Each has a lower bound at
  # 3.0 as the `minval` parameter was broadcast to match `maxval`'s shape.
  u = Uniform(minval, maxval)

  # `event_shape` is `TensorShape([])`.
  event_shape = u.event_shape
  # `event_shape_t` is a `Tensor` which will evaluate to [].
  event_shape_t = u.event_shape_tensor()

  # Sampling returns a sample per distribution. `samples` has shape
  # [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5,
  # batch_shape=[2, 2], and event_shape=[].
  samples = u.sample_n(5)

  # The broadcasting holds across methods. Here we use `cdf` as an example. The
  # same holds for `log_cdf` and the likelihood functions.

  # `cum_prob_broadcast` has shape [2, 2] as the `value` argument was broadcast
  # to the shape of the `Uniform` instance.
  cum_prob_broadcast = u.cdf(4.0)

  # `cum_prob_per_dist` has shape [2, 2], one per distribution. No broadcasting
  # occurred.
  cum_prob_per_dist = u.cdf([[4.0, 5.0],
                             [6.0, 7.0]])

  # INVALID as the `value` argument is not broadcastable to the distribution's
  # shape.
  cum_prob_invalid = u.cdf([4.0, 5.0, 6.0])
  ```

  #### Shapes

  There are three important concepts associated with TensorFlow Distributions
  shapes:
  - Event shape describes the shape of a single draw from the distribution;
    it may be dependent across dimensions. For scalar distributions, the event
    shape is `[]`. For a 5-dimensional MultivariateNormal, the event shape is
    `[5]`.
  - Batch shape describes independent, not identically distributed draws, aka a
    "collection" or "bunch" of distributions.
  - Sample shape describes independent, identically distributed draws of batches
    from the distribution family.

  The event shape and the batch shape are properties of a Distribution object,
  whereas the sample shape is associated with a specific call to `sample` or
  `log_prob`.

  For detailed usage examples of TensorFlow Distributions shapes, see
  [this tutorial](
  https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Understanding_TensorFlow_Distributions_Shapes.ipynb)

  #### Parameter values leading to undefined statistics or distributions

  Some distributions do not have well-defined statistics for all initialization
  parameter values. For example, the beta distribution is parameterized by
  positive real numbers `concentration1` and `concentration0`, and does not have
  a well-defined mode if `concentration1 < 1` or `concentration0 < 1`.

  The user is given the option of raising an exception or returning `NaN`.

  ```python
  a = tf.exp(tf.matmul(logits, weights_a))
  b = tf.exp(tf.matmul(logits, weights_b))

  # Will raise an exception if ANY batch member has a < 1 or b < 1.
  dist = distributions.Beta(a, b, allow_nan_stats=False)
  mode = dist.mode().eval()

  # Will return NaN for batch members with either a < 1 or b < 1.
  dist = distributions.Beta(a, b, allow_nan_stats=True)  # Default behavior
  mode = dist.mode().eval()
  ```

  In all cases, an exception is raised if *invalid* parameters are passed, e.g.

  ```python
  # Will raise an exception if any Op is run.
  negative_a = -1.0 * a  # beta distribution by definition has a > 0.
  dist = distributions.Beta(negative_a, b, allow_nan_stats=True)
  dist.mean().eval()
  ```

  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               dtype,
               reparameterization_type,
               validate_args,
               allow_nan_stats,
               parameters=None,
               graph_parents=None,
               name=None):
    """Constructs the `Distribution`.

    **This is a private method for subclass use.**

    Args:
      dtype: The type of the event samples. `None` implies no type-enforcement.
      reparameterization_type: Instance of `ReparameterizationType`.
        If `distributions.FULLY_REPARAMETERIZED`, this
        `Distribution` can be reparameterized in terms of some standard
        distribution with a function whose Jacobian is constant for the support
        of the standard distribution. If `distributions.NOT_REPARAMETERIZED`,
        then no such reparameterization is available.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      parameters: Python `dict` of parameters used to instantiate this
        `Distribution`.
      graph_parents: Python `list` of graph prerequisites of this
        `Distribution`.
      name: Python `str` name prefixed to Ops created by this class. Default:
        subclass name.

    Raises:
      ValueError: if any member of graph_parents is `None` or not a `Tensor`.
    """
    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not tensor_util.is_tf_type(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
    if not name or name[-1] != "/":  # `name` is not a name scope
      non_unique_name = name or type(self).__name__
      with ops.name_scope(non_unique_name) as name:
        pass
    self._dtype = dtype
    self._reparameterization_type = reparameterization_type
    self._allow_nan_stats = allow_nan_stats
    self._validate_args = validate_args
    self._parameters = parameters or {}
    self._graph_parents = graph_parents
    self._name = name

  @property
  def _parameters(self):
    return self._parameter_dict

  @_parameters.setter
  def _parameters(self, value):
    """Intercept assignments to self._parameters to avoid reference cycles.

    Parameters are often created using locals(), so we need to clean out any
    references to `self` before assigning it to an attribute.

    Args:
      value: A dictionary of parameters to assign to the `_parameters` property.
    """
    if "self" in value:
      del value["self"]
    self._parameter_dict = value

  @classmethod
  def param_shapes(cls, sample_shape, name="DistributionParamShapes"):
    """Shapes of parameters given the desired shape of a call to `sample()`.

    This is a class method that describes what key/value arguments are required
    to instantiate the given `Distribution` so that a particular shape is
    returned for that instance's call to `sample()`.

    Subclasses should override class method `_param_shapes`.

    Args:
      sample_shape: `Tensor` or python list/tuple. Desired shape of a call to
        `sample()`.
      name: Name to prepend to ops created by this method.

    Returns:
      `dict` of parameter name to `Tensor` shapes.
    """
    with ops.name_scope(name, values=[sample_shape]):
      return cls._param_shapes(sample_shape)

  @classmethod
  def param_static_shapes(cls, sample_shape):
    """param_shapes with static (i.e. `TensorShape`) shapes.

    This is a class method that describes what key/value arguments are required
    to instantiate the given `Distribution` so that a particular shape is
    returned for that instance's call to `sample()`. Assumes that the sample's
    shape is known statically.

    Subclasses should override class method `_param_shapes` to return
    constant-valued tensors when constant values are fed.

    Args:
      sample_shape: `TensorShape` or python list/tuple. Desired shape of a call
        to `sample()`.

    Returns:
      `dict` of parameter name to `TensorShape`.

    Raises:
      ValueError: if `sample_shape` is a `TensorShape` and is not fully defined.
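
    For example, a sketch assuming a `Normal(loc, scale)` subclass that
    implements `_param_shapes`:

    ```python
    shapes = Normal.param_static_shapes([100])
    # ==> {"loc": TensorShape([100]), "scale": TensorShape([100])}
    ```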
    """
    if isinstance(sample_shape, tensor_shape.TensorShape):
      if not sample_shape.is_fully_defined():
        raise ValueError("TensorShape sample_shape must be fully defined")
      sample_shape = sample_shape.as_list()

    params = cls.param_shapes(sample_shape)

    static_params = {}
    for name, shape in params.items():
      static_shape = tensor_util.constant_value(shape)
      if static_shape is None:
        raise ValueError(
            "sample_shape must be a fully-defined TensorShape or list/tuple")
      static_params[name] = tensor_shape.TensorShape(static_shape)

    return static_params

  @staticmethod
  def _param_shapes(sample_shape):
    raise NotImplementedError("_param_shapes not implemented")

  @property
  def name(self):
    """Name prepended to all ops created by this `Distribution`."""
    return self._name

  @property
  def dtype(self):
    """The `DType` of `Tensor`s handled by this `Distribution`."""
    return self._dtype

  @property
  def parameters(self):
    """Dictionary of parameters used to instantiate this `Distribution`."""
    # Remove "self", "__class__", or other special variables. These can appear
    # if the subclass used:
    # `parameters = dict(locals())`.
    return {k: v for k, v in self._parameters.items()
            if not k.startswith("__") and k != "self"}

  @property
  def reparameterization_type(self):
    """Describes how samples from the distribution are reparameterized.

    Currently this is one of the static instances
    `distributions.FULLY_REPARAMETERIZED`
    or `distributions.NOT_REPARAMETERIZED`.

    Returns:
      An instance of `ReparameterizationType`.
    """
    return self._reparameterization_type

  @property
  def allow_nan_stats(self):
    """Python `bool` describing behavior when a stat is undefined.

    Stats return +/- infinity when it makes sense. E.g., the variance of a
    Cauchy distribution is infinity. However, sometimes the statistic is
    undefined, e.g., if a distribution's pdf does not achieve a maximum within
    the support of the distribution, the mode is undefined. If the mean is
    undefined, then by definition the variance is undefined. E.g. the mean for
    Student's T for df = 1 is undefined (no clear way to say it is either + or -
    infinity), so the variance = E[(X - mean)**2] is also undefined.

    Returns:
      allow_nan_stats: Python `bool`.
    """
    return self._allow_nan_stats

  @property
  def validate_args(self):
    """Python `bool` indicating possibly expensive checks are enabled."""
    return self._validate_args

  def copy(self, **override_parameters_kwargs):
    """Creates a deep copy of the distribution.

    Note: the copied distribution may continue to depend on the original
    initialization arguments.

    Args:
      **override_parameters_kwargs: String/value dictionary of initialization
        arguments to override with new values.

    Returns:
      distribution: A new instance of `type(self)` initialized from the union
        of self.parameters and override_parameters_kwargs, i.e.,
        `dict(self.parameters, **override_parameters_kwargs)`.
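
    For example, a sketch assuming a `Normal(loc, scale)` subclass:

    ```python
    n = Normal(loc=0., scale=1.)
    n2 = n.copy(scale=2.)
    # n2 is a new Normal with loc=0. and scale=2.
    ```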
    """
    parameters = dict(self.parameters, **override_parameters_kwargs)
    return type(self)(**parameters)

  def _batch_shape_tensor(self):
    raise NotImplementedError(
        "batch_shape_tensor is not implemented: {}".format(type(self).__name__))

  def batch_shape_tensor(self, name="batch_shape_tensor"):
    """Shape of a single sample from a single event index as a 1-D `Tensor`.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Args:
      name: name to give to the op

    Returns:
      batch_shape: `Tensor`.
    """
    with self._name_scope(name):
      if self.batch_shape.is_fully_defined():
        return ops.convert_to_tensor(self.batch_shape.as_list(),
                                     dtype=dtypes.int32,
                                     name="batch_shape")
      return self._batch_shape_tensor()

  def _batch_shape(self):
    return tensor_shape.TensorShape(None)

  @property
  def batch_shape(self):
    """Shape of a single sample from a single event index as a `TensorShape`.

    May be partially defined or unknown.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Returns:
      batch_shape: `TensorShape`, possibly unknown.
    """
    return tensor_shape.as_shape(self._batch_shape())

  def _event_shape_tensor(self):
    raise NotImplementedError(
        "event_shape_tensor is not implemented: {}".format(type(self).__name__))

  def event_shape_tensor(self, name="event_shape_tensor"):
    """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

    Args:
      name: name to give to the op

    Returns:
      event_shape: `Tensor`.
    """
    with self._name_scope(name):
      if self.event_shape.is_fully_defined():
        return ops.convert_to_tensor(self.event_shape.as_list(),
                                     dtype=dtypes.int32,
                                     name="event_shape")
      return self._event_shape_tensor()

  def _event_shape(self):
    return tensor_shape.TensorShape(None)

  @property
  def event_shape(self):
    """Shape of a single sample from a single batch as a `TensorShape`.

    May be partially defined or unknown.

    Returns:
      event_shape: `TensorShape`, possibly unknown.
    """
    return tensor_shape.as_shape(self._event_shape())

  def is_scalar_event(self, name="is_scalar_event"):
    """Indicates that `event_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_event: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      return ops.convert_to_tensor(
          self._is_scalar_helper(self.event_shape, self.event_shape_tensor),
          name="is_scalar_event")

  def is_scalar_batch(self, name="is_scalar_batch"):
    """Indicates that `batch_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_batch: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      return ops.convert_to_tensor(
          self._is_scalar_helper(self.batch_shape, self.batch_shape_tensor),
          name="is_scalar_batch")

  def _sample_n(self, n, seed=None):
    raise NotImplementedError("sample_n is not implemented: {}".format(
        type(self).__name__))

  def _call_sample_n(self, sample_shape, seed, name, **kwargs):
    with self._name_scope(name, values=[sample_shape]):
      sample_shape = ops.convert_to_tensor(
          sample_shape, dtype=dtypes.int32, name="sample_shape")
      sample_shape, n = self._expand_sample_shape_to_vector(
          sample_shape, "sample_shape")
      samples = self._sample_n(n, seed, **kwargs)
      batch_event_shape = array_ops.shape(samples)[1:]
      final_shape = array_ops.concat([sample_shape, batch_event_shape], 0)
      samples = array_ops.reshape(samples, final_shape)
      samples = self._set_sample_static_shape(samples, sample_shape)
      return samples

  def sample(self, sample_shape=(), seed=None, name="sample"):
    """Generate samples of the specified shape.

    Note that a call to `sample()` without arguments will generate a single
    sample.

    Args:
      sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples.
      seed: Python integer seed for RNG
      name: name to give to the op.

    Returns:
      samples: a `Tensor` with prepended dimensions `sample_shape`.
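
    For example, a sketch assuming a scalar-event `Normal` subclass:

    ```python
    dist = Normal(loc=[0., 1.], scale=1.)  # batch_shape == [2]
    x = dist.sample([3, 4])
    # x.shape == [3, 4, 2], i.e. sample_shape + batch_shape + event_shape.
    ```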
    """
    return self._call_sample_n(sample_shape, seed, name)

  def _log_prob(self, value):
    raise NotImplementedError("log_prob is not implemented: {}".format(
        type(self).__name__))

  def _call_log_prob(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
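      # Prefer the subclass's `_log_prob`; fall back to `log(_prob(value))`
      # only if `_log_prob` is not implemented. Several of the `_call_*`
      # wrappers below use the same fallback pattern.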
      try:
        return self._log_prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._prob(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_prob(self, value, name="log_prob"):
    """Log probability density/mass function.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_log_prob(value, name)

  def _prob(self, value):
    raise NotImplementedError("prob is not implemented: {}".format(
        type(self).__name__))

  def _call_prob(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_prob(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def prob(self, value, name="prob"):
    """Probability density/mass function.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_prob(value, name)

  def _log_cdf(self, value):
    raise NotImplementedError("log_cdf is not implemented: {}".format(
        type(self).__name__))

  def _call_log_cdf(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._log_cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_cdf(self, value, name="log_cdf"):
    """Log cumulative distribution function.

    Given random variable `X`, the cumulative distribution function `cdf` is:

    ```none
    log_cdf(x) := Log[ P[X <= x] ]
    ```

    Often, a numerical approximation can be used for `log_cdf(x)` that yields
    a more accurate answer than simply taking the logarithm of the `cdf` when
    `x << -1`.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      logcdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_log_cdf(value, name)

  def _cdf(self, value):
    raise NotImplementedError("cdf is not implemented: {}".format(
        type(self).__name__))

  def _call_cdf(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def cdf(self, value, name="cdf"):
    """Cumulative distribution function.

    Given random variable `X`, the cumulative distribution function `cdf` is:

    ```none
    cdf(x) := P[X <= x]
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_cdf(value, name)

  def _log_survival_function(self, value):
    raise NotImplementedError(
        "log_survival_function is not implemented: {}".format(
            type(self).__name__))

  def _call_log_survival_function(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._log_survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log1p(-self.cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_survival_function(self, value, name="log_survival_function"):
    """Log survival function.

    Given random variable `X`, the survival function is defined:

    ```none
    log_survival_function(x) = Log[ P[X > x] ]
                             = Log[ 1 - P[X <= x] ]
                             = Log[ 1 - cdf(x) ]
    ```

    Typically, different numerical approximations can be used for the log
    survival function, which are more accurate than `1 - cdf(x)` when `x >> 1`.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
        `self.dtype`.
    """
    return self._call_log_survival_function(value, name)

  def _survival_function(self, value):
    raise NotImplementedError("survival_function is not implemented: {}".format(
        type(self).__name__))

  def _call_survival_function(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      try:
        return self._survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return 1. - self.cdf(value, **kwargs)
        except NotImplementedError:
          raise original_exception

  def survival_function(self, value, name="survival_function"):
    """Survival function.

    Given random variable `X`, the survival function is defined:

    ```none
    survival_function(x) = P[X > x]
                         = 1 - P[X <= x]
                         = 1 - cdf(x).
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
        `self.dtype`.
    """
    return self._call_survival_function(value, name)

  def _entropy(self):
    raise NotImplementedError("entropy is not implemented: {}".format(
        type(self).__name__))

  def entropy(self, name="entropy"):
    """Shannon entropy in nats."""
    with self._name_scope(name):
      return self._entropy()

  def _mean(self):
    raise NotImplementedError("mean is not implemented: {}".format(
        type(self).__name__))

  def mean(self, name="mean"):
    """Mean."""
    with self._name_scope(name):
      return self._mean()

  def _quantile(self, value):
    raise NotImplementedError("quantile is not implemented: {}".format(
        type(self).__name__))

  def _call_quantile(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = _convert_to_tensor(
          value, name="value", preferred_dtype=self.dtype)
      return self._quantile(value, **kwargs)

  def quantile(self, value, name="quantile"):
    """Quantile function. Aka "inverse cdf" or "percent point function".

    Given random variable `X` and `p in [0, 1]`, the `quantile` is:

    ```none
    quantile(p) := x such that P[X <= x] == p
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      quantile: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
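
    For example, a sketch assuming a subclass that implements `_quantile`:

    ```python
    dist = Normal(loc=0., scale=1.)
    median = dist.quantile(0.5)  # ==> 0., since P[X <= 0.] == 0.5.
    ```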
    """
    return self._call_quantile(value, name)

  def _variance(self):
    raise NotImplementedError("variance is not implemented: {}".format(
        type(self).__name__))

  def variance(self, name="variance"):
    """Variance.

    Variance is defined as,

    ```none
    Var = E[(X - E[X])**2]
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `Var.shape = batch_shape + event_shape`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      variance: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
    """
    with self._name_scope(name):
      try:
        return self._variance()
      except NotImplementedError as original_exception:
        try:
          return math_ops.square(self._stddev())
        except NotImplementedError:
          raise original_exception

  def _stddev(self):
    raise NotImplementedError("stddev is not implemented: {}".format(
        type(self).__name__))

  def stddev(self, name="stddev"):
    """Standard deviation.

    Standard deviation is defined as,

    ```none
    stddev = E[(X - E[X])**2]**0.5
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `stddev.shape = batch_shape + event_shape`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      stddev: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
    """

    with self._name_scope(name):
      try:
        return self._stddev()
      except NotImplementedError as original_exception:
        try:
          return math_ops.sqrt(self._variance())
        except NotImplementedError:
          raise original_exception

  def _covariance(self):
    raise NotImplementedError("covariance is not implemented: {}".format(
        type(self).__name__))

  def covariance(self, name="covariance"):
    """Covariance.

    Covariance is (possibly) defined only for non-scalar-event distributions.

    For example, for a length-`k`, vector-valued distribution, it is calculated
    as,

    ```none
    Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])]
    ```

    where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E`
    denotes expectation.

    Alternatively, for non-vector, multivariate distributions (e.g.,
    matrix-valued, Wishart), `Covariance` shall return a (batch of) matrices
    under some vectorization of the events, i.e.,

    ```none
    Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above]
    ```
    where `Cov` is a (batch of) `k' x k'` matrix,
    `0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function
    mapping indices of this distribution's event dimensions to indices of a
    length-`k'` vector.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']`
        where the first `n` dimensions are batch coordinates and
        `k' = reduce_prod(self.event_shape)`.
    """
    with self._name_scope(name):
      return self._covariance()

  def _mode(self):
    raise NotImplementedError("mode is not implemented: {}".format(
        type(self).__name__))

  def mode(self, name="mode"):
    """Mode."""
    with self._name_scope(name):
      return self._mode()

  def _cross_entropy(self, other):
    return kullback_leibler.cross_entropy(
        self, other, allow_nan_stats=self.allow_nan_stats)

  def cross_entropy(self, other, name="cross_entropy"):
    """Computes the (Shannon) cross entropy.

    Denote this distribution (`self`) by `P` and the `other` distribution by
    `Q`. Assuming `P, Q` are absolutely continuous with respect to
    one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, (Shannon)
    cross entropy is defined as:

    ```none
    H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
    ```

    where `F` denotes the support of the random variable `X ~ P`.

    Args:
      other: `tfp.distributions.Distribution` instance.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
        representing `n` different calculations of (Shannon) cross entropy.
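
    For example, a sketch of the identity `H[P, Q] = H[P] + KL[P || Q]`,
    assuming distributions `p` and `q` with a registered KL method:

    ```python
    h_pq = p.cross_entropy(q)
    # Should numerically agree with:
    h_pq_alt = p.entropy() + p.kl_divergence(q)
    ```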
    """
    with self._name_scope(name):
      return self._cross_entropy(other)

  def _kl_divergence(self, other):
    return kullback_leibler.kl_divergence(
        self, other, allow_nan_stats=self.allow_nan_stats)

  def kl_divergence(self, other, name="kl_divergence"):
    """Computes the Kullback--Leibler divergence.

    Denote this distribution (`self`) by `p` and the `other` distribution by
    `q`. Assuming `p, q` are absolutely continuous with respect to reference
    measure `r`, the KL divergence is defined as:

    ```none
    KL[p, q] = E_p[log(p(X)/q(X))]
             = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
             = H[p, q] - H[p]
    ```

    where `F` denotes the support of the random variable `X ~ p`, `H[., .]`
    denotes (Shannon) cross entropy, and `H[.]` denotes (Shannon) entropy.

    Args:
      other: `tfp.distributions.Distribution` instance.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
        representing `n` different calculations of the Kullback-Leibler
        divergence.
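
    For example, a sketch assuming two `Normal` instances whose KL is
    registered via `kullback_leibler.RegisterKL`:

    ```python
    p = Normal(loc=0., scale=1.)
    q = Normal(loc=1., scale=2.)
    kl = p.kl_divergence(q)  # Non-negative scalar `Tensor`.
    ```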
    """
    with self._name_scope(name):
      return self._kl_divergence(other)

  def __str__(self):
    return ("tfp.distributions.{type_name}("
            "\"{self_name}\""
            "{maybe_batch_shape}"
            "{maybe_event_shape}"
            ", dtype={dtype})".format(
                type_name=type(self).__name__,
                self_name=self.name,
                maybe_batch_shape=(", batch_shape={}".format(self.batch_shape)
                                   if self.batch_shape.ndims is not None
                                   else ""),
                maybe_event_shape=(", event_shape={}".format(self.event_shape)
                                   if self.event_shape.ndims is not None
                                   else ""),
                dtype=self.dtype.name))

  def __repr__(self):
    return ("<tfp.distributions.{type_name} "
            "'{self_name}'"
            " batch_shape={batch_shape}"
            " event_shape={event_shape}"
            " dtype={dtype}>".format(
                type_name=type(self).__name__,
                self_name=self.name,
                batch_shape=self.batch_shape,
                event_shape=self.event_shape,
                dtype=self.dtype.name))

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope."""
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=(
          ([] if values is None else values) + self._graph_parents)) as scope:
        yield scope

  def _expand_sample_shape_to_vector(self, x, name):
    """Helper to `sample` which ensures input is 1D."""
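    # For example (a sketch): a scalar `sample_shape` of 5 is reshaped to the
    # vector [5] and returned with prod 5; a vector [3, 4] is returned as-is
    # with prod 12 (the total number of samples to draw).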
    x_static_val = tensor_util.constant_value(x)
    if x_static_val is None:
      prod = math_ops.reduce_prod(x)
    else:
      prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype())

    ndims = x.get_shape().ndims  # != sample_ndims
    if ndims is None:
      # Maybe expand_dims.
      ndims = array_ops.rank(x)
      expanded_shape = util.pick_vector(
          math_ops.equal(ndims, 0),
          np.array([1], dtype=np.int32), array_ops.shape(x))
      x = array_ops.reshape(x, expanded_shape)
    elif ndims == 0:
      # Definitely expand_dims.
      if x_static_val is not None:
        x = ops.convert_to_tensor(
            np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()),
            name=name)
      else:
        x = array_ops.reshape(x, [1])
    elif ndims != 1:
      raise ValueError("Input is neither scalar nor vector.")

    return x, prod

  def _set_sample_static_shape(self, x, sample_shape):
    """Helper to `sample`; sets static shape info."""
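    # For example (a sketch): with sample_shape=[3], batch_shape=[2], and
    # event_shape=[], the static shape hint merged into `x` below is [3, 2].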
    # Set shape hints.
    sample_shape = tensor_shape.TensorShape(
        tensor_util.constant_value(sample_shape))

    ndims = x.get_shape().ndims
    sample_ndims = sample_shape.ndims
    batch_ndims = self.batch_shape.ndims
    event_ndims = self.event_shape.ndims

    # Infer rank(x).
    if (ndims is None and
        sample_ndims is not None and
        batch_ndims is not None and
        event_ndims is not None):
      ndims = sample_ndims + batch_ndims + event_ndims
      x.set_shape([None] * ndims)

    # Infer sample shape.
    if ndims is not None and sample_ndims is not None:
      shape = sample_shape.concatenate([None]*(ndims - sample_ndims))
      x.set_shape(x.get_shape().merge_with(shape))

    # Infer event shape.
    if ndims is not None and event_ndims is not None:
      shape = tensor_shape.TensorShape(
          [None]*(ndims - event_ndims)).concatenate(self.event_shape)
      x.set_shape(x.get_shape().merge_with(shape))

    # Infer batch shape.
    if batch_ndims is not None:
      if ndims is not None:
        if sample_ndims is None and event_ndims is not None:
          sample_ndims = ndims - batch_ndims - event_ndims
        elif event_ndims is None and sample_ndims is not None:
          event_ndims = ndims - batch_ndims - sample_ndims
      if sample_ndims is not None and event_ndims is not None:
        shape = tensor_shape.TensorShape([None]*sample_ndims).concatenate(
            self.batch_shape).concatenate([None]*event_ndims)
        x.set_shape(x.get_shape().merge_with(shape))

    return x

  def _is_scalar_helper(self, static_shape, dynamic_shape_fn):
    """Implementation for `is_scalar_batch` and `is_scalar_event`."""
    if static_shape.ndims is not None:
      return static_shape.ndims == 0
    shape = dynamic_shape_fn()
    if (shape.get_shape().ndims is not None and
        shape.get_shape().dims[0].value is not None):
      # If the static_shape is correctly written then we should never execute
      # this branch. We keep it just in case there's some unimagined corner
      # case.
      return shape.get_shape().as_list() == [0]
    return math_ops.equal(array_ops.shape(shape)[0], 0)