#!/bin/bash
#
# CPU speed benchmark: clones the pytorch/examples "perftests" branch, runs
# one epoch of MNIST training NUM_RUNS times, and prints runtime statistics,
# optionally comparing them against (or updating) a stored baseline.
#
# Usage (direct execution):
#   ./test_cpu_speed_mnist.sh NUM_RUNS [compare_with_baseline|compare_and_update]
#
# When sourced, this file only defines test_cpu_speed_mnist; the caller is
# expected to invoke it (presumably via run_test from common.sh — TODO confirm).
set -e

. ./common.sh

#######################################
# Benchmark one-epoch MNIST training on CPU.
# Globals:
#   OMP_NUM_THREADS, MKL_NUM_THREADS (exported, both set to 4)
# Arguments:
#   $1 - number of timed training runs (positive integer)
#   $2 - optional mode: "compare_with_baseline" runs
#        ../compare_with_baseline.py on the stats; "compare_and_update" runs
#        it with --update; any other value (or none) only prints the stats.
# Outputs:
#   Each run's runtime, then the output of ../get_stats.py, on stdout.
# Returns:
#   Non-zero if $1 is not a positive integer or if any step fails.
# NOTE(review): ../get_stats.py and ../compare_with_baseline.py are resolved
#   relative to the caller's working directory, which presumably is a scratch
#   directory created by run_test one level below those scripts — TODO confirm.
#######################################
test_cpu_speed_mnist() {
  local -r num_runs=$1
  local -r mode=${2:-}

  # Validate before the expensive clone/install; an empty or non-numeric
  # count would otherwise silently run zero iterations.
  if [[ ! "$num_runs" =~ ^[1-9][0-9]*$ ]]; then
    printf '%s: NUM_RUNS must be a positive integer, got "%s"\n' \
      "${FUNCNAME[0]}" "$num_runs" >&2
    return 1
  fi

  echo "Testing: MNIST, CPU"

  # Fixed thread counts, presumably so runtimes are comparable across runs.
  export OMP_NUM_THREADS=4
  export MKL_NUM_THREADS=4

  git clone https://github.com/pytorch/examples.git -b perftests

  # pushd/popd return to the exact starting directory, so the relative
  # ../*.py paths below resolve from where the function was called.
  pushd examples/mnist > /dev/null

  # -y: do not block on conda's confirmation prompt in non-interactive CI.
  conda install -y -c pytorch torchvision-cpu

  # Download data (zero training epochs).
  python main.py --epochs 0

  local -a samples=()
  local runtime
  local run
  for (( run = 1; run <= num_runs; run++ )); do
    # Assigned separately from 'local' so a failure still trips set -e.
    runtime=$(get_runtime_of_command python main.py --epochs 1 --no-log)
    echo "$runtime"
    samples+=("$runtime")
  done

  popd > /dev/null

  local stats
  stats=$(python ../get_stats.py "${samples[@]}")
  echo "Runtime stats in seconds:"
  echo "$stats"

  local -a compare_args=(--test-name "${FUNCNAME[0]}" --sample-stats "$stats")
  case "$mode" in
    compare_with_baseline)
      python ../compare_with_baseline.py "${compare_args[@]}"
      ;;
    compare_and_update)
      python ../compare_with_baseline.py "${compare_args[@]}" --update
      ;;
  esac
}

if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
  run_test test_cpu_speed_mnist "$@"
fi