/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <vector>

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>

namespace example {

/**
 * Computes the cross-attention mask for text + image inputs. Text tokens that
 * participate in cross-attention with an image token will show True in the mask
 * and follow the interleaved structure laid out in Fig. 7 of the Flamingo paper
 * (https://arxiv.org/pdf/2204.14198):
 *
 * (1) Text tokens immediately following the image token up until the next
 * image token (2) Consecutive image tokens attend to subsequent text tokens
 *
 * ::
 *
 *            ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
 *     img1   │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
 *            └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
 *            ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
 *     img2   │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
 *            └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
 *            ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
 *     img3   │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
 *            └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
 *           <img1> <img2>These are two dogs. <img3> This is a cat.
 *
 *
 *
 * Resultant mask is constructed per image and is of shape (text_seq_len,
 * image_seq_len), where True indicates that the token outputted from the image
 * encoder attends to the token in the text sequence in cross-attention. A list
 * of these masks are returned with length equal to number of images in the
 * sample.
 *
 * @param tokens Vector of tokens participating in the cross attention.
 * @param images Vector of images participating in the cross attention.
 * @param tile_size The size of the image tiles from the image transform.
 * @param patch_size The size of each patch. Used to divide the tiles into
 * patches. E.g. for patch_size = 40, a tile of shape (400, 400) will have 10x10
 * grid of patches with shape (40, 40) each.
 * @param image_token_id Token ID of the image special token in `tokens`.
 * @param out Out vector holding the raw data wrapped by the returned cross
 * attention masks. Must outlive the returned Tensors, which view into it.
 *
 * @returns A vector of cross attention masks, as Tensors, one for each image.
 */
std::vector<::executorch::extension::TensorPtr> cross_attention_mask(
    const std::vector<int>& tokens,
    const std::vector<::executorch::aten::Tensor>& images,
    size_t tile_size,
    size_t patch_size,
    int image_token_id,
    std::vector<std::vector<int>>& out);

} // namespace example