xref: /aosp_15_r20/external/pytorch/.ci/pytorch/perf_test/test_cpu_speed_mnist.sh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#!/bin/bash
# CPU perf test: measures the runtime of the pytorch/examples MNIST script.
#
# Arguments: NUM_RUNS [compare_with_baseline|compare_and_update]
# (passed through "$@" to run_test / test_cpu_speed_mnist at the bottom).
#
# NOTE(review): run_test and get_runtime_of_command are not defined in this
# file; presumably they come from ./common.sh — confirm.
set -o errexit

# Resolved against the current working directory, not this script's location.
source ./common.sh

#######################################
# Time single-epoch CPU training runs of the pytorch/examples MNIST script.
# Clones pytorch/examples (branch "perftests") into the current directory,
# installs torchvision-cpu with conda, downloads the dataset, then runs
# `python main.py --epochs 1 --no-log` NUM_RUNS times through
# get_runtime_of_command and summarizes the samples with ../get_stats.py.
# The relative ../*.py paths assume the caller's working directory sits one
# level below the directory holding get_stats.py and
# compare_with_baseline.py (NOTE(review): presumably the scratch directory
# created by run_test — confirm).
# Globals:
#   OMP_NUM_THREADS, MKL_NUM_THREADS (exported, set to 4)
# Arguments:
#   $1 - number of timed runs (positive integer)
#   $2 - optional mode: "compare_with_baseline" or "compare_and_update"
# Outputs:
#   progress line, each run's runtime, and the stats summary on stdout
# Returns:
#   1 on an invalid run count; otherwise the status of the first failing
#   step, or of the final comparison command.
#######################################
test_cpu_speed_mnist () {
  local -r num_runs=${1-}
  local -r mode=${2-}
  # Reject a missing/non-numeric count up front: it would otherwise run zero
  # iterations and hand get_stats.py an empty sample list.
  if [[ ! "$num_runs" =~ ^[1-9][0-9]*$ ]]; then
    echo "test_cpu_speed_mnist: expected a positive run count, got '${num_runs}'" >&2
    return 1
  fi

  echo "Testing: MNIST, CPU"

  export OMP_NUM_THREADS=4
  export MKL_NUM_THREADS=4

  # Explicit `|| return`: errexit is suspended when a function is called from
  # a condition, so failures must still propagate on their own.
  git clone https://github.com/pytorch/examples.git -b perftests || return
  cd examples/mnist || return

  # -y: answer conda's "Proceed?" confirmation automatically (non-interactive).
  conda install -y -c pytorch torchvision-cpu || return

  # Download data
  python main.py --epochs 0 || return

  local -a samples=()
  local i runtime
  for (( i = 1; i <= num_runs; i++ )); do
    # Assign separately from `local` so a failing run is not masked.
    runtime=$(get_runtime_of_command python main.py --epochs 1 --no-log) || return
    echo "$runtime"
    samples+=("$runtime")
  done

  cd ../.. || return

  local stats
  stats=$(python ../get_stats.py "${samples[@]}") || return
  echo "Runtime stats in seconds:"
  echo "$stats"

  local -a compare_args=(--test-name "${FUNCNAME[0]}" --sample-stats "$stats")
  case "$mode" in
    compare_with_baseline)
      python ../compare_with_baseline.py "${compare_args[@]}"
      ;;
    compare_and_update)
      python ../compare_with_baseline.py "${compare_args[@]}" --update
      ;;
  esac
}

# Run the benchmark only when this file is executed directly; when it is
# sourced instead, only the function definition above is loaded.
if [[ "$0" == "${BASH_SOURCE[0]}" ]]; then
  run_test test_cpu_speed_mnist "$@"
fi
