/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/grad_test_helper.h"

#include "tensorflow/c/eager/gradient_checker.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace gradients {
namespace internal {

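// Runs `grad_model` to produce autodiff gradients for every input of `model`,
// then checks each one element-wise against a numerical gradient computed by
// `CalcNumericalGrad`, requiring agreement within `abs_error`.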
void CompareNumericalAndAutodiffGradients(
    Model model, Model grad_model, AbstractContext* ctx,
    absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
    double abs_error) {
  auto num_inputs = inputs.size();
  std::vector<AbstractTensorHandle*> outputs(num_inputs);
  auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs),
                    /*use_function=*/use_function);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  for (size_t i = 0; i < num_inputs; ++i) {
    if (!outputs[i]) continue;

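    // CalcNumericalGrad returns a raw handle; wrap it in an
    // AbstractTensorHandlePtr so it is released automatically.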
    AbstractTensorHandlePtr numerical_grad;
    {
      AbstractTensorHandle* numerical_grad_raw;
      s = CalcNumericalGrad(ctx, model, inputs,
                            /*input_index=*/i, use_function,
                            &numerical_grad_raw);
      ASSERT_EQ(errors::OK, s.code()) << s.error_message();
      numerical_grad.reset(numerical_grad_raw);
    }

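    // Materialize both gradients as concrete TF_Tensors so their contents
    // can be compared.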
    TF_Tensor* numerical_tensor;
    s = GetValue(numerical_grad.get(), &numerical_tensor);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    auto num_elem_numerical = TF_TensorElementCount(numerical_tensor);

    TF_Tensor* analytical_tensor;
    s = GetValue(outputs[i], &analytical_tensor);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    auto num_elem_analytical = TF_TensorElementCount(analytical_tensor);

    ASSERT_EQ(num_elem_numerical, num_elem_analytical);

    // Use std::vector rather than raw new[]/delete[] so the buffers are not
    // leaked when an ASSERT_* macro returns early on failure.
    std::vector<float> dnumerical(num_elem_numerical);
    memcpy(dnumerical.data(), TF_TensorData(numerical_tensor),
           TF_TensorByteSize(numerical_tensor));
    std::vector<float> danalytical(num_elem_analytical);
    memcpy(danalytical.data(), TF_TensorData(analytical_tensor),
           TF_TensorByteSize(analytical_tensor));

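    // The numerical and autodiff gradients must agree within abs_error.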
    for (int64_t j = 0; j < num_elem_numerical; j++) {
      ASSERT_NEAR(dnumerical[j], danalytical[j], abs_error);
    }
    TF_DeleteTensor(analytical_tensor);
    TF_DeleteTensor(numerical_tensor);
    outputs[i]->Unref();
  }
}

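// Checks that tensor `t` has shape `dims` and that its values match
// `manuals`. With abs_error == 0 the comparison is exact; otherwise values
// must agree within abs_error.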
void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
                      absl::Span<const int64_t> dims, double abs_error) {
  TF_Tensor* analytical_tensor;
  auto s = GetValue(t, &analytical_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  int64_t num_elem_analytical = 1;
  auto num_dims_analytical = TF_NumDims(analytical_tensor);
  ASSERT_EQ(dims.size(), num_dims_analytical);
  for (int j = 0; j < num_dims_analytical; j++) {
    auto dim_analytical = TF_Dim(analytical_tensor, j);
    ASSERT_EQ(dims[j], dim_analytical);
    num_elem_analytical *= dim_analytical;
  }

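  // Copy the tensor contents into a host buffer; std::vector keeps cleanup
  // automatic if an assertion below fails and returns early.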
  std::vector<float> danalytical(num_elem_analytical);
  memcpy(danalytical.data(), TF_TensorData(analytical_tensor),
         TF_TensorByteSize(analytical_tensor));

  for (int64_t j = 0; j < num_elem_analytical; j++) {
    if (abs_error == 0) {
      ASSERT_EQ(manuals[j], danalytical[j]);
    } else {
      ASSERT_NEAR(manuals[j], danalytical[j], abs_error);
    }
  }

  TF_DeleteTensor(analytical_tensor);
}

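// Wraps `forward` into a model that returns the gradients of `forward`'s
// (single) output with respect to all of its inputs, using `registry` to
// look up per-op gradient functions.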
Model BuildGradModel(Model forward, GradientRegistry registry) {
  return [forward_model = std::move(forward),
          grad_registry = std::move(registry)](
             AbstractContext* ctx,
             absl::Span<AbstractTensorHandle* const> inputs,
             absl::Span<AbstractTensorHandle*> outputs) -> Status {
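    // Record the forward pass on a non-persistent tape, watching every input,
    // then differentiate the single forward output w.r.t. all inputs.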
    Tape tape(/*persistent=*/false);
    for (size_t i{}; i < inputs.size(); ++i) {
      tape.Watch(inputs[i]);
    }
    std::vector<AbstractTensorHandle*> temp_outputs(1);
    AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, grad_registry));
    TF_RETURN_IF_ERROR(
        forward_model(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));

    TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                            /*sources=*/inputs,
                                            /*output_gradients=*/{}, outputs));
    for (auto temp_output : temp_outputs) {
      temp_output->Unref();
    }
    return OkStatus();
  };
}
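
// Illustrative sketch of how a gradient test might use these helpers; the
// model, registerer, context, and tensor handles below are hypothetical
// placeholders, not part of this file:
//
//   GradientRegistry registry;
//   TF_ASSERT_OK(registry.Register("Exp", ExpRegisterer));
//   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
//       ExpModel, BuildGradModel(ExpModel, registry), ctx.get(), {x.get()},
//       /*use_function=*/false, /*abs_error=*/1e-2));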

}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow