xref: /aosp_15_r20/external/pytorch/tools/download_mnist.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import argparse
import gzip
import os
import shutil
import sys
from typing import Optional, Sequence
from urllib.error import URLError
from urllib.request import urlretrieve
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
# Base URLs tried in order; each resource name below is appended to them.
# The HTTPS S3 mirror comes first: the original yann.lecun.com host is plain
# HTTP and is frequently unavailable, so it is kept only as a fallback.
MIRRORS = [
    "https://ossci-datasets.s3.amazonaws.com/mnist/",
    "http://yann.lecun.com/exdb/mnist/",
]

# Gzip-compressed MNIST files (train / t10k images and labels) to fetch.
RESOURCES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
def report_download_progress(
    chunk_number: int,
    chunk_size: int,
    file_size: int,
) -> None:
    """Draw a 64-column text progress bar; intended as ``urlretrieve``'s reporthook.

    Args:
        chunk_number: Number of blocks transferred so far.
        chunk_size: Size of each block in bytes.
        file_size: Total size in bytes, or ``-1`` when the server did not
            report one.

    Nothing is drawn when the total size is unknown (``-1``) or zero. No
    meaningful fraction exists in those cases, and a zero size would
    otherwise raise ``ZeroDivisionError``.
    """
    if file_size <= 0:
        return
    # The final block can overshoot the total size, so clamp at 100%.
    fraction = min(1, (chunk_number * chunk_size) / file_size)
    bar = "#" * int(64 * fraction)
    # "\r" rewinds to the start of the line so each call redraws the bar.
    sys.stdout.write(f"\r0% |{bar:<64}| {int(fraction * 100)}%")
    # No newline is written, so a line-buffered stdout would otherwise hold
    # the bar back until the download finishes.
    sys.stdout.flush()
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker
def download(
    destination_path: str,
    resource: str,
    quiet: bool,
    mirrors: Optional[Sequence[str]] = None,
) -> None:
    """Download ``resource`` to ``destination_path``, trying each mirror in order.

    Args:
        destination_path: Where the downloaded file is stored. If it already
            exists, nothing is downloaded.
        resource: File name appended to each mirror's base URL.
        quiet: Suppress informational output and the progress bar.
        mirrors: Base URLs to try in order; defaults to ``MIRRORS``.

    Raises:
        RuntimeError: If every mirror fails. The last download error is
            chained as the cause.

    The transfer goes to a sibling ``.part`` file, which is moved into place
    only after a complete download. A failed or interrupted attempt therefore
    never leaves a truncated ``destination_path`` that a later run would skip.
    """
    if os.path.exists(destination_path):
        if not quiet:
            print(f"{destination_path} already exists, skipping ...")
        return

    if mirrors is None:
        mirrors = MIRRORS

    partial_path = destination_path + ".part"
    last_error: Optional[BaseException] = None
    try:
        for mirror in mirrors:
            url = mirror + resource
            if not quiet:
                print(f"Downloading {url} ...")
            try:
                hook = None if quiet else report_download_progress
                urlretrieve(url, partial_path, reporthook=hook)
            except (URLError, ConnectionError) as e:
                last_error = e
                print(f"Failed to download (trying next):\n{e}")
                continue
            finally:
                if not quiet:
                    # Terminate the progress-bar line.
                    print()
            # Only a complete transfer gets here; publish it atomically.
            os.replace(partial_path, destination_path)
            return
    finally:
        # Remove leftovers from a failed attempt or an interruption
        # (e.g. KeyboardInterrupt). After a successful os.replace the
        # temporary file no longer exists.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    raise RuntimeError("Error downloading resource!") from last_error
53*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("Error downloading resource!")
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker
def unzip(zipped_path: str, quiet: bool) -> None:
    """Decompress a gzip file next to itself, dropping its final extension.

    Args:
        zipped_path: Path of the ``.gz`` file. The output path is this path
            with its last extension removed.
        quiet: Suppress informational output.

    If the output already exists, nothing is done. Decompression writes to a
    temporary ``.part`` file that is moved into place only on success. A
    corrupt or interrupted archive therefore never leaves a truncated output
    that later runs would skip. Errors from ``gzip`` propagate to the caller.
    """
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print(f"{unzipped_path} already exists, skipping ... ")
        return
    partial_path = unzipped_path + ".part"
    try:
        with gzip.open(zipped_path, "rb") as zipped_file, open(
            partial_path, "wb"
        ) as unzipped_file:
            # Stream in chunks instead of holding the whole file in memory.
            shutil.copyfileobj(zipped_file, unzipped_file)
        os.replace(partial_path, unzipped_path)
    finally:
        # After a successful os.replace the temp file is already gone.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    if not quiet:
        print(f"Unzipped {zipped_path} ...")
66*da0073e9SAndroid Build Coastguard Worker                print(f"Unzipped {zipped_path} ...")
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker
def main(argv: Optional[Sequence[str]] = None) -> None:
    """Parse command-line options, then download and decompress every MNIST file.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behaviour.

    On Ctrl-C it prints "Interrupted" and exits with status 130, so calling
    scripts and CI do not mistake an aborted run for success.
    """
    parser = argparse.ArgumentParser(
        description="Download the MNIST dataset from the internet"
    )
    parser.add_argument(
        "-d", "--destination", default=".", help="Destination directory"
    )
    parser.add_argument(
        "-q", "--quiet", action="store_true", help="Don't report about progress"
    )
    options = parser.parse_args(argv)

    # exist_ok avoids the race of a separate exists() check before creating.
    os.makedirs(options.destination, exist_ok=True)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            download(path, resource, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print("Interrupted")
        # Conventional shell status for SIGINT: 128 + 2.
        sys.exit(130)
90*da0073e9SAndroid Build Coastguard Worker        print("Interrupted")
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run the downloader only when executed as a script, not on import.
    main()
95