1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/examples/models/llama3_2_vision/cross_attention/cross_attention_mask.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>
13
14 namespace example {
15
16 using ::executorch::aten::ScalarType;
17 using ::executorch::aten::Tensor;
18 using ::executorch::aten::TensorImpl;
19
// Forward declarations needed for ARM compilers.
21 int32_t safe_size_t_to_sizes_type(size_t value);
22 std::vector<std::vector<int>> _get_image_attention_intervals(
23 const std::vector<int>& tokens,
24 int image_token_id);
25
safe_size_t_to_sizes_type(size_t value)26 int32_t safe_size_t_to_sizes_type(size_t value) {
27 if (value >
28 static_cast<size_t>(std::numeric_limits<TensorImpl::SizesType>::max())) {
29 throw std::overflow_error(
30 "size_t value too large for TensorImpl::SizesType");
31 }
32 return static_cast<TensorImpl::SizesType>(value);
33 }
34
35 /**
36 * Returns a list of lists of the form [start, end) where start is the index
37 * of the current image token and end is the index of the next image token,
38 * exclusive.
39 *
40 * Example:
41 * >>> text = "<img1><img2>These are two dogs. <img3>This is a cat."
42 * >>> size_t image_token_id = 1;
 * >>> std::vector<int> tokens = {1, 1, 9673, 527, 1403, 12875, 13, 1, 1115,
 * 374, 264, 8415};
45 * >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40,
46 * image_token_id=1)
47 * >>> intervals = _get_image_attention_intervals(tokens, image_token_id)
48 * [[0, 7], [1, 7], [7, 12]]
49 *
50 * @param tokens List of token IDs in the text sequence.
51 * @param image_token_id The value of the image token.
52 *
53 * @returns Vector of vectors of the form [start, end) indicating the range of
54 * positions in the text sequence that should attend to the image.
55 */
_get_image_attention_intervals(const std::vector<int> & tokens,int image_token_id)56 std::vector<std::vector<int>> _get_image_attention_intervals(
57 const std::vector<int>& tokens,
58 int image_token_id) {
59 std::vector<std::vector<int>> vision_masks;
60 int end = tokens.size();
61 std::vector<int> vision_token_locations;
62
63 // Find all vision token locations.
64 for (int i = 0; i < tokens.size(); ++i) {
65 if (tokens[i] == image_token_id) {
66 vision_token_locations.push_back(i);
67 }
68 }
69
70 // Return empty vector if there are no images.
71 if (vision_token_locations.empty()) {
72 return vision_masks;
73 }
74
75 // If there is only one image, it will attend to subsequent text until end.
76 if (vision_token_locations.size() == 1) {
77 vision_masks.push_back({vision_token_locations[0], end});
78 return vision_masks;
79 }
80
81 // Construct intervals from previous image token to next image token.
82 for (int i = 0; i < vision_token_locations.size() - 1; ++i) {
83 vision_masks.push_back(
84 {vision_token_locations[i], vision_token_locations[i + 1]});
85 }
86
87 // Last image will attend to subsequent text until end.
88 vision_masks.push_back({vision_token_locations.back(), end});
89
90 // If there are consecutive vision tokens, they should all attend to the
91 // same subsequent text.
92 int last_mask_end = vision_masks.back()[1];
93 for (auto it = vision_masks.rbegin(); it != vision_masks.rend(); ++it) {
94 if ((*it)[0] == (*it)[1] - 1) {
95 (*it)[1] = last_mask_end;
96 }
97 last_mask_end = (*it)[1];
98 }
99
100 return vision_masks;
101 }
102
cross_attention_mask(const std::vector<int> & tokens,const std::vector<Tensor> & images,size_t tile_size,size_t patch_size,int image_token_id,std::vector<std::vector<int>> & out)103 std::vector<executorch::extension::TensorPtr> cross_attention_mask(
104 const std::vector<int>& tokens,
105 const std::vector<Tensor>& images,
106 size_t tile_size,
107 size_t patch_size,
108 int image_token_id,
109 std::vector<std::vector<int>>& out) {
110 size_t patch_grid_size = tile_size / patch_size;
111 size_t patches_per_tile = patch_grid_size * patch_grid_size;
112
113 std::vector<std::vector<int>> image_intervals =
114 _get_image_attention_intervals(tokens, image_token_id);
115
116 if (image_intervals.size() != images.size()) {
117 throw std::runtime_error(
118 "The number of image tokens (" +
119 std::to_string(image_intervals.size()) +
120 ") does not match the number of images (" +
121 std::to_string(images.size()) + ")");
122 }
123
124 // Create mask for each individual image based on its number of tokens,
125 // which can vary based on number of tiles since they are not yet tile padded.
126 // The masks are padded and concatenated together in the batch collator.
127 std::vector<executorch::extension::TensorPtr> cross_attention_masks;
128 size_t text_seq_len = tokens.size();
129 for (size_t image_idx = 0; image_idx < image_intervals.size(); ++image_idx) {
130 size_t n_tiles = images[image_idx].size(0);
131 size_t image_seq_len =
132 n_tiles * (patches_per_tile + 1); // +1 for the CLS token.
133
134 // Mask will be block of 1s at the corresponding interval in the text.
135 // It is not a causal block because all the image tokens correspond
136 // to a single image, so text tokens attend to all the image's tokens.
137 std::vector<TensorImpl::SizesType> sizes = {
138 safe_size_t_to_sizes_type(text_seq_len),
139 safe_size_t_to_sizes_type(image_seq_len)};
140
141 // Allocate the underlying data to be handled by the managed tensor.
142 size_t num_elements = text_seq_len * image_seq_len;
143 size_t stride = image_seq_len;
144 std::vector<int> mask_data(num_elements);
145
146 auto mask = executorch::extension::from_blob(
147 mask_data.data(), sizes, ScalarType::Int);
148 cross_attention_masks.emplace_back(std::move(mask));
149
150 // Add the allocated data to the output vector.
151 out.emplace_back(std::move(mask_data));
152
153 // All rows of tensor in the text_seq_len dimension within the interval are
154 // set to 1 (true).
155 size_t start = image_intervals[image_idx][0];
156 size_t end = image_intervals[image_idx][1]; // End is exclusive.
157 for (size_t i = start; i < end; ++i) {
158 for (size_t j = 0; j < image_seq_len; ++j) {
159 size_t unrolled_index = i * image_seq_len + j;
160 if (unrolled_index >= out[image_idx].size()) {
161 throw std::out_of_range(
162 "Index " + std::to_string(unrolled_index) +
163 " out of range of output tensor.");
164 }
165 out[image_idx][i * stride + j] = 1;
166 }
167 }
168 }
169
170 return cross_attention_masks;
171 }
172
173 } // namespace example
174