import argparse
import gzip
import os
import sys
from urllib.error import URLError
from urllib.request import urlretrieve


# Base URLs tried in order; the first mirror that serves a resource wins.
MIRRORS = [
    "http://yann.lecun.com/exdb/mnist/",
    "https://ossci-datasets.s3.amazonaws.com/mnist/",
]

# Gzipped files that are fetched from a mirror and then decompressed.
RESOURCES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]


def report_download_progress(
    chunk_number: int,
    chunk_size: int,
    file_size: int,
) -> None:
    """Draw a one-line progress bar on stdout.

    The signature matches the ``reporthook`` callback of
    ``urllib.request.urlretrieve``.

    Args:
        chunk_number: Number of chunks transferred so far.
        chunk_size: Size of one chunk in bytes.
        file_size: Total size in bytes; ``-1`` (or any non-positive value)
            means the size is unknown and nothing is drawn.
    """
    # A non-positive size is either "unknown" (-1) or would divide by zero.
    if file_size <= 0:
        return
    bar_width = 64
    # The last chunk may overshoot the real size, so clamp at 100%.
    percent = min(1, (chunk_number * chunk_size) / file_size)
    bar = "#" * int(bar_width * percent)
    sys.stdout.write(f"\r0% |{bar:<{bar_width}}| {int(percent * 100)}%")
    # The line has no trailing newline, so flush to make it visible.
    sys.stdout.flush()
def _remove_quietly(path: str) -> None:
    """Delete *path* if it exists; a missing file is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def download(destination_path: str, resource: str, quiet: bool) -> None:
    """Fetch *resource* from the first working mirror into *destination_path*.

    Nothing is downloaded when *destination_path* already exists.

    Args:
        destination_path: Where the downloaded file is stored.
        resource: File name appended to each mirror's base URL.
        quiet: When true, suppress the progress bar and the "skipping" note.

    Raises:
        RuntimeError: If every mirror in ``MIRRORS`` failed.
    """
    if os.path.exists(destination_path):
        if not quiet:
            print(f"{destination_path} already exists, skipping ...")
        return

    hook = None if quiet else report_download_progress
    # Download next to the target and rename on success, so an aborted
    # transfer never leaves a truncated file that later runs would skip.
    partial_path = destination_path + ".part"
    for mirror in MIRRORS:
        url = mirror + resource
        print(f"Downloading {url} ...")
        try:
            urlretrieve(url, partial_path, reporthook=hook)
        except (URLError, ConnectionError) as e:
            print(f"Failed to download (trying next):\n{e}")
            _remove_quietly(partial_path)
            continue
        finally:
            if not quiet:
                # Just a newline, to terminate the progress-bar line.
                print()
        os.replace(partial_path, destination_path)
        return
    raise RuntimeError("Error downloading resource!")


def unzip(zipped_path: str, quiet: bool) -> None:
    """Decompress a gzip file next to itself, dropping its last extension.

    Nothing is written when the decompressed file already exists.

    Args:
        zipped_path: Path of the gzip archive.
        quiet: When true, suppress status messages.
    """
    import shutil  # local import: keeps this block self-contained

    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print(f"{unzipped_path} already exists, skipping ... ")
        return

    # Stream into a temporary file so a corrupt archive cannot leave a
    # half-written output that would be skipped on the next run.
    partial_path = unzipped_path + ".part"
    try:
        with gzip.open(zipped_path, "rb") as zipped_file:
            with open(partial_path, "wb") as unzipped_file:
                shutil.copyfileobj(zipped_file, unzipped_file)
    except BaseException:
        _remove_quietly(partial_path)
        raise
    os.replace(partial_path, unzipped_path)
    if not quiet:
        print(f"Unzipped {zipped_path} ...")


def main() -> None:
    """Parse command-line options, then download and unzip every resource."""
    parser = argparse.ArgumentParser(
        description="Download the MNIST dataset from the internet"
    )
    parser.add_argument(
        "-d", "--destination", default=".", help="Destination directory"
    )
    parser.add_argument(
        "-q", "--quiet", action="store_true", help="Don't report about progress"
    )
    options = parser.parse_args()

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(options.destination, exist_ok=True)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            download(path, resource, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print("Interrupted")


if __name__ == "__main__":
    main()