1*c217d954SCole Faust /*
2*c217d954SCole Faust Copyright 2017 Leon Merten Lohse
3*c217d954SCole Faust
4*c217d954SCole Faust Permission is hereby granted, free of charge, to any person obtaining a copy
5*c217d954SCole Faust of this software and associated documentation files (the "Software"), to deal
6*c217d954SCole Faust in the Software without restriction, including without limitation the rights
7*c217d954SCole Faust to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8*c217d954SCole Faust copies of the Software, and to permit persons to whom the Software is
9*c217d954SCole Faust furnished to do so, subject to the following conditions:
10*c217d954SCole Faust
11*c217d954SCole Faust The above copyright notice and this permission notice shall be included in
12*c217d954SCole Faust all copies or substantial portions of the Software.
13*c217d954SCole Faust
14*c217d954SCole Faust THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15*c217d954SCole Faust IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16*c217d954SCole Faust FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17*c217d954SCole Faust AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18*c217d954SCole Faust LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19*c217d954SCole Faust OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20*c217d954SCole Faust SOFTWARE.
21*c217d954SCole Faust */
22*c217d954SCole Faust
23*c217d954SCole Faust #ifndef NPY_HPP_
24*c217d954SCole Faust #define NPY_HPP_
25*c217d954SCole Faust
#include <algorithm>
#include <array>
#include <cctype>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
43*c217d954SCole Faust
44*c217d954SCole Faust
45*c217d954SCole Faust namespace npy {
46*c217d954SCole Faust
/* Compile-time test for byte order.
   If your compiler defines none of the macros tested below, you may want to
   define one of them manually (e.g. -D__BIG_ENDIAN__).
   Defaults to little endian order. */
#if (defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) || \
    (defined(__BYTE_ORDER) && defined(__BIG_ENDIAN) && __BYTE_ORDER == __BIG_ENDIAN) || \
    defined(__BIG_ENDIAN__) || \
    defined(__ARMEB__) || \
    defined(__THUMBEB__) || \
    defined(__AARCH64EB__) || \
    defined(_MIPSEB) || defined(__MIPSEB) || defined(__MIPSEB__)
const bool big_endian = true;
#else
const bool big_endian = false;
#endif


/* Every npy file starts with these 6 bytes. */
const char magic_string[] = "\x93NUMPY";
const size_t magic_string_length = 6;
static_assert(sizeof(magic_string) - 1 == magic_string_length, "magic_string_length must match magic_string");

/* Byte-order characters of a numpy typestring. */
const char little_endian_char = '<';
const char big_endian_char = '>';
const char no_endian_char = '|';  // used for single-byte types

/* Characters accepted by parse_descr for the byte-order and kind fields. */
constexpr std::array<char, 3>
    endian_chars = {little_endian_char, big_endian_char, no_endian_char};
constexpr std::array<char, 4>
    numtype_chars = {'f', 'i', 'u', 'c'};

/* Byte-order character matching the byte order detected above. */
constexpr char host_endian_char = (big_endian ?
                                   big_endian_char :
                                   little_endian_char);

/* Type used for array dimensions and element counts. */
typedef unsigned long int ndarray_len_t;

/* Format version as (major, minor) bytes. */
typedef std::pair<char, char> version_t;
83*c217d954SCole Faust
/// Element type of an npy array, i.e. the three parts of a typestring such as "<f8".
struct dtype_t {
  const char byteorder;         ///< byte-order character (see endian_chars)
  const char kind;              ///< kind character (see numtype_chars)
  const unsigned int itemsize;  ///< size of one element in bytes

  // TODO(llohse): implement as constexpr
  /// Renders the dtype as a typestring: byteorder, kind, then the decimal itemsize.
  inline std::string str() const {
    std::array<char, 16> text{};
    std::snprintf(text.data(), text.size(), "%c%c%u", byteorder, kind, itemsize);
    return std::string(text.data());
  }

  /// Bundles the three fields into a tuple so two dtypes can be compared with == / !=.
  inline std::tuple<const char, const char, const unsigned int> tie() const {
    return std::tuple<const char, const char, const unsigned int>(byteorder, kind, itemsize);
  }
};
101*c217d954SCole Faust
102*c217d954SCole Faust
103*c217d954SCole Faust struct header_t {
104*c217d954SCole Faust const dtype_t dtype;
105*c217d954SCole Faust const bool fortran_order;
106*c217d954SCole Faust const std::vector <ndarray_len_t> shape;
107*c217d954SCole Faust };
108*c217d954SCole Faust
write_magic(std::ostream & ostream,version_t version)109*c217d954SCole Faust inline void write_magic(std::ostream &ostream, version_t version) {
110*c217d954SCole Faust ostream.write(magic_string, magic_string_length);
111*c217d954SCole Faust ostream.put(version.first);
112*c217d954SCole Faust ostream.put(version.second);
113*c217d954SCole Faust }
114*c217d954SCole Faust
read_magic(std::istream & istream)115*c217d954SCole Faust inline version_t read_magic(std::istream &istream) {
116*c217d954SCole Faust char buf[magic_string_length + 2];
117*c217d954SCole Faust istream.read(buf, magic_string_length + 2);
118*c217d954SCole Faust
119*c217d954SCole Faust if (!istream) {
120*c217d954SCole Faust throw std::runtime_error("io error: failed reading file");
121*c217d954SCole Faust }
122*c217d954SCole Faust
123*c217d954SCole Faust if (0 != std::memcmp(buf, magic_string, magic_string_length))
124*c217d954SCole Faust throw std::runtime_error("this file does not have a valid npy format.");
125*c217d954SCole Faust
126*c217d954SCole Faust version_t version;
127*c217d954SCole Faust version.first = buf[magic_string_length];
128*c217d954SCole Faust version.second = buf[magic_string_length + 1];
129*c217d954SCole Faust
130*c217d954SCole Faust return version;
131*c217d954SCole Faust }
132*c217d954SCole Faust
133*c217d954SCole Faust const std::unordered_map<std::type_index, dtype_t> dtype_map = {
134*c217d954SCole Faust {std::type_index(typeid(float)), {host_endian_char, 'f', sizeof(float)}},
135*c217d954SCole Faust {std::type_index(typeid(double)), {host_endian_char, 'f', sizeof(double)}},
136*c217d954SCole Faust {std::type_index(typeid(long double)), {host_endian_char, 'f', sizeof(long double)}},
137*c217d954SCole Faust {std::type_index(typeid(char)), {no_endian_char, 'i', sizeof(char)}},
138*c217d954SCole Faust {std::type_index(typeid(signed char)), {no_endian_char, 'i', sizeof(signed char)}},
139*c217d954SCole Faust {std::type_index(typeid(short)), {host_endian_char, 'i', sizeof(short)}},
140*c217d954SCole Faust {std::type_index(typeid(int)), {host_endian_char, 'i', sizeof(int)}},
141*c217d954SCole Faust {std::type_index(typeid(long)), {host_endian_char, 'i', sizeof(long)}},
142*c217d954SCole Faust {std::type_index(typeid(long long)), {host_endian_char, 'i', sizeof(long long)}},
143*c217d954SCole Faust {std::type_index(typeid(unsigned char)), {no_endian_char, 'u', sizeof(unsigned char)}},
144*c217d954SCole Faust {std::type_index(typeid(unsigned short)), {host_endian_char, 'u', sizeof(unsigned short)}},
145*c217d954SCole Faust {std::type_index(typeid(unsigned int)), {host_endian_char, 'u', sizeof(unsigned int)}},
146*c217d954SCole Faust {std::type_index(typeid(unsigned long)), {host_endian_char, 'u', sizeof(unsigned long)}},
147*c217d954SCole Faust {std::type_index(typeid(unsigned long long)), {host_endian_char, 'u', sizeof(unsigned long long)}},
148*c217d954SCole Faust {std::type_index(typeid(std::complex<float>)), {host_endian_char, 'c', sizeof(std::complex<float>)}},
149*c217d954SCole Faust {std::type_index(typeid(std::complex<double>)), {host_endian_char, 'c', sizeof(std::complex<double>)}},
150*c217d954SCole Faust {std::type_index(typeid(std::complex<long double>)), {host_endian_char, 'c', sizeof(std::complex<long double>)}}
151*c217d954SCole Faust };
152*c217d954SCole Faust
153*c217d954SCole Faust
154*c217d954SCole Faust // helpers
/// Returns true if every character of `str` is an ASCII decimal digit ('0'-'9').
/// An empty string yields true (vacuously); callers needing at least one digit must check the length.
inline bool is_digits(const std::string &str) {
  // Compare against the ASCII range directly instead of calling ::isdigit. Passing a
  // negative char (any byte >= 0x80 where char is signed) to isdigit is undefined
  // behaviour, and its result would also depend on the current C locale.
  return std::all_of(str.begin(), str.end(), [](char c) { return c >= '0' && c <= '9'; });
}
158*c217d954SCole Faust
/// Tells whether `val` compares equal to at least one element of `arr`.
template<typename T, size_t N>
inline bool in_array(T val, const std::array<T, N> &arr) {
  return std::any_of(arr.begin(), arr.end(), [&val](const T &elem) { return elem == val; });
}
163*c217d954SCole Faust
parse_descr(std::string typestring)164*c217d954SCole Faust inline dtype_t parse_descr(std::string typestring) {
165*c217d954SCole Faust if (typestring.length() < 3) {
166*c217d954SCole Faust throw std::runtime_error("invalid typestring (length)");
167*c217d954SCole Faust }
168*c217d954SCole Faust
169*c217d954SCole Faust char byteorder_c = typestring.at(0);
170*c217d954SCole Faust char kind_c = typestring.at(1);
171*c217d954SCole Faust std::string itemsize_s = typestring.substr(2);
172*c217d954SCole Faust
173*c217d954SCole Faust if (!in_array(byteorder_c, endian_chars)) {
174*c217d954SCole Faust throw std::runtime_error("invalid typestring (byteorder)");
175*c217d954SCole Faust }
176*c217d954SCole Faust
177*c217d954SCole Faust if (!in_array(kind_c, numtype_chars)) {
178*c217d954SCole Faust throw std::runtime_error("invalid typestring (kind)");
179*c217d954SCole Faust }
180*c217d954SCole Faust
181*c217d954SCole Faust if (!is_digits(itemsize_s)) {
182*c217d954SCole Faust throw std::runtime_error("invalid typestring (itemsize)");
183*c217d954SCole Faust }
184*c217d954SCole Faust unsigned int itemsize = std::stoul(itemsize_s);
185*c217d954SCole Faust
186*c217d954SCole Faust return {byteorder_c, kind_c, itemsize};
187*c217d954SCole Faust }
188*c217d954SCole Faust
189*c217d954SCole Faust namespace pyparse {
190*c217d954SCole Faust
/**
 Strips leading and trailing blanks (spaces and tabs only) from `str`.
 Returns an empty string when `str` contains nothing but such blanks.
 */
inline std::string trim(const std::string &str) {
  const char *const blanks = " \t";

  const std::size_t first = str.find_first_not_of(blanks);
  if (first == std::string::npos) {
    return std::string();
  }

  const std::size_t last = str.find_last_not_of(blanks);
  return str.substr(first, last + 1 - first);
}
205*c217d954SCole Faust
206*c217d954SCole Faust
get_value_from_map(const std::string & mapstr)207*c217d954SCole Faust inline std::string get_value_from_map(const std::string &mapstr) {
208*c217d954SCole Faust size_t sep_pos = mapstr.find_first_of(":");
209*c217d954SCole Faust if (sep_pos == std::string::npos)
210*c217d954SCole Faust return "";
211*c217d954SCole Faust
212*c217d954SCole Faust std::string tmp = mapstr.substr(sep_pos + 1);
213*c217d954SCole Faust return trim(tmp);
214*c217d954SCole Faust }
215*c217d954SCole Faust
216*c217d954SCole Faust /**
217*c217d954SCole Faust Parses the string representation of a Python dict
218*c217d954SCole Faust
219*c217d954SCole Faust The keys need to be known and may not appear anywhere else in the data.
220*c217d954SCole Faust */
parse_dict(std::string in,const std::vector<std::string> & keys)221*c217d954SCole Faust inline std::unordered_map <std::string, std::string> parse_dict(std::string in, const std::vector <std::string> &keys) {
222*c217d954SCole Faust std::unordered_map <std::string, std::string> map;
223*c217d954SCole Faust
224*c217d954SCole Faust if (keys.size() == 0)
225*c217d954SCole Faust return map;
226*c217d954SCole Faust
227*c217d954SCole Faust in = trim(in);
228*c217d954SCole Faust
229*c217d954SCole Faust // unwrap dictionary
230*c217d954SCole Faust if ((in.front() == '{') && (in.back() == '}'))
231*c217d954SCole Faust in = in.substr(1, in.length() - 2);
232*c217d954SCole Faust else
233*c217d954SCole Faust throw std::runtime_error("Not a Python dictionary.");
234*c217d954SCole Faust
235*c217d954SCole Faust std::vector <std::pair<size_t, std::string>> positions;
236*c217d954SCole Faust
237*c217d954SCole Faust for (auto const &value : keys) {
238*c217d954SCole Faust size_t pos = in.find("'" + value + "'");
239*c217d954SCole Faust
240*c217d954SCole Faust if (pos == std::string::npos)
241*c217d954SCole Faust throw std::runtime_error("Missing '" + value + "' key.");
242*c217d954SCole Faust
243*c217d954SCole Faust std::pair <size_t, std::string> position_pair{pos, value};
244*c217d954SCole Faust positions.push_back(position_pair);
245*c217d954SCole Faust }
246*c217d954SCole Faust
247*c217d954SCole Faust // sort by position in dict
248*c217d954SCole Faust std::sort(positions.begin(), positions.end());
249*c217d954SCole Faust
250*c217d954SCole Faust for (size_t i = 0; i < positions.size(); ++i) {
251*c217d954SCole Faust std::string raw_value;
252*c217d954SCole Faust size_t begin{positions[i].first};
253*c217d954SCole Faust size_t end{std::string::npos};
254*c217d954SCole Faust
255*c217d954SCole Faust std::string key = positions[i].second;
256*c217d954SCole Faust
257*c217d954SCole Faust if (i + 1 < positions.size())
258*c217d954SCole Faust end = positions[i + 1].first;
259*c217d954SCole Faust
260*c217d954SCole Faust raw_value = in.substr(begin, end - begin);
261*c217d954SCole Faust
262*c217d954SCole Faust raw_value = trim(raw_value);
263*c217d954SCole Faust
264*c217d954SCole Faust if (raw_value.back() == ',')
265*c217d954SCole Faust raw_value.pop_back();
266*c217d954SCole Faust
267*c217d954SCole Faust map[key] = get_value_from_map(raw_value);
268*c217d954SCole Faust }
269*c217d954SCole Faust
270*c217d954SCole Faust return map;
271*c217d954SCole Faust }
272*c217d954SCole Faust
/**
 Parses the string representation of a Python boolean.
 Only the exact spellings "True" and "False" are accepted (no surrounding whitespace).
 @throws std::runtime_error for any other input
 */
inline bool parse_bool(const std::string &in) {
  if (in == "True")
    return true;
  if (in == "False")
    return false;

  throw std::runtime_error("Invalid python boolean.");
}
284*c217d954SCole Faust
/**
 Parses the string representation of a single-quoted Python str and returns the
 text between the quotes. Escape sequences are not interpreted.
 @throws std::runtime_error unless `in` is at least two characters long and both
         starts and ends with a single quote
 */
inline std::string parse_str(const std::string &in) {
  // the length check guards front()/back() (UB on an empty string) and rejects a lone "'"
  if (in.length() >= 2 && in.front() == '\'' && in.back() == '\'')
    return in.substr(1, in.length() - 2);

  throw std::runtime_error("Invalid python string.");
}
294*c217d954SCole Faust
295*c217d954SCole Faust /**
296*c217d954SCole Faust Parses the string represenatation of a Python tuple into a vector of its items
297*c217d954SCole Faust */
parse_tuple(std::string in)298*c217d954SCole Faust inline std::vector <std::string> parse_tuple(std::string in) {
299*c217d954SCole Faust std::vector <std::string> v;
300*c217d954SCole Faust const char seperator = ',';
301*c217d954SCole Faust
302*c217d954SCole Faust in = trim(in);
303*c217d954SCole Faust
304*c217d954SCole Faust if ((in.front() == '(') && (in.back() == ')'))
305*c217d954SCole Faust in = in.substr(1, in.length() - 2);
306*c217d954SCole Faust else
307*c217d954SCole Faust throw std::runtime_error("Invalid Python tuple.");
308*c217d954SCole Faust
309*c217d954SCole Faust std::istringstream iss(in);
310*c217d954SCole Faust
311*c217d954SCole Faust for (std::string token; std::getline(iss, token, seperator);) {
312*c217d954SCole Faust v.push_back(token);
313*c217d954SCole Faust }
314*c217d954SCole Faust
315*c217d954SCole Faust return v;
316*c217d954SCole Faust }
317*c217d954SCole Faust
/**
 Formats `v` as a Python tuple literal. It writes "()" for no items, "(x,)" for
 one item and "(a, b, ...)" otherwise. Items are written with operator<<.
 */
template<typename T>
inline std::string write_tuple(const std::vector<T> &v) {
  if (v.empty())
    return "()";

  std::ostringstream out;
  out << '(';
  for (std::size_t idx = 0; idx < v.size(); ++idx) {
    if (idx != 0)
      out << ", ";
    out << v[idx];
  }
  // a one-element tuple needs a trailing comma to be a tuple in Python
  if (v.size() == 1)
    out << ',';
  out << ')';

  return out.str();
}
338*c217d954SCole Faust
/// Formats a bool as a Python boolean literal ("True" / "False").
inline std::string write_boolean(bool b) {
  return b ? std::string("True") : std::string("False");
}
345*c217d954SCole Faust
346*c217d954SCole Faust } // namespace pyparse
347*c217d954SCole Faust
348*c217d954SCole Faust
parse_header(std::string header)349*c217d954SCole Faust inline header_t parse_header(std::string header) {
350*c217d954SCole Faust /*
351*c217d954SCole Faust The first 6 bytes are a magic string: exactly "x93NUMPY".
352*c217d954SCole Faust The next 1 byte is an unsigned byte: the major version number of the file format, e.g. x01.
353*c217d954SCole Faust The next 1 byte is an unsigned byte: the minor version number of the file format, e.g. x00. Note: the version of the file format is not tied to the version of the numpy package.
354*c217d954SCole Faust The next 2 bytes form a little-endian unsigned short int: the length of the header data HEADER_LEN.
355*c217d954SCole Faust The next HEADER_LEN bytes form the header data describing the array's format. It is an ASCII string which contains a Python literal expression of a dictionary. It is terminated by a newline ('n') and padded with spaces ('x20') to make the total length of the magic string + 4 + HEADER_LEN be evenly divisible by 16 for alignment purposes.
356*c217d954SCole Faust The dictionary contains three keys:
357*c217d954SCole Faust
358*c217d954SCole Faust "descr" : dtype.descr
359*c217d954SCole Faust An object that can be passed as an argument to the numpy.dtype() constructor to create the array's dtype.
360*c217d954SCole Faust "fortran_order" : bool
361*c217d954SCole Faust Whether the array data is Fortran-contiguous or not. Since Fortran-contiguous arrays are a common form of non-C-contiguity, we allow them to be written directly to disk for efficiency.
362*c217d954SCole Faust "shape" : tuple of int
363*c217d954SCole Faust The shape of the array.
364*c217d954SCole Faust For repeatability and readability, this dictionary is formatted using pprint.pformat() so the keys are in alphabetic order.
365*c217d954SCole Faust */
366*c217d954SCole Faust
367*c217d954SCole Faust // remove trailing newline
368*c217d954SCole Faust if (header.back() != '\n')
369*c217d954SCole Faust throw std::runtime_error("invalid header");
370*c217d954SCole Faust header.pop_back();
371*c217d954SCole Faust
372*c217d954SCole Faust // parse the dictionary
373*c217d954SCole Faust std::vector <std::string> keys{"descr", "fortran_order", "shape"};
374*c217d954SCole Faust auto dict_map = npy::pyparse::parse_dict(header, keys);
375*c217d954SCole Faust
376*c217d954SCole Faust if (dict_map.size() == 0)
377*c217d954SCole Faust throw std::runtime_error("invalid dictionary in header");
378*c217d954SCole Faust
379*c217d954SCole Faust std::string descr_s = dict_map["descr"];
380*c217d954SCole Faust std::string fortran_s = dict_map["fortran_order"];
381*c217d954SCole Faust std::string shape_s = dict_map["shape"];
382*c217d954SCole Faust
383*c217d954SCole Faust std::string descr = npy::pyparse::parse_str(descr_s);
384*c217d954SCole Faust dtype_t dtype = parse_descr(descr);
385*c217d954SCole Faust
386*c217d954SCole Faust // convert literal Python bool to C++ bool
387*c217d954SCole Faust bool fortran_order = npy::pyparse::parse_bool(fortran_s);
388*c217d954SCole Faust
389*c217d954SCole Faust // parse the shape tuple
390*c217d954SCole Faust auto shape_v = npy::pyparse::parse_tuple(shape_s);
391*c217d954SCole Faust
392*c217d954SCole Faust std::vector <ndarray_len_t> shape;
393*c217d954SCole Faust for (auto item : shape_v) {
394*c217d954SCole Faust ndarray_len_t dim = static_cast<ndarray_len_t>(std::stoul(item));
395*c217d954SCole Faust shape.push_back(dim);
396*c217d954SCole Faust }
397*c217d954SCole Faust
398*c217d954SCole Faust return {dtype, fortran_order, shape};
399*c217d954SCole Faust }
400*c217d954SCole Faust
401*c217d954SCole Faust
402*c217d954SCole Faust inline std::string
write_header_dict(const std::string & descr,bool fortran_order,const std::vector<ndarray_len_t> & shape)403*c217d954SCole Faust write_header_dict(const std::string &descr, bool fortran_order, const std::vector <ndarray_len_t> &shape) {
404*c217d954SCole Faust std::string s_fortran_order = npy::pyparse::write_boolean(fortran_order);
405*c217d954SCole Faust std::string shape_s = npy::pyparse::write_tuple(shape);
406*c217d954SCole Faust
407*c217d954SCole Faust return "{'descr': '" + descr + "', 'fortran_order': " + s_fortran_order + ", 'shape': " + shape_s + ", }";
408*c217d954SCole Faust }
409*c217d954SCole Faust
write_header(std::ostream & out,const header_t & header)410*c217d954SCole Faust inline void write_header(std::ostream &out, const header_t &header) {
411*c217d954SCole Faust std::string header_dict = write_header_dict(header.dtype.str(), header.fortran_order, header.shape);
412*c217d954SCole Faust
413*c217d954SCole Faust size_t length = magic_string_length + 2 + 2 + header_dict.length() + 1;
414*c217d954SCole Faust
415*c217d954SCole Faust version_t version{1, 0};
416*c217d954SCole Faust if (length >= 255 * 255) {
417*c217d954SCole Faust length = magic_string_length + 2 + 4 + header_dict.length() + 1;
418*c217d954SCole Faust version = {2, 0};
419*c217d954SCole Faust }
420*c217d954SCole Faust size_t padding_len = 16 - length % 16;
421*c217d954SCole Faust std::string padding(padding_len, ' ');
422*c217d954SCole Faust
423*c217d954SCole Faust // write magic
424*c217d954SCole Faust write_magic(out, version);
425*c217d954SCole Faust
426*c217d954SCole Faust // write header length
427*c217d954SCole Faust if (version == version_t{1, 0}) {
428*c217d954SCole Faust uint8_t header_len_le16[2];
429*c217d954SCole Faust uint16_t header_len = static_cast<uint16_t>(header_dict.length() + padding.length() + 1);
430*c217d954SCole Faust
431*c217d954SCole Faust header_len_le16[0] = (header_len >> 0) & 0xff;
432*c217d954SCole Faust header_len_le16[1] = (header_len >> 8) & 0xff;
433*c217d954SCole Faust out.write(reinterpret_cast<char *>(header_len_le16), 2);
434*c217d954SCole Faust } else {
435*c217d954SCole Faust uint8_t header_len_le32[4];
436*c217d954SCole Faust uint32_t header_len = static_cast<uint32_t>(header_dict.length() + padding.length() + 1);
437*c217d954SCole Faust
438*c217d954SCole Faust header_len_le32[0] = (header_len >> 0) & 0xff;
439*c217d954SCole Faust header_len_le32[1] = (header_len >> 8) & 0xff;
440*c217d954SCole Faust header_len_le32[2] = (header_len >> 16) & 0xff;
441*c217d954SCole Faust header_len_le32[3] = (header_len >> 24) & 0xff;
442*c217d954SCole Faust out.write(reinterpret_cast<char *>(header_len_le32), 4);
443*c217d954SCole Faust }
444*c217d954SCole Faust
445*c217d954SCole Faust out << header_dict << padding << '\n';
446*c217d954SCole Faust }
447*c217d954SCole Faust
read_header(std::istream & istream)448*c217d954SCole Faust inline std::string read_header(std::istream &istream) {
449*c217d954SCole Faust // check magic bytes an version number
450*c217d954SCole Faust version_t version = read_magic(istream);
451*c217d954SCole Faust
452*c217d954SCole Faust uint32_t header_length;
453*c217d954SCole Faust if (version == version_t{1, 0}) {
454*c217d954SCole Faust uint8_t header_len_le16[2];
455*c217d954SCole Faust istream.read(reinterpret_cast<char *>(header_len_le16), 2);
456*c217d954SCole Faust header_length = (header_len_le16[0] << 0) | (header_len_le16[1] << 8);
457*c217d954SCole Faust
458*c217d954SCole Faust if ((magic_string_length + 2 + 2 + header_length) % 16 != 0) {
459*c217d954SCole Faust // TODO(llohse): display warning
460*c217d954SCole Faust }
461*c217d954SCole Faust } else if (version == version_t{2, 0}) {
462*c217d954SCole Faust uint8_t header_len_le32[4];
463*c217d954SCole Faust istream.read(reinterpret_cast<char *>(header_len_le32), 4);
464*c217d954SCole Faust
465*c217d954SCole Faust header_length = (header_len_le32[0] << 0) | (header_len_le32[1] << 8)
466*c217d954SCole Faust | (header_len_le32[2] << 16) | (header_len_le32[3] << 24);
467*c217d954SCole Faust
468*c217d954SCole Faust if ((magic_string_length + 2 + 4 + header_length) % 16 != 0) {
469*c217d954SCole Faust // TODO(llohse): display warning
470*c217d954SCole Faust }
471*c217d954SCole Faust } else {
472*c217d954SCole Faust throw std::runtime_error("unsupported file format version");
473*c217d954SCole Faust }
474*c217d954SCole Faust
475*c217d954SCole Faust auto buf_v = std::vector<char>(header_length);
476*c217d954SCole Faust istream.read(buf_v.data(), header_length);
477*c217d954SCole Faust std::string header(buf_v.data(), header_length);
478*c217d954SCole Faust
479*c217d954SCole Faust return header;
480*c217d954SCole Faust }
481*c217d954SCole Faust
comp_size(const std::vector<ndarray_len_t> & shape)482*c217d954SCole Faust inline ndarray_len_t comp_size(const std::vector <ndarray_len_t> &shape) {
483*c217d954SCole Faust ndarray_len_t size = 1;
484*c217d954SCole Faust for (ndarray_len_t i : shape)
485*c217d954SCole Faust size *= i;
486*c217d954SCole Faust
487*c217d954SCole Faust return size;
488*c217d954SCole Faust }
489*c217d954SCole Faust
490*c217d954SCole Faust template<typename Scalar>
491*c217d954SCole Faust inline void
SaveArrayAsNumpy(const std::string & filename,bool fortran_order,unsigned int n_dims,const unsigned long shape[],const Scalar * data)492*c217d954SCole Faust SaveArrayAsNumpy(const std::string &filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[],
493*c217d954SCole Faust const Scalar* data) {
494*c217d954SCole Faust // static_assert(has_typestring<Scalar>::value, "scalar type not understood");
495*c217d954SCole Faust const dtype_t dtype = dtype_map.at(std::type_index(typeid(Scalar)));
496*c217d954SCole Faust
497*c217d954SCole Faust std::ofstream stream(filename, std::ofstream::binary);
498*c217d954SCole Faust if (!stream) {
499*c217d954SCole Faust throw std::runtime_error("io error: failed to open a file.");
500*c217d954SCole Faust }
501*c217d954SCole Faust
502*c217d954SCole Faust std::vector <ndarray_len_t> shape_v(shape, shape + n_dims);
503*c217d954SCole Faust header_t header{dtype, fortran_order, shape_v};
504*c217d954SCole Faust write_header(stream, header);
505*c217d954SCole Faust
506*c217d954SCole Faust auto size = static_cast<size_t>(comp_size(shape_v));
507*c217d954SCole Faust
508*c217d954SCole Faust stream.write(reinterpret_cast<const char *>(data), sizeof(Scalar) * size);
509*c217d954SCole Faust }
510*c217d954SCole Faust
511*c217d954SCole Faust template<typename Scalar>
512*c217d954SCole Faust inline void
SaveArrayAsNumpy(const std::string & filename,bool fortran_order,unsigned int n_dims,const unsigned long shape[],const std::vector<Scalar> & data)513*c217d954SCole Faust SaveArrayAsNumpy(const std::string &filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[],
514*c217d954SCole Faust const std::vector <Scalar> &data) {
515*c217d954SCole Faust SaveArrayAsNumpy(filename, fortran_order, n_dims, shape, data.data());
516*c217d954SCole Faust }
517*c217d954SCole Faust
// Forward declaration of the overload that also reports fortran_order. The call
// below is dependent on Scalar, so ordinary lookup only sees declarations visible
// here. ADL only searches std (and Scalar's namespace), so without this declaration
// instantiating the function below fails to compile.
template<typename Scalar>
inline void LoadArrayFromNumpy(const std::string &filename, std::vector<unsigned long> &shape, bool &fortran_order,
                               std::vector<Scalar> &data);

/// Loads an npy file into `data` and its dimensions into `shape`, discarding the
/// header's 'fortran_order' flag.
/// @throws whatever the four-argument overload throws
template<typename Scalar>
inline void
LoadArrayFromNumpy(const std::string &filename, std::vector<unsigned long> &shape, std::vector<Scalar> &data) {
  bool fortran_order = false;
  LoadArrayFromNumpy<Scalar>(filename, shape, fortran_order, data);
}
524*c217d954SCole Faust
525*c217d954SCole Faust template<typename Scalar>
LoadArrayFromNumpy(const std::string & filename,std::vector<unsigned long> & shape,bool & fortran_order,std::vector<Scalar> & data)526*c217d954SCole Faust inline void LoadArrayFromNumpy(const std::string &filename, std::vector<unsigned long> &shape, bool &fortran_order,
527*c217d954SCole Faust std::vector <Scalar> &data) {
528*c217d954SCole Faust std::ifstream stream(filename, std::ifstream::binary);
529*c217d954SCole Faust if (!stream) {
530*c217d954SCole Faust throw std::runtime_error("io error: failed to open a file.");
531*c217d954SCole Faust }
532*c217d954SCole Faust
533*c217d954SCole Faust std::string header_s = read_header(stream);
534*c217d954SCole Faust
535*c217d954SCole Faust // parse header
536*c217d954SCole Faust header_t header = parse_header(header_s);
537*c217d954SCole Faust
538*c217d954SCole Faust // check if the typestring matches the given one
539*c217d954SCole Faust // static_assert(has_typestring<Scalar>::value, "scalar type not understood");
540*c217d954SCole Faust const dtype_t dtype = dtype_map.at(std::type_index(typeid(Scalar)));
541*c217d954SCole Faust
542*c217d954SCole Faust if (header.dtype.tie() != dtype.tie()) {
543*c217d954SCole Faust throw std::runtime_error("formatting error: typestrings not matching");
544*c217d954SCole Faust }
545*c217d954SCole Faust
546*c217d954SCole Faust shape = header.shape;
547*c217d954SCole Faust fortran_order = header.fortran_order;
548*c217d954SCole Faust
549*c217d954SCole Faust // compute the data size based on the shape
550*c217d954SCole Faust auto size = static_cast<size_t>(comp_size(shape));
551*c217d954SCole Faust data.resize(size);
552*c217d954SCole Faust
553*c217d954SCole Faust // read the data
554*c217d954SCole Faust stream.read(reinterpret_cast<char *>(data.data()), sizeof(Scalar) * size);
555*c217d954SCole Faust }
556*c217d954SCole Faust
557*c217d954SCole Faust } // namespace npy
558*c217d954SCole Faust
559*c217d954SCole Faust #endif // NPY_HPP_
560