1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <vector>
12 
13 #include <executorch/extension/tensor/tensor.h>
14 #include <executorch/runtime/core/exec_aten/exec_aten.h>
15 
16 namespace example {
17 
18 /**
19  * Computes the cross-attention mask for text + image inputs. Text tokens that
20  * participate in cross-attention with an image token will show True in the mask
21  * and follow the interleaved structure laid out in Fig. 7 of the Flamingo paper
22  * (https://arxiv.org/pdf/2204.14198):
23  *
24  *     (1) Text tokens immediately following the image token up until the next
25  *         image token
26  *     (2) Consecutive image tokens attend to subsequent text tokens
27  *
28  * Example (■ marks a True entry in that image's mask):
29  *           ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
30  *      img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
31  *           └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
32  *           ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
33  *      img2 │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
34  *           └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
35  *           ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
36  *      img3 │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
37  *           └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
38  *         <img1> <img2>These  are   two  dogs. <img3> This   is    a    cat.
39  *
40  *
41  *
42  * Resultant mask is constructed per image and is of shape (text_seq_len,
43  * image_seq_len), where True indicates that the token outputted from the image
44  * encoder attends to the token in the text sequence in cross-attention. A
45  * vector of these masks is returned, with length equal to the number of images
46  * in the sample.
47  *
48  * @param tokens Vector of tokens participating in the cross attention.
49  * @param images Vector of images participating in the cross attention.
50  * @param tile_size The size of the image tiles from the image transform.
51  * @param patch_size The size of each patch. Used to divide the tiles into
52  * patches. E.g. for patch_size = 40, a tile of shape (400, 400) will have a
53  * 10x10 grid of patches with shape (40, 40) each.
54  * @param image_token_id Token ID of the image special token.
55  * @param out Out vector holding the raw data wrapped by the returned cross
56  * attention masks. NOTE(review): presumably must outlive the returned tensors
57  * since they wrap this data; confirm against the implementation.
58  *
59  * @returns A vector of cross attention masks, as Tensors, one for each image.
60  */
61 std::vector<::executorch::extension::TensorPtr> cross_attention_mask(
62     const std::vector<int>& tokens,
63     const std::vector<::executorch::aten::Tensor>& images,
64     size_t tile_size,
65     size_t patch_size,
66     int image_token_id,
67     std::vector<std::vector<int>>& out);
68 
69 } // namespace example
70