1 #pragma once
2
3 #include <gtest/gtest.h>
4 #include <gmock/gmock.h>
5
6 #include <ATen/core/Tensor.h>
7 #include <ATen/core/dispatch/Dispatcher.h>
8 #include <ATen/core/ivalue.h>
9 #include <c10/core/CPUAllocator.h>
10 #include <c10/util/irange.h>
11
12 template<class... Inputs>
makeStack(Inputs &&...inputs)13 inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
14 return {std::forward<Inputs>(inputs)...};
15 }
16
17 inline at::Tensor dummyTensor(c10::DispatchKeySet ks, bool requires_grad=false) {
18 auto* allocator = c10::GetCPUAllocator();
19 int64_t nelements = 1;
20 auto dtype = caffe2::TypeMeta::Make<float>();
21 int64_t size_bytes = nelements * dtype.itemsize();
22 auto storage_impl = c10::make_intrusive<c10::StorageImpl>(
23 c10::StorageImpl::use_byte_size_t(),
24 size_bytes,
25 allocator->allocate(size_bytes),
26 allocator,
27 /*resizable=*/true);
28 at::Tensor t = at::detail::make_tensor<c10::TensorImpl>(storage_impl, ks, dtype);
29 // TODO: We add this to simulate the ideal case where we only have Autograd backend keys
30 // on Tensor when it requires grad. But currently Autograd keys are added in TensorImpl
31 // constructor by default.
32 if (!requires_grad) {
33 t.unsafeGetTensorImpl()->remove_autograd_key();
34 }
35 return t;
36 }
37
38 inline at::Tensor dummyTensor(c10::DispatchKey dispatch_key, bool requires_grad=false) {
39 return dummyTensor(c10::DispatchKeySet(dispatch_key), requires_grad);
40 }
41
42 template<class... Args>
callOp(const c10::OperatorHandle & op,Args...args)43 inline std::vector<c10::IValue> callOp(const c10::OperatorHandle& op, Args... args) {
44 auto stack = makeStack(std::forward<Args>(args)...);
45 op.callBoxed(&stack);
46 return stack;
47 }
48
49 template<class Result, class... Args>
callOpUnboxed(const c10::OperatorHandle & op,Args...args)50 inline Result callOpUnboxed(const c10::OperatorHandle& op, Args... args) {
51 return op.typed<Result(Args...)>().call(std::forward<Args>(args)...);
52 }
53
54 template<class Result, class... Args>
callOpUnboxedWithDispatchKey(const c10::OperatorHandle & op,c10::DispatchKey dispatchKey,Args...args)55 inline Result callOpUnboxedWithDispatchKey(const c10::OperatorHandle& op, c10::DispatchKey dispatchKey, Args... args) {
56 return op.typed<Result(Args...)>().callWithDispatchKey(dispatchKey, std::forward<Args>(args)...);
57 }
58
59 template<class Result, class... Args>
callOpUnboxedWithPrecomputedDispatchKeySet(const c10::OperatorHandle & op,c10::DispatchKeySet ks,Args...args)60 inline Result callOpUnboxedWithPrecomputedDispatchKeySet(const c10::OperatorHandle& op, c10::DispatchKeySet ks, Args... args) {
61 return op.typed<Result(Args...)>().redispatch(ks, std::forward<Args>(args)...);
62 }
63
expectDoesntFindKernel(const char * op_name,c10::DispatchKey dispatch_key)64 inline void expectDoesntFindKernel(const char* op_name, c10::DispatchKey dispatch_key) {
65 auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
66 EXPECT_ANY_THROW(
67 callOp(*op, dummyTensor(dispatch_key), 5);
68 );
69 }
70
expectDoesntFindOperator(const char * op_name)71 inline void expectDoesntFindOperator(const char* op_name) {
72 auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
73 EXPECT_FALSE(op.has_value());
74 }
75
76 template<class Exception, class Functor>
expectThrows(Functor && functor,const char * expectMessageContains)77 inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
78 try {
79 std::forward<Functor>(functor)();
80 } catch (const Exception& e) {
81 EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains));
82 return;
83 }
84 ADD_FAILURE() << "Expected to throw exception containing \""
85 << expectMessageContains << "\" but didn't throw";
86 }
87
88 template<class T, size_t N>
expectListEquals(c10::ArrayRef<T> expected,std::array<T,N> actual)89 void expectListEquals(c10::ArrayRef<T> expected, std::array<T, N> actual) {
90 EXPECT_EQ(expected.size(), actual.size());
91 for (const auto i : c10::irange(expected.size())) {
92 EXPECT_EQ(expected[i], actual[i]);
93 }
94 }
95
96 template<class T>
expectListEquals(c10::ArrayRef<T> expected,c10::ArrayRef<T> actual)97 void expectListEquals(c10::ArrayRef<T> expected, c10::ArrayRef<T> actual) {
98 EXPECT_EQ(expected.size(), actual.size());
99 for (const auto i : c10::irange(expected.size())) {
100 EXPECT_EQ(expected[i], actual[i]);
101 }
102 }
103
104 template<class T>
expectListEquals(c10::ArrayRef<T> expected,c10::List<T> actual)105 void expectListEquals(c10::ArrayRef<T> expected, c10::List<T> actual) {
106 EXPECT_EQ(expected.size(), actual.size());
107 for (const auto i : c10::irange(expected.size())) {
108 EXPECT_EQ(expected[i], actual.get(i));
109 }
110 }
111
112 template<class T>
expectListEquals(c10::ArrayRef<T> expected,std::vector<T> actual)113 void expectListEquals(c10::ArrayRef<T> expected, std::vector<T> actual) {
114 EXPECT_EQ(expected.size(), actual.size());
115 for (const auto i : c10::irange(expected.size())) {
116 EXPECT_EQ(expected[i], actual[i]);
117 }
118 }
119
120 // NB: This is not really sound, but all of the type sets constructed here
121 // are singletons so it's fine
extractDispatchKey(const at::Tensor & t)122 static inline c10::DispatchKey extractDispatchKey(const at::Tensor& t) {
123 return legacyExtractDispatchKey(t.key_set());
124 }
125