# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# A torch.export()-friendly version of torchtune's positional embeddings.
# Adds torch._check() calls to make sure guards on symints are enforced.
# See https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/_position_embeddings.py

import logging
import math
from typing import Any, Dict, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed._tensor import distribute_tensor, DTensor

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


class TilePositionalEmbedding(nn.Module):
    """
    Positional embedding for tiles, different for every tile, same for every token within a tile.

    Notice that tile is different from patch (token). For details, please check the documentation of
    :class:`torchtune.modules.vision_transformer.VisionTransformer`.

    Args:
        max_num_tiles (int): The maximum number of tiles an image can be divided into.
        embed_dim (int): The dimensionality of each tile embedding.
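
    Example:
        A minimal illustrative sketch; ``max_num_tiles`` and ``embed_dim`` below are arbitrary
        values chosen for demonstration, not values tied to any particular model.

        >>> pos_embed = TilePositionalEmbedding(max_num_tiles=4, embed_dim=1280)
        >>> pos_embed.embedding.shape
        torch.Size([4, 4, 1, 1280])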
    """

    def __init__(
        self,
        max_num_tiles: int,
        embed_dim: int,
    ):
        super().__init__()
        self.max_num_tiles = max_num_tiles
        self.embed_dim = embed_dim

        scale = embed_dim**-0.5
        self.embedding = nn.Parameter(
            scale * torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim)
        )
        self.gate = nn.Parameter(torch.zeros(1))

        # Register load hook to interpolate positional embeddings
        self._register_load_state_dict_pre_hook(self._load_state_dict_hook)

    # TODO: Switch to public method after 2.5 is stable
    @torch.no_grad()
    def _load_state_dict_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Tuple[Any],
        **kwargs: Dict[str, Any],
    ):
        """
        Interpolates positional embeddings to accommodate a different number of tiles,
        in case the model was instantiated with different
        settings than the checkpoint you are loading the state dict from.

        For more info, check the self._resize_position_embedding function.

        Args:
            state_dict (Dict[str, Any]): The state dict to load.
            prefix (str): The prefix of the state dict.
            *args (Tuple[Any]): Additional positional arguments.
            **kwargs (Dict[str, Any]): Additional keyword arguments.

        Raises:
            ValueError: if the shape of the loaded embedding is not compatible with the current embedding.
            ValueError: if max_num_tiles_x and max_num_tiles_y are not equal.
            ValueError: if, after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
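
        Example:
            An illustrative sketch of the resize behavior; the tile counts and embed_dim below
            are arbitrary and only meant to show interpolation from a larger checkpoint.

            >>> src = TilePositionalEmbedding(max_num_tiles=4, embed_dim=8)
            >>> dst = TilePositionalEmbedding(max_num_tiles=2, embed_dim=8)
            >>> _ = dst.load_state_dict(src.state_dict())  # (4, 4, 1, 8) is interpolated to (2, 2, 1, 8)
            >>> dst.embedding.shape
            torch.Size([2, 2, 1, 8])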
        """

        embedding = state_dict.get(prefix + "embedding")

        if embedding is not None:

            # instantiated pos emb
            (
                tgt_max_num_tiles_x,
                tgt_max_num_tiles_y,
                tgt_num_tokens,
                tgt_emb,
            ) = self.embedding.shape

            # ckpt pos emb
            (
                inpt_max_num_tiles_x,
                inpt_max_num_tiles_y,
                inpt_num_tokens,
                inpt_emb,
            ) = state_dict[prefix + "embedding"].shape

            # sanity check
            if inpt_num_tokens != tgt_num_tokens or inpt_emb != tgt_emb:
                raise ValueError(
                    "Expected num_tokens and embed_dim of the loaded embedding to match the current one,"
                    f" but found shapes {self.embedding.shape} and {state_dict[prefix + 'embedding'].shape}"
                )

            if inpt_max_num_tiles_x != inpt_max_num_tiles_y:
                raise ValueError(
                    "Expected max_num_tiles_x and max_num_tiles_y to be equal, but found "
                    f"(max_num_tiles_x, max_num_tiles_y, 1, embed_dim) = {state_dict[prefix + 'embedding'].shape}"
                )

            # resize ckpt to match instantiated shape
            embedding_new = self._resize_position_embedding(
                embedding, tgt_max_num_tiles=tgt_max_num_tiles_x
            )

            # update state dict
            state_dict[prefix + "embedding"] = embedding_new
            if embedding_new.shape != self.embedding.shape:
                raise ValueError(
                    "Expected embedding shape and embedding_new.shape to match"
                    f" but found shapes {self.embedding.shape} and {embedding_new.shape}"
                )

    @staticmethod
    def _resize_position_embedding(
        embedding: torch.Tensor, tgt_max_num_tiles: int
    ) -> torch.Tensor:
        """
        Interpolates positional embeddings to accommodate a different max_num_tiles. These
        are the only dimensions that change during interpolation.

        Args:
            embedding (torch.Tensor): torch.Tensor with shape (max_num_tiles, max_num_tiles, 1, embed_dim).
            tgt_max_num_tiles (int): The number of tiles to resize to.

        Returns:
            torch.Tensor: The resized embedding.

        Example:
            >>> import torch
            >>> # create dummy embedding
            >>> embedding = torch.arange(2*2*2*2).reshape(2, 2, 2, 2).float()
            >>> resized_embed = _resize_position_embedding(embedding, tgt_max_num_tiles=1)
            >>> print(resized_embed.shape)
            torch.Size([1, 1, 2, 2])
        """
        # move the max_num_tiles dims to the last two dimensions
        embedding = embedding.permute(2, 3, 0, 1)

        embedding = F.interpolate(
            embedding,
            size=(tgt_max_num_tiles, tgt_max_num_tiles),
            mode="bilinear",
            align_corners=True,
        )
        # permute back to the original dimension order
        embedding = embedding.permute(2, 3, 0, 1)
        return embedding

    def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
            aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2),
                representing the aspect ratio of the image before tile-cropping, e.g. (2,1).
        Returns:
            torch.Tensor: The input tensor with added positional embeddings.
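
        Example:
            An illustrative sketch; the batch size, tile count, token count and embed_dim are
            arbitrary, and the aspect ratio (2, 2) means all four tiles are non-padded.

            >>> pos_embed = TilePositionalEmbedding(max_num_tiles=4, embed_dim=8)
            >>> x = torch.randn(1, 4, 100, 8)  # (bsz * n_imgs, n_tiles, n_tokens, embed_dim)
            >>> aspect_ratio = torch.tensor([[2, 2]])
            >>> pos_embed(x, aspect_ratio).shape
            torch.Size([1, 4, 100, 8])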
        """
        bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape
        torch._check(n_tiles <= self.max_num_tiles)

        for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio):
            # When we batch images, all are padded to the same number of tiles.
            # The aspect_ratio lets us know the non-padded tiles for each image.
            # We only add positional encoding to those.
            n_tiles_h = n_tiles_h.item()
            n_tiles_w = n_tiles_w.item()

            n_non_padded_tiles = int(n_tiles_h * n_tiles_w)

            # We get only the positional encoding for non-padded tiles,
            # i.e. n_tiles_h, n_tiles_w.
            torch._check_is_size(n_tiles_h)
            torch._check_is_size(n_tiles_w)
            torch._check(n_tiles_h >= 1)
            torch._check(n_tiles_w >= 1)
            torch._check(n_tiles_h <= self.max_num_tiles)
            torch._check(n_tiles_w <= self.max_num_tiles)
            # TODO: Remove this once pytorch/pytorch#120288 is fixed
            padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1))
            pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]

            # We need to do a clone here in order to make this model export
            # friendly as the reshape is collapsing dim 0 and dim 1 into a
            # single dim.
            pos_embed = pos_embed.clone()
            pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim)

            x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0))
            torch._check_is_size(n_non_padded_tiles)
            torch._check(n_non_padded_tiles < x.size(1))
            x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed * self.gate.tanh()
            x = x[:, :n_tiles, :, :]

        return x


class TiledTokenPositionalEmbedding(nn.Module):
    """
    Token positional embedding for tiled images, different for every tile, different for every token.

    There are two positional embeddings in this module:

    * local_token_positional_embedding: same for every tile, different for every token. Equivalent \
        to :class:`torchtune.models.clip._position_embeddings.TokenPositionalEmbedding`, but gated.
    * global_token_positional_embedding: different for every tile, different for every token.

    Notice that tile is different from patch (token). For details, please check the documentation of
    :class:`torchtune.modules.vision_transformer.VisionTransformer`.

    Args:
        max_num_tiles (int): The maximum number of tiles an image can be divided into.
        embed_dim (int): The dimensionality of each token embedding.
        tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise,
            the size of the input image. In this case, the function will consider your image as a single tile.
        patch_size (int): The size of each patch. Used to divide the tiles into patches.
            E.g. for ``patch_size=40``, a tile of shape (400, 400) will have a 10x10 grid of patches
            with shape (40, 40) each.
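
    Example:
        A minimal illustrative sketch; the values below are arbitrary. With ``tile_size=224``
        and ``patch_size=14``, each tile has a 16x16 patch grid plus a CLS token,
        i.e. 16 * 16 + 1 = 257 tokens per tile.

        >>> pos_embed = TiledTokenPositionalEmbedding(
        ...     max_num_tiles=4, embed_dim=32, tile_size=224, patch_size=14
        ... )
        >>> pos_embed.n_tokens_per_tile
        257
        >>> pos_embed.global_token_positional_embedding.shape
        torch.Size([4, 4, 257, 32])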
    """

    def __init__(
        self, max_num_tiles: int, embed_dim: int, tile_size: int, patch_size: int
    ) -> None:
        super().__init__()

        patch_grid_size = tile_size // patch_size
        self.n_tokens_per_tile = patch_grid_size**2 + 1  # +1 for cls token
        scale = embed_dim**-0.5

        # different for every token, same for every tile
        self.local_token_positional_embedding = nn.Parameter(
            scale * torch.randn((self.n_tokens_per_tile, embed_dim))
        )

        # different for every token, different for every tile
        self.global_token_positional_embedding = nn.Parameter(
            scale
            * torch.randn(
                max_num_tiles,
                max_num_tiles,
                self.n_tokens_per_tile,
                embed_dim,
            )
        )
        self.max_num_tiles = max_num_tiles
        self.gate = nn.Parameter(torch.zeros(1))

        self._register_load_state_dict_pre_hook(self._load_state_dict_hook)

    @torch.no_grad()
    def _load_state_dict_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Tuple[Any],
        **kwargs: Dict[str, Any],
    ) -> None:
        """
        Interpolates positional embeddings to accommodate a different number of tiles
        and tokens per tile, in case the model was instantiated with different
        settings than the checkpoint you are loading the state dict from.

        For more info, please check the self._resize_local_position_embedding and
        self._resize_global_position_embedding functions.

        Args:
            state_dict (Dict[str, Any]): The state dict to load.
            prefix (str): The prefix of the state dict.
            *args (Tuple[Any]): Additional positional arguments.
            **kwargs (Dict[str, Any]): Additional keyword arguments.

        Raises:
            ValueError: if the loaded local or global embedding n_tokens_per_tile is not derived
                from a squared grid.
            ValueError: if, after interpolation, the shape of the loaded local embedding
                is not compatible with the current embedding.
            ValueError: if, after interpolation, the shape of the loaded global embedding
                is not compatible with the current embedding.
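
        Example:
            An illustrative sketch of the resize behavior; the values below are arbitrary and
            simply produce a checkpoint with more tiles and more tokens per tile than the
            instantiated module.

            >>> src = TiledTokenPositionalEmbedding(
            ...     max_num_tiles=4, embed_dim=8, tile_size=64, patch_size=16
            ... )
            >>> dst = TiledTokenPositionalEmbedding(
            ...     max_num_tiles=2, embed_dim=8, tile_size=64, patch_size=32
            ... )
            >>> _ = dst.load_state_dict(src.state_dict())  # both embeddings are interpolated
            >>> dst.global_token_positional_embedding.shape
            torch.Size([2, 2, 5, 8])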
        """

        # process local_token_positional_embedding
        inpt_local_pos_embed = state_dict.get(
            prefix + "local_token_positional_embedding"
        )

        if inpt_local_pos_embed is not None:

            # We can only apply F.interpolate to vanilla tensors, not DTensors
            # If pos embeds are a DTensor, we gather the full tensor, apply
            # interpolate, and then reshard after
            if isinstance(inpt_local_pos_embed, DTensor):
                local_embed_is_sharded = True
                local_embed_device_mesh = inpt_local_pos_embed.device_mesh
                local_embed_placements = inpt_local_pos_embed.placements
                inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
            else:
                local_embed_is_sharded = False

            # sanity check
            inpt_n_tokens_per_tile, inpt_embed_dim = inpt_local_pos_embed.shape
            if math.sqrt(inpt_n_tokens_per_tile - 1) % 1 != 0:
                raise ValueError(
                    f"Loaded local positional embedding has {inpt_n_tokens_per_tile=}, "
                    f"which does not correspond to a square patch grid. This is currently not supported."
                )

            # instantiated pos emb
            (
                tgt_n_tokens_per_tile,
                tgt_embed_dim,
            ) = self.local_token_positional_embedding.shape

            # resize ckpt to match instantiated shape
            inpt_local_pos_embed = self._resize_local_position_embedding(
                local_pos_embed=inpt_local_pos_embed,
                tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
            )

            if local_embed_is_sharded:
                inpt_local_pos_embed = distribute_tensor(
                    inpt_local_pos_embed,
                    device_mesh=local_embed_device_mesh,
                    placements=local_embed_placements,
                )

            # update state dict
            state_dict[prefix + "local_token_positional_embedding"] = (
                inpt_local_pos_embed
            )
            if (
                inpt_local_pos_embed.shape
                != self.local_token_positional_embedding.shape
            ):
                raise ValueError(
                    f"Loaded local positional embedding has shape {inpt_local_pos_embed.shape} "
                    f"after interpolation. Expected shape {self.local_token_positional_embedding.shape}."
                )

        # process global_token_positional_embedding
        inpt_global_pos_embed = state_dict.get(
            prefix + "global_token_positional_embedding"
        )

        if inpt_global_pos_embed is not None:

            # We can only apply F.interpolate to vanilla tensors, not DTensors
            # If pos embeds are a DTensor, we gather the full tensor, apply
            # interpolate, and then reshard after
            if isinstance(inpt_global_pos_embed, DTensor):
                global_embed_is_sharded = True
                global_embed_device_mesh = inpt_global_pos_embed.device_mesh
                global_embed_placements = inpt_global_pos_embed.placements
                inpt_global_pos_embed = inpt_global_pos_embed.full_tensor()
            else:
                global_embed_is_sharded = False

            _, _, inpt_n_tokens_per_tile, _ = inpt_global_pos_embed.shape

            # sanity check
            if math.sqrt(inpt_n_tokens_per_tile - 1) % 1 != 0:
                raise ValueError(
                    f"Loaded global positional embedding has {inpt_n_tokens_per_tile=}, "
                    f"which does not correspond to a square patch grid. This is currently not supported."
                )

            # instantiated pos emb
            (
                tgt_max_num_tiles_x,
                tgt_max_num_tiles_y,  # not used, same as tgt_max_num_tiles_x
                tgt_n_tokens_per_tile,
                tgt_embed_dim,
            ) = self.global_token_positional_embedding.shape

            # resize ckpt to match instantiated shape
            inpt_global_pos_embed = self._resize_global_position_embedding(
                global_pos_embed=inpt_global_pos_embed,
                tgt_max_num_tiles=tgt_max_num_tiles_x,
                tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
            )

            if global_embed_is_sharded:
                inpt_global_pos_embed = distribute_tensor(
                    inpt_global_pos_embed,
                    device_mesh=global_embed_device_mesh,
                    placements=global_embed_placements,
                )

            # update state dict
            state_dict[prefix + "global_token_positional_embedding"] = (
                inpt_global_pos_embed
            )
            if (
                inpt_global_pos_embed.shape
                != self.global_token_positional_embedding.shape
            ):
                raise ValueError(
                    f"Loaded global positional embedding has shape {inpt_global_pos_embed.shape} "
                    f"after interpolation. Expected shape {self.global_token_positional_embedding.shape}."
                )

    @staticmethod
    def _resize_local_position_embedding(
        local_pos_embed: torch.Tensor, tgt_patch_grid_size: int
    ) -> torch.Tensor:
        """
        Interpolates the local position embedding for a vision encoder to accommodate
        a different number of tokens per tile. This is the only dimension that
        changes during interpolation.

        Args:
            local_pos_embed (torch.Tensor): The position embeddings tensor to be resized. It
                has shape [n_tokens_per_tile, emb_dim], where the first token is the CLS token
                and n_tokens_per_tile = patch_grid_size**2 + 1.
            tgt_patch_grid_size (int): The target size of each patch grid, i.e.,
                the square root of the number of tokens per tile, excluding the class token.

        Returns:
            torch.Tensor: The resized position embeddings tensor of shape
                [tgt_n_tokens_per_tile, dim], where tgt_n_tokens_per_tile = tgt_patch_grid_size**2 + 1.

        Example:
            >>> import torch
            >>> import math
            >>> local_pos_embed = torch.randn((10*10+1, 64))  # Example input tensor
            >>> tgt_patch_grid_size = 20  # Target patch grid size
            >>> resized_pos_embed = _resize_local_position_embedding(local_pos_embed, tgt_patch_grid_size)
            >>> print(resized_pos_embed.shape)
            torch.Size([20*20+1, 64])
        """
        # recover patch_grid_size from n_tokens_per_tile = patch_grid_size**2 + 1, where +1 is the cls token
        inpt_n_tokens_per_tile, inpt_embed_dim = local_pos_embed.shape
        inpt_patch_grid_size = int(math.sqrt(inpt_n_tokens_per_tile - 1))

        # split tokens between cls and img tokens.
        # we don't want to interpolate cls token.
        cls_token, local_pos_embed = (
            local_pos_embed[[0]],  # cls token
            local_pos_embed[1:],  # image tokens
        )

        # we reshape n_tokens_per_tile - 1 --> (inpt_patch_grid_size, inpt_patch_grid_size)
        # and permute to have inpt_patch_grid_size as the last two dimensions
        # we also add a batch dim to the tensor, since F.interpolate expects it
        local_pos_embed = local_pos_embed.reshape(
            1, inpt_patch_grid_size, inpt_patch_grid_size, -1
        ).permute(0, 3, 1, 2)

        local_pos_embed = F.interpolate(
            local_pos_embed,
            size=[tgt_patch_grid_size, tgt_patch_grid_size],
            mode="bilinear",
            align_corners=True,  # defaults from internal-llama-models
        )

        # reshape back to [1, tokens_per_tile, embed_dim]
        local_pos_embed = local_pos_embed.permute(0, 2, 3, 1).reshape(
            1, -1, inpt_embed_dim
        )

        # remove batch dim added previously
        local_pos_embed = local_pos_embed.squeeze(0)

        # add cls token back in
        local_pos_embed = torch.cat([cls_token, local_pos_embed], dim=0)

        return local_pos_embed

    # TODO: Switch to public method after 2.5 is stable
    @staticmethod
    def _resize_global_position_embedding(
        global_pos_embed: torch.Tensor,
        tgt_max_num_tiles: int,
        tgt_patch_grid_size: int,
    ) -> torch.Tensor:
        """
        Interpolates the global position embedding for a vision encoder to accommodate new grid dimensions.
        The embedding dimension is not changed during interpolation, only max_num_tiles and num_tokens_per_tile.

        Args:
            global_pos_embed (torch.Tensor): The input global position embeddings tensor of shape
                [max_num_tiles_x, max_num_tiles_y, num_tokens_per_tile, embed_dim],
                where num_tokens_per_tile = inpt_patch_grid_size * inpt_patch_grid_size + 1 (CLS token), and
                max_num_tiles_x == max_num_tiles_y.
            tgt_max_num_tiles (int): The target maximum number of tiles along one dimension (assumed square grid).
            tgt_patch_grid_size (int): The target size of each patch grid, i.e., the square root of the number of tokens
                per tile, excluding the class token.

        Returns:
            torch.Tensor: The resized global position embeddings tensor of shape
                [tgt_max_num_tiles, tgt_max_num_tiles, tgt_patch_grid_size * tgt_patch_grid_size + 1, embed_dim].

        Example:
            >>> import torch
            >>> global_pos_embed = torch.arange(3*3*(2*2+1)*4).reshape((3, 3, 2*2+1, 4)).float()  # Example input tensor
            >>> tgt_max_num_tiles = 2  # Target maximum number of tiles
            >>> tgt_patch_grid_size = 3  # Target patch grid size
            >>> resized_global_pos_embed = _resize_global_position_embedding(
            ...     global_pos_embed, tgt_max_num_tiles, tgt_patch_grid_size
            ... )
            >>> print(resized_global_pos_embed.shape)
            torch.Size([2, 2, 3*3+1, 4])
        """

        # remove cls token to interpolate it separately
        pos_embed = global_pos_embed[:, :, 1:, :]
        cls_embed = global_pos_embed[:, :, [0], :]

        (
            max_num_tiles_x,
            max_num_tiles_y,
            n_tokens_per_tile,
            embed_dim,
        ) = pos_embed.shape

        # tokens_per_tile == inpt_patch_grid_size**2
        # we reshape n_tokens_per_tile --> (inpt_patch_grid_size, inpt_patch_grid_size)
        inpt_patch_grid_size = int(math.sqrt(n_tokens_per_tile))
        pos_embed = pos_embed.reshape(
            max_num_tiles_x,
            max_num_tiles_y,
            inpt_patch_grid_size,
            inpt_patch_grid_size,
            embed_dim,
        )

        # combine max_num_tiles and patch_grid_size into one dimension
        pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
        pos_embed = pos_embed.reshape(
            max_num_tiles_x * inpt_patch_grid_size,
            max_num_tiles_y * inpt_patch_grid_size,
            embed_dim,
        )

        # add batch dim for interpolation
        pos_embed = pos_embed.unsqueeze(0)

        tgt_size = (
            int(tgt_max_num_tiles * tgt_patch_grid_size),
            int(tgt_max_num_tiles * tgt_patch_grid_size),
        )

        # move to the last two dim for interpolation
        pos_embed = pos_embed.permute(0, 3, 1, 2)
        pos_embed = F.interpolate(
            pos_embed,
            size=tgt_size,
            mode="bilinear",
            align_corners=True,  # defaults from internal-llama-models
        )

        # return to original shape and remove batch dim
        pos_embed = pos_embed.permute(0, 2, 3, 1).squeeze(0)

        # move it back in place
        pos_embed = pos_embed.view(
            tgt_max_num_tiles,
            tgt_patch_grid_size,
            tgt_max_num_tiles,
            tgt_patch_grid_size,
            embed_dim,
        )
        pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
        pos_embed = pos_embed.view(
            tgt_max_num_tiles,
            tgt_max_num_tiles,
            int(tgt_patch_grid_size**2),
            embed_dim,
        )

        # interpolate cls token
        cls_embed = cls_embed.permute(2, 3, 0, 1)
        cls_embed_resized = F.interpolate(
            cls_embed,
            size=(tgt_max_num_tiles, tgt_max_num_tiles),
            mode="bilinear",
            align_corners=True,  # defaults from internal-llama-models
        )
        cls_embed = cls_embed_resized.permute(2, 3, 0, 1)

        # add cls token back in
        global_pos_embed = torch.cat([cls_embed, pos_embed], dim=2)

        return global_pos_embed

    def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): torch.Tensor with shape
                (bsz * n_imgs, n_tiles, n_tokens_per_tile, embed_dim).
            aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2),
                where aspect_ratio[k] represents the aspect ratio of the k^th image
                of the batch before tile-cropping, e.g. aspect_ratio[k] = (2,1).
        Returns:
            torch.Tensor: The input tensor with added positional embeddings.
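
        Example:
            An illustrative sketch; the module parameters and tensor shapes below are arbitrary.
            With tile_size=64 and patch_size=16 there are 4 * 4 + 1 = 17 tokens per tile, and
            the aspect ratio (2, 2) means all four tiles are non-padded.

            >>> pos_embed = TiledTokenPositionalEmbedding(
            ...     max_num_tiles=4, embed_dim=8, tile_size=64, patch_size=16
            ... )
            >>> x = torch.randn(1, 4, 17, 8)  # (bsz * n_imgs, n_tiles, n_tokens_per_tile, embed_dim)
            >>> aspect_ratio = torch.tensor([[2, 2]])
            >>> pos_embed(x, aspect_ratio).shape
            torch.Size([1, 4, 17, 8])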
        """
        bsz_and_n_imgs, n_tiles, n_tokens_per_tile, embed_dim = x.shape

        # apply local position embedding (same for every tile)
        x = x + (self.local_token_positional_embedding * (1 - self.gate.tanh()))

        # apply global positional embedding (different for every tile)
        x = x.view(bsz_and_n_imgs, n_tiles, n_tokens_per_tile, embed_dim)
        for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio):
            # When we batch images, all are padded to the same number of tiles.
            # The aspect_ratio lets us know the non-padded tiles for each image.
            # We only add positional encoding to those.
            n_tiles_h = n_tiles_h.item()
            n_tiles_w = n_tiles_w.item()

            n_non_padded_tiles = int(n_tiles_h * n_tiles_w)

            # We get only the positional encoding for non-padded tiles,
            # i.e. n_tiles_h, n_tiles_w.
            torch._check(n_tiles_h > 0)
            torch._check(n_tiles_w > 0)
            torch._check(n_tiles_h <= self.max_num_tiles)
            torch._check(n_tiles_w <= self.max_num_tiles)
            padded_embedding = F.pad(
                self.global_token_positional_embedding, (0, 0, 0, 0, 0, 1, 0, 1)
            )

            pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]

            # Add pos encoding to the non-padded tiles.
            pos_embed = pos_embed.clone()
            pos_embed = pos_embed.reshape(
                n_non_padded_tiles, self.n_tokens_per_tile, embed_dim
            )
            pos_embed = pos_embed * self.gate.tanh()
            x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0))
            torch._check(n_non_padded_tiles < self.max_num_tiles + 1)
            torch._check(n_non_padded_tiles < x.size(1))
            x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed
            x = x[:, :n_tiles, :, :]

        return x


def replace_tile_positional_embedding(model: nn.Module) -> nn.Module:
    """
    Replace the tile positional embedding from torchtune with an export-friendly one.
    Recursively searches the submodules of the model and replaces the tile positional embedding if found.

    Args:
        model (nn.Module): The model to replace the tile positional embedding in.

    Returns:
        nn.Module: The model after replacing the tile positional embedding.

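    Example:
        A hypothetical usage sketch; ``vision_encoder`` stands for any module (e.g. a torchtune
        CLIP vision encoder) that contains a torchtune TilePositionalEmbedding submodule.

        >>> vision_encoder = replace_tile_positional_embedding(vision_encoder)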
    """
    from torchtune.models.clip._position_embeddings import (
        TilePositionalEmbedding as TuneTilePositionalEmbedding,
    )

    for name, module in model.named_children():
        if isinstance(module, TuneTilePositionalEmbedding):
            logging.info(
                f"Replacing tile positional embedding in {name} with export-friendly one."
            )
            max_num_tiles, _, _, embed_dim = module.embedding.shape
            mod = TilePositionalEmbedding(
                max_num_tiles=max_num_tiles,
                embed_dim=embed_dim,
            )
            mod.load_state_dict(module.state_dict())
            setattr(
                model,
                name,
                mod,
            )
        else:
            replace_tile_positional_embedding(module)
    return model


def replace_tiled_token_positional_embedding(model: nn.Module) -> nn.Module:
    """
    Replace the tiled token positional embedding from torchtune with an export-friendly one.
    Recursively searches the submodules of the model and replaces the tiled token positional embedding if found.

    Args:
        model (nn.Module): The model to replace the tiled token positional embedding in.

    Returns:
        nn.Module: The model after replacing the tiled token positional embedding.

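    Example:
        A hypothetical usage sketch; ``vision_encoder`` stands for any module (e.g. a torchtune
        CLIP vision encoder) that contains a torchtune TiledTokenPositionalEmbedding submodule.

        >>> vision_encoder = replace_tiled_token_positional_embedding(vision_encoder)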
    """
    from torchtune.models.clip._position_embeddings import (
        TiledTokenPositionalEmbedding as TuneTiledTokenPositionalEmbedding,
    )

    for name, module in model.named_children():
        if isinstance(module, TuneTiledTokenPositionalEmbedding):
            logging.info(
                f"Replacing tiled token positional embedding in {name} with export-friendly one."
            )
            max_num_tiles, _, n_tokens_per_tile, embed_dim = (
                module.global_token_positional_embedding.shape
            )
            mod = TiledTokenPositionalEmbedding(
                max_num_tiles=max_num_tiles,
                embed_dim=embed_dim,
                tile_size=int(math.sqrt((n_tokens_per_tile - 1))),
                patch_size=1,
            )
            mod.load_state_dict(module.state_dict())
            setattr(
                model,
                name,
                mod,
            )
        else:
            replace_tiled_token_positional_embedding(module)
    return model