#include <gtest/gtest.h>
#include <test/cpp/api/support.h>
#include <torch/script.h>

using namespace torch::autograd;
using namespace torch::test;

namespace {
torch::Tensor functional_op(torch::Tensor& x) {
  return x * x;
}

void inplace_op(torch::Tensor& x) {
  x.mul_(1);
}

torch::Tensor view_op(torch::Tensor& x) {
  return x.view({2, 3});
}

/*
  Only the following combos of Autograd & ADInplaceOrView keys on tensors are
  valid:
    - Autograd=true, ADInplaceOrView=true (normal tensor)
    - Autograd=false, ADInplaceOrView=false (inference tensor)
  Tensors created in InferenceMode are mostly inference tensors. The only
  exception is that views of normal tensors created in InferenceMode still
  produce normal tensors.
*/
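// A minimal sketch (not part of the original tests) of how the key combos
// above could be read off a tensor's dispatch key set; it assumes only the
// public accessors TensorBase::key_set(), DispatchKeySet::has(), and set
// intersection with c10::autograd_dispatch_keyset.
[[maybe_unused]] bool is_valid_key_combo(const torch::Tensor& t) {
  bool has_autograd = !(t.key_set() & c10::autograd_dispatch_keyset).empty();
  bool has_inplace_or_view =
      t.key_set().has(c10::DispatchKey::ADInplaceOrView);
  // Normal tensor: both keys present; inference tensor: neither.
  return has_autograd == has_inplace_or_view;
}
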
void assert_TLS_states(bool inference_mode) {
  ASSERT_EQ(InferenceMode::is_enabled(), inference_mode);
  ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(
      c10::DispatchKey::ADInplaceOrView));
  ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(
      c10::autograd_dispatch_keyset));
  ASSERT_EQ(
      c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset),
      inference_mode);
  ASSERT_EQ(
      c10::impl::tls_is_dispatch_key_included(
          c10::DispatchKey::ADInplaceOrView),
      !inference_mode);
  ASSERT_EQ(GradMode::is_enabled(), !inference_mode);
}
} // namespace

TEST(InferenceModeTest, TestTLSState) {
  assert_TLS_states(false);
  {
    InferenceMode guard;
    assert_TLS_states(true);
    {
      InferenceMode guard(false);
      assert_TLS_states(false);
    }
    assert_TLS_states(true);
  }
  assert_TLS_states(false);
}

TEST(InferenceModeTest, TestInferenceTensorCreation) {
  {
    InferenceMode guard;
    // New tensors created through constructors are inference tensors.
    torch::Tensor c = torch::ones({1, 2, 3});
    ASSERT_FALSE(c.requires_grad());
    ASSERT_TRUE(c.is_inference());

    // requires_grad doesn't change inference tensor behavior inside
    // InferenceMode.
    torch::Tensor tmp = torch::ones({1, 2, 3}).set_requires_grad(true);
    ASSERT_TRUE(tmp.requires_grad());
    ASSERT_TRUE(tmp.is_inference());

    tmp = torch::ones({1, 2, 3}).set_requires_grad(false);
    ASSERT_FALSE(tmp.requires_grad());
    ASSERT_TRUE(tmp.is_inference());
  }
}

TEST(InferenceModeTest, TestExistingAutogradSession) {
  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
  torch::Tensor a = s.clone();

  // Save `a` in an existing autograd session.
  torch::Tensor out = a * a;
  {
    InferenceMode guard;
    inplace_op(a);
  }
  // Performing backward should trigger an error since `a`'s version has been
  // bumped.
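  // (mul's backward saved `a`, and a saved tensor remembers the version at
  // save time; since `a` is a normal tensor, the inplace_op above bumped its
  // version even though it ran inside InferenceMode.)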
  ASSERT_THROWS_WITH(
      out.backward(torch::ones_like(out)),
      "one of the variables needed for gradient computation has been modified by an inplace operation")
}

TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) {
  c10::InferenceMode guard;
  for (bool requires_grad : {true, false}) {
    torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);

    torch::Tensor func_out = functional_op(c); // go through kernels: CPU
    ASSERT_TRUE(func_out.is_inference());
    ASSERT_FALSE(func_out.requires_grad());
  }
}

TEST(InferenceModeTest, TestInferenceTensorInInferenceModeInplaceOp) {
  c10::InferenceMode guard;
  for (bool requires_grad : {true, false}) {
    torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);

    inplace_op(c); // go through kernels: CPU
    ASSERT_TRUE(c.is_inference());
    ASSERT_EQ(c.requires_grad(), requires_grad);
  }
}

TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) {
  c10::InferenceMode guard;
  for (bool requires_grad : {true, false}) {
    torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);

    torch::Tensor view_out = view_op(c); // go through kernels: CPU
    ASSERT_TRUE(view_out.is_inference());
    // Note this is different from NoGradMode, but it makes sense: view ops on
    // inference tensors produce plain inference tensors, not views.
    ASSERT_FALSE(view_out.requires_grad());
    ASSERT_FALSE(view_out.is_view());
  }
}

TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) {
  torch::Tensor inference_tensor;
  for (bool requires_grad : {true, false}) {
    {
      InferenceMode guard;
      inference_tensor =
          torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }

    // Due to issue #54614, this might run slower compared to InferenceMode
    // since intermediate tensors are normal tensors and might dispatch to
    // VariableType kernels. This is fine since users can easily fix it by
    // moving the call inside an InferenceMode block.
    torch::Tensor tmp =
        functional_op(inference_tensor); // go through kernels:
                                         // ADInplaceOrView(fallthrough), CPU
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_FALSE(tmp.requires_grad());
  }
}

TEST(InferenceModeTest, TestInferenceTensorInNormalModeInplaceOp) {
  torch::Tensor inference_tensor;
  for (bool requires_grad : {true, false}) {
    {
      InferenceMode guard;
      inference_tensor =
          torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }
    ASSERT_THROWS_WITH(
        inplace_op(
            inference_tensor), // go through kernels: ADInplaceOrView, CPU
        "Inplace update to inference tensor outside InferenceMode is not allowed");
  }
}

TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) {
  torch::Tensor inference_tensor;
  for (bool requires_grad : {true, false}) {
    {
      InferenceMode guard;
      inference_tensor =
          torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }
    torch::Tensor out =
        view_op(inference_tensor); // go through kernels: ADInplaceOrView, CPU
    ASSERT_TRUE(out.is_inference());
    ASSERT_FALSE(out.requires_grad());
    ASSERT_FALSE(out.is_view());
    ASSERT_TRUE(out.is_leaf());
  }
}

TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();

    {
      c10::InferenceMode guard;

      inplace_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(a.is_inference());
      ASSERT_EQ(a.requires_grad(), requires_grad);

      // inplace -> inplace
      inplace_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(a.is_inference());
      ASSERT_EQ(a.requires_grad(), requires_grad);

      // inplace -> inplace -> view
      torch::Tensor view_out =
          view_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(view_out.is_inference());
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
    }
  }
}

TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();

    {
      c10::InferenceMode guard;

      inplace_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(a.is_inference());
      ASSERT_EQ(a.requires_grad(), requires_grad);
    }

    torch::Tensor tmp = functional_op(a); // go through kernels: VariableType,
                                          // ADInplaceOrView(fallthrough), CPU
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);

    inplace_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
    ASSERT_FALSE(a.is_inference());
    ASSERT_EQ(a.requires_grad(), requires_grad);

    tmp = view_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);
  }
}

TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out, tmp;

    {
      c10::InferenceMode guard;
      // View ops on normal tensors produce normal tensors as output.
      // - A view output carries both dispatch keys because of the way we
      //   create view tensors in alias_with_sizes_and_strides:
      //   ```
      //     auto impl = c10::make_intrusive<TensorImpl>(
      //     Storage(self.storage()), self.key_set(), self.dtype());
      //   ```
      //   These view outputs are normal in the sense that they have both
      //   Autograd and ADInplaceOrView keys. But they're still special since
      //   they carry CreationMeta::INFERENCE_MODE. In other words, they
      //   behave exactly the same as a view tensor created in no_grad mode.

      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(view_out.is_inference());
      assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
      ASSERT_TRUE(view_out.is_leaf());

      // view -> view
      tmp = view_op(view_out); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(tmp.is_inference());
      assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(tmp.requires_grad(), requires_grad);
      ASSERT_TRUE(tmp.is_leaf());

      // view -> view -> inplace
      inplace_op(tmp); // go through kernels: ADInplaceOrView, CPU
      assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
      ASSERT_FALSE(tmp.is_inference());
      ASSERT_EQ(tmp.requires_grad(), requires_grad);
      ASSERT_TRUE(tmp.is_leaf());
      ASSERT_EQ(a._version(), tmp._version());
    }
  }
}

TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out, tmp;

    {
      c10::InferenceMode guard;
      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(view_out.is_inference());
      assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
      ASSERT_TRUE(view_out.is_leaf());
    }

    tmp = functional_op(view_out);
    ASSERT_FALSE(view_out.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);

    if (requires_grad) {
      ASSERT_THROWS_WITH(
          inplace_op(view_out), // go through kernels: VariableType,
                                // ADInplaceOrView, CPU
          "A view was created in inference mode and is being modified inplace")
    } else {
      inplace_op(view_out);
    }

    tmp = view_op(view_out);
    ASSERT_FALSE(view_out.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);
  }
}

TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }

    // add(Tensor, Tensor) is safe with an inference tensor since it doesn't
    // save any variable for backward.
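    // (d(a+b)/da is 1 regardless of the inputs, so add's backward needs no
    // saved tensors; mul's backward needs the other operand, which is exactly
    // what fails for inference tensors further below.)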
    torch::Tensor out = c.add(s); // go through kernels: VariableType,
                                  // ADInplaceOrView(fallthrough), CPU
    ASSERT_FALSE(out.is_inference());
    ASSERT_EQ(out.requires_grad(), requires_grad);
    if (requires_grad) {
      // A leaf inference tensor with requires_grad=true can still get a
      // gradient. Note this behavior is different from NoGradMode, which
      // leaves the grad empty.
      out.backward(torch::ones_like(out));
      assert_tensor_equal(c.grad(), torch::ones_like(c));
    }

    if (requires_grad) {
      // mul(self, other) saves variables when requires_grad=true.
      ASSERT_THROWS_WITH(
          c.mul(s), "Inference tensors cannot be saved for backward.");

      // Inference tensor in TensorList input.
      // stack does not capture anymore, so this is disabled.
      // TODO: find alternative Function that captures a list (maybe custom fn)
      /*
      std::vector<torch::Tensor> inputs = {s, c};
      ASSERT_THROWS_WITH(
          torch::stack(inputs), // go through kernels: VariableType(ERROR)!,
                                // ADInplaceOrView(fallthrough), CPU
          "Inference tensors cannot be saved for backward.")
      */
    }
  }
}

TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3});
    }

    if (requires_grad) {
      ASSERT_THROWS_WITH(
          a.mul_(c), // go through kernels: VariableType(ERROR!), InferenceMode,
                     // CPU
          "Inference tensors cannot be saved for backward.");

      ASSERT_THROWS_WITH(
          torch::mul_out(
              /*out=*/c, s, s), // go through kernels: VariableType(ERROR!),
                                // ADInplaceOrView, CPU
          "out=... arguments don't support automatic differentiation, but one of the arguments requires grad")
    } else {
      a.mul_(c);

      ASSERT_THROWS_WITH(
          torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType,
                                           // ADInplaceOrView(ERROR!), CPU
          "Inplace update to inference tensor outside InferenceMode is not allowed");
    }
  }
}

TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3});
    }

    // view_as is a composite op that calls view() with only one tensor
    // argument, so there's no mixed inference-tensor/normal-tensor input
    // scenario for view ops.
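    // (c.view_as(s) lowers to c.view(s.sizes()), so only `c` reaches the
    // view kernel; `s` contributes nothing but its sizes.)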
    torch::Tensor tmp1 =
        c.view_as(s); // go through kernels: ADInplaceOrView, CPU
    ASSERT_TRUE(tmp1.is_inference());
    ASSERT_FALSE(tmp1.requires_grad());

    // This is fine since it's equivalent to s.view(c.sizes()), which
    // isn't a mixed input scenario.
    torch::Tensor tmp2 =
        s.view_as(c); // go through kernels: VariableType, ADInplaceOrView, CPU
    ASSERT_FALSE(tmp2.is_inference());
    ASSERT_EQ(tmp2.requires_grad(), requires_grad);
  }
}

TEST(InferenceModeTest, TestHandleDirectViewOnRebase) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out;
    {
      InferenceMode guard;
      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
    }
    if (requires_grad) {
      ASSERT_THROWS_WITH(
          inplace_op(view_out),
          "A view was created in inference mode and is being modified inplace")
    } else {
      inplace_op(view_out);
    }
  }
}

TEST(InferenceModeTest, TestHandleInDirectViewOnRebase) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out;
    {
      InferenceMode guard;
      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
    }
    inplace_op(a);
    if (requires_grad) {
      ASSERT_THROWS_WITH(
          view_out.grad_fn(),
          "A view was created in inference mode and its base or another view of its base has been modified inplace");
    } else {
      view_out.grad_fn();
    }
  }
}

TEST(InferenceModeTest, TestCreationMetaPropagation) {
  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
  torch::Tensor b, c;
  {
    InferenceMode guard;
    b = s.view_as(s);
  }
  ASSERT_THROWS_WITH(
      b.add_(1),
      "A view was created in inference mode and is being modified inplace");
  {
    AutoGradMode mode(false);
    c = b.view_as(b);
  }
  ASSERT_THROWS_WITH(
      c.add_(1),
      "A view was created in inference mode and is being modified inplace");
}

TEST(InferenceModeTest, TestCreationMetaPropagationInput) {
  torch::Tensor s = torch::ones({2, 2, 3}).set_requires_grad(true);
  auto s_view = s.view_as(s);
  std::vector<at::Tensor> b, c;
  {
    InferenceMode guard;
    b = s_view.split_with_sizes({1, 1});

    s = s.view_as(s);
    c = s.split_with_sizes({1, 1});
  }
  for (auto& b_el : b) {
    assert_tensor_creation_meta(b_el, CreationMeta::INFERENCE_MODE);
    ASSERT_THROWS_WITH(
        b_el.add_(1),
        "A view was created in inference mode and is being modified inplace");
  }
  for (auto& c_el : c) {
    assert_tensor_creation_meta(c_el, CreationMeta::INFERENCE_MODE);
    ASSERT_THROWS_WITH(
        c_el.add_(1),
        "A view was created in inference mode and is being modified inplace");
  }
}

TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor t;
    {
      InferenceMode guard;
      t = torch::ones({1, 2, 3});
      t.copy_(s);
      ASSERT_TRUE(t.is_inference());
      ASSERT_FALSE(t.requires_grad());
    }

    ASSERT_THROWS_WITH(
        t.copy_(s),
        "Inplace update to inference tensor outside InferenceMode is not allowed");
  }
}

TEST(InferenceModeTest, TestSetRequiresGradInNormalMode) {
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});
  }
  t.set_requires_grad(false);
  ASSERT_THROWS_WITH(
      t.set_requires_grad(true),
      "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
}

TEST(InferenceModeTest, TestAccessVersionCounter) {
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});
    ASSERT_THROWS_WITH(
        t.unsafeGetTensorImpl()->version_counter().current_version(),
        "Inference tensors do not track version counter.");
    t.unsafeGetTensorImpl()->bump_version();
  }
  ASSERT_THROWS_WITH(
      t.unsafeGetTensorImpl()->version_counter().current_version(),
      "Inference tensors do not track version counter.");
  ASSERT_THROWS_WITH(
      t.unsafeGetTensorImpl()->bump_version(),
      "Inplace update to inference tensor outside InferenceMode is not allowed.");
  // Suggested workaround: clone the inference tensor to get a normal tensor
  // with a working version counter.
  torch::Tensor c = t.clone();
  uint32_t v = c.unsafeGetTensorImpl()->version_counter().current_version();
  c.unsafeGetTensorImpl()->bump_version();
  ASSERT_EQ(
      c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1);
}

TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) {
  torch::Tensor s = torch::ones({1, 2, 3});
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});
    // Testing both copy_ from VariableTypeManual and add_ from generated code.
    s.copy_(t);
    s.add_(t);
    t.add_(s);
    t.copy_(s);
  }
  s.copy_(t);
  s.add_(t);
  ASSERT_THROWS_WITH(
      t.copy_(s),
      "Inplace update to inference tensor outside InferenceMode is not allowed");

  ASSERT_THROWS_WITH(
      t.add_(s),
      "Inplace update to inference tensor outside InferenceMode is not allowed");
}

TEST(InferenceModeTest, TestComplexViewInInferenceMode) {
  torch::Tensor s = torch::ones({3, 3, 2});
  torch::Tensor t = torch::view_as_complex(s);
  {
    InferenceMode guard;
    torch::Tensor tmp;

    tmp = torch::view_as_real(t);
    ASSERT_FALSE(tmp.is_inference());
    tmp = torch::view_as_complex(s);
    ASSERT_FALSE(tmp.is_inference());

    torch::Tensor e = torch::ones({3, 3, 2});
    tmp = torch::view_as_complex(e);
    ASSERT_TRUE(tmp.is_inference());
    tmp = torch::view_as_real(tmp);
    ASSERT_TRUE(tmp.is_inference());
  }
}

TEST(InferenceModeTest, TestComplexViewInNormalMode) {
  torch::Tensor s;
  {
    InferenceMode guard;
    s = torch::ones({3, 3, 2});
  }
  torch::Tensor tmp = torch::view_as_complex(s);
  ASSERT_TRUE(tmp.is_inference());
  tmp = torch::view_as_real(tmp);
  ASSERT_TRUE(tmp.is_inference());
}

TEST(InferenceModeTest, TestCustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        int mul,
        Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      variable_list output = {
          grad_output[0] + grad_output[0] * var2,
          Variable(),
          grad_output[0] * mul + grad_output[0] * var1};
      return output;
    }
  };

  {
    InferenceMode guard;
    torch::Tensor var1 = torch::ones({3, 3}).set_requires_grad(true);
    auto var2 = var1.clone();
    int mul = 2;
    // If InferenceMode didn't set NoGradGuard automatically, this line
    // would error out when trying to save `var1` and `var2` for backward.
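    // (With grad mode off, apply() creates no autograd node, so the tensors
    // passed to save_for_backward are never materialized as SavedVariables
    // and the inference-tensor check never fires.)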
    auto y = MyFunction::apply(var1, mul, var2);
    torch::Tensor expected = var1 + mul * var2 + var1 * var2;
    assert_tensor_equal(y, expected);
  }
}

TEST(InferenceModeTest, TestLegacyAutoNonVariableTypeModeWarning) {
  c10::WarningUtils::WarnAlways warn_always(true);
  WarningCapture warnings;
  at::AutoNonVariableTypeMode guard;
  ASSERT_TRUE(
      warnings.str().find("AutoNonVariableTypeMode is deprecated") !=
      std::string::npos);
}
658