# Source: /aosp_15_r20/external/pytorch/scripts/onnx/test.sh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#!/bin/bash

# -e: stop at the first failing command; -x: echo each command before running it.
set -ex

# Exported, so every child process (python, pip, xdoctest, ...) inherits it.
export TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=ERRORS

# Default option values; PARALLEL may also be (re)set by -p/--parallel.
PARALLEL=1

# Command-line arguments the option loop does not recognise are collected here.
declare -a UNKNOWN=()

# Walk the command line once. -p/--parallel is consumed; everything else is
# kept, in order, and becomes the new positional parameters afterwards.
# NOTE(review): PARALLEL already defaults to 1 and is not read anywhere else in
# this file, so -p/--parallel currently has no effect — confirm intent.
for arg in "$@"; do
  case "$arg" in
    -p | --parallel) PARALLEL=1 ;;
    *) UNKNOWN+=("$arg") ;;
  esac
done
set -- "${UNKNOWN[@]}" # leave only the unrecognised arguments

# Install the repository's coverage plug-in package (editable mode) so that a
# coverage run does not fail because the plug-in is missing.
pip install --editable tools/coverage_plugins_package

#######################################
# Print the directory three levels above the given path. For this script
# (scripts/onnx/test.sh) that is the top of the source tree.
# Arguments: $1 - path of a file
# Outputs:   the ancestor directory on stdout
#######################################
repo_root_from_script() {
  # Every nested substitution is quoted so paths with spaces/globs survive.
  dirname "$(dirname "$(dirname "$1")")"
}

# realpath might not be available on macOS, so resolve the script path via Python.
script_path=$(python -c "import os; import sys; print(os.path.realpath(sys.argv[1]))" "${BASH_SOURCE[0]}")
top_dir=$(repo_root_from_script "$script_path")
# NOTE(review): test_paths is not referenced elsewhere in this file — confirm it is still needed.
test_paths=(
  "$top_dir/test/onnx"
)

# pytest-cov style options (verbose, coverage, XML report into test/, append).
# NOTE(review): this array is not passed to any command in this file — confirm
# whether it is still needed.
args=(
  -v
  --cov
  --cov-report "xml:test/coverage.xml"
  --cov-append
)

# Run the ONNX tests through run_test.py, passing SHARD_NUMBER and a shard count
# of 2 (presumably --shard takes <index> <total>; SHARD_NUMBER comes from the
# environment). `time` reports how long the run took.
time python "${top_dir}/test/run_test.py" --onnx --shard "$SHARD_NUMBER" 2 --verbose

#######################################
# Map a benchmark suite to the single model used for the quick ONNX check.
# Arguments: $1 - suite name (torchbench, huggingface or timm_models)
# Outputs:   model name on stdout; error message on stderr for unknown suites
# Returns:   0 on success, 1 for an unknown suite
#######################################
onnx_smoke_model_for_suite() {
  case "$1" in
    torchbench) echo resnet18 ;;
    huggingface) echo ElectraForQuestionAnswering ;;
    timm_models) echo lcnet_050 ;;
    *)
      echo "Unknown suite: $1" >&2
      return 1
      ;;
  esac
}

# Extra checks that only run on shard 2.
if [[ "$SHARD_NUMBER" == "2" ]]; then
  # xdoctests on onnx
  xdoctest torch.onnx --style=google --options="+IGNORE_WHITESPACE"

  # Sanity check on torchbench w/ onnx
  pip install pandas
  log_folder="test/.torchbench_logs" # relative to the current working directory
  device="cpu"
  modes=("accuracy" "performance")
  compilers=("dynamo-onnx" "torchscript-onnx")
  suites=("huggingface" "timm_models")

  mkdir -p "${log_folder}"
  for mode in "${modes[@]}"; do
    for compiler in "${compilers[@]}"; do
      for suite in "${suites[@]}"; do
        # Run only one selected model per suite to quickly validate that the
        # benchmark suite works as expected.
        model=$(onnx_smoke_model_for_suite "$suite") || exit 1
        output_file="${log_folder}/${compiler}_${suite}_float32_inference_${device}_${mode}.csv"
        bench_file="benchmarks/dynamo/${suite}.py"
        bench_args=("--${mode}" --float32 "-d${device}" "--output=${output_file}" "--output-directory=${top_dir}" --inference -n5 "--${compiler}" --no-skip --dashboard --batch-size 1 -k "$model")
        python "${top_dir}/${bench_file}" "${bench_args[@]}"
      done
    done
  done
fi

#######################################
# Our CI expects both coverage.xml and .coverage to be within test/.
# coverage.py stores its data in a regular *file* named .coverage, so test
# with -f; a directory test (-d) never matches and the move never happens.
# Outputs:   moves ./.coverage to ./test/.coverage when present
# Returns:   0 when there is nothing to move; mv's status otherwise
#######################################
move_coverage_data() {
  if [[ -f .coverage ]]; then
    mv .coverage test/.coverage
  fi
}
move_coverage_data
93