xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/calculate_necessary_args.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/frontend/schema_matching.h>
5 #include <cstddef>
6 
7 namespace torch::jit {
8 
9 // Calculates the number of args that need to be passed in.
10 // Less args may be needed if defaults are provided.
11 // Returns: {number args needed, number of out args}
CalculateNecessaryArgs(const std::vector<Argument> & schema_args,at::ArrayRef<Value * > actual_inputs,bool allow_trailing_out_args)12 inline std::pair<int64_t, int64_t> CalculateNecessaryArgs(
13     const std::vector<Argument>& schema_args,
14     at::ArrayRef<Value*> actual_inputs,
15     bool allow_trailing_out_args) {
16   if (schema_args.empty()) {
17     return std::make_pair(0, 0);
18   }
19 
20   // count number of out arguments
21   int64_t schema_idx = static_cast<int64_t>(schema_args.size()) - 1;
22   if (allow_trailing_out_args) {
23     // skip over out arguments in the end.
24     while (schema_idx >= 0) {
25       const auto& current_arg = schema_args.at(schema_idx);
26       if (!current_arg.is_out()) {
27         break;
28       }
29       schema_idx--;
30     }
31   }
32 
33   int64_t num_out = static_cast<int64_t>(schema_args.size()) - schema_idx - 1;
34 
35   if (schema_args.size() < actual_inputs.size()) {
36     return std::make_pair(actual_inputs.size(), num_out);
37   }
38 
39   // if it is the default args, we reset the index to the last element
40   if (!allow_trailing_out_args) {
41     schema_idx = schema_args.size() - 1;
42   }
43   // keeps track of trailing unnecessary args
44   while (schema_idx >= 0) {
45     // this means it is not default argument, so it is necessary
46     if (!schema_args.at(schema_idx).default_value().has_value()) {
47       return std::make_pair(schema_idx + 1, num_out);
48     } else {
49       auto schema_value =
50           schema_args.at(schema_idx).default_value().value().toIValue();
51       // non-const value will become nullptr here, so will be marked necessary
52       // non-const would include prim::ListConstruct, prim::DictConstruct as
53       // well.
54       auto actual_value = toIValue(actual_inputs[schema_idx]);
55       if (!actual_value.has_value()) {
56         return std::make_pair(schema_idx + 1, num_out);
57       }
58       // if the IR has same value as default value of the schema,
59       // it is not necessary argument.
60       if (schema_value != actual_value.value()) {
61         return std::make_pair(schema_idx + 1, num_out);
62       }
63     }
64     schema_idx--;
65   }
66   return std::make_pair(0, num_out);
67 }
68 
69 } // namespace torch::jit
70