1 #pragma once
2
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/frontend/schema_matching.h>
5 #include <cstddef>
6
7 namespace torch::jit {
8
9 // Calculates the number of args that need to be passed in.
10 // Less args may be needed if defaults are provided.
11 // Returns: {number args needed, number of out args}
CalculateNecessaryArgs(const std::vector<Argument> & schema_args,at::ArrayRef<Value * > actual_inputs,bool allow_trailing_out_args)12 inline std::pair<int64_t, int64_t> CalculateNecessaryArgs(
13 const std::vector<Argument>& schema_args,
14 at::ArrayRef<Value*> actual_inputs,
15 bool allow_trailing_out_args) {
16 if (schema_args.empty()) {
17 return std::make_pair(0, 0);
18 }
19
20 // count number of out arguments
21 int64_t schema_idx = static_cast<int64_t>(schema_args.size()) - 1;
22 if (allow_trailing_out_args) {
23 // skip over out arguments in the end.
24 while (schema_idx >= 0) {
25 const auto& current_arg = schema_args.at(schema_idx);
26 if (!current_arg.is_out()) {
27 break;
28 }
29 schema_idx--;
30 }
31 }
32
33 int64_t num_out = static_cast<int64_t>(schema_args.size()) - schema_idx - 1;
34
35 if (schema_args.size() < actual_inputs.size()) {
36 return std::make_pair(actual_inputs.size(), num_out);
37 }
38
39 // if it is the default args, we reset the index to the last element
40 if (!allow_trailing_out_args) {
41 schema_idx = schema_args.size() - 1;
42 }
43 // keeps track of trailing unnecessary args
44 while (schema_idx >= 0) {
45 // this means it is not default argument, so it is necessary
46 if (!schema_args.at(schema_idx).default_value().has_value()) {
47 return std::make_pair(schema_idx + 1, num_out);
48 } else {
49 auto schema_value =
50 schema_args.at(schema_idx).default_value().value().toIValue();
51 // non-const value will become nullptr here, so will be marked necessary
52 // non-const would include prim::ListConstruct, prim::DictConstruct as
53 // well.
54 auto actual_value = toIValue(actual_inputs[schema_idx]);
55 if (!actual_value.has_value()) {
56 return std::make_pair(schema_idx + 1, num_out);
57 }
58 // if the IR has same value as default value of the schema,
59 // it is not necessary argument.
60 if (schema_value != actual_value.value()) {
61 return std::make_pair(schema_idx + 1, num_out);
62 }
63 }
64 schema_idx--;
65 }
66 return std::make_pair(0, num_out);
67 }
68
69 } // namespace torch::jit
70