xref: /aosp_15_r20/external/ComputeLibrary/include/libnpy/npy.hpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2    Copyright 2017 Leon Merten Lohse
3 
4    Permission is hereby granted, free of charge, to any person obtaining a copy
5    of this software and associated documentation files (the "Software"), to deal
6    in the Software without restriction, including without limitation the rights
7    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8    copies of the Software, and to permit persons to whom the Software is
9    furnished to do so, subject to the following conditions:
10 
11    The above copyright notice and this permission notice shall be included in
12    all copies or substantial portions of the Software.
13 
14    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20    SOFTWARE.
21 */
22 
23 #ifndef NPY_HPP_
24 #define NPY_HPP_
25 
#include <algorithm>
#include <array>
#include <cctype>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
43 
44 
45 namespace npy {
46 
/* Compile-time test for byte order.
   If your compiler does not define any of these macros by default, you may want
   to define one of them manually. Defaults to little endian order.

   __BYTE_ORDER__ / __ORDER_BIG_ENDIAN__ are predefined by GCC (4.6+) and Clang and
   catch big-endian targets (e.g. s390x) that define none of the architecture-specific
   macros below. The MIPS big-endian macros are spelled _MIPSEB / __MIPSEB / __MIPSEB__. */
#if (defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) || \
    (defined(__BYTE_ORDER) && defined(__BIG_ENDIAN) && __BYTE_ORDER == __BIG_ENDIAN) || \
    defined(__BIG_ENDIAN__) || \
    defined(__ARMEB__) || \
    defined(__THUMBEB__) || \
    defined(__AARCH64EB__) || \
    defined(_MIPSEB) || defined(__MIPSEB) || defined(__MIPSEB__)
constexpr bool big_endian = true;
#else
constexpr bool big_endian = false;
#endif

/* Magic prefix of every .npy file: the byte 0x93 followed by "NUMPY". */
const char magic_string[] = "\x93NUMPY";
/* Length of the magic prefix, excluding the string literal's terminating NUL. */
const size_t magic_string_length = sizeof(magic_string) - 1;

/* Byte-order characters used in numpy typestrings. */
constexpr char little_endian_char = '<';
constexpr char big_endian_char = '>';
constexpr char no_endian_char = '|';  // byte order not applicable (single-byte types)

/* Byte-order characters accepted when parsing a typestring. */
constexpr std::array<char, 3>
endian_chars = {little_endian_char, big_endian_char, no_endian_char};
/* Kind characters accepted when parsing a typestring: float, signed int,
   unsigned int, complex. */
constexpr std::array<char, 4>
numtype_chars = {'f', 'i', 'u', 'c'};

/* Byte-order character describing the machine this code is compiled for. */
constexpr char host_endian_char = (big_endian ?
                                   big_endian_char :
                                   little_endian_char);

/* npy array length */
using ndarray_len_t = unsigned long int;

/* (major, minor) .npy file format version */
using version_t = std::pair<char, char>;
83 
/* A numpy scalar type description ("descr"), e.g. '<f8' for little-endian float64. */
struct dtype_t {
  const char byteorder;         // byte-order character: '<', '>' or '|'
  const char kind;              // kind character, e.g. 'f', 'i', 'u', 'c'
  const unsigned int itemsize;  // size of one element in bytes

  /* Renders the dtype as a numpy typestring: byteorder, kind, then itemsize in
     decimal (e.g. "<f8", "|u1"). Built with std::to_string so it needs neither
     <cstdio> nor a fixed-size buffer. */
  // TODO(llohse): implement as constexpr
  inline std::string str() const {
    std::string s;
    s += byteorder;
    s += kind;
    s += std::to_string(itemsize);
    return s;
  }

  /* All three fields as a tuple, so two dtypes can be compared field by field. */
  inline std::tuple<const char, const char, const unsigned int> tie() const {
    return std::tie(byteorder, kind, itemsize);
  }
};
101 
102 
103 struct header_t {
104   const dtype_t dtype;
105   const bool fortran_order;
106   const std::vector <ndarray_len_t> shape;
107 };
108 
write_magic(std::ostream & ostream,version_t version)109 inline void write_magic(std::ostream &ostream, version_t version) {
110   ostream.write(magic_string, magic_string_length);
111   ostream.put(version.first);
112   ostream.put(version.second);
113 }
114 
read_magic(std::istream & istream)115 inline version_t read_magic(std::istream &istream) {
116   char buf[magic_string_length + 2];
117   istream.read(buf, magic_string_length + 2);
118 
119   if (!istream) {
120     throw std::runtime_error("io error: failed reading file");
121   }
122 
123   if (0 != std::memcmp(buf, magic_string, magic_string_length))
124     throw std::runtime_error("this file does not have a valid npy format.");
125 
126   version_t version;
127   version.first = buf[magic_string_length];
128   version.second = buf[magic_string_length + 1];
129 
130   return version;
131 }
132 
133 const std::unordered_map<std::type_index, dtype_t> dtype_map = {
134   {std::type_index(typeid(float)), {host_endian_char, 'f', sizeof(float)}},
135   {std::type_index(typeid(double)), {host_endian_char, 'f', sizeof(double)}},
136   {std::type_index(typeid(long double)), {host_endian_char, 'f', sizeof(long double)}},
137   {std::type_index(typeid(char)), {no_endian_char, 'i', sizeof(char)}},
138   {std::type_index(typeid(signed char)), {no_endian_char, 'i', sizeof(signed char)}},
139   {std::type_index(typeid(short)), {host_endian_char, 'i', sizeof(short)}},
140   {std::type_index(typeid(int)), {host_endian_char, 'i', sizeof(int)}},
141   {std::type_index(typeid(long)), {host_endian_char, 'i', sizeof(long)}},
142   {std::type_index(typeid(long long)), {host_endian_char, 'i', sizeof(long long)}},
143   {std::type_index(typeid(unsigned char)), {no_endian_char, 'u', sizeof(unsigned char)}},
144   {std::type_index(typeid(unsigned short)), {host_endian_char, 'u', sizeof(unsigned short)}},
145   {std::type_index(typeid(unsigned int)), {host_endian_char, 'u', sizeof(unsigned int)}},
146   {std::type_index(typeid(unsigned long)), {host_endian_char, 'u', sizeof(unsigned long)}},
147   {std::type_index(typeid(unsigned long long)), {host_endian_char, 'u', sizeof(unsigned long long)}},
148   {std::type_index(typeid(std::complex<float>)), {host_endian_char, 'c', sizeof(std::complex<float>)}},
149   {std::type_index(typeid(std::complex<double>)), {host_endian_char, 'c', sizeof(std::complex<double>)}},
150   {std::type_index(typeid(std::complex<long double>)), {host_endian_char, 'c', sizeof(std::complex<long double>)}}
151 };
152 
153 
// helpers

/* True if every character of `str` is a decimal digit '0'-'9' (vacuously true
   for an empty string). Each char is converted to unsigned char before reaching
   std::isdigit: passing a negative value (a byte >= 0x80 where char is signed)
   is undefined behaviour. */
inline bool is_digits(const std::string &str) {
  return std::all_of(str.begin(), str.end(),
                     [](unsigned char c) { return std::isdigit(c) != 0; });
}
158 
/* Linear search: true if any element of `arr` compares equal to `val`. */
template<typename T, size_t N>
inline bool in_array(T val, const std::array<T, N> &arr) {
  for (const T &candidate : arr) {
    if (candidate == val) {
      return true;
    }
  }
  return false;
}
163 
parse_descr(std::string typestring)164 inline dtype_t parse_descr(std::string typestring) {
165   if (typestring.length() < 3) {
166     throw std::runtime_error("invalid typestring (length)");
167   }
168 
169   char byteorder_c = typestring.at(0);
170   char kind_c = typestring.at(1);
171   std::string itemsize_s = typestring.substr(2);
172 
173   if (!in_array(byteorder_c, endian_chars)) {
174     throw std::runtime_error("invalid typestring (byteorder)");
175   }
176 
177   if (!in_array(kind_c, numtype_chars)) {
178     throw std::runtime_error("invalid typestring (kind)");
179   }
180 
181   if (!is_digits(itemsize_s)) {
182     throw std::runtime_error("invalid typestring (itemsize)");
183   }
184   unsigned int itemsize = std::stoul(itemsize_s);
185 
186   return {byteorder_c, kind_c, itemsize};
187 }
188 
namespace pyparse {

/**
  Removes leading and trailing whitespace (spaces and tabs only).
  Returns "" if `str` consists solely of such characters.
  */
inline std::string trim(const std::string &str) {
  const std::string whitespace = " \t";
  auto begin = str.find_first_not_of(whitespace);

  if (begin == std::string::npos)
    return "";

  auto end = str.find_last_not_of(whitespace);

  return str.substr(begin, end - begin + 1);
}

/**
  Returns the trimmed text after the first ':' of a "key: value" fragment,
  or "" if the fragment contains no ':'.
  */
inline std::string get_value_from_map(const std::string &mapstr) {
  const size_t sep_pos = mapstr.find(':');
  if (sep_pos == std::string::npos)
    return "";

  return trim(mapstr.substr(sep_pos + 1));
}

/**
   Parses the string representation of a Python dict into a map from each of
   `keys` to the raw (trimmed, still-quoted) text of its value.

   The keys need to be known and may not appear anywhere else in the data: a
   value is taken to be the text between its quoted key and the next key (or
   the end of the dict), minus one trailing comma.

   Returns an empty map if `keys` is empty. Throws std::runtime_error if `in`
   is not wrapped in braces or a key is missing.
 */
inline std::unordered_map<std::string, std::string> parse_dict(std::string in, const std::vector<std::string> &keys) {
  std::unordered_map<std::string, std::string> map;

  if (keys.empty())
    return map;

  in = trim(in);

  // unwrap dictionary; the size check keeps front()/back() off an empty string
  if (in.size() >= 2 && in.front() == '{' && in.back() == '}')
    in = in.substr(1, in.length() - 2);
  else
    throw std::runtime_error("Not a Python dictionary.");

  // locate each quoted key
  std::vector<std::pair<size_t, std::string>> positions;
  positions.reserve(keys.size());
  for (const auto &key : keys) {
    const size_t pos = in.find("'" + key + "'");

    if (pos == std::string::npos)
      throw std::runtime_error("Missing '" + key + "' key.");

    positions.emplace_back(pos, key);
  }

  // sort by position in dict, so each value ends where the next key starts
  std::sort(positions.begin(), positions.end());

  for (size_t i = 0; i < positions.size(); ++i) {
    const size_t begin = positions[i].first;
    const size_t end = (i + 1 < positions.size()) ? positions[i + 1].first : std::string::npos;

    std::string raw_value = trim(in.substr(begin, end - begin));

    // raw_value is empty when `keys` lists the same key twice
    if (!raw_value.empty() && raw_value.back() == ',')
      raw_value.pop_back();

    map[positions[i].second] = get_value_from_map(raw_value);
  }

  return map;
}

/**
  Parses the string representation of a Python boolean ("True" or "False").
  Throws std::runtime_error for anything else.
  */
inline bool parse_bool(const std::string &in) {
  if (in == "True")
    return true;
  if (in == "False")
    return false;

  throw std::runtime_error("Invalid python boolean.");
}

/**
  Parses the string representation of a single-quoted Python str and returns the
  text between the quotes (escape sequences are not interpreted).
  Throws std::runtime_error if `in` is not of the form '...'.
  */
inline std::string parse_str(const std::string &in) {
  // size >= 2 rejects a lone quote and keeps front()/back() off an empty string
  if (in.size() >= 2 && in.front() == '\'' && in.back() == '\'')
    return in.substr(1, in.length() - 2);

  throw std::runtime_error("Invalid python string.");
}

/**
  Parses the string representation of a Python tuple into a vector of its items.
  Items are split on ',' and returned untrimmed ("(3, 4)" -> {"3", " 4"}); a
  trailing comma adds no item ("(3,)" -> {"3"}) and "()" yields an empty vector.
  Throws std::runtime_error if `in` is not wrapped in parentheses.
 */
inline std::vector<std::string> parse_tuple(std::string in) {
  std::vector<std::string> v;
  const char separator = ',';

  in = trim(in);

  // the size check keeps front()/back() off an empty string
  if (in.size() >= 2 && in.front() == '(' && in.back() == ')')
    in = in.substr(1, in.length() - 2);
  else
    throw std::runtime_error("Invalid Python tuple.");

  std::istringstream iss(in);

  for (std::string token; std::getline(iss, token, separator);) {
    v.push_back(token);
  }

  return v;
}

/**
  Formats `v` as a Python tuple literal: "()" when empty, "(x,)" for one item,
  otherwise the items joined by ", " (e.g. "(3, 4)"). Items are written with operator<<.
  */
template<typename T>
inline std::string write_tuple(const std::vector<T> &v) {
  if (v.empty())
    return "()";

  std::ostringstream ss;

  if (v.size() == 1) {
    ss << "(" << v.front() << ",)";
  } else {
    const std::string delimiter = ", ";
    // v.size() > 1
    ss << "(";
    std::copy(v.begin(), v.end() - 1, std::ostream_iterator<T>(ss, delimiter.c_str()));
    ss << v.back();
    ss << ")";
  }

  return ss.str();
}

/** Formats a C++ bool as the Python literal "True" or "False". */
inline std::string write_boolean(bool b) {
  if (b)
    return "True";
  else
    return "False";
}

}  // namespace pyparse
347 
348 
parse_header(std::string header)349 inline header_t parse_header(std::string header) {
350   /*
351      The first 6 bytes are a magic string: exactly "x93NUMPY".
352      The next 1 byte is an unsigned byte: the major version number of the file format, e.g. x01.
353      The next 1 byte is an unsigned byte: the minor version number of the file format, e.g. x00. Note: the version of the file format is not tied to the version of the numpy package.
354      The next 2 bytes form a little-endian unsigned short int: the length of the header data HEADER_LEN.
355      The next HEADER_LEN bytes form the header data describing the array's format. It is an ASCII string which contains a Python literal expression of a dictionary. It is terminated by a newline ('n') and padded with spaces ('x20') to make the total length of the magic string + 4 + HEADER_LEN be evenly divisible by 16 for alignment purposes.
356      The dictionary contains three keys:
357 
358      "descr" : dtype.descr
359      An object that can be passed as an argument to the numpy.dtype() constructor to create the array's dtype.
360      "fortran_order" : bool
361      Whether the array data is Fortran-contiguous or not. Since Fortran-contiguous arrays are a common form of non-C-contiguity, we allow them to be written directly to disk for efficiency.
362      "shape" : tuple of int
363      The shape of the array.
364      For repeatability and readability, this dictionary is formatted using pprint.pformat() so the keys are in alphabetic order.
365    */
366 
367   // remove trailing newline
368   if (header.back() != '\n')
369     throw std::runtime_error("invalid header");
370   header.pop_back();
371 
372   // parse the dictionary
373   std::vector <std::string> keys{"descr", "fortran_order", "shape"};
374   auto dict_map = npy::pyparse::parse_dict(header, keys);
375 
376   if (dict_map.size() == 0)
377     throw std::runtime_error("invalid dictionary in header");
378 
379   std::string descr_s = dict_map["descr"];
380   std::string fortran_s = dict_map["fortran_order"];
381   std::string shape_s = dict_map["shape"];
382 
383   std::string descr = npy::pyparse::parse_str(descr_s);
384   dtype_t dtype = parse_descr(descr);
385 
386   // convert literal Python bool to C++ bool
387   bool fortran_order = npy::pyparse::parse_bool(fortran_s);
388 
389   // parse the shape tuple
390   auto shape_v = npy::pyparse::parse_tuple(shape_s);
391 
392   std::vector <ndarray_len_t> shape;
393   for (auto item : shape_v) {
394     ndarray_len_t dim = static_cast<ndarray_len_t>(std::stoul(item));
395     shape.push_back(dim);
396   }
397 
398   return {dtype, fortran_order, shape};
399 }
400 
401 
402 inline std::string
write_header_dict(const std::string & descr,bool fortran_order,const std::vector<ndarray_len_t> & shape)403 write_header_dict(const std::string &descr, bool fortran_order, const std::vector <ndarray_len_t> &shape) {
404   std::string s_fortran_order = npy::pyparse::write_boolean(fortran_order);
405   std::string shape_s = npy::pyparse::write_tuple(shape);
406 
407   return "{'descr': '" + descr + "', 'fortran_order': " + s_fortran_order + ", 'shape': " + shape_s + ", }";
408 }
409 
write_header(std::ostream & out,const header_t & header)410 inline void write_header(std::ostream &out, const header_t &header) {
411   std::string header_dict = write_header_dict(header.dtype.str(), header.fortran_order, header.shape);
412 
413   size_t length = magic_string_length + 2 + 2 + header_dict.length() + 1;
414 
415   version_t version{1, 0};
416   if (length >= 255 * 255) {
417     length = magic_string_length + 2 + 4 + header_dict.length() + 1;
418     version = {2, 0};
419   }
420   size_t padding_len = 16 - length % 16;
421   std::string padding(padding_len, ' ');
422 
423   // write magic
424   write_magic(out, version);
425 
426   // write header length
427   if (version == version_t{1, 0}) {
428     uint8_t header_len_le16[2];
429     uint16_t header_len = static_cast<uint16_t>(header_dict.length() + padding.length() + 1);
430 
431     header_len_le16[0] = (header_len >> 0) & 0xff;
432     header_len_le16[1] = (header_len >> 8) & 0xff;
433     out.write(reinterpret_cast<char *>(header_len_le16), 2);
434   } else {
435     uint8_t header_len_le32[4];
436     uint32_t header_len = static_cast<uint32_t>(header_dict.length() + padding.length() + 1);
437 
438     header_len_le32[0] = (header_len >> 0) & 0xff;
439     header_len_le32[1] = (header_len >> 8) & 0xff;
440     header_len_le32[2] = (header_len >> 16) & 0xff;
441     header_len_le32[3] = (header_len >> 24) & 0xff;
442     out.write(reinterpret_cast<char *>(header_len_le32), 4);
443   }
444 
445   out << header_dict << padding << '\n';
446 }
447 
read_header(std::istream & istream)448 inline std::string read_header(std::istream &istream) {
449   // check magic bytes an version number
450   version_t version = read_magic(istream);
451 
452   uint32_t header_length;
453   if (version == version_t{1, 0}) {
454     uint8_t header_len_le16[2];
455     istream.read(reinterpret_cast<char *>(header_len_le16), 2);
456     header_length = (header_len_le16[0] << 0) | (header_len_le16[1] << 8);
457 
458     if ((magic_string_length + 2 + 2 + header_length) % 16 != 0) {
459       // TODO(llohse): display warning
460     }
461   } else if (version == version_t{2, 0}) {
462     uint8_t header_len_le32[4];
463     istream.read(reinterpret_cast<char *>(header_len_le32), 4);
464 
465     header_length = (header_len_le32[0] << 0) | (header_len_le32[1] << 8)
466                     | (header_len_le32[2] << 16) | (header_len_le32[3] << 24);
467 
468     if ((magic_string_length + 2 + 4 + header_length) % 16 != 0) {
469       // TODO(llohse): display warning
470     }
471   } else {
472     throw std::runtime_error("unsupported file format version");
473   }
474 
475   auto buf_v = std::vector<char>(header_length);
476   istream.read(buf_v.data(), header_length);
477   std::string header(buf_v.data(), header_length);
478 
479   return header;
480 }
481 
comp_size(const std::vector<ndarray_len_t> & shape)482 inline ndarray_len_t comp_size(const std::vector <ndarray_len_t> &shape) {
483   ndarray_len_t size = 1;
484   for (ndarray_len_t i : shape)
485     size *= i;
486 
487   return size;
488 }
489 
490 template<typename Scalar>
491 inline void
SaveArrayAsNumpy(const std::string & filename,bool fortran_order,unsigned int n_dims,const unsigned long shape[],const Scalar * data)492 SaveArrayAsNumpy(const std::string &filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[],
493                  const Scalar* data) {
494 //  static_assert(has_typestring<Scalar>::value, "scalar type not understood");
495   const dtype_t dtype = dtype_map.at(std::type_index(typeid(Scalar)));
496 
497   std::ofstream stream(filename, std::ofstream::binary);
498   if (!stream) {
499     throw std::runtime_error("io error: failed to open a file.");
500   }
501 
502   std::vector <ndarray_len_t> shape_v(shape, shape + n_dims);
503   header_t header{dtype, fortran_order, shape_v};
504   write_header(stream, header);
505 
506   auto size = static_cast<size_t>(comp_size(shape_v));
507 
508   stream.write(reinterpret_cast<const char *>(data), sizeof(Scalar) * size);
509 }
510 
511 template<typename Scalar>
512 inline void
SaveArrayAsNumpy(const std::string & filename,bool fortran_order,unsigned int n_dims,const unsigned long shape[],const std::vector<Scalar> & data)513 SaveArrayAsNumpy(const std::string &filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[],
514                  const std::vector <Scalar> &data) {
515   SaveArrayAsNumpy(filename, fortran_order, n_dims, shape, data.data());
516 }
517 
// Forward declaration of the four-argument overload defined further below. Without
// it, the dependent call in the next function is resolved by two-phase lookup: the
// ordinary lookup at definition time only sees this three-argument template, and
// ADL at instantiation only searches namespace std for these argument types, so
// conforming compilers (GCC, Clang) reject the call.
template<typename Scalar>
inline void LoadArrayFromNumpy(const std::string &filename, std::vector<unsigned long> &shape, bool &fortran_order,
                               std::vector<Scalar> &data);

/* Loads an .npy file into `shape` and `data`, discarding the fortran_order flag. */
template<typename Scalar>
inline void
LoadArrayFromNumpy(const std::string &filename, std::vector<unsigned long> &shape, std::vector<Scalar> &data) {
  bool fortran_order;
  LoadArrayFromNumpy<Scalar>(filename, shape, fortran_order, data);
}
524 
525 template<typename Scalar>
LoadArrayFromNumpy(const std::string & filename,std::vector<unsigned long> & shape,bool & fortran_order,std::vector<Scalar> & data)526 inline void LoadArrayFromNumpy(const std::string &filename, std::vector<unsigned long> &shape, bool &fortran_order,
527                                std::vector <Scalar> &data) {
528   std::ifstream stream(filename, std::ifstream::binary);
529   if (!stream) {
530     throw std::runtime_error("io error: failed to open a file.");
531   }
532 
533   std::string header_s = read_header(stream);
534 
535   // parse header
536   header_t header = parse_header(header_s);
537 
538   // check if the typestring matches the given one
539 //  static_assert(has_typestring<Scalar>::value, "scalar type not understood");
540   const dtype_t dtype = dtype_map.at(std::type_index(typeid(Scalar)));
541 
542   if (header.dtype.tie() != dtype.tie()) {
543     throw std::runtime_error("formatting error: typestrings not matching");
544   }
545 
546   shape = header.shape;
547   fortran_order = header.fortran_order;
548 
549   // compute the data size based on the shape
550   auto size = static_cast<size_t>(comp_size(shape));
551   data.resize(size);
552 
553   // read the data
554   stream.read(reinterpret_cast<char *>(data.data()), sizeof(Scalar) * size);
555 }
556 
557 }  // namespace npy
558 
559 #endif  // NPY_HPP_
560