1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/examples/models/llama3_2_vision/cross_attention/cross_attention_mask.h>
10 
#include <algorithm>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>
13 
14 namespace example {
15 
16 using ::executorch::aten::ScalarType;
17 using ::executorch::aten::Tensor;
18 using ::executorch::aten::TensorImpl;
19 
// Forward declaration needed for ARM compilers.
21 int32_t safe_size_t_to_sizes_type(size_t value);
22 std::vector<std::vector<int>> _get_image_attention_intervals(
23     const std::vector<int>& tokens,
24     int image_token_id);
25 
safe_size_t_to_sizes_type(size_t value)26 int32_t safe_size_t_to_sizes_type(size_t value) {
27   if (value >
28       static_cast<size_t>(std::numeric_limits<TensorImpl::SizesType>::max())) {
29     throw std::overflow_error(
30         "size_t value too large for TensorImpl::SizesType");
31   }
32   return static_cast<TensorImpl::SizesType>(value);
33 }
34 
35 /**
36  * Returns a list of lists of the form [start, end) where start is the index
37  * of the current image token and end is the index of the next image token,
38  * exclusive.
39  *
40  * Example:
41  *     >>> text = "<img1><img2>These are two dogs. <img3>This is a cat."
42  *     >>> size_t image_token_id = 1;
43  *     >>> std::vector<int> tokens = {1, 1, 9673, 527, 1403, 12875, 13, 1, 1115,
 * 374, 264, 8415};
45  *     >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40,
46  * image_token_id=1)
47  *     >>> intervals = _get_image_attention_intervals(tokens, image_token_id)
48  *     [[0, 7], [1, 7], [7, 12]]
49  *
50  * @param tokens List of token IDs in the text sequence.
51  * @param image_token_id The value of the image token.
52  *
53  * @returns Vector of vectors of the form [start, end) indicating the range of
54  * positions in the text sequence that should attend to the image.
55  */
std::vector<std::vector<int>> _get_image_attention_intervals(
    const std::vector<int>& tokens,
    int image_token_id) {
  std::vector<std::vector<int>> vision_masks;
  // Interval ends are exclusive; the final image attends through the end of
  // the text sequence. Cast once so the unsigned size never mixes with the
  // signed interval values below.
  const int end = static_cast<int>(tokens.size());
  std::vector<int> vision_token_locations;

  // Find all vision token locations. size_t index avoids the
  // signed/unsigned comparison against tokens.size().
  for (size_t i = 0; i < tokens.size(); ++i) {
    if (tokens[i] == image_token_id) {
      vision_token_locations.push_back(static_cast<int>(i));
    }
  }

  // Return empty vector if there are no images.
  if (vision_token_locations.empty()) {
    return vision_masks;
  }

  // If there is only one image, it will attend to subsequent text until end.
  if (vision_token_locations.size() == 1) {
    vision_masks.push_back({vision_token_locations[0], end});
    return vision_masks;
  }

  // Construct intervals from previous image token to next image token.
  // `i + 1 < size()` is the underflow-safe form of `i < size() - 1`.
  for (size_t i = 0; i + 1 < vision_token_locations.size(); ++i) {
    vision_masks.push_back(
        {vision_token_locations[i], vision_token_locations[i + 1]});
  }

  // Last image will attend to subsequent text until end.
  vision_masks.push_back({vision_token_locations.back(), end});

  // If there are consecutive vision tokens, they should all attend to the
  // same subsequent text. Walk backwards so each single-token interval
  // [i, i + 1) inherits the (already-fixed-up) end of its successor.
  int last_mask_end = vision_masks.back()[1];
  for (auto it = vision_masks.rbegin(); it != vision_masks.rend(); ++it) {
    if ((*it)[0] == (*it)[1] - 1) {
      (*it)[1] = last_mask_end;
    }
    last_mask_end = (*it)[1];
  }

  return vision_masks;
}
102 
/**
 * Builds one cross-attention mask per image: a [text_seq_len x image_seq_len]
 * tensor of ints that is 1 for every (text, image) position pair inside the
 * image's attention interval and 0 elsewhere.
 *
 * @param tokens Text token IDs; occurrences of image_token_id mark images.
 * @param images One tensor per image; dim 0 is the image's tile count.
 * @param tile_size Pixel size of one (square) tile.
 * @param patch_size Pixel size of one (square) patch. NOTE(review): the
 *     integer division below silently drops any remainder — presumably
 *     patch_size is expected to divide tile_size evenly; verify at callers.
 * @param image_token_id Token ID that denotes an image in `tokens`.
 * @param out Receives the backing int buffer of each mask, in image order.
 *     NOTE(review): the `out[image_idx]` indexing assumes `out` is empty on
 *     entry so index image_idx is the buffer appended this iteration —
 *     confirm callers pass an empty vector.
 *
 * @returns Tensors viewing the buffers stored in `out`; valid only while
 *     `out` and its elements are alive and unmodified.
 *
 * @throws std::runtime_error if the image-token and image counts differ;
 *     std::overflow_error (via safe_size_t_to_sizes_type) if a dimension
 *     exceeds SizesType; std::out_of_range on an out-of-bounds mask write.
 */
std::vector<executorch::extension::TensorPtr> cross_attention_mask(
    const std::vector<int>& tokens,
    const std::vector<Tensor>& images,
    size_t tile_size,
    size_t patch_size,
    int image_token_id,
    std::vector<std::vector<int>>& out) {
  // Patches per tile edge, then squared for the full tile grid.
  size_t patch_grid_size = tile_size / patch_size;
  size_t patches_per_tile = patch_grid_size * patch_grid_size;

  // [start, end) text interval that should attend to each image, in order.
  std::vector<std::vector<int>> image_intervals =
      _get_image_attention_intervals(tokens, image_token_id);

  if (image_intervals.size() != images.size()) {
    throw std::runtime_error(
        "The number of image tokens (" +
        std::to_string(image_intervals.size()) +
        ") does not match the number of images (" +
        std::to_string(images.size()) + ")");
  }

  // Create mask for each individual image based on its number of tokens,
  // which can vary based on number of tiles since they are not yet tile padded.
  // The masks are padded and concatenated together in the batch collator.
  std::vector<executorch::extension::TensorPtr> cross_attention_masks;
  size_t text_seq_len = tokens.size();
  for (size_t image_idx = 0; image_idx < image_intervals.size(); ++image_idx) {
    size_t n_tiles = images[image_idx].size(0);
    size_t image_seq_len =
        n_tiles * (patches_per_tile + 1); // +1 for the CLS token.

    // Mask will be block of 1s at the corresponding interval in the text.
    // It is not a causal block because all the image tokens correspond
    // to a single image, so text tokens attend to all the image's tokens.
    std::vector<TensorImpl::SizesType> sizes = {
        safe_size_t_to_sizes_type(text_seq_len),
        safe_size_t_to_sizes_type(image_seq_len)};

    // Allocate the underlying data to be handled by the managed tensor.
    // Elements are value-initialized to 0.
    size_t num_elements = text_seq_len * image_seq_len;
    size_t stride = image_seq_len; // Row stride of the 2-D mask.
    std::vector<int> mask_data(num_elements);

    // from_blob does not own the buffer. Moving mask_data into `out` below
    // transfers the same heap allocation, so the pointer captured here stays
    // valid for as long as out's element is alive and unmodified.
    auto mask = executorch::extension::from_blob(
        mask_data.data(), sizes, ScalarType::Int);
    cross_attention_masks.emplace_back(std::move(mask));

    // Add the allocated data to the output vector.
    out.emplace_back(std::move(mask_data));

    // All rows of tensor in the text_seq_len dimension within the interval are
    // set to 1 (true).
    size_t start = image_intervals[image_idx][0];
    size_t end = image_intervals[image_idx][1]; // End is exclusive.
    for (size_t i = start; i < end; ++i) {
      for (size_t j = 0; j < image_seq_len; ++j) {
        size_t unrolled_index = i * image_seq_len + j;
        if (unrolled_index >= out[image_idx].size()) {
          throw std::out_of_range(
              "Index " + std::to_string(unrolled_index) +
              " out of range of output tensor.");
        }
        out[image_idx][i * stride + j] = 1;
      }
    }
  }

  return cross_attention_masks;
}
172 
173 } // namespace example
174