xref: /aosp_15_r20/external/clpeak/src/compute_dp.cpp (revision 1cd03ba3888297bc945f2c84574e105e3ced3e34)
1*1cd03ba3SJeremy Kemp #include <clpeak.h>
2*1cd03ba3SJeremy Kemp 
runComputeDP(cl::CommandQueue & queue,cl::Program & prog,device_info_t & devInfo)3*1cd03ba3SJeremy Kemp int clPeak::runComputeDP(cl::CommandQueue &queue, cl::Program &prog, device_info_t &devInfo)
4*1cd03ba3SJeremy Kemp {
5*1cd03ba3SJeremy Kemp   float timed, gflops;
6*1cd03ba3SJeremy Kemp   cl_uint workPerWI;
7*1cd03ba3SJeremy Kemp   cl::NDRange globalSize, localSize;
8*1cd03ba3SJeremy Kemp   cl_double A = 1.3f;
9*1cd03ba3SJeremy Kemp   uint iters = devInfo.computeIters;
10*1cd03ba3SJeremy Kemp 
11*1cd03ba3SJeremy Kemp   if (!isComputeDP)
12*1cd03ba3SJeremy Kemp     return 0;
13*1cd03ba3SJeremy Kemp 
14*1cd03ba3SJeremy Kemp   if (!devInfo.doubleSupported)
15*1cd03ba3SJeremy Kemp   {
16*1cd03ba3SJeremy Kemp     log->print(NEWLINE TAB TAB "No double precision support! Skipped" NEWLINE);
17*1cd03ba3SJeremy Kemp     return 0;
18*1cd03ba3SJeremy Kemp   }
19*1cd03ba3SJeremy Kemp 
20*1cd03ba3SJeremy Kemp   try
21*1cd03ba3SJeremy Kemp   {
22*1cd03ba3SJeremy Kemp     log->print(NEWLINE TAB TAB "Double-precision compute (GFLOPS)" NEWLINE);
23*1cd03ba3SJeremy Kemp     log->xmlOpenTag("double_precision_compute");
24*1cd03ba3SJeremy Kemp     log->xmlAppendAttribs("unit", "gflops");
25*1cd03ba3SJeremy Kemp 
26*1cd03ba3SJeremy Kemp     cl::Context ctx = queue.getInfo<CL_QUEUE_CONTEXT>();
27*1cd03ba3SJeremy Kemp 
28*1cd03ba3SJeremy Kemp     uint64_t globalWIs = (devInfo.numCUs) * (devInfo.computeDPWgsPerCU) * (devInfo.maxWGSize);
29*1cd03ba3SJeremy Kemp     uint64_t t = std::min((globalWIs * sizeof(cl_double)), devInfo.maxAllocSize) / sizeof(cl_double);
30*1cd03ba3SJeremy Kemp     globalWIs = roundToMultipleOf(t, devInfo.maxWGSize);
31*1cd03ba3SJeremy Kemp 
32*1cd03ba3SJeremy Kemp     cl::Buffer outputBuf = cl::Buffer(ctx, CL_MEM_WRITE_ONLY, (globalWIs * sizeof(cl_double)));
33*1cd03ba3SJeremy Kemp 
34*1cd03ba3SJeremy Kemp     globalSize = globalWIs;
35*1cd03ba3SJeremy Kemp     localSize = devInfo.maxWGSize;
36*1cd03ba3SJeremy Kemp 
37*1cd03ba3SJeremy Kemp     cl::Kernel kernel_v1(prog, "compute_dp_v1");
38*1cd03ba3SJeremy Kemp     kernel_v1.setArg(0, outputBuf), kernel_v1.setArg(1, A);
39*1cd03ba3SJeremy Kemp 
40*1cd03ba3SJeremy Kemp     cl::Kernel kernel_v2(prog, "compute_dp_v2");
41*1cd03ba3SJeremy Kemp     kernel_v2.setArg(0, outputBuf), kernel_v2.setArg(1, A);
42*1cd03ba3SJeremy Kemp 
43*1cd03ba3SJeremy Kemp     cl::Kernel kernel_v4(prog, "compute_dp_v4");
44*1cd03ba3SJeremy Kemp     kernel_v4.setArg(0, outputBuf), kernel_v4.setArg(1, A);
45*1cd03ba3SJeremy Kemp 
46*1cd03ba3SJeremy Kemp     cl::Kernel kernel_v8(prog, "compute_dp_v8");
47*1cd03ba3SJeremy Kemp     kernel_v8.setArg(0, outputBuf), kernel_v8.setArg(1, A);
48*1cd03ba3SJeremy Kemp 
49*1cd03ba3SJeremy Kemp     cl::Kernel kernel_v16(prog, "compute_dp_v16");
50*1cd03ba3SJeremy Kemp     kernel_v16.setArg(0, outputBuf), kernel_v16.setArg(1, A);
51*1cd03ba3SJeremy Kemp 
52*1cd03ba3SJeremy Kemp     ///////////////////////////////////////////////////////////////////////////
53*1cd03ba3SJeremy Kemp     // Vector width 1
54*1cd03ba3SJeremy Kemp     if (!forceTest || strcmp(specifiedTestName, "double") == 0)
55*1cd03ba3SJeremy Kemp     {
56*1cd03ba3SJeremy Kemp       log->print(TAB TAB TAB "double   : ");
57*1cd03ba3SJeremy Kemp 
58*1cd03ba3SJeremy Kemp       workPerWI = 4096; // Indicates flops executed per work-item
59*1cd03ba3SJeremy Kemp 
60*1cd03ba3SJeremy Kemp       timed = run_kernel(queue, kernel_v1, globalSize, localSize, iters);
61*1cd03ba3SJeremy Kemp 
62*1cd03ba3SJeremy Kemp       gflops = (static_cast<float>(globalWIs) * static_cast<float>(workPerWI)) / timed / 1e3f;
63*1cd03ba3SJeremy Kemp 
64*1cd03ba3SJeremy Kemp       log->print(gflops);
65*1cd03ba3SJeremy Kemp       log->print(NEWLINE);
66*1cd03ba3SJeremy Kemp       log->xmlRecord("double", gflops);
67*1cd03ba3SJeremy Kemp     }
68*1cd03ba3SJeremy Kemp     ///////////////////////////////////////////////////////////////////////////
69*1cd03ba3SJeremy Kemp 
70*1cd03ba3SJeremy Kemp     // Vector width 2
71*1cd03ba3SJeremy Kemp     if (!forceTest || strcmp(specifiedTestName, "double2") == 0)
72*1cd03ba3SJeremy Kemp     {
73*1cd03ba3SJeremy Kemp       log->print(TAB TAB TAB "double2  : ");
74*1cd03ba3SJeremy Kemp 
75*1cd03ba3SJeremy Kemp       workPerWI = 4096;
76*1cd03ba3SJeremy Kemp 
77*1cd03ba3SJeremy Kemp       timed = run_kernel(queue, kernel_v2, globalSize, localSize, iters);
78*1cd03ba3SJeremy Kemp 
79*1cd03ba3SJeremy Kemp       gflops = (static_cast<float>(globalWIs) * static_cast<float>(workPerWI)) / timed / 1e3f;
80*1cd03ba3SJeremy Kemp 
81*1cd03ba3SJeremy Kemp       log->print(gflops);
82*1cd03ba3SJeremy Kemp       log->print(NEWLINE);
83*1cd03ba3SJeremy Kemp       log->xmlRecord("double2", gflops);
84*1cd03ba3SJeremy Kemp     }
85*1cd03ba3SJeremy Kemp     ///////////////////////////////////////////////////////////////////////////
86*1cd03ba3SJeremy Kemp 
87*1cd03ba3SJeremy Kemp     // Vector width 4
88*1cd03ba3SJeremy Kemp     if (!forceTest || strcmp(specifiedTestName, "double4") == 0)
89*1cd03ba3SJeremy Kemp     {
90*1cd03ba3SJeremy Kemp       log->print(TAB TAB TAB "double4  : ");
91*1cd03ba3SJeremy Kemp 
92*1cd03ba3SJeremy Kemp       workPerWI = 4096;
93*1cd03ba3SJeremy Kemp 
94*1cd03ba3SJeremy Kemp       timed = run_kernel(queue, kernel_v4, globalSize, localSize, iters);
95*1cd03ba3SJeremy Kemp 
96*1cd03ba3SJeremy Kemp       gflops = (static_cast<float>(globalWIs) * static_cast<float>(workPerWI)) / timed / 1e3f;
97*1cd03ba3SJeremy Kemp 
98*1cd03ba3SJeremy Kemp       log->print(gflops);
99*1cd03ba3SJeremy Kemp       log->print(NEWLINE);
100*1cd03ba3SJeremy Kemp       log->xmlRecord("double4", gflops);
101*1cd03ba3SJeremy Kemp     }
102*1cd03ba3SJeremy Kemp     ///////////////////////////////////////////////////////////////////////////
103*1cd03ba3SJeremy Kemp 
104*1cd03ba3SJeremy Kemp     // Vector width 8
105*1cd03ba3SJeremy Kemp     if (!forceTest || strcmp(specifiedTestName, "double8") == 0)
106*1cd03ba3SJeremy Kemp     {
107*1cd03ba3SJeremy Kemp       log->print(TAB TAB TAB "double8  : ");
108*1cd03ba3SJeremy Kemp       workPerWI = 4096;
109*1cd03ba3SJeremy Kemp 
110*1cd03ba3SJeremy Kemp       timed = run_kernel(queue, kernel_v8, globalSize, localSize, iters);
111*1cd03ba3SJeremy Kemp 
112*1cd03ba3SJeremy Kemp       gflops = (static_cast<float>(globalWIs) * static_cast<float>(workPerWI)) / timed / 1e3f;
113*1cd03ba3SJeremy Kemp 
114*1cd03ba3SJeremy Kemp       log->print(gflops);
115*1cd03ba3SJeremy Kemp       log->print(NEWLINE);
116*1cd03ba3SJeremy Kemp       log->xmlRecord("double8", gflops);
117*1cd03ba3SJeremy Kemp     }
118*1cd03ba3SJeremy Kemp     ///////////////////////////////////////////////////////////////////////////
119*1cd03ba3SJeremy Kemp 
120*1cd03ba3SJeremy Kemp     // Vector width 16
121*1cd03ba3SJeremy Kemp     if (!forceTest || strcmp(specifiedTestName, "double16") == 0)
122*1cd03ba3SJeremy Kemp     {
123*1cd03ba3SJeremy Kemp       log->print(TAB TAB TAB "double16 : ");
124*1cd03ba3SJeremy Kemp 
125*1cd03ba3SJeremy Kemp       workPerWI = 4096;
126*1cd03ba3SJeremy Kemp 
127*1cd03ba3SJeremy Kemp       timed = run_kernel(queue, kernel_v16, globalSize, localSize, iters);
128*1cd03ba3SJeremy Kemp 
129*1cd03ba3SJeremy Kemp       gflops = (static_cast<float>(globalWIs) * static_cast<float>(workPerWI)) / timed / 1e3f;
130*1cd03ba3SJeremy Kemp 
131*1cd03ba3SJeremy Kemp       log->print(gflops);
132*1cd03ba3SJeremy Kemp       log->print(NEWLINE);
133*1cd03ba3SJeremy Kemp       log->xmlRecord("double16", gflops);
134*1cd03ba3SJeremy Kemp     }
135*1cd03ba3SJeremy Kemp     ///////////////////////////////////////////////////////////////////////////
136*1cd03ba3SJeremy Kemp     log->xmlCloseTag(); // double_precision_compute
137*1cd03ba3SJeremy Kemp   }
138*1cd03ba3SJeremy Kemp   catch (cl::Error &error)
139*1cd03ba3SJeremy Kemp   {
140*1cd03ba3SJeremy Kemp     stringstream ss;
141*1cd03ba3SJeremy Kemp     ss << error.what() << " (" << error.err() << ")" NEWLINE
142*1cd03ba3SJeremy Kemp        << TAB TAB TAB "Tests skipped" NEWLINE;
143*1cd03ba3SJeremy Kemp     log->print(ss.str());
144*1cd03ba3SJeremy Kemp     return -1;
145*1cd03ba3SJeremy Kemp   }
146*1cd03ba3SJeremy Kemp 
147*1cd03ba3SJeremy Kemp   return 0;
148*1cd03ba3SJeremy Kemp }
149