.. _broadcasting-semantics:

Broadcasting semantics
======================

Many PyTorch operations support NumPy's broadcasting semantics.
See https://numpy.org/doc/stable/user/basics.broadcasting.html for details.

In short, if a PyTorch operation supports broadcasting, then its Tensor arguments can be
automatically expanded to be of equal sizes (without making copies of the data).
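
The expansion is done with views rather than copies. As a minimal sketch of the
idea, :meth:`torch.Tensor.expand` produces exactly this kind of zero-copy view::

    >>> x = torch.ones(3, 1)
    >>> y = x.expand(3, 4)            # view of x with the size-1 dim repeated
    >>> y.stride()
    (1, 0)                            # stride 0: the same data is revisited, not copied
    >>> y.data_ptr() == x.data_ptr()  # both tensors share the same storage
    True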

General semantics
-----------------
Two tensors are "broadcastable" if the following rules hold:

- Each tensor has at least one dimension.
- When iterating over the dimension sizes, starting at the trailing dimension,
  the dimension sizes must be equal, one of them must be 1, or one of them
  must not exist.

For example::

    >>> x=torch.empty(5,7,3)
    >>> y=torch.empty(5,7,3)
    # same shapes are always broadcastable (i.e. the above rules always hold)

    >>> x=torch.empty((0,))
    >>> y=torch.empty(2,2)
    # x and y are not broadcastable, because x does not have at least 1 dimension

    # can line up trailing dimensions
    >>> x=torch.empty(5,3,4,1)
    >>> y=torch.empty(  3,1,1)
    # x and y are broadcastable.
    # 1st trailing dimension: both have size 1
    # 2nd trailing dimension: y has size 1
    # 3rd trailing dimension: x size == y size
    # 4th trailing dimension: y dimension doesn't exist

    # but:
    >>> x=torch.empty(5,2,4,1)
    >>> y=torch.empty(  3,1,1)
    # x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

If two tensors :attr:`x`, :attr:`y` are "broadcastable", the resulting tensor size
is calculated as follows:

- If the number of dimensions of :attr:`x` and :attr:`y` are not equal, prepend 1
  to the dimensions of the tensor with fewer dimensions to make them equal length.
- Then, for each dimension size, the resulting dimension size is the max of the sizes of
  :attr:`x` and :attr:`y` along that dimension.

For example::

    # can line up trailing dimensions to make reading easier
    >>> x=torch.empty(5,1,4,1)
    >>> y=torch.empty(  3,1,1)
    >>> (x+y).size()
    torch.Size([5, 3, 4, 1])

    # but not necessary:
    >>> x=torch.empty(1)
    >>> y=torch.empty(3,1,7)
    >>> (x+y).size()
    torch.Size([3, 1, 7])

    >>> x=torch.empty(5,2,4,1)
    >>> y=torch.empty(3,1,1)
    >>> (x+y).size()
    RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
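
The two rules above can be written out as a short helper. This is just an
illustrative sketch (the name ``broadcast_shape`` is hypothetical); PyTorch
provides the same computation as :func:`torch.broadcast_shapes`::

    >>> def broadcast_shape(a, b):
    ...     a, b = list(a), list(b)
    ...     # rule 1: prepend 1s to the shorter shape until lengths match
    ...     while len(a) < len(b): a.insert(0, 1)
    ...     while len(b) < len(a): b.insert(0, 1)
    ...     if any(m != n and m != 1 and n != 1 for m, n in zip(a, b)):
    ...         raise ValueError("shapes are not broadcastable")
    ...     # rule 2: the result size is the max along each dimension
    ...     return torch.Size(max(m, n) for m, n in zip(a, b))
    >>> broadcast_shape((5, 1, 4, 1), (3, 1, 1))
    torch.Size([5, 3, 4, 1])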

In-place semantics
------------------
One complication is that in-place operations do not allow the in-place tensor to change shape
as a result of the broadcast.

For example::

    >>> x=torch.empty(5,3,4,1)
    >>> y=torch.empty(3,1,1)
    >>> (x.add_(y)).size()
    torch.Size([5, 3, 4, 1])

    # but:
    >>> x=torch.empty(1,3,1)
    >>> y=torch.empty(3,1,7)
    >>> (x.add_(y)).size()
    RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.
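
In other words, the in-place tensor's shape must already equal the broadcast
result shape, so only :attr:`y` may be expanded here, never :attr:`x`. The same
operands broadcast fine out of place, where a fresh result tensor is free to
take the broadcast shape::

    >>> x=torch.empty(1,3,1)
    >>> y=torch.empty(3,1,7)
    >>> (x + y).size()   # out-of-place: the result takes the broadcast shape
    torch.Size([3, 3, 7])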

Backwards compatibility
-----------------------
Prior versions of PyTorch allowed certain pointwise functions to execute on tensors with different shapes,
as long as the number of elements in each tensor was equal.  The pointwise operation would then be carried
out by viewing each tensor as 1-dimensional.  PyTorch now supports broadcasting; the "1-dimensional"
pointwise behavior is deprecated and will generate a Python warning in cases where the tensors are
not broadcastable but have the same number of elements.

Note that the introduction of broadcasting can cause backwards incompatible changes in the case where
two tensors do not have the same shape, but are broadcastable and have the same number of elements.
For example::

    >>> torch.add(torch.ones(4,1), torch.randn(4))

would previously produce a Tensor of size torch.Size([4, 1]), but now produces a Tensor of size
torch.Size([4, 4]): the shape (4,) is first prepended with 1 to become (1, 4), which then
broadcasts against (4, 1).
In order to help identify cases in your code where backwards incompatibilities introduced by broadcasting may exist,
you may set ``torch.utils.backcompat.broadcast_warning.enabled`` to ``True``, which will generate a Python warning
in such cases.

For example::

    >>> torch.utils.backcompat.broadcast_warning.enabled=True
    >>> torch.add(torch.ones(4,1), torch.ones(4))
    __main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
    Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.