/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/grad_test_helper.h"

#include "tensorflow/c/eager/gradient_checker.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace gradients {
namespace internal {

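// Runs `grad_model` to produce autodiff gradients for `inputs`, then checks
// each gradient elementwise against a numerical gradient computed by
// CalcNumericalGrad, requiring agreement within `abs_error`.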
void CompareNumericalAndAutodiffGradients(
    Model model, Model grad_model, AbstractContext* ctx,
    absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
    double abs_error) {
  auto num_inputs = inputs.size();
  std::vector<AbstractTensorHandle*> outputs(num_inputs);
  auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs),
                    /*use_function=*/use_function);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

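  // For each input that received a gradient from `grad_model`, compute the
  // corresponding numerical gradient and compare the two.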
  for (int i = 0; i < num_inputs; ++i) {
    if (!outputs[i]) continue;

    AbstractTensorHandlePtr numerical_grad;
    {
      AbstractTensorHandle* numerical_grad_raw;
      s = CalcNumericalGrad(ctx, model, inputs,
                            /*input_index=*/i, use_function,
                            &numerical_grad_raw);
      ASSERT_EQ(errors::OK, s.code()) << s.error_message();
      numerical_grad.reset(numerical_grad_raw);
    }

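    // Materialize both gradients as TF_Tensors so their contents can be read
    // out and compared on the host.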
    TF_Tensor* numerical_tensor;
    s = GetValue(numerical_grad.get(), &numerical_tensor);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    auto num_elem_numerical = TF_TensorElementCount(numerical_tensor);

    TF_Tensor* analytical_tensor;
    s = GetValue(outputs[i], &analytical_tensor);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    auto num_elem_analytical = TF_TensorElementCount(analytical_tensor);

    ASSERT_EQ(num_elem_numerical, num_elem_analytical);

    float* dnumerical = new float[num_elem_numerical]{0};
    memcpy(&dnumerical[0], TF_TensorData(numerical_tensor),
           TF_TensorByteSize(numerical_tensor));
    float* danalytical = new float[num_elem_analytical]{0};
    memcpy(&danalytical[0], TF_TensorData(analytical_tensor),
           TF_TensorByteSize(analytical_tensor));

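    // Check the two gradients elementwise within the requested tolerance,
    // then release the temporaries.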
    for (int j = 0; j < num_elem_numerical; j++) {
      ASSERT_NEAR(dnumerical[j], danalytical[j], abs_error);
    }
    TF_DeleteTensor(analytical_tensor);
    TF_DeleteTensor(numerical_tensor);
    delete[] danalytical;
    delete[] dnumerical;
    outputs[i]->Unref();
  }
}

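// Checks that `t` has shape `dims` and that its elements match `manuals`:
// exactly when `abs_error` is 0, and within `abs_error` otherwise.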
void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
                      absl::Span<const int64_t> dims, double abs_error) {
  TF_Tensor* analytical_tensor;
  auto s = GetValue(t, &analytical_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  int64_t num_elem_analytical = 1;
  auto num_dims_analytical = TF_NumDims(analytical_tensor);
  ASSERT_EQ(dims.size(), num_dims_analytical);
  for (int j = 0; j < num_dims_analytical; j++) {
    auto dim_analytical = TF_Dim(analytical_tensor, j);
    ASSERT_EQ(dims[j], dim_analytical);
    num_elem_analytical *= dim_analytical;
  }

  float* danalytical = new float[num_elem_analytical]{0};
  memcpy(&danalytical[0], TF_TensorData(analytical_tensor),
         TF_TensorByteSize(analytical_tensor));

  for (int64_t j = 0; j < num_elem_analytical; j++) {
    if (abs_error == 0) {
      ASSERT_EQ(manuals[j], danalytical[j]);
    } else {
      ASSERT_NEAR(manuals[j], danalytical[j], abs_error);
    }
  }

  TF_DeleteTensor(analytical_tensor);
  delete[] danalytical;
}

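// Wraps `forward` into a model that computes the gradients of its outputs
// with respect to its inputs, using the gradients registered in `registry`.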
Model BuildGradModel(Model forward, GradientRegistry registry) {
  return [forward_model = std::move(forward),
          grad_registry = std::move(registry)](
             AbstractContext* ctx,
             absl::Span<AbstractTensorHandle* const> inputs,
             absl::Span<AbstractTensorHandle*> outputs) -> Status {
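    // Run the forward model under a non-persistent tape that watches every
    // input, so the backward pass can be computed once afterwards.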
    Tape tape(/*persistent=*/false);
    for (size_t i{}; i < inputs.size(); ++i) {
      tape.Watch(inputs[i]);
    }
    std::vector<AbstractTensorHandle*> temp_outputs(1);
    AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, grad_registry));
    TF_RETURN_IF_ERROR(
        forward_model(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));

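    // Differentiate the traced output with respect to all inputs; no explicit
    // output gradients are supplied.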
    TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                            /*sources=*/inputs,
                                            /*output_gradients=*/{}, outputs));
    for (auto temp_output : temp_outputs) {
      temp_output->Unref();
    }
    return OkStatus();
  };
}

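// A minimal usage sketch, assuming a hypothetical `ExpModel` forward function,
// an already-populated registry, and live `ctx`/`x` handles (real call sites
// are the per-op gradient tests that include this helper):
//
//   GradientRegistry registry;
//   // ... register the op's gradient in `registry` ...
//   CompareNumericalAndAutodiffGradients(
//       ExpModel, BuildGradModel(ExpModel, registry), ctx.get(),
//       {x.get()}, /*use_function=*/false, /*abs_error=*/1e-2);
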
}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow