1*c217d954SCole Faust /*
2*c217d954SCole Faust * Copyright (c) 2017-2018, 2021-2022 Arm Limited.
3*c217d954SCole Faust *
4*c217d954SCole Faust * SPDX-License-Identifier: MIT
5*c217d954SCole Faust *
6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust *
13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust * copies or substantial portions of the Software.
15*c217d954SCole Faust *
16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust * SOFTWARE.
23*c217d954SCole Faust */
24*c217d954SCole Faust
25*c217d954SCole Faust #ifndef ARM_COMPUTE_TEST_SIMPLE_TENSOR_PRINTER
26*c217d954SCole Faust #define ARM_COMPUTE_TEST_SIMPLE_TENSOR_PRINTER
27*c217d954SCole Faust
#include "arm_compute/core/Error.h"

#include "tests/RawTensor.h"
#include "tests/SimpleTensor.h"

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
35*c217d954SCole Faust
36*c217d954SCole Faust namespace arm_compute
37*c217d954SCole Faust {
38*c217d954SCole Faust namespace test
39*c217d954SCole Faust {
40*c217d954SCole Faust template <typename T>
41*c217d954SCole Faust inline std::string prettify_tensor(const SimpleTensor<T> &input, const IOFormatInfo &io_fmt = IOFormatInfo{ IOFormatInfo::PrintRegion::NoPadding })
42*c217d954SCole Faust {
43*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(input.data() == nullptr);
44*c217d954SCole Faust
45*c217d954SCole Faust RawTensor tensor(std::move(SimpleTensor<T>(input)));
46*c217d954SCole Faust
47*c217d954SCole Faust TensorInfo info(tensor.shape(), tensor.num_channels(), tensor.data_type());
48*c217d954SCole Faust
49*c217d954SCole Faust const DataType dt = info.data_type();
50*c217d954SCole Faust const size_t slices2D = info.tensor_shape().total_size_upper(2);
51*c217d954SCole Faust const Strides strides = info.strides_in_bytes();
52*c217d954SCole Faust const PaddingSize padding = info.padding();
53*c217d954SCole Faust const size_t num_channels = info.num_channels();
54*c217d954SCole Faust
55*c217d954SCole Faust std::ostringstream os;
56*c217d954SCole Faust
57*c217d954SCole Faust // Set precision
58*c217d954SCole Faust if(is_data_type_float(dt) && (io_fmt.precision_type != IOFormatInfo::PrecisionType::Default))
59*c217d954SCole Faust {
60*c217d954SCole Faust int precision = io_fmt.precision;
61*c217d954SCole Faust if(io_fmt.precision_type == IOFormatInfo::PrecisionType::Full)
62*c217d954SCole Faust {
63*c217d954SCole Faust precision = std::numeric_limits<float>().max_digits10;
64*c217d954SCole Faust }
65*c217d954SCole Faust os.precision(precision);
66*c217d954SCole Faust }
67*c217d954SCole Faust
68*c217d954SCole Faust // Define region to print
69*c217d954SCole Faust size_t print_width = 0;
70*c217d954SCole Faust size_t print_height = 0;
71*c217d954SCole Faust int start_offset = 0;
72*c217d954SCole Faust switch(io_fmt.print_region)
73*c217d954SCole Faust {
74*c217d954SCole Faust case IOFormatInfo::PrintRegion::NoPadding:
75*c217d954SCole Faust print_width = info.dimension(0);
76*c217d954SCole Faust print_height = info.dimension(1);
77*c217d954SCole Faust start_offset = info.offset_first_element_in_bytes();
78*c217d954SCole Faust break;
79*c217d954SCole Faust case IOFormatInfo::PrintRegion::ValidRegion:
80*c217d954SCole Faust print_width = info.valid_region().shape.x();
81*c217d954SCole Faust print_height = info.valid_region().shape.y();
82*c217d954SCole Faust start_offset = info.offset_element_in_bytes(Coordinates(info.valid_region().anchor.x(),
83*c217d954SCole Faust info.valid_region().anchor.y()));
84*c217d954SCole Faust break;
85*c217d954SCole Faust case IOFormatInfo::PrintRegion::Full:
86*c217d954SCole Faust print_width = padding.left + info.dimension(0) + padding.right;
87*c217d954SCole Faust print_height = padding.top + info.dimension(1) + padding.bottom;
88*c217d954SCole Faust start_offset = static_cast<int>(info.offset_first_element_in_bytes()) - padding.top * strides[1] - padding.left * strides[0];
89*c217d954SCole Faust break;
90*c217d954SCole Faust default:
91*c217d954SCole Faust break;
92*c217d954SCole Faust }
93*c217d954SCole Faust
94*c217d954SCole Faust print_width = print_width * num_channels;
95*c217d954SCole Faust
96*c217d954SCole Faust // Set pointer to start
97*c217d954SCole Faust const uint8_t *ptr = tensor.data() + start_offset;
98*c217d954SCole Faust
99*c217d954SCole Faust // Start printing
100*c217d954SCole Faust for(size_t i = 0; i < slices2D; ++i)
101*c217d954SCole Faust {
102*c217d954SCole Faust // Find max_width of elements in slice to align columns
103*c217d954SCole Faust int max_element_width = 0;
104*c217d954SCole Faust if(io_fmt.align_columns)
105*c217d954SCole Faust {
106*c217d954SCole Faust size_t offset = i * strides[2];
107*c217d954SCole Faust for(size_t h = 0; h < print_height; ++h)
108*c217d954SCole Faust {
109*c217d954SCole Faust max_element_width = std::max<int>(max_element_width, max_consecutive_elements_display_width(os, dt, ptr + offset, print_width));
110*c217d954SCole Faust offset += strides[1];
111*c217d954SCole Faust }
112*c217d954SCole Faust }
113*c217d954SCole Faust
114*c217d954SCole Faust // Print slice
115*c217d954SCole Faust {
116*c217d954SCole Faust size_t offset = i * strides[2];
117*c217d954SCole Faust for(size_t h = 0; h < print_height; ++h)
118*c217d954SCole Faust {
119*c217d954SCole Faust print_consecutive_elements(os, dt, ptr + offset, print_width, max_element_width, io_fmt.element_delim);
120*c217d954SCole Faust offset += strides[1];
121*c217d954SCole Faust os << io_fmt.row_delim;
122*c217d954SCole Faust }
123*c217d954SCole Faust os << io_fmt.row_delim;
124*c217d954SCole Faust }
125*c217d954SCole Faust }
126*c217d954SCole Faust
127*c217d954SCole Faust return os.str();
128*c217d954SCole Faust }
129*c217d954SCole Faust
130*c217d954SCole Faust template <typename T>
131*c217d954SCole Faust inline std::ostream &operator<<(std::ostream &os, const SimpleTensor<T> &tensor)
132*c217d954SCole Faust {
133*c217d954SCole Faust os << prettify_tensor(tensor, IOFormatInfo{ IOFormatInfo::PrintRegion::NoPadding });
134*c217d954SCole Faust return os;
135*c217d954SCole Faust }
136*c217d954SCole Faust
137*c217d954SCole Faust template <typename T>
to_string(const SimpleTensor<T> & tensor)138*c217d954SCole Faust inline std::string to_string(const SimpleTensor<T> &tensor)
139*c217d954SCole Faust {
140*c217d954SCole Faust std::stringstream ss;
141*c217d954SCole Faust ss << tensor;
142*c217d954SCole Faust return ss.str();
143*c217d954SCole Faust }
144*c217d954SCole Faust
#if PRINT_TENSOR_LIMIT
/** Debug helper: print @p tensor to std::cout under a title, but only if it is small.
 *
 * Nothing is printed when the tensor has PRINT_TENSOR_LIMIT or more elements.
 *
 * @param[in] tensor Tensor to print.
 * @param[in] title  Heading written (followed by ':') on its own line before the tensor contents.
 * @param[in] region Part of the tensor to print. Defaults to NoPadding.
 */
template <typename T>
inline void print_simpletensor(const SimpleTensor<T> &tensor, const std::string &title, IOFormatInfo::PrintRegion region = IOFormatInfo::PrintRegion::NoPadding)
{
    if(tensor.num_elements() < PRINT_TENSOR_LIMIT)
    {
        std::cout << title << ":" << std::endl;
        std::cout << prettify_tensor(tensor, IOFormatInfo{ region });
    }
}
#endif // PRINT_TENSOR_LIMIT
156*c217d954SCole Faust } // namespace test
157*c217d954SCole Faust } // namespace arm_compute
158*c217d954SCole Faust #endif /* ARM_COMPUTE_TEST_SIMPLE_TENSOR_PRINTER */
159