# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# A torch.export()-friendly version of torchtune's positional embeddings.
# Added torch._check() calls to make sure guards on symints are enforced.
# See https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/_position_embeddings.py

import logging
import math
from typing import Any, Dict, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed._tensor import distribute_tensor, DTensor

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


class TilePositionalEmbedding(nn.Module):
    """
    Positional embedding for tiles, different for every tile, same for every token within a tile.

    Notice that tile is different from patch (token). For details, please check the documentation of
    :class:`torchtune.modules.vision_transformer.VisionTransformer`.

    Args:
        max_num_tiles (int): The maximum number of tiles an image can be divided into.
        embed_dim (int): The dimensionality of each tile embedding.
    """

    def __init__(
        self,
        max_num_tiles: int,
        embed_dim: int,
    ):
        super().__init__()
        self.max_num_tiles = max_num_tiles
        self.embed_dim = embed_dim

        scale = embed_dim**-0.5
        self.embedding = nn.Parameter(
            scale * torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim)
        )
        self.gate = nn.Parameter(torch.zeros(1))

        # Register load hook to interpolate positional embeddings
        self._register_load_state_dict_pre_hook(self._load_state_dict_hook)

    # TODO: Switch to public method after 2.5 is stable
    @torch.no_grad()
    def _load_state_dict_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Tuple[Any],
        **kwargs: Dict[str, Any],
    ):
        """
        Interpolates positional embeddings to accommodate a different number of tiles,
        in case the model was instantiated with different
        settings than the one you are loading the state dict from.

        For more info, check the self._resize_position_embedding function.

        Args:
            state_dict (Dict[str, Any]): The state dict to load.
            prefix (str): The prefix of the state dict.
            *args (Tuple[Any]): Additional positional arguments.
            **kwargs (Dict[str, Any]): Additional keyword arguments.

        Raises:
            ValueError: if the shape of the loaded embedding is not compatible with the current embedding.
            ValueError: if max_num_tiles_x, max_num_tiles_y are not equal.
            ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
80 """ 81 82 embedding = state_dict.get(prefix + "embedding") 83 84 if embedding is not None: 85 86 # ckpt pos emb 87 ( 88 tgt_max_num_tiles_x, 89 tgt_max_num_tiles_y, 90 tgt_num_tokens, 91 tgt_emb, 92 ) = self.embedding.shape 93 94 # instantiated pos emb 95 ( 96 inpt_max_num_tiles_x, 97 inpt_max_num_tiles_y, 98 inpt_num_tokens, 99 inpt_emb, 100 ) = state_dict[prefix + "embedding"].shape 101 102 # sanity check 103 if inpt_num_tokens != tgt_num_tokens or inpt_emb != tgt_emb: 104 raise ValueError( 105 "Expected embedding shape to be (..., num_tokens, tgt_emb) to match" 106 f" but found shapes {self.embedding.shape} and {state_dict[prefix + 'embedding'].shape}" 107 ) 108 109 if inpt_max_num_tiles_x != inpt_max_num_tiles_y: 110 raise ValueError( 111 "Expected max_num_tiles_x, max_num_tiles_y to be equal but found, but found" 112 f"(max_num_tiles_x, max_num_tiles_y, 1, embed_dim) = {self.embedding.shape}" 113 ) 114 115 # resize ckpt to match instantiated shape 116 embedding_new = self._resize_position_embedding( 117 embedding, tgt_max_num_tiles=tgt_max_num_tiles_x 118 ) 119 120 # update state dict 121 state_dict[prefix + "embedding"] = embedding_new 122 if embedding_new.shape != self.embedding.shape: 123 raise ValueError( 124 "Expected embedding shape and embedding_new.shape to match" 125 f" but found shapes {self.embedding.shape} and {embedding_new.shape}" 126 ) 127 128 @staticmethod 129 def _resize_position_embedding( 130 embedding: torch.Tensor, tgt_max_num_tiles: int 131 ) -> torch.Tensor: 132 """ 133 Interpolates positional embeddings to accomodate a different max_num_tiles. These 134 are the only dimensions that changes during interpolation. 135 136 Args: 137 embedding (torch.Tensor): torch.Tensor with shape (max_num_tiles, max_num_tiles, 1, embed_dim 138 tgt_max_num_tiles (int): The number of tiles to resize to. 139 140 Returns: 141 torch.Tensor: The resized embedding. 142 143 Example: 144 >>> import torch 145 >>> # create dummy embedding 146 >>> embedding = torch.arange(2*2*2*2).reshape(2, 2, 2, 2).float() 147 >>> resized_embed = _dynamic_resize(embedding, tgt_max_num_tiles=1) 148 >>> print(resized_embed.shape) 149 >>> torch.Size([1, 1, 2, 2]) 150 """ 151 # set max_num_tiles to the last dimension 152 embedding = embedding.permute(2, 3, 0, 1) 153 154 embedding = F.interpolate( 155 embedding, 156 size=(tgt_max_num_tiles, tgt_max_num_tiles), 157 mode="bilinear", 158 align_corners=True, 159 ) 160 # permute to the original shape 161 embedding = embedding.permute(2, 3, 0, 1) 162 return embedding 163 164 def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor: 165 """ 166 args: 167 x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim). 168 aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2), 169 representing the aspect ratio of the image before tile-cropping, e.g. (2,1). 170 returns: 171 torch.Tensor: The input tensor with added positional embeddings. 172 """ 173 bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape 174 torch._check(n_tiles <= self.max_num_tiles) 175 176 for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio): 177 # When we batch images, all are padded to the same amount of tiles. 178 # The aspect_ratio lets us know the non padded tiles for each image. 179 # We only add positional encoding to those. 
            n_tiles_h = n_tiles_h.item()
            n_tiles_w = n_tiles_w.item()

            n_non_padded_tiles = int(n_tiles_h * n_tiles_w)

            # We get only the positional encoding for non padded tiles,
            # i.e. n_tiles_h, n_tiles_w.
            torch._check_is_size(n_tiles_h)
            torch._check_is_size(n_tiles_w)
            torch._check(n_tiles_h >= 1)
            torch._check(n_tiles_w >= 1)
            torch._check(n_tiles_h <= self.max_num_tiles)
            torch._check(n_tiles_w <= self.max_num_tiles)
            # TODO: Remove this once pytorch/pytorch#120288 is fixed
            padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1))
            pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]

            # We need to do a clone here in order to make this model export
            # friendly as the reshape is collapsing dim 0 and dim 1 into a
            # single dim.
            pos_embed = pos_embed.clone()
            pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim)

            x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0))
            torch._check_is_size(n_non_padded_tiles)
            torch._check(n_non_padded_tiles < x.size(1))
            x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed * self.gate.tanh()
            x = x[:, :n_tiles, :, :]

        return x


class TiledTokenPositionalEmbedding(nn.Module):
    """
    Token positional embedding for tiled images, different for every tile, different for every token.

    There are two positional embeddings in this module:

    * local_token_positional_embedding: same for every tile, different for every token. Equivalent \
        to :class:`torchtune.models.clip._position_embeddings.TokenPositionalEmbedding`, but gated.
    * global_token_positional_embedding: different for every tile, different for every token.

    Notice that tile is different from patch (token). For details, please check the documentation of
    :class:`torchtune.modules.vision_transformer.VisionTransformer`.

    Args:
        max_num_tiles (int): The maximum number of tiles an image can be divided into.
        embed_dim (int): The dimensionality of each token embedding.
        tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise,
            the size of the input image. In this case, the function will consider your image as a single tile.
        patch_size (int): The size of each patch. Used to divide the tiles into patches.
            E.g. for ``patch_size=40``, a tile of shape (400, 400) will have a 10x10 grid of patches
            with shape (40, 40) each.
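
    Example:
        >>> # Illustrative values (not from the original torchtune docs):
        >>> # tile_size=80 with patch_size=40 gives a 2x2 patch grid,
        >>> # i.e. 5 tokens per tile including the CLS token.
        >>> pos_emb = TiledTokenPositionalEmbedding(max_num_tiles=4, embed_dim=32, tile_size=80, patch_size=40)
        >>> pos_emb.n_tokens_per_tile
        5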
234 """ 235 236 def __init__( 237 self, max_num_tiles: int, embed_dim: int, tile_size: int, patch_size: int 238 ) -> None: 239 super().__init__() 240 241 patch_grid_size = tile_size // patch_size 242 self.n_tokens_per_tile = patch_grid_size**2 + 1 # +1 for cls token 243 scale = embed_dim**-0.5 244 245 # different for every token, same for every tile 246 self.local_token_positional_embedding = nn.Parameter( 247 scale * torch.randn((self.n_tokens_per_tile, embed_dim)) 248 ) 249 250 # different for every token, different for every tile 251 self.global_token_positional_embedding = nn.Parameter( 252 scale 253 * torch.randn( 254 max_num_tiles, 255 max_num_tiles, 256 self.n_tokens_per_tile, 257 embed_dim, 258 ) 259 ) 260 self.max_num_tiles = max_num_tiles 261 self.gate = nn.Parameter(torch.zeros(1)) 262 263 self._register_load_state_dict_pre_hook(self._load_state_dict_hook) 264 265 @torch.no_grad() 266 def _load_state_dict_hook( 267 self, 268 state_dict: Dict[str, Any], 269 prefix: str, 270 *args: Tuple[Any], 271 **kwargs: Dict[str, Any], 272 ) -> None: 273 """ 274 Interpolates positional embeddings to accomodate different number of tiles 275 and tokens per tile, in case the model was instantiated with different 276 settings than the one you are loading the state dict from. 277 278 For more info, please check self._resize_local_position_embedding and 279 self._resize_global_position_embedding functions. 280 281 Args: 282 state_dict (Dict[str, Any]): The state dict to load. 283 prefix (str): The prefix of the state dict. 284 *args (Tuple[Any]): Additional positional arguments. 285 **kwargs (Dict[str, Any]): Additional keyword arguments. 286 287 Raises: 288 ValueError: if loaded local or global embedding n_tokens_per_tile is not derived 289 from a squared grid. 290 ValueError: if after interpolation, the shape of the loaded local embedding 291 is not compatible with the current embedding. 292 ValueError: if after interpolation, the shape of the loaded global embedding 293 is not compatible with the current embedding. 294 """ 295 296 # process local_token_positional_embedding 297 inpt_local_pos_embed = state_dict.get( 298 prefix + "local_token_positional_embedding" 299 ) 300 301 if inpt_local_pos_embed is not None: 302 303 # We can only apply F.interpolate to vanilla tensors, not DTensors 304 # If pos embeds are a DTensor, we gather the full tensor, apply 305 # interpolate, and then reshard after 306 if isinstance(inpt_local_pos_embed, DTensor): 307 local_embed_is_sharded = True 308 local_embed_device_mesh = inpt_local_pos_embed.device_mesh 309 local_embed_placements = inpt_local_pos_embed.placements 310 inpt_local_pos_embed = inpt_local_pos_embed.full_tensor() 311 else: 312 local_embed_is_sharded = False 313 314 # sanity check 315 inpt_n_tokens_per_tile, inpt_embed_dim = inpt_local_pos_embed.shape 316 if math.sqrt(inpt_n_tokens_per_tile - 1) % 1 != 0: 317 raise ValueError( 318 f"Loaded local positional embedding has shape {inpt_n_tokens_per_tile=}, " 319 f"which indicates a grid_size that is not squared. This is currently not supported." 
                )

            # instantiated pos emb
            (
                tgt_n_tokens_per_tile,
                tgt_embed_dim,
            ) = self.local_token_positional_embedding.shape

            # resize ckpt to match instantiated shape
            inpt_local_pos_embed = self._resize_local_position_embedding(
                local_pos_embed=inpt_local_pos_embed,
                tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
            )

            if local_embed_is_sharded:
                inpt_local_pos_embed = distribute_tensor(
                    inpt_local_pos_embed,
                    device_mesh=local_embed_device_mesh,
                    placements=local_embed_placements,
                )

            # update state dict
            state_dict[prefix + "local_token_positional_embedding"] = (
                inpt_local_pos_embed
            )
            if (
                inpt_local_pos_embed.shape
                != self.local_token_positional_embedding.shape
            ):
                raise ValueError(
                    f"Loaded local positional embedding has shape {inpt_local_pos_embed.shape} "
                    f"after interpolation. Expected shape {self.local_token_positional_embedding.shape}."
                )

        # process global_token_positional_embedding
        inpt_global_pos_embed = state_dict.get(
            prefix + "global_token_positional_embedding"
        )

        if inpt_global_pos_embed is not None:

            # We can only apply F.interpolate to vanilla tensors, not DTensors.
            # If the pos embeds are a DTensor, we gather the full tensor, apply
            # interpolate, and then reshard afterwards.
            if isinstance(inpt_global_pos_embed, DTensor):
                global_embed_is_sharded = True
                global_embed_device_mesh = inpt_global_pos_embed.device_mesh
                global_embed_placements = inpt_global_pos_embed.placements
                inpt_global_pos_embed = inpt_global_pos_embed.full_tensor()
            else:
                global_embed_is_sharded = False

            _, _, inpt_n_tokens_per_tile, _ = inpt_global_pos_embed.shape

            # sanity check
            if math.sqrt(inpt_n_tokens_per_tile - 1) % 1 != 0:
                raise ValueError(
                    f"Loaded global positional embedding has shape {inpt_n_tokens_per_tile=}, "
                    f"which indicates a grid_size that is not squared. This is currently not supported."
                )

            # instantiated pos emb
            (
                tgt_max_num_tiles_x,
                tgt_max_num_tiles_y,  # not used, same as tgt_max_num_tiles_x
                tgt_n_tokens_per_tile,
                tgt_embed_dim,
            ) = self.global_token_positional_embedding.shape

            # resize ckpt to match instantiated shape
            inpt_global_pos_embed = self._resize_global_position_embedding(
                global_pos_embed=inpt_global_pos_embed,
                tgt_max_num_tiles=tgt_max_num_tiles_x,
                tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
            )

            if global_embed_is_sharded:
                inpt_global_pos_embed = distribute_tensor(
                    inpt_global_pos_embed,
                    device_mesh=global_embed_device_mesh,
                    placements=global_embed_placements,
                )

            # update state dict
            state_dict[prefix + "global_token_positional_embedding"] = (
                inpt_global_pos_embed
            )
            if (
                inpt_global_pos_embed.shape
                != self.global_token_positional_embedding.shape
            ):
                raise ValueError(
                    f"Loaded global positional embedding has shape {inpt_global_pos_embed.shape} "
                    f"after interpolation. Expected shape {self.global_token_positional_embedding.shape}."
                )

    @staticmethod
    def _resize_local_position_embedding(
        local_pos_embed: torch.Tensor, tgt_patch_grid_size: int
    ) -> torch.Tensor:
        """
        Interpolates the local position embedding for a vision encoder to accommodate
        a different number of tokens per tile. This is the only dimension that
        changes during interpolation.

        Args:
            local_pos_embed (torch.Tensor): The position embeddings tensor to be resized. It
                has shape [n_tokens_per_tile, emb_dim], where the first token is the CLS token
                and n_tokens_per_tile = patch_grid_size**2 + 1.
            tgt_patch_grid_size (int): The target size of each patch grid, i.e.,
                the square root of the number of tokens per tile, excluding the class token.

        Returns:
            torch.Tensor: The resized position embeddings tensor of shape
                [tgt_n_tokens_per_tile, dim], where tgt_n_tokens_per_tile = tgt_patch_grid_size**2 + 1.

        Example:
            >>> import torch
            >>> import math
            >>> local_pos_embed = torch.randn((10*10+1, 64))  # Example input tensor
            >>> tgt_patch_grid_size = 20  # Target patch grid size
            >>> resized_pos_embed = _resize_local_position_embedding(local_pos_embed, tgt_patch_grid_size)
            >>> print(resized_pos_embed.shape)
            torch.Size([20*20+1, 64])
        """
        # invert n_tokens_per_tile = patch_grid_size**2 + 1, where +1 is the cls token
        inpt_n_tokens_per_tile, inpt_embed_dim = local_pos_embed.shape
        inpt_patch_grid_size = int(math.sqrt(inpt_n_tokens_per_tile - 1))

        # split tokens between cls and img tokens.
        # we don't want to interpolate cls token.
        cls_token, local_pos_embed = (
            local_pos_embed[[0]],  # cls token
            local_pos_embed[1:],  # image tokens
        )

        # we reshape n_tokens_per_tile - 1 --> (inpt_patch_grid_size, inpt_patch_grid_size)
        # and permute to have inpt_patch_grid_size as the last two dimensions
        # we also add a batch dim to the tensor, since F.interpolate expects it
        local_pos_embed = local_pos_embed.reshape(
            1, inpt_patch_grid_size, inpt_patch_grid_size, -1
        ).permute(0, 3, 1, 2)

        local_pos_embed = F.interpolate(
            local_pos_embed,
            size=[tgt_patch_grid_size, tgt_patch_grid_size],
            mode="bilinear",
            align_corners=True,  # defaults from internal-llama-models
        )

        # reshape back to [1, tokens_per_tile, embed_dim]
        local_pos_embed = local_pos_embed.permute(0, 2, 3, 1).reshape(
            1, -1, inpt_embed_dim
        )

        # remove batch dim added previously
        local_pos_embed = local_pos_embed.squeeze(0)

        # add cls token back in
        local_pos_embed = torch.cat([cls_token, local_pos_embed], dim=0)

        return local_pos_embed

    # TODO: Switch to public method after 2.5 is stable
    @staticmethod
    def _resize_global_position_embedding(
        global_pos_embed: torch.Tensor,
        tgt_max_num_tiles: int,
        tgt_patch_grid_size: int,
    ) -> torch.Tensor:
        """
        Interpolates the global position embedding for a vision encoder to accommodate new grid dimensions.
        The embedding dimension is not changed during interpolation, only max_num_tiles and num_tokens_per_tile.

        Args:
            global_pos_embed (torch.Tensor): The input global position embeddings tensor of shape
                [max_num_tiles_x, max_num_tiles_y, num_tokens_per_tile, embed_dim],
                where num_tokens_per_tile = inpt_patch_grid_size * inpt_patch_grid_size + 1 (CLS token), and
                max_num_tiles_x == max_num_tiles_y.
            tgt_max_num_tiles (int): The target maximum number of tiles along one dimension (assumed square grid).
            tgt_patch_grid_size (int): The target size of each patch grid, i.e., the square root of the number of tokens
                per tile, excluding the class token.

        Returns:
            torch.Tensor: The resized global position embeddings tensor of shape
                [tgt_max_num_tiles, tgt_max_num_tiles, tgt_patch_grid_size * tgt_patch_grid_size + 1, embed_dim].

        Example:
            >>> import torch
            >>> global_pos_embed = torch.arange(3*3*(2*2+1)*4).reshape((3, 3, 2*2+1, 4)).float()  # Example input tensor
            >>> tgt_max_num_tiles = 2  # Target maximum number of tiles
            >>> tgt_patch_grid_size = 3  # Target patch grid size
            >>> resized_global_pos_embed = (
            >>>     _resize_global_position_embedding(global_pos_embed, tgt_max_num_tiles, tgt_patch_grid_size))
            >>> print(resized_global_pos_embed.shape)
            torch.Size([2, 2, 3*3+1, 4])
        """

        # remove cls token to interpolate it separately
        pos_embed = global_pos_embed[:, :, 1:, :]
        cls_embed = global_pos_embed[:, :, [0], :]

        (
            max_num_tiles_x,
            max_num_tiles_y,
            n_tokens_per_tile,
            embed_dim,
        ) = pos_embed.shape

        # tokens_per_tile == inpt_patch_grid_size**2
        # we reshape n_tokens_per_tile --> (inpt_patch_grid_size, inpt_patch_grid_size)
        inpt_patch_grid_size = int(math.sqrt(n_tokens_per_tile))
        pos_embed = pos_embed.reshape(
            max_num_tiles_x,
            max_num_tiles_y,
            inpt_patch_grid_size,
            inpt_patch_grid_size,
            embed_dim,
        )

        # combine max_num_tiles and patch_grid_size into one dimension
        pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
        pos_embed = pos_embed.reshape(
            max_num_tiles_x * inpt_patch_grid_size,
            max_num_tiles_y * inpt_patch_grid_size,
            embed_dim,
        )

        # add batch dim for interpolation
        pos_embed = pos_embed.unsqueeze(0)

        tgt_size = (
            int(tgt_max_num_tiles * tgt_patch_grid_size),
            int(tgt_max_num_tiles * tgt_patch_grid_size),
        )

        # move the spatial dims to the last two positions for interpolation
        pos_embed = pos_embed.permute(0, 3, 1, 2)
        pos_embed = F.interpolate(
            pos_embed,
            size=tgt_size,
            mode="bilinear",
            align_corners=True,  # defaults from internal-llama-models
        )

        # return to original layout and remove batch dim
        pos_embed = pos_embed.permute(0, 2, 3, 1).squeeze(0)

        # move it back in place
        pos_embed = pos_embed.view(
            tgt_max_num_tiles,
            tgt_patch_grid_size,
            tgt_max_num_tiles,
            tgt_patch_grid_size,
            embed_dim,
        )
        pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
        pos_embed = pos_embed.view(
            tgt_max_num_tiles,
            tgt_max_num_tiles,
            int(tgt_patch_grid_size**2),
            embed_dim,
        )

        # interpolate cls token
        cls_embed = cls_embed.permute(2, 3, 0, 1)
        cls_embed_resized = F.interpolate(
            cls_embed,
            size=(tgt_max_num_tiles, tgt_max_num_tiles),
            mode="bilinear",
            align_corners=True,  # defaults from internal-llama-models
        )
        cls_embed = cls_embed_resized.permute(2, 3, 0, 1)

        # add cls token back in
        global_pos_embed = torch.cat([cls_embed, pos_embed], dim=2)

        return global_pos_embed

    def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): torch.Tensor with shape
                (bsz * n_imgs, n_tiles, n_tokens_per_tile, embed_dim).
            aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2),
                where aspect_ratio[k] represents the aspect ratio of the k^th image
                of the batch before tile-cropping, e.g. aspect_ratio[k] = (2,1).
        Returns:
            torch.Tensor: The input tensor with added positional embeddings.
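
        Example:
            >>> # Illustrative shapes (not from the original docstring): with
            >>> # max_num_tiles=4, tile_size=2, patch_size=1 each tile has a 2x2
            >>> # patch grid, i.e. 5 tokens per tile including the CLS token.
            >>> pos_emb = TiledTokenPositionalEmbedding(max_num_tiles=4, embed_dim=8, tile_size=2, patch_size=1)
            >>> x = torch.randn(1, 4, 5, 8)
            >>> aspect_ratio = torch.tensor([[2, 2]])
            >>> pos_emb(x, aspect_ratio).shape
            torch.Size([1, 4, 5, 8])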
610 """ 611 bsz_and_n_imgs, n_tiles, n_tokens_per_tile, embed_dim = x.shape 612 613 # apply local position embedding (same for every tile) 614 x = x + (self.local_token_positional_embedding * (1 - self.gate.tanh())) 615 616 # apply global positional embedding (different for every tile) 617 x = x.view(bsz_and_n_imgs, n_tiles, n_tokens_per_tile, embed_dim) 618 for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio): 619 # When we batch images, all are padded to the same amount of tiles. 620 # The aspect_ratio lets us know the non padded tiles for each image. 621 # We only add positional encoding to those. 622 n_tiles_h = n_tiles_h.item() 623 n_tiles_w = n_tiles_w.item() 624 625 n_non_padded_tiles = int(n_tiles_h * n_tiles_w) 626 627 # We get only the positional encoding for non padded tiles, 628 # i.e. n_tiles_h, n_tiles_w. 629 torch._check(n_tiles_h > 0) 630 torch._check(n_tiles_w > 0) 631 torch._check(n_tiles_h <= self.max_num_tiles) 632 torch._check(n_tiles_w <= self.max_num_tiles) 633 padded_embedding = F.pad( 634 self.global_token_positional_embedding, (0, 0, 0, 0, 0, 1, 0, 1) 635 ) 636 637 pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :] 638 639 # Add pos encoding to the non padded tiles. 640 pos_embed = pos_embed.clone() 641 pos_embed = pos_embed.reshape( 642 n_non_padded_tiles, self.n_tokens_per_tile, embed_dim 643 ) 644 pos_embed = pos_embed * self.gate.tanh() 645 x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0)) 646 torch._check(n_non_padded_tiles < self.max_num_tiles + 1) 647 torch._check(n_non_padded_tiles < x.size(1)) 648 x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed 649 x = x[:, :n_tiles, :, :] 650 651 return x 652 653 654def replace_tile_positional_embedding(model: nn.Module) -> nn.Module: 655 """ 656 Replace the tile positional embedding from torchtune with an export-friendly one. 657 Recursively searches the submodules of the model and replaces the tile positional embedding if found. 658 Args: 659 model (nn.Module): The model to replace the tile positional embedding in. 660 661 Returns: 662 nn.Module: The model after replacing the tile positional embedding. 663 664 """ 665 from torchtune.models.clip._position_embeddings import ( 666 TilePositionalEmbedding as TuneTilePositionalEmbedding, 667 ) 668 669 for name, module in model.named_children(): 670 if isinstance(module, TuneTilePositionalEmbedding): 671 logging.info( 672 f"Replacing tile positional embedding in {name} with export-friendly one." 673 ) 674 max_num_tiles, _, _, embed_dim = module.embedding.shape 675 mod = TilePositionalEmbedding( 676 max_num_tiles=max_num_tiles, 677 embed_dim=embed_dim, 678 ) 679 mod.load_state_dict(module.state_dict()) 680 setattr( 681 model, 682 name, 683 mod, 684 ) 685 else: 686 replace_tile_positional_embedding(module) 687 return model 688 689 690def replace_tiled_token_positional_embedding(model: nn.Module) -> nn.Module: 691 """ 692 Replace the tiled token positional embedding from torchtune with an export-friendly one. 693 Recursively searches the submodules of the model and replaces the tiled token positional embedding if found. 694 Args: 695 model (nn.Module): The model to replace the tiled token positional embedding in. 696 697 Returns: 698 nn.Module: The model after replacing the tiled token positional embedding. 

    """
    from torchtune.models.clip._position_embeddings import (
        TiledTokenPositionalEmbedding as TuneTiledTokenPositionalEmbedding,
    )

    for name, module in model.named_children():
        if isinstance(module, TuneTiledTokenPositionalEmbedding):
            logging.info(
                f"Replacing tiled token positional embedding in {name} with export-friendly one."
            )
            max_num_tiles, _, n_tokens_per_tile, embed_dim = (
                module.global_token_positional_embedding.shape
            )
            mod = TiledTokenPositionalEmbedding(
                max_num_tiles=max_num_tiles,
                embed_dim=embed_dim,
                # Recover the patch grid size from the token count; with
                # patch_size=1, tile_size // patch_size == patch_grid_size.
                tile_size=int(math.sqrt(n_tokens_per_tile - 1)),
                patch_size=1,
            )
            mod.load_state_dict(module.state_dict())
            setattr(
                model,
                name,
                mod,
            )
        else:
            replace_tiled_token_positional_embedding(module)
    return model
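

if __name__ == "__main__":
    # Minimal eager smoke test (added for illustration; the sizes below are
    # arbitrary assumptions, not values taken from torchtune or ExecuTorch).
    # It checks that both export-friendly modules run and preserve input shape.
    with torch.no_grad():
        tile_emb = TilePositionalEmbedding(max_num_tiles=4, embed_dim=8)
        token_emb = TiledTokenPositionalEmbedding(
            max_num_tiles=4, embed_dim=8, tile_size=2, patch_size=1
        )
        # (bsz * n_imgs, n_tiles, n_tokens_per_tile, embed_dim); aspect ratios
        # are (tiles_h, tiles_w) per image, e.g. a 2x2 and a 2x1 tiling.
        x = torch.randn(2, 4, 5, 8)
        aspect_ratio = torch.tensor([[2, 2], [2, 1]])
        logging.info("TilePositionalEmbedding out: %s", tile_emb(x, aspect_ratio).shape)
        logging.info(
            "TiledTokenPositionalEmbedding out: %s", token_emb(x, aspect_ratio).shape
        )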