xref: /aosp_15_r20/external/pytorch/.ci/pytorch/perf_test/test_cpu_speed_mini_sequence_labeler.sh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#!/bin/bash
# Stop at the first command that fails outside a conditional context.
set -o errexit

# Helpers that this file calls but does not define (run_test,
# get_runtime_of_command) presumably come from common.sh. It is looked up
# relative to the current working directory, not this script's location.
source ./common.sh
5
#######################################
# CPU timing benchmark for the "mini sequence labeler" script from
# pytorch/benchmark.
#
# Clones pytorch/benchmark into the current directory and checks out a
# pinned commit. It then times `python main.py` $1 times via
# get_runtime_of_command, which is not defined in this file (presumably it
# comes from common.sh). It prints the stats produced by ../get_stats.py
# and, depending on $2, passes them to ../compare_with_baseline.py. Both .py
# paths are resolved relative to the caller's working directory.
#
# Globals:
#   OMP_NUM_THREADS, MKL_NUM_THREADS - exported as 4; they stay set in the
#     calling shell after the function returns.
# Arguments:
#   $1 - number of timed runs; must be a positive integer.
#   $2 - optional mode: "compare_with_baseline", or "compare_and_update"
#        (which adds --update). Any other value only prints the stats.
# Outputs:
#   Progress messages and the runtime stats on stdout.
# Returns:
#   1 if $1 is not a positive integer. Non-zero if a clone, checkout, cd,
#   timing, or stats step fails. Otherwise, the status of the last command
#   run.
#######################################
test_cpu_speed_mini_sequence_labeler () {
  local -r num_runs=${1:-}
  local -r mode=${2:-}
  local -r start_dir=$PWD
  local -a samples=()
  local runtime stats i

  echo "Testing: mini sequence labeler, CPU"

  # Fail fast, before the slow clone. An empty, zero, or non-numeric count
  # would otherwise reach get_stats.py with no samples at all.
  if [[ ! "$num_runs" =~ ^[1-9][0-9]*$ ]]; then
    printf 'error: number of runs must be a positive integer, got %q\n' \
      "$num_runs" >&2
    return 1
  fi

  export OMP_NUM_THREADS=4
  export MKL_NUM_THREADS=4

  git clone https://github.com/pytorch/benchmark.git || return
  cd benchmark/ || return
  # Pin the benchmark revision so timings stay comparable across runs.
  git checkout 726567a455edbfda6199445922a8cfee82535664 || return
  cd scripts/mini_sequence_labeler || return

  for (( i = 1; i <= num_runs; i++ )); do
    runtime=$(get_runtime_of_command python main.py) || return
    samples+=("${runtime}")
  done

  # Go back to the starting directory instead of counting `..` hops.
  cd "$start_dir" || return

  stats=$(python ../get_stats.py "${samples[@]}") || return
  echo "Runtime stats in seconds:"
  printf '%s\n' "$stats"

  local -a compare_args=(--test-name "${FUNCNAME[0]}" --sample-stats "${stats}")
  case "$mode" in
    compare_with_baseline)
      python ../compare_with_baseline.py "${compare_args[@]}"
      ;;
    compare_and_update)
      python ../compare_with_baseline.py "${compare_args[@]}" --update
      ;;
  esac
}
40
# When this file is executed directly (not sourced), run the test through
# run_test, forwarding the script's command-line arguments. When sourced,
# only the function definition above is provided to the caller.
[[ "${BASH_SOURCE[0]}" != "$0" ]] || run_test test_cpu_speed_mini_sequence_labeler "$@"
44