.. _broadcasting-semantics:

Broadcasting semantics
======================

Many PyTorch operations support NumPy's broadcasting semantics.
See https://numpy.org/doc/stable/user/basics.broadcasting.html for details.

In short, if a PyTorch operation supports broadcast, then its Tensor arguments can be
automatically expanded to be of equal sizes (without making copies of the data).

General semantics
-----------------
Two tensors are "broadcastable" if the following rules hold:

- Each tensor has at least one dimension.
- When iterating over the dimension sizes, starting at the trailing dimension,
  the dimension sizes must be equal, one of them must be 1, or one of them
  must not exist.

For example::

    >>> x=torch.empty(5,7,3)
    >>> y=torch.empty(5,7,3)
    # same shapes are always broadcastable (i.e. the above rules always hold)

    >>> x=torch.empty((0,))
    >>> y=torch.empty(2,2)
    # x and y are not broadcastable: x does have one dimension (of size 0),
    # but in the 1st trailing dimension 0 != 2 and neither size is 1

    # can line up trailing dimensions
    >>> x=torch.empty(5,3,4,1)
    >>> y=torch.empty(  3,1,1)
    # x and y are broadcastable.
    # 1st trailing dimension: both have size 1
    # 2nd trailing dimension: y has size 1
    # 3rd trailing dimension: x size == y size
    # 4th trailing dimension: y dimension doesn't exist

    # but:
    >>> x=torch.empty(5,2,4,1)
    >>> y=torch.empty(  3,1,1)
    # x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3
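
These rules translate directly into code. Below is a minimal sketch (the helper
``is_broadcastable`` is illustrative, not part of the PyTorch API) that checks
whether two shapes are broadcastable::

    import torch

    def is_broadcastable(shape_x, shape_y):
        # rule 1: each tensor must have at least one dimension
        if len(shape_x) == 0 or len(shape_y) == 0:
            return False
        # rule 2: walk the sizes from the trailing dimension; they must be
        # equal or one of them must be 1 (a missing dimension always matches,
        # which is why zip() stopping at the shorter shape is what we want)
        for a, b in zip(reversed(shape_x), reversed(shape_y)):
            if a != b and a != 1 and b != 1:
                return False
        return True

    is_broadcastable(torch.empty(5,3,4,1).shape, torch.empty(3,1,1).shape)  # True
    is_broadcastable(torch.empty(5,2,4,1).shape, torch.empty(3,1,1).shape)  # False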

If two tensors :attr:`x`, :attr:`y` are "broadcastable", the resulting tensor size
is calculated as follows:

- If the number of dimensions of :attr:`x` and :attr:`y` is not equal, prepend 1
  to the dimensions of the tensor with fewer dimensions to make them equal length.
- Then, for each dimension size, the resulting dimension size is the max of the sizes of
  :attr:`x` and :attr:`y` along that dimension.

For example::

    # can line up trailing dimensions to make reading easier
    >>> x=torch.empty(5,1,4,1)
    >>> y=torch.empty(  3,1,1)
    >>> (x+y).size()
    torch.Size([5, 3, 4, 1])

    # but not necessary:
    >>> x=torch.empty(1)
    >>> y=torch.empty(3,1,7)
    >>> (x+y).size()
    torch.Size([3, 1, 7])

    >>> x=torch.empty(5,2,4,1)
    >>> y=torch.empty(3,1,1)
    >>> (x+y).size()
    RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
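
The two rules above amount to a few lines of shape arithmetic. The helper
``broadcast_result_shape`` below is an illustrative sketch, not a PyTorch API::

    import torch

    def broadcast_result_shape(shape_x, shape_y):
        # prepend 1s so both shapes have the same number of dimensions
        ndim = max(len(shape_x), len(shape_y))
        shape_x = (1,) * (ndim - len(shape_x)) + tuple(shape_x)
        shape_y = (1,) * (ndim - len(shape_y)) + tuple(shape_y)
        # the resulting size along each dimension is the max of the two sizes
        return torch.Size([max(a, b) for a, b in zip(shape_x, shape_y)])

    broadcast_result_shape((5,1,4,1), (3,1,1))  # torch.Size([5, 3, 4, 1])

Recent PyTorch releases expose the same computation as :func:`torch.broadcast_shapes`::

    >>> torch.broadcast_shapes((5,1,4,1), (3,1,1))
    torch.Size([5, 3, 4, 1])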

In-place semantics
------------------
One complication is that in-place operations do not allow the in-place tensor to change shape
as a result of the broadcast.

For example::

    >>> x=torch.empty(5,3,4,1)
    >>> y=torch.empty(3,1,1)
    >>> (x.add_(y)).size()
    torch.Size([5, 3, 4, 1])

    # but:
    >>> x=torch.empty(1,3,1)
    >>> y=torch.empty(3,1,7)
    >>> (x.add_(y)).size()
    RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.
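
If the broadcast result is what you actually want, one way around this restriction
(a sketch, not an official recipe) is to use the out-of-place op, or to materialize
:attr:`x` at the broadcast shape before operating in place::

    >>> x=torch.empty(1,3,1)
    >>> y=torch.empty(3,1,7)
    >>> (x + y).size()       # out-of-place ops may broadcast freely
    torch.Size([3, 3, 7])

    # or expand x up front; expand() returns a view that shares (and repeats)
    # storage, so copy it with contiguous() before writing in place
    >>> x = x.expand(3,3,7).contiguous()
    >>> (x.add_(y)).size()
    torch.Size([3, 3, 7])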

Backwards compatibility
-----------------------
Prior versions of PyTorch allowed certain pointwise functions to execute on tensors with different shapes,
as long as the number of elements in each tensor was equal.  The pointwise operation would then be carried
out by viewing each tensor as 1-dimensional.  PyTorch now supports broadcasting, and the "1-dimensional"
pointwise behavior is deprecated: it will generate a Python warning in cases where tensors are
not broadcastable but have the same number of elements.

Note that the introduction of broadcasting can cause backwards incompatible changes in the case where
two tensors do not have the same shape, but are broadcastable and have the same number of elements.
For example::

    >>> torch.add(torch.ones(4,1), torch.randn(4))

would previously produce a Tensor with size: torch.Size([4,1]), but now produces a Tensor with size: torch.Size([4,4]).
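
A sketch of the equivalence (the old behavior corresponded to flattening both
operands, operating pointwise, and reshaping to the shape of the first argument)::

    >>> x, y = torch.ones(4,1), torch.randn(4)
    >>> (x.view(-1) + y.view(-1)).view(x.size()).size()   # old-style result
    torch.Size([4, 1])
    >>> (x + y).size()                                    # broadcasting result
    torch.Size([4, 4])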

In order to help identify cases in your code where backwards incompatibilities introduced by broadcasting may exist,
you may set `torch.utils.backcompat.broadcast_warning.enabled` to `True`, which will generate a Python warning
in such cases.

For example::

    >>> torch.utils.backcompat.broadcast_warning.enabled=True
    >>> torch.add(torch.ones(4,1), torch.ones(4))
    __main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
    Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.
115