#!/usr/bin/env python3
# Copyright 2020 NXP
# SPDX-License-Identifier: MIT
"""Downloads and extracts resources for unit tests.

It is mandatory to run this script prior to running unit tests. Resources are stored as a tar.gz or a tar.bz2 archive and
extracted into the test/testdata/shared folder.
"""

import tarfile
import requests
import os
import uuid

SCRIPTS_DIR = os.path.dirname(os.path.realpath(__file__))
EXTRACT_DIR = os.path.join(SCRIPTS_DIR, "..", "test")
ARCHIVE_URL = "https://snapshots.linaro.org/components/pyarmnn-tests/pyarmnn_testdata_201100_20201022.tar.bz2"

# Supported archive suffixes and the tarfile.open() mode used to read each.
_ARCHIVE_MODES = ((".tar.bz2", "r:bz2"), (".tar.gz", "r:gz"))

# Seconds to wait for the server to connect or to send the next bytes.
# This is a per-socket-operation limit, not a cap on total download time.
_HTTP_TIMEOUT = 60

# Number of bytes written to disk per iteration while streaming the download.
_CHUNK_SIZE = 1 << 20


def _download_to(url, file_path):
    """Stream the body of ``url`` into ``file_path`` without buffering it all in memory.

    Raises:
        RuntimeError: on any network error or non-2xx HTTP status.
    """
    try:
        with requests.get(url, stream=True, timeout=_HTTP_TIMEOUT) as response:
            # Without this check an HTML error page would be saved as the
            # "archive" and tarfile would later fail with a confusing error.
            response.raise_for_status()
            with open(file_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=_CHUNK_SIZE):
                    f.write(chunk)
    except requests.exceptions.RequestException as e:
        raise RuntimeError("Unable to download file: {}".format(e)) from e


def _extract_archive(file_path, mode, dest_dir):
    """Extract the tar archive at ``file_path`` (opened with ``mode``) into ``dest_dir``.

    Members whose paths would resolve outside ``dest_dir`` are rejected.

    Raises:
        RuntimeError: if the fallback check finds a member escaping ``dest_dir``.
    """
    with tarfile.open(file_path, mode) as tar:
        print("Extracting '{}'".format(file_path))
        if hasattr(tarfile, "data_filter"):
            # PEP 706 "data" filter (Python 3.12+ and security backports):
            # refuses absolute paths, ".." escapes and links pointing outside dest_dir.
            tar.extractall(dest_dir, filter="data")
        else:
            # Older interpreters: at least reject member names that escape dest_dir.
            # NOTE: this does not detect symlink-based escapes.
            root = os.path.realpath(dest_dir)
            for member in tar.getmembers():
                target = os.path.realpath(os.path.join(root, member.name))
                if os.path.commonpath([root, target]) != root:
                    raise RuntimeError("Unsafe path in archive: {}".format(member.name))
            tar.extractall(dest_dir)


def download_resources(url, save_path):
    """Download a .tar.gz/.tar.bz2 archive from ``url`` and extract it into ``save_path``.

    The archive is saved under a random (uuid4) name inside ``save_path`` and is
    always removed afterwards, even if the download or extraction fails.

    Args:
        url: address of the archive; must end in ".tar.bz2" or ".tar.gz".
        save_path: directory to extract into; created if it does not exist.

    Raises:
        RuntimeError: if the URL has an unsupported suffix, the download fails,
            or (on older Pythons) the archive contains an unsafe path.
    """
    print("Downloading '{}'".format(url))
    for suffix, mode in _ARCHIVE_MODES:
        if url.endswith(suffix):
            break
    else:
        raise RuntimeError("Unsupported file.")

    os.makedirs(save_path, exist_ok=True)
    file_path = os.path.join(save_path, str(uuid.uuid4()) + suffix)
    try:
        _download_to(url, file_path)
        _extract_archive(file_path, mode, save_path)
    finally:
        # Clean up the temporary archive whether or not the steps above succeeded.
        if os.path.exists(file_path):
            print("Removing '{}'".format(file_path))
            os.remove(file_path)


if __name__ == "__main__":
    download_resources(ARCHIVE_URL, EXTRACT_DIR)