1 #include <ATen/core/op_registration/infer_schema.h>
2 #include <c10/util/irange.h>
3 #include <fmt/format.h>
4
5 namespace c10 {
6
7 namespace detail {
8 namespace infer_schema {
9 namespace {
10
createArgumentVector(c10::ArrayRef<ArgumentDef> args)11 std::vector<Argument> createArgumentVector(c10::ArrayRef<ArgumentDef> args) {
12 std::vector<Argument> result;
13 result.reserve(args.size());
14 for (const auto i : c10::irange(args.size())) {
15 // Arguments are named "_<index>"
16 result.emplace_back(
17 fmt::format("_{}", i),
18 (*args[i].getFakeTypeFn)(),
19 (*args[i].getTypeFn)());
20 }
21 return result;
22 }
23 } // namespace
// This is intentionally a separate, non-template function defined in a .cpp
// file: keeping its body out of the template that calls it makes that
// template smaller, which reduces binary size.
make_function_schema(std::string && name,std::string && overload_name,c10::ArrayRef<ArgumentDef> arguments,c10::ArrayRef<ArgumentDef> returns)26 FunctionSchema make_function_schema(
27 std::string&& name,
28 std::string&& overload_name,
29 c10::ArrayRef<ArgumentDef> arguments,
30 c10::ArrayRef<ArgumentDef> returns) {
31 return FunctionSchema(
32 std::move(name),
33 std::move(overload_name),
34 createArgumentVector(arguments),
35 createArgumentVector(returns));
36 }
37
make_function_schema(c10::ArrayRef<ArgumentDef> arguments,c10::ArrayRef<ArgumentDef> returns)38 FunctionSchema make_function_schema(
39 c10::ArrayRef<ArgumentDef> arguments,
40 c10::ArrayRef<ArgumentDef> returns) {
41 return make_function_schema("", "", arguments, returns);
42 }
43 } // namespace infer_schema
44 } // namespace detail
45
findSchemaDifferences(const FunctionSchema & lhs,const FunctionSchema & rhs)46 std::optional<std::string> findSchemaDifferences(
47 const FunctionSchema& lhs,
48 const FunctionSchema& rhs) {
49 if (lhs.arguments().size() != rhs.arguments().size()) {
50 return fmt::format(
51 "The number of arguments is different. {} vs {}.",
52 lhs.arguments().size(),
53 rhs.arguments().size());
54 }
55 if (lhs.returns().size() != rhs.returns().size()) {
56 return fmt::format(
57 "The number of returns is different. {} vs {}.",
58 lhs.returns().size(),
59 rhs.returns().size());
60 }
61
62 for (const auto i : c10::irange(lhs.arguments().size())) {
63 const TypePtr& leftType = lhs.arguments()[i].type();
64 const TypePtr& rightType = rhs.arguments()[i].type();
65 // Type::operator== is virtual. Comparing pointers first is
66 // cheaper, particularly when one of the types is a singleton like
67 // NumberType or AnyType.
68 if (leftType.get() != rightType.get() && *leftType != *rightType) {
69 return fmt::format(
70 "Type mismatch in argument {}: {} vs {}.",
71 i + 1,
72 lhs.arguments()[i].type()->str(),
73 rhs.arguments()[i].type()->str());
74 }
75 }
76
77 for (const auto i : c10::irange(lhs.returns().size())) {
78 const TypePtr& leftType = lhs.returns()[i].type();
79 const TypePtr& rightType = rhs.returns()[i].type();
80 // See above about comparing pointers first.
81 if (leftType.get() != rightType.get() && *leftType != *rightType) {
82 return fmt::format(
83 "Type mismatch in return {}: {} vs {}.",
84 i + 1,
85 lhs.returns()[i].type()->str(),
86 rhs.returns()[i].type()->str());
87 }
88 }
89
90 // no differences found
91 return std::nullopt;
92 }
93
94 } // namespace c10
95