xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_runtime/model_container.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <algorithm>
4 #include <condition_variable>
5 #include <deque>
6 #include <mutex>
7 #include <shared_mutex>
8 
9 // WARNING: Be careful when adding new includes here. This header will be used
10 // in model.so, and should not refer to any aten/c10 headers except the stable
11 // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule
12 // applies to other files under torch/csrc/inductor/aoti_runtime/.
13 #include <torch/csrc/inductor/aoti_runtime/model.h>
14 
15 namespace torch::aot_inductor {
16 
17 class AOTInductorModelContainer {
18  public:
19   AOTInductorModelContainer(
20       size_t num_models,
21       const std::string& device_str,
22       const std::optional<std::string>& cubin_dir = std::nullopt)
23       : use_secondary_(false), constant_folded_(false) {
24     constants_map_ = std::make_shared<ConstantMap>();
25     constants_array_ = std::make_shared<std::vector<ConstantHandle>>();
26 
27     models_.reserve(num_models);
28     available_models_.reserve(num_models);
29     for (size_t i = 0; i < num_models; ++i) {
30       models_.push_back(AOTInductorModel::Create(
31           constants_map_, constants_array_, device_str, cubin_dir));
32       available_models_.push_back(models_.back().get());
33     }
34 
35     // Note that all the following fields (input_names_, output_names_,
36     // etc.) can be filled in by the AOT
37     // codegen. However, we choose to query such information from
38     // the owned AOTInductorModel for a couple of reasons:
39     //   * simplify the codegen templates
40     //   * reduce information fragmentation and duplication
41     //   * the initialization process below is done only once when the container
42     //     is constructed, so it would have little performance impact
43     auto* model = available_models_[0];
44     size_t num_inputs = model->num_inputs();
45     input_names_.reserve(num_inputs);
46     for (size_t i = 0; i < num_inputs; i++) {
47       input_names_.emplace_back(model->input_name(static_cast<int64_t>(i)));
48     }
49 
50     size_t num_outputs = model->num_outputs();
51     output_names_.reserve(num_outputs);
52     for (size_t i = 0; i < num_outputs; i++) {
53       output_names_.emplace_back(model->output_name(static_cast<int64_t>(i)));
54     }
55 
56     model->load_constants();
57 #ifdef USE_CUDA
58     constant_blob_ = model->release_constant_blob();
59     constants_internal_offset_.resize(model->num_constants());
60     model->compute_cuda_constant_blob(blob_size_, constants_internal_offset_);
61 #endif
62 
63     for (auto& model : models_) {
64       model->update_constants_map(constants_map_);
65     }
66 
67     in_spec_ = model->get_in_spec();
68     out_spec_ = model->get_out_spec();
69   }
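  // Illustrative construction sketch added by the editor (not part of the
  // original source); the argument values are hypothetical:
  //
  //   torch::aot_inductor::AOTInductorModelContainer container(
  //       /*num_models=*/2, /*device_str=*/"cuda");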
70 
71   void run(
72       AtenTensorHandle*
73           input_handles, // array of input AtenTensorHandle; handles
74                          // are stolen; the array itself is borrowed
75       AtenTensorHandle*
76           output_handles, // array for writing output AtenTensorHandle; handles
77                           // will be stolen by the caller; the array itself is
78                           // borrowed
79       DeviceStreamType stream,
80       AOTIProxyExecutorHandle proxy_executor) {
81     std::shared_lock model_lk(model_exec_mutex_);
82     auto* model = get_available_model();
83 
84     if (!constant_folded_) {
85       // At this point, the constants are not ready yet. We need to run
86       // constant folding before we execute the model. We obtain a unique lock
87       // at this point to make sure the constants are ready for all models.
88       model_lk.unlock();
89       std::unique_lock constants_folding_lk(model_exec_mutex_);
90       // Double-checked locking to make sure constant folding is only run once.
91       if (!constant_folded_) {
92         auto folded_const_map = model->run_const_fold(
93             stream, proxy_executor, /* initialization = */ true);
94         update_constant_buffer(
95             folded_const_map,
96             /* use_inactive = */ false,
97             /* validate_full_update = */ false);
98         constant_folded_ = true;
99       }
100       constants_folding_lk.unlock();
101       model_lk.lock();
102     }
103 
104     try {
105       model->run(input_handles, output_handles, stream, proxy_executor);
106     } catch (...) {
107       std::lock_guard lk(models_mutex_);
108       available_models_.push_back(model);
109       throw;
110     }
111 
112     {
113       std::lock_guard lk(models_mutex_);
114       pending_models_.push_back(model);
115     }
116     pending_models_available_.notify_one();
117   }
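  // Illustrative call sketch added by the editor; `container`, `stream`, and
  // `proxy_executor` are hypothetical. Per the ownership notes above, run()
  // steals the input handles and writes output handles that the caller then
  // owns, while both arrays themselves remain owned by the caller:
  //
  //   std::vector<AtenTensorHandle> inputs(container.num_inputs());   // filled by the caller
  //   std::vector<AtenTensorHandle> outputs(container.num_outputs());
  //   container.run(inputs.data(), outputs.data(), stream, proxy_executor);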
118 
119   size_t num_constants() const {
120     if (this->num_models() == 0) {
121       throw std::runtime_error("No available models in container!");
122     }
123     return models_[0]->num_constants();
124   }
125 
126   // retrieve the constant name of constants_info_[idx]
127   const char* constant_name(size_t idx) const {
128     if (this->num_models() == 0) {
129       throw std::runtime_error("No available models in container!");
130     }
131     return models_[0]->constant_name(static_cast<int64_t>(idx));
132   }
133 
134   // retrieve original FQN of constants_info_[idx]
135   const char* constant_original_fqn(size_t idx) const {
136     if (this->num_models() == 0) {
137       throw std::runtime_error("No available models in container!");
138     }
139     return models_[0]->constant_original_fqn(static_cast<int64_t>(idx));
140   }
141 
142   // retrieve whether the constant at constants_info_[idx] comes from constant folding
143   bool constant_from_folded(size_t idx) const {
144     if (this->num_models() == 0) {
145       throw std::runtime_error("No available models in container!");
146     }
147     return models_[0]->constant_from_folded(static_cast<int64_t>(idx));
148   }
149 
150   // retrieve dtype of constants_info_[idx]
151   int32_t constant_dtype(size_t idx) const {
152     if (this->num_models() == 0) {
153       throw std::runtime_error("No available models in container!");
154     }
155     return models_[0]->constant_dtype(static_cast<int64_t>(idx));
156   }
157 
158   void run_const_fold(
159       bool inactive_buffer,
160       DeviceStreamType stream,
161       AOTIProxyExecutorHandle proxy_executor) {
162     std::shared_lock model_lk(model_exec_mutex_);
163     auto* model = get_available_model();
164 
165     if (!inactive_buffer) {
166       // We would need to acquire a unique lock if we want to run constant
167       // folding on the active buffer.
168       model_lk.unlock();
169       std::unique_lock constants_folding_lk(model_exec_mutex_);
170       try {
171         auto folded_const_map = model->run_const_fold(stream, proxy_executor);
172         update_constant_buffer(
173             folded_const_map,
174             /* use_inactive = */ false,
175             /* validate_full_update = */ false);
176       } catch (...) {
177         std::lock_guard lk(models_mutex_);
178         available_models_.push_back(model);
179         throw;
180       }
181       constants_folding_lk.unlock();
182       model_lk.lock();
183     } else {
184       // We swap the model's constant mapping to the inactive buffer so that
185       // constant folding runs against it.
186       auto constants_map = get_constants_map(/* get_inactive= */ true);
187       auto constants_array = get_constants_array(/* get_inactive= */ true);
188 
189       try {
190         model->update_constants_map(
191             constants_map, /* remap_constants_array= */ false);
192         model->update_constants_array(constants_array);
193 
194         auto folded_const_map = model->run_const_fold(stream, proxy_executor);
195         update_constant_buffer(
196             folded_const_map,
197             /* use_inactive = */ true,
198             /* validate_full_update = */ false);
199 
200         // Swap back the model's constants mapping
201         constants_map = get_constants_map(/* get_inactive= */ false);
202         constants_array = get_constants_array(/* get_inactive= */ false);
203         model->update_constants_map(
204             constants_map, /* remap_constants_array= */ false);
205         model->update_constants_array(constants_array);
206       } catch (...) {
207         std::lock_guard lk(models_mutex_);
208         available_models_.push_back(model);
209         throw;
210       }
211     }
212 
213     {
214       std::lock_guard lk(models_mutex_);
215       pending_models_.push_back(model);
216     }
217     pending_models_available_.notify_one();
218   }
219 
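  // Returns true iff the constant name starts with the "_tensor_constant"
  // prefix that tracing gives to lifted tensor constants; for example
  // (editor's illustrative values), "_tensor_constant0" matches while
  // "fc1.weight" does not.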
220   bool _is_tensor_constant(const std::string& constant_name) const {
221     return constant_name.rfind("_tensor_constant", 0) == 0;
222   }
223   // This function updates the buffer for storing constants.
224   // It updates the underlying storage, the constants map, and the constants array.
225   void update_constant_buffer(
226       const std::unordered_map<std::string, AtenTensorHandle>& constants_map,
227       bool use_inactive,
228       bool validate_full_update) {
229     if (this->num_models() == 0) {
230       throw std::runtime_error("No model available in container!");
231     }
232     auto num_constants = models_[0]->num_constants();
233 
234     if (validate_full_update) {
235       for (size_t idx = 0; idx < num_constants; idx++) {
236         if (models_[0]->constant_from_folded(static_cast<int64_t>(idx))) {
237           continue;
238         }
239 
240         auto constant_name =
241             std::string(models_[0]->constant_name(static_cast<int64_t>(idx)));
242         auto it = constants_map.find(constant_name);
243         if (it == constants_map.end()) {
244           if (_is_tensor_constant(constant_name)) {
245             // Tracing sometimes creates tensors that do not exist in the
246             // original graph. We can skip those and do a direct copy.
247             std::cerr << "[WARNING] Found constant " << constant_name
248                       << " in model, but not provided by user!\n";
249             continue;
250           }
251           throw std::runtime_error(
252               std::string("Cannot find constants ") + constant_name +
253               std::string(" in constants_map!"));
254         }
255       }
256     }
257 
258     auto original_constants_map = get_constants_map(!use_inactive);
259     auto constants_map_to_update = get_constants_map(use_inactive);
260 
261     for (size_t idx = 0; idx < num_constants; idx++) {
262       auto constant_name =
263           std::string(models_[0]->constant_name(static_cast<int64_t>(idx)));
264       auto it = constants_map.find(constant_name);
265       if (it == constants_map.end() &&
266           !(_is_tensor_constant(constant_name) && use_inactive)) {
267         continue;
268       }
269 
270 #ifdef USE_CUDA
271       AtenTensorHandle tensor;
272       if (_is_tensor_constant(constant_name) && use_inactive) {
273         tensor = original_constants_map->find(constant_name)->second.get();
274       } else {
275         tensor = it->second;
276       }
277       auto* constants_blob_ptr =
278           static_cast<uint8_t*>(get_constant_blob_ptr(use_inactive));
279 
280       // Move the data into the container-managed blob.
281       uint8_t* internal_constants_ptr =
282           constants_blob_ptr + constants_internal_offset_[idx];
283       void* user_constant_ptr;
284       int64_t constant_size;
285       aoti_torch_get_data_ptr(tensor, &user_constant_ptr);
286       aoti_torch_get_storage_size(tensor, &constant_size);
287 
288       AOTI_RUNTIME_DEVICE_CHECK(cudaMemcpy(
289           internal_constants_ptr,
290           user_constant_ptr,
291           constant_size,
292           cudaMemcpyDefault));
293 
294       // Create a tensor from the container-managed blob.
295       // We extract the stride and offset from the provided tensor since we do
296       // not guarantee that the tensor is contiguous.
297       AtenTensorHandle tensor_handle;
298       int64_t* stride;
299       int64_t offset;
300       int device_idx = models_[0]->get_device_idx();
301       AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(tensor, &stride));
302       AOTI_TORCH_ERROR_CODE_CHECK(
303           aoti_torch_get_storage_offset(tensor, &offset));
304       AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob(
305           internal_constants_ptr,
306           models_[0]->constant_ndim(idx),
307           models_[0]->constant_shape(idx),
308           stride,
309           offset,
310           models_[0]->constant_dtype(idx),
311           aoti_torch_device_type_cuda(),
312           device_idx,
313           &tensor_handle));
314 #else // USE_CUDA
315       AtenTensorHandle tensor_handle = it->second;
316 #endif // USE_CUDA
317 
318       // Now place the tensor into constants_map. Note that at this point
319       // ownership of tensor_handle is taken over by the map.
320       constants_map_to_update->emplace(constant_name, tensor_handle);
321     }
322     // Update the inactive constant array.
323     update_array_from_map(
324         get_constants_array(use_inactive), constants_map_to_update);
325   }
326 
327   void update_array_from_map(
328       const std::shared_ptr<std::vector<ConstantHandle>>& constants_array,
329       const std::shared_ptr<ConstantMap>& constants_map) {
330     auto num_constants = models_[0]->num_constants();
331     for (size_t idx = 0; idx < num_constants; idx++) {
332       if (constants_map->find(models_[0]->constant_name(
333               static_cast<int64_t>(idx))) != constants_map->end()) {
334         constants_array->at(idx) = ConstantHandle(
335             constants_map
336                 ->find(models_[0]->constant_name(static_cast<int64_t>(idx)))
337                 ->second);
338       }
339     }
340   }
341 
342   void swap_constant_buffer() {
343     std::lock_guard unique_lk(model_exec_mutex_);
344 
345     auto constants_map = get_constants_map(/* get_inactive= */ true);
346     auto constants_array = get_constants_array(/* get_inactive= */ true);
347 
348     for (auto& model : models_) {
349       model->update_constants_map(
350           constants_map, /* remap_constants_array = */ false);
351       model->update_constants_array(constants_array);
352     }
353 
354     use_secondary_ = !use_secondary_;
355   }
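  // Illustrative weight-update sketch added by the editor (`container`,
  // `new_weights`, `stream`, and `proxy_executor` are hypothetical): stage the
  // new constants into the inactive buffer, optionally re-run constant folding
  // against it, then swap the buffers.
  //
  //   std::unordered_map<std::string, AtenTensorHandle> new_weights = /*...*/;
  //   container.update_constant_buffer(
  //       new_weights, /*use_inactive=*/true, /*validate_full_update=*/true);
  //   container.run_const_fold(/*inactive_buffer=*/true, stream, proxy_executor);
  //   container.swap_constant_buffer();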
356 
357   size_t num_inputs() const {
358     return input_names_.size();
359   }
360 
361   size_t num_outputs() const {
362     return output_names_.size();
363   }
364 
365   const char* input_name(size_t idx) const {
366     return input_names_.at(idx).c_str();
367   }
368 
369   const char* output_name(size_t idx) const {
370     return output_names_.at(idx).c_str();
371   }
372 
373   size_t num_models() const {
374     return models_.size();
375   }
376 
377   const char* get_in_spec() const {
378     return in_spec_;
379   }
380 
381   const char* get_out_spec() const {
382     return out_spec_;
383   }
384 
385  private:
386   std::vector<std::string> input_names_;
387   std::vector<std::string> output_names_;
388   const char* in_spec_;
389   const char* out_spec_;
390 
391 #ifdef USE_CUDA
392   // Holds the blob storage for the constants' at::Tensors on CUDA.
393   CUDAPtr constant_blob_;
394   CUDAPtr constant_blob_secondary_;
395 
396   // Keep these within USE_CUDA for now, until we fully support constant
397   // updates for the CPU case.
398   size_t blob_size_;
399   std::vector<size_t> constants_internal_offset_;
400 #endif // USE_CUDA
401 
402   // Determines which set of constants the model is using. If true, the
403   // secondary set
404   // (constants_map_secondary_/constant_blob_secondary_/constants_array_secondary_)
405   // is being used.
406   bool use_secondary_;
407 
408   // Tracks whether we have run constant folding.
409   bool constant_folded_;
410 
411   // Holds the mapping of constants to at::Tensor.
412   // The underlying data of the at::Tensor lives in either constant_blob_
413   // (for CUDA) or _binary_constants_bin_start (for CPU).
414   std::shared_ptr<ConstantMap> constants_map_;
415   std::shared_ptr<ConstantMap> constants_map_secondary_;
416 
417   // Holds the indexed array of constants for faster lookup at runtime.
418   std::shared_ptr<std::vector<ConstantHandle>> constants_array_;
419   std::shared_ptr<std::vector<ConstantHandle>> constants_array_secondary_;
420 
421   // Holds all the AOTInductorModel instances owned by this container.
422   std::vector<std::unique_ptr<AOTInductorModel>> models_;
423 
424   // Holds the AOTInductorModel instances available for inference.
425   std::vector<AOTInductorModel*> available_models_;
426 
427   // Holds the AOTInductorModel instances that have started running
428   // inference and can be placed onto available_models_ upon their
429   // completion.
430   std::deque<AOTInductorModel*> pending_models_;
431 
432   // Protects available_models_ and pending_models_.
433   std::mutex models_mutex_;
434 
435   // Notified whenever a model is placed onto pending_models_.
436   std::condition_variable pending_models_available_;
437 
438   AOTInductorModel* get_available_model() {
439     std::unique_lock lk(models_mutex_);
440     if (available_models_.empty()) {
441       reclaim_finished_models(lk);
442     }
443     auto* result = available_models_.back();
444     available_models_.pop_back();
445     return result;
446   }
447 
448   // This mutex is used to protect model execution.
449   // We acquire the mutex in shared mode when we allow concurrent execution.
450   // We acquire the mutex in unique mode when we need exclusive access to the
451   // model. One such case is weight swapping, where we want to make sure no
452   // one is executing the model.
453   std::shared_mutex model_exec_mutex_;
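  // Editor's summary of how the methods above use this mutex: run() and the
  // inactive-buffer path of run_const_fold() hold it in shared mode, while
  // first-time constant folding in run(), active-buffer constant folding, and
  // swap_constant_buffer() hold it in unique (exclusive) mode.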
454 
455 #ifdef USE_CUDA
456   void* get_constant_blob_ptr(bool get_inactive) {
457     if ((get_inactive && use_secondary_) ||
458         (!get_inactive && !use_secondary_)) {
459       return constant_blob_.get();
460     } else {
461       if (!constant_blob_secondary_) {
462         constant_blob_secondary_ = RAII_cudaMalloc(blob_size_);
463       }
464       return constant_blob_secondary_.get();
465     }
466   }
467 #endif // USE_CUDA
468 
469   std::shared_ptr<ConstantMap> get_constants_map(bool get_inactive) {
470     if ((get_inactive && use_secondary_) ||
471         (!get_inactive && !use_secondary_)) {
472       return constants_map_;
473     } else {
474       if (!constants_map_secondary_) {
475         constants_map_secondary_ = std::make_shared<ConstantMap>();
476       }
477       return constants_map_secondary_;
478     }
479   }
480 
481   std::shared_ptr<std::vector<ConstantHandle>> get_constants_array(
482       bool get_inactive) {
483     if ((get_inactive && use_secondary_) ||
484         (!get_inactive && !use_secondary_)) {
485       return constants_array_;
486     } else {
487       if (!constants_array_secondary_) {
488         constants_array_secondary_ =
489             std::make_shared<std::vector<ConstantHandle>>(
490                 models_[0]->num_constants());
491       }
492       return constants_array_secondary_;
493     }
494   }
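  // Editor's note on the get_constant*/get_constants* helpers above: the
  // primary blob/map/array is returned when get_inactive == use_secondary_,
  // otherwise the secondary one is returned (and lazily allocated on first
  // use). In other words, use_secondary_ marks which buffer is active and
  // get_inactive selects the other one.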
495 
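  // Editor's note: reclaim_finished_models() first sweeps pending_models_ and
  // moves any instances that have already finished back to available_models_;
  // if none have finished, it blocks on pending_models_available_ and then
  // waits for the oldest pending model to complete.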
496   void reclaim_finished_models(std::unique_lock<std::mutex>& lk) {
497     // push finished model instances to the end of pending_models_
498     auto it = std::stable_partition(
499         pending_models_.begin(),
500         pending_models_.end(),
501         [](AOTInductorModel* m) { return !m->is_finished(); });
502 
503     if (it != pending_models_.end()) {
504       // We have finished model instances that can be pushed into
505       // available_models_, so we don't have to block waiting on the
506       // pending_models_available_ condition.
507       available_models_.insert(
508           available_models_.end(), it, pending_models_.end());
509       pending_models_.erase(it, pending_models_.end());
510       return;
511     }
512 
513     pending_models_available_.wait(
514         lk, [this]() { return !pending_models_.empty(); });
515     // Let's keep the scheduling simple for now: we always wait on the first
516     // model in pending_models_ to complete.
517     auto* model = pending_models_.front();
518     pending_models_.pop_front();
519     lk.unlock();
520     try {
521       model->wait_for_completion();
522     } catch (...) {
523       lk.lock();
524       available_models_.push_back(model);
525       throw;
526     }
527     lk.lock();
528     available_models_.push_back(model);
529   }
530 };
531 
532 } // namespace torch::aot_inductor
533