# xref: /aosp_15_r20/external/pytorch/tools/download_mnist.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import argparse
import gzip
import os
import shutil
import sys
from urllib.error import URLError
from urllib.request import urlretrieve
7
8
# Base URLs tried in order by download(). Each entry must end with "/"
# because the resource file name is appended to it verbatim.
MIRRORS = [
    "http://yann.lecun.com/exdb/mnist/",
    "https://ossci-datasets.s3.amazonaws.com/mnist/",
]
13
# Archive names, relative to a mirror base URL, that main() downloads into
# the destination directory and then unzips.
RESOURCES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]
20
21
def report_download_progress(
    chunk_number: int,
    chunk_size: int,
    file_size: int,
) -> None:
    """Draw a single-line text progress bar on stdout.

    Intended to be passed as the ``reporthook`` callback of
    ``urllib.request.urlretrieve``.

    Args:
        chunk_number: Number of chunks transferred so far.
        chunk_size: Size of one chunk in bytes.
        file_size: Total size of the download in bytes. Nothing is drawn
            when it is not positive (``-1`` is passed for an unknown size).
    """
    # A non-positive total gives no meaningful ratio; 0 in particular would
    # raise ZeroDivisionError below.
    if file_size <= 0:
        return
    # The last chunk may overshoot the real size, so clamp to 100%.
    percent = min(1, (chunk_number * chunk_size) / file_size)
    bar = "#" * int(64 * percent)
    sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")
    # The line ends without a newline, so flush explicitly or a buffered
    # stdout may never show the bar while the download is running.
    sys.stdout.flush()
31
32
def download(destination_path: str, resource: str, quiet: bool) -> None:
    """Fetch ``resource`` from the first mirror that serves it.

    The data is written to ``destination_path + ".part"`` and only renamed
    to ``destination_path`` after the transfer finished, so an aborted or
    failed download never leaves a truncated file that a later run would
    mistake for a finished one.

    Args:
        destination_path: Where the downloaded file is stored. If it already
            exists nothing is downloaded.
        resource: File name appended to each entry of ``MIRRORS``.
        quiet: When true, suppress the progress bar and the "already exists"
            notice.

    Raises:
        RuntimeError: If every mirror failed.
    """
    if os.path.exists(destination_path):
        if not quiet:
            print(f"{destination_path} already exists, skipping ...")
        return

    partial_path = destination_path + ".part"
    hook = None if quiet else report_download_progress
    try:
        for mirror in MIRRORS:
            url = mirror + resource
            print(f"Downloading {url} ...")
            try:
                urlretrieve(url, partial_path, reporthook=hook)
            except (URLError, ConnectionError) as e:
                print(f"Failed to download (trying next):\n{e}")
                continue
            finally:
                if not quiet:
                    # Terminate the progress-bar line, which is written
                    # without a newline of its own.
                    print()
            # Promote the finished download in one step.
            os.replace(partial_path, destination_path)
            break
        else:
            raise RuntimeError("Error downloading resource!")
    finally:
        # Drop whatever a failed or interrupted transfer left behind. After
        # a successful rename there is nothing to remove.
        try:
            os.remove(partial_path)
        except FileNotFoundError:
            pass
54
55
def unzip(zipped_path: str, quiet: bool) -> None:
    """Decompress a gzip file next to itself, dropping the last extension.

    The output is streamed into ``<target>.part`` and renamed to the target
    only after decompression succeeded. A corrupt archive therefore leaves
    no empty or truncated output behind for a later run to skip over.

    Args:
        zipped_path: Path of the gzip-compressed file.
        quiet: When true, print nothing.

    Raises:
        OSError: If the archive cannot be read or is not valid gzip data.
    """
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print(f"{unzipped_path} already exists, skipping ... ")
        return

    partial_path = unzipped_path + ".part"
    try:
        with gzip.open(zipped_path, "rb") as zipped_file:
            with open(partial_path, "wb") as unzipped_file:
                # Copy in chunks instead of holding the whole decompressed
                # payload in memory at once.
                shutil.copyfileobj(zipped_file, unzipped_file)
        os.replace(partial_path, unzipped_path)
    finally:
        # Only present if decompression failed part-way through.
        try:
            os.remove(partial_path)
        except FileNotFoundError:
            pass
    if not quiet:
        print(f"Unzipped {zipped_path} ...")
67
68
def main() -> None:
    """Parse the command line, then download and unzip every resource.

    Options:
        -d / --destination: directory that receives the files (default ".").
        -q / --quiet: do not report progress.
    """
    parser = argparse.ArgumentParser(
        description="Download the MNIST dataset from the internet"
    )
    parser.add_argument(
        "-d", "--destination", default=".", help="Destination directory"
    )
    parser.add_argument(
        "-q", "--quiet", action="store_true", help="Don't report about progress"
    )
    options = parser.parse_args()

    # exist_ok avoids the check-then-create race of a separate exists() test
    # and fails right away if the path is an existing non-directory.
    os.makedirs(options.destination, exist_ok=True)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            download(path, resource, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        # NOTE(review): the interrupt is reported but main() still returns
        # normally, so a script run exits with status 0 - confirm that this
        # is what callers expect.
        print("Interrupted")
91
92
# Run the downloader only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
95