#!/usr/bin/env python
"""Extract MNIST image data from the Caffe data files and store it in numpy arrays.

Usage
    python caffe_mnist_image_extractor.py -d path_to_caffe_data_directory -o desired_output_path

Saves the first 10 images extracted as input10.npy, the first 100 images as input100.npy, and the
corresponding labels to labels100.npy.

NOTE: labels100.npy is a plain-text file (one decimal label per line), not a numpy archive; the
file name is kept as-is so existing consumers of this script's output keep working.

Images are saved as float32 arrays of shape (N, 1, rows, cols) with every pixel divided by 256.

Tested with Caffe 1.0 on Python 2.7; only syntax valid on both Python 2.7 and Python 3 is used.
"""
import argparse
import os
import struct
from array import array

import numpy as np

# Magic numbers of the IDX container format used by the MNIST distribution.
IMAGES_MAGIC = 2051  # 0x00000803: unsigned bytes, 3 dimensions
LABELS_MAGIC = 2049  # 0x00000801: unsigned bytes, 1 dimension

# Divisor applied to the raw 0..255 pixel values before saving.
PIXEL_SCALE = 256.0

# Number of samples written to the small and the large output file.
SMALL_BATCH = 10
LARGE_BATCH = 100


def _read_header(stream, fmt, expected_magic, path):
    """Read and validate an IDX header, returning the fields after the magic number.

    Args:
        stream: Binary file object positioned at the start of the file.
        fmt: ``struct`` format of the header (magic number first).
        expected_magic: Magic number the file must start with.
        path: File name, used only in error messages.

    Raises:
        ValueError: If the header is truncated or the magic number does not match.
    """
    size = struct.calcsize(fmt)
    header = stream.read(size)
    if len(header) != size:
        raise ValueError('%s: truncated IDX header' % path)
    fields = struct.unpack(fmt, header)
    if fields[0] != expected_magic:
        raise ValueError('%s: bad magic number %d (expected %d)' % (path, fields[0], expected_magic))
    return fields[1:]


def read_idx_images(path, count):
    """Return the first ``count`` images of an IDX3 file as a uint8 array of shape (count, rows, cols).

    Only the bytes belonging to the requested images are read from disk.

    Raises:
        ValueError: If the file is not a valid IDX3 image file or holds fewer than ``count`` images.
    """
    with open(path, 'rb') as stream:
        total, rows, cols = _read_header(stream, '>IIII', IMAGES_MAGIC, path)
        if count > total:
            raise ValueError('%s: requested %d images but file declares only %d' % (path, count, total))
        expected = count * rows * cols
        pixels = array('B', stream.read(expected))
    if len(pixels) != expected:
        raise ValueError('%s: truncated image data' % path)
    return np.array(pixels, dtype=np.uint8).reshape((count, rows, cols))


def read_idx_labels(path, count):
    """Return the first ``count`` labels of an IDX1 file as a list of ints.

    Raises:
        ValueError: If the file is not a valid IDX1 label file or holds fewer than ``count`` labels.
    """
    with open(path, 'rb') as stream:
        (total,) = _read_header(stream, '>II', LABELS_MAGIC, path)
        if count > total:
            raise ValueError('%s: requested %d labels but file declares only %d' % (path, count, total))
        # Labels are stored as unsigned bytes in the IDX format.
        labels = array('B', stream.read(count))
    if len(labels) != count:
        raise ValueError('%s: truncated label data' % path)
    return labels.tolist()


def to_network_input(images):
    """Scale uint8 images of shape (N, rows, cols) to float32 of shape (N, 1, rows, cols)."""
    scaled = (images / PIXEL_SCALE).astype(np.float32)
    # Insert the single channel axis directly after the batch axis.
    return scaled[:, np.newaxis, :, :]


def main(argv=None):
    """Run the extractor.

    Args:
        argv: Command-line arguments without the program name; ``None`` means ``sys.argv[1:]``.
    """
    # Parse arguments
    parser = argparse.ArgumentParser(description='Extract Caffe mnist image data')
    parser.add_argument('-d', dest='dataDir', type=str, required=True, help='Path to Caffe data directory')
    parser.add_argument('-o', dest='outDir', type=str, default='.', help='Output directory (default = current directory)')
    args = parser.parse_args(argv)

    images_filename = os.path.join(args.dataDir, 'mnist/t10k-images-idx3-ubyte')
    labels_filename = os.path.join(args.dataDir, 'mnist/t10k-labels-idx1-ubyte')

    input10_path = os.path.join(args.outDir, 'input10.npy')
    input100_path = os.path.join(args.outDir, 'input100.npy')
    labels100_path = os.path.join(args.outDir, 'labels100.npy')

    batch = to_network_input(read_idx_images(images_filename, LARGE_BATCH))
    labels = read_idx_labels(labels_filename, LARGE_BATCH)

    np.save(input10_path, batch[:SMALL_BATCH])
    print("Wrote %s" % input10_path)

    # Plain text, one label per line (see the module docstring about the extension).
    with open(labels100_path, 'w') as labels_output:
        labels_output.write(''.join(str(label) + '\n' for label in labels))
    print("Wrote %s" % labels100_path)

    np.save(input100_path, batch)
    print("Wrote %s" % input100_path)


if __name__ == "__main__":
    main()