xref: /aosp_15_r20/external/libaom/aom_dsp/x86/highbd_subtract_sse2.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <emmintrin.h>
14 #include <stddef.h>
15 
16 #include "config/aom_config.h"
17 #include "config/aom_dsp_rtcd.h"
18 
/*
 * Signature shared by every fixed-size subtract kernel in this file: each
 * one takes a destination residual buffer plus source and prediction sample
 * buffers, every buffer accompanied by its own row stride (in elements).
 */
typedef void (*SubtractWxHFuncType)(int16_t *diff,
                                    ptrdiff_t diff_stride,
                                    const uint16_t *src,
                                    ptrdiff_t src_stride,
                                    const uint16_t *pred,
                                    ptrdiff_t pred_stride);
23 
/*
 * Writes diff = src - pred for a block 4 samples wide and 4 rows tall.
 *
 * Each row is 4 x 16-bit samples, i.e. the low 64 bits of an SSE2 register.
 * The subtraction is done with 16-bit wrap-around lanes (_mm_sub_epi16).
 * All strides are expressed in elements, not bytes.
 *
 * All source/prediction rows are read before any residual row is written.
 */
static void subtract_4x4(int16_t *diff, ptrdiff_t diff_stride,
                         const uint16_t *src, ptrdiff_t src_stride,
                         const uint16_t *pred, ptrdiff_t pred_stride) {
  __m128i residual[4];
  int row;

  for (row = 0; row < 4; ++row) {
    const __m128i s =
        _mm_loadl_epi64((const __m128i *)(src + row * src_stride));
    const __m128i p =
        _mm_loadl_epi64((const __m128i *)(pred + row * pred_stride));
    residual[row] = _mm_sub_epi16(s, p);
  }

  /* Store straight through the intrinsic's pointer type; no intermediate
   * int64_t pointer is needed (rows are only guaranteed int16_t alignment). */
  for (row = 0; row < 4; ++row) {
    _mm_storel_epi64((__m128i *)(diff + row * diff_stride), residual[row]);
  }
}
55 
/*
 * Writes diff = src - pred for a block 4 samples wide and 8 rows tall.
 *
 * Each row is 4 x 16-bit samples, i.e. the low 64 bits of an SSE2 register.
 * The subtraction is done with 16-bit wrap-around lanes (_mm_sub_epi16).
 * All strides are expressed in elements, not bytes.
 *
 * All source/prediction rows are read before any residual row is written.
 */
static void subtract_4x8(int16_t *diff, ptrdiff_t diff_stride,
                         const uint16_t *src, ptrdiff_t src_stride,
                         const uint16_t *pred, ptrdiff_t pred_stride) {
  __m128i residual[8];
  int row;

  for (row = 0; row < 8; ++row) {
    const __m128i s =
        _mm_loadl_epi64((const __m128i *)(src + row * src_stride));
    const __m128i p =
        _mm_loadl_epi64((const __m128i *)(pred + row * pred_stride));
    residual[row] = _mm_sub_epi16(s, p);
  }

  /* Store straight through the intrinsic's pointer type; no intermediate
   * int64_t pointer is needed (rows are only guaranteed int16_t alignment). */
  for (row = 0; row < 8; ++row) {
    _mm_storel_epi64((__m128i *)(diff + row * diff_stride), residual[row]);
  }
}
107 
/*
 * Writes diff = src - pred for a block 8 samples wide and 4 rows tall.
 *
 * One row of 8 x 16-bit samples fills a whole 128-bit register; unaligned
 * loads and stores are used for every row. Lanes are subtracted with
 * 16-bit wrap-around (_mm_sub_epi16). Strides are in elements.
 *
 * Every input row is read before the first residual row is written.
 */
static void subtract_8x4(int16_t *diff, ptrdiff_t diff_stride,
                         const uint16_t *src, ptrdiff_t src_stride,
                         const uint16_t *pred, ptrdiff_t pred_stride) {
  __m128i residual[4];
  int row;

  for (row = 0; row < 4; ++row) {
    const __m128i src_row =
        _mm_loadu_si128((const __m128i *)(src + row * src_stride));
    const __m128i pred_row =
        _mm_loadu_si128((const __m128i *)(pred + row * pred_stride));
    residual[row] = _mm_sub_epi16(src_row, pred_row);
  }

  for (row = 0; row < 4; ++row) {
    _mm_storeu_si128((__m128i *)(diff + row * diff_stride), residual[row]);
  }
}
135 
/*
 * Writes diff = src - pred for a block 8 samples wide and 8 rows tall.
 *
 * One row of 8 x 16-bit samples fills a whole 128-bit register; unaligned
 * loads and stores are used for every row. Lanes are subtracted with
 * 16-bit wrap-around (_mm_sub_epi16). Strides are in elements.
 *
 * Every input row is read before the first residual row is written.
 */
static void subtract_8x8(int16_t *diff, ptrdiff_t diff_stride,
                         const uint16_t *src, ptrdiff_t src_stride,
                         const uint16_t *pred, ptrdiff_t pred_stride) {
  __m128i residual[8];
  int row;

  for (row = 0; row < 8; ++row) {
    const __m128i src_row =
        _mm_loadu_si128((const __m128i *)(src + row * src_stride));
    const __m128i pred_row =
        _mm_loadu_si128((const __m128i *)(pred + row * pred_stride));
    residual[row] = _mm_sub_epi16(src_row, pred_row);
  }

  for (row = 0; row < 8; ++row) {
    _mm_storeu_si128((__m128i *)(diff + row * diff_stride), residual[row]);
  }
}
179 
/*
 * Helpers for composing a larger subtract kernel out of two smaller ones.
 *
 * STACK_V and STACK_H do not take the buffers as arguments: they expand to
 * code that uses the identifiers diff, diff_stride, src, src_stride, pred
 * and pred_stride from the scope in which they are expanded.
 *
 * The size arguments are parenthesized in the expansion so that an
 * expression such as `a + b` is applied as a whole.
 */

/* Runs `fun` on the block at the current position, then on the block that
 * starts `h` rows further down in all three buffers. */
#define STACK_V(h, fun)                                             \
  do {                                                              \
    fun(diff, diff_stride, src, src_stride, pred, pred_stride);     \
    fun(diff + diff_stride * (h), diff_stride,                      \
        src + src_stride * (h), src_stride,                         \
        pred + pred_stride * (h), pred_stride);                     \
  } while (0)

/* Runs `fun` on the block at the current position, then on the block that
 * starts `w` samples further right in all three buffers. */
#define STACK_H(w, fun)                                             \
  do {                                                              \
    fun(diff, diff_stride, src, src_stride, pred, pred_stride);     \
    fun(diff + (w), diff_stride, src + (w), src_stride,             \
        pred + (w), pred_stride);                                   \
  } while (0)

/* Expands to the header of `static void subtract_<size>(...)` using the
 * parameter list shared by all kernels; the caller supplies the body. */
#define SUBTRACT_FUN(size)                                               \
  static void subtract_##size(int16_t *diff, ptrdiff_t diff_stride,      \
                              const uint16_t *src, ptrdiff_t src_stride, \
                              const uint16_t *pred, ptrdiff_t pred_stride)
197 
198 SUBTRACT_FUN(8x16) { STACK_V(8, subtract_8x8); }
199 SUBTRACT_FUN(16x8) { STACK_H(8, subtract_8x8); }
200 SUBTRACT_FUN(16x16) { STACK_V(8, subtract_16x8); }
201 SUBTRACT_FUN(16x32) { STACK_V(16, subtract_16x16); }
202 SUBTRACT_FUN(32x16) { STACK_H(16, subtract_16x16); }
203 SUBTRACT_FUN(32x32) { STACK_V(16, subtract_32x16); }
204 SUBTRACT_FUN(32x64) { STACK_V(32, subtract_32x32); }
205 SUBTRACT_FUN(64x32) { STACK_H(32, subtract_32x32); }
206 SUBTRACT_FUN(64x64) { STACK_V(32, subtract_64x32); }
207 SUBTRACT_FUN(64x128) { STACK_V(64, subtract_64x64); }
208 SUBTRACT_FUN(128x64) { STACK_H(64, subtract_64x64); }
209 SUBTRACT_FUN(128x128) { STACK_V(64, subtract_128x64); }
210 SUBTRACT_FUN(4x16) { STACK_V(8, subtract_4x8); }
211 SUBTRACT_FUN(16x4) { STACK_H(8, subtract_8x4); }
212 SUBTRACT_FUN(8x32) { STACK_V(16, subtract_8x16); }
213 SUBTRACT_FUN(32x8) { STACK_H(16, subtract_16x8); }
214 SUBTRACT_FUN(16x64) { STACK_V(32, subtract_16x32); }
215 SUBTRACT_FUN(64x16) { STACK_H(32, subtract_32x16); }
216 
getSubtractFunc(int rows,int cols)217 static SubtractWxHFuncType getSubtractFunc(int rows, int cols) {
218   if (rows == 4) {
219     if (cols == 4) return subtract_4x4;
220     if (cols == 8) return subtract_8x4;
221     if (cols == 16) return subtract_16x4;
222   }
223   if (rows == 8) {
224     if (cols == 4) return subtract_4x8;
225     if (cols == 8) return subtract_8x8;
226     if (cols == 16) return subtract_16x8;
227     if (cols == 32) return subtract_32x8;
228   }
229   if (rows == 16) {
230     if (cols == 4) return subtract_4x16;
231     if (cols == 8) return subtract_8x16;
232     if (cols == 16) return subtract_16x16;
233     if (cols == 32) return subtract_32x16;
234     if (cols == 64) return subtract_64x16;
235   }
236   if (rows == 32) {
237     if (cols == 8) return subtract_8x32;
238     if (cols == 16) return subtract_16x32;
239     if (cols == 32) return subtract_32x32;
240     if (cols == 64) return subtract_64x32;
241   }
242   if (rows == 64) {
243     if (cols == 16) return subtract_16x64;
244     if (cols == 32) return subtract_32x64;
245     if (cols == 64) return subtract_64x64;
246     if (cols == 128) return subtract_128x64;
247   }
248   if (rows == 128) {
249     if (cols == 64) return subtract_64x128;
250     if (cols == 128) return subtract_128x128;
251   }
252   assert(0);
253   return NULL;
254 }
255 
aom_highbd_subtract_block_sse2(int rows,int cols,int16_t * diff,ptrdiff_t diff_stride,const uint8_t * src8,ptrdiff_t src_stride,const uint8_t * pred8,ptrdiff_t pred_stride)256 void aom_highbd_subtract_block_sse2(int rows, int cols, int16_t *diff,
257                                     ptrdiff_t diff_stride, const uint8_t *src8,
258                                     ptrdiff_t src_stride, const uint8_t *pred8,
259                                     ptrdiff_t pred_stride) {
260   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
261   uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
262   SubtractWxHFuncType func;
263 
264   func = getSubtractFunc(rows, cols);
265   func(diff, diff_stride, src, src_stride, pred, pred_stride);
266 }
267