"""Download and unpack the MNIST dataset files into a local directory.

Each archive listed in ``RESOURCES`` is fetched from the first mirror in
``MIRRORS`` that works and is then gunzipped next to the downloaded archive.
Files that already exist on disk are skipped, so the script can be re-run.
"""

import argparse
import gzip
import os
import shutil
import sys
from urllib.error import URLError
from urllib.request import urlretrieve


# Base URLs tried in order; the resource file name is appended verbatim.
MIRRORS = [
    "http://yann.lecun.com/exdb/mnist/",
    "https://ossci-datasets.s3.amazonaws.com/mnist/",
]

# Archive names fetched from a mirror and then decompressed.
RESOURCES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]


def _remove_if_exists(path: str) -> None:
    """Delete *path*, doing nothing when it is already absent."""
    try:
        os.remove(path)
    except FileNotFoundError:
        # Nothing to clean up: the file was never created or was renamed.
        pass


def report_download_progress(
    chunk_number: int,
    chunk_size: int,
    file_size: int,
) -> None:
    """Draw a 64-column progress bar on stdout.

    The signature matches the ``reporthook`` callback of ``urlretrieve``.

    Args:
        chunk_number: Number of chunks transferred so far.
        chunk_size: Size of one chunk in bytes.
        file_size: Total size in bytes; ``-1`` when unknown.
    """
    # An unknown (-1) or zero size gives nothing to compute a ratio against;
    # dividing by a zero size would raise ZeroDivisionError.
    if file_size <= 0:
        return
    # The last chunk may overshoot the real size, so clamp to 100%.
    percent = min(1.0, (chunk_number * chunk_size) / file_size)
    bar = "#" * int(64 * percent)
    sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")


def download(destination_path: str, resource: str, quiet: bool) -> None:
    """Fetch *resource* from the first working mirror into *destination_path*.

    The transfer goes to ``destination_path + ".part"`` and is renamed only
    after it completes, so a failed or interrupted download never leaves a
    truncated file that a later run would mistake for a finished one.

    Args:
        destination_path: Where the downloaded file should end up.
        resource: File name appended to each mirror's base URL.
        quiet: Suppress the progress bar and the "already exists" notice.

    Raises:
        RuntimeError: If no mirror could deliver the resource.
    """
    if os.path.exists(destination_path):
        if not quiet:
            print(f"{destination_path} already exists, skipping ...")
        return

    partial_path = destination_path + ".part"
    hook = None if quiet else report_download_progress
    last_error = None
    try:
        for mirror in MIRRORS:
            url = mirror + resource
            print(f"Downloading {url} ...")
            try:
                urlretrieve(url, partial_path, reporthook=hook)
            except (URLError, ConnectionError) as e:
                print(f"Failed to download (trying next):\n{e}")
                last_error = e
                continue
            finally:
                if not quiet:
                    # Just a newline, to end the progress-bar line.
                    print()
            os.replace(partial_path, destination_path)
            return
        raise RuntimeError("Error downloading resource!") from last_error
    finally:
        # Covers mirror failures and KeyboardInterrupt alike; after a
        # successful rename the partial file is already gone.
        _remove_if_exists(partial_path)


def unzip(zipped_path: str, quiet: bool) -> None:
    """Decompress the gzip file *zipped_path* next to itself.

    The output path is *zipped_path* with its last extension removed. Data is
    streamed into ``<output>.part`` and renamed on success, so a corrupt
    archive or an interruption leaves no partial output behind.

    Args:
        zipped_path: Path of the ``.gz`` archive to decompress.
        quiet: Suppress status messages.
    """
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print(f"{unzipped_path} already exists, skipping ... ")
        return

    partial_path = unzipped_path + ".part"
    try:
        with gzip.open(zipped_path, "rb") as zipped_file, open(
            partial_path, "wb"
        ) as unzipped_file:
            # Stream in chunks instead of holding the whole payload in memory.
            shutil.copyfileobj(zipped_file, unzipped_file)
        os.replace(partial_path, unzipped_path)
    finally:
        _remove_if_exists(partial_path)

    if not quiet:
        print(f"Unzipped {zipped_path} ...")


def main() -> None:
    """Parse command-line options, then download and unzip every resource."""
    parser = argparse.ArgumentParser(
        description="Download the MNIST dataset from the internet"
    )
    parser.add_argument(
        "-d", "--destination", default=".", help="Destination directory"
    )
    parser.add_argument(
        "-q", "--quiet", action="store_true", help="Don't report about progress"
    )
    options = parser.parse_args()

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(options.destination, exist_ok=True)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            download(path, resource, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print("Interrupted")


if __name__ == "__main__":
    main()