xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/op_registration/infer_schema.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/op_registration/infer_schema.h>
2 #include <c10/util/irange.h>
3 #include <fmt/format.h>
4 
5 namespace c10 {
6 
7 namespace detail {
8 namespace infer_schema {
9 namespace {
10 
createArgumentVector(c10::ArrayRef<ArgumentDef> args)11 std::vector<Argument> createArgumentVector(c10::ArrayRef<ArgumentDef> args) {
12   std::vector<Argument> result;
13   result.reserve(args.size());
14   for (const auto i : c10::irange(args.size())) {
15     // Arguments are named "_<index>"
16     result.emplace_back(
17         fmt::format("_{}", i),
18         (*args[i].getFakeTypeFn)(),
19         (*args[i].getTypeFn)());
20   }
21   return result;
22 }
23 } // namespace
24 // This is intentionally a separate function and in a .cpp file
25 // because then the template is smaller and that benefits binary size
make_function_schema(std::string && name,std::string && overload_name,c10::ArrayRef<ArgumentDef> arguments,c10::ArrayRef<ArgumentDef> returns)26 FunctionSchema make_function_schema(
27     std::string&& name,
28     std::string&& overload_name,
29     c10::ArrayRef<ArgumentDef> arguments,
30     c10::ArrayRef<ArgumentDef> returns) {
31   return FunctionSchema(
32       std::move(name),
33       std::move(overload_name),
34       createArgumentVector(arguments),
35       createArgumentVector(returns));
36 }
37 
make_function_schema(c10::ArrayRef<ArgumentDef> arguments,c10::ArrayRef<ArgumentDef> returns)38 FunctionSchema make_function_schema(
39     c10::ArrayRef<ArgumentDef> arguments,
40     c10::ArrayRef<ArgumentDef> returns) {
41   return make_function_schema("", "", arguments, returns);
42 }
43 } // namespace infer_schema
44 } // namespace detail
45 
findSchemaDifferences(const FunctionSchema & lhs,const FunctionSchema & rhs)46 std::optional<std::string> findSchemaDifferences(
47     const FunctionSchema& lhs,
48     const FunctionSchema& rhs) {
49   if (lhs.arguments().size() != rhs.arguments().size()) {
50     return fmt::format(
51         "The number of arguments is different. {} vs {}.",
52         lhs.arguments().size(),
53         rhs.arguments().size());
54   }
55   if (lhs.returns().size() != rhs.returns().size()) {
56     return fmt::format(
57         "The number of returns is different. {} vs {}.",
58         lhs.returns().size(),
59         rhs.returns().size());
60   }
61 
62   for (const auto i : c10::irange(lhs.arguments().size())) {
63     const TypePtr& leftType = lhs.arguments()[i].type();
64     const TypePtr& rightType = rhs.arguments()[i].type();
65     // Type::operator== is virtual. Comparing pointers first is
66     // cheaper, particularly when one of the types is a singleton like
67     // NumberType or AnyType.
68     if (leftType.get() != rightType.get() && *leftType != *rightType) {
69       return fmt::format(
70           "Type mismatch in argument {}: {} vs {}.",
71           i + 1,
72           lhs.arguments()[i].type()->str(),
73           rhs.arguments()[i].type()->str());
74     }
75   }
76 
77   for (const auto i : c10::irange(lhs.returns().size())) {
78     const TypePtr& leftType = lhs.returns()[i].type();
79     const TypePtr& rightType = rhs.returns()[i].type();
80     // See above about comparing pointers first.
81     if (leftType.get() != rightType.get() && *leftType != *rightType) {
82       return fmt::format(
83           "Type mismatch in return {}: {} vs {}.",
84           i + 1,
85           lhs.returns()[i].type()->str(),
86           rhs.returns()[i].type()->str());
87     }
88   }
89 
90   // no differences found
91   return std::nullopt;
92 }
93 
94 } // namespace c10
95