xref: /aosp_15_r20/frameworks/native/libs/renderengine/skia/filters/LutShader.cpp (revision 38e8c45f13ce32b0dcecb25141ffecaf386fa17f)
1 /*
2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "LutShader.h"
17 
18 #include <SkM44.h>
19 #include <SkTileMode.h>
20 #include <common/trace.h>
21 #include <cutils/ashmem.h>
22 #include <math/half.h>
23 #include <sys/mman.h>
24 #include <ui/ColorSpace.h>
25 
26 #include "include/core/SkColorSpace.h"
27 #include "src/core/SkColorFilterPriv.h"
28 
29 using aidl::android::hardware::graphics::composer3::LutProperties;
30 
31 namespace android {
32 namespace renderengine {
33 namespace skia {
34 
35 static const SkString kShader = SkString(R"(
36     uniform shader image;
37     uniform shader lut;
38     uniform int size;
39     uniform int key;
40     uniform int dimension;
41     uniform vec3 luminanceCoefficients; // for CIE_Y
42 
43     vec4 main(vec2 xy) {
44         float4 rgba = image.eval(xy);
45         float3 linear = toLinearSrgb(rgba.rgb);
46         if (dimension == 1) {
47             // RGB
48             if (key == 0) {
49                 float indexR = linear.r * float(size - 1);
50                 float indexG = linear.g * float(size - 1);
51                 float indexB = linear.b * float(size - 1);
52                 float gainR = lut.eval(vec2(indexR, 0.0) + 0.5).r;
53                 float gainG = lut.eval(vec2(indexG, 0.0) + 0.5).r;
54                 float gainB = lut.eval(vec2(indexB, 0.0) + 0.5).r;
55                 return float4(linear.r * gainR, linear.g * gainG, linear.b * gainB, rgba.a);
56             // MAX_RGB
57             } else if (key == 1) {
58                 float maxRGB = max(linear.r, max(linear.g, linear.b));
59                 float index = maxRGB * float(size - 1);
60                 float gain = lut.eval(vec2(index, 0.0) + 0.5).r;
61                 return float4(linear * gain, rgba.a);
62             // CIE_Y
63             } else if (key == 2) {
64                 float y = dot(linear, luminanceCoefficients) / 3.0;
65                 float index = y * float(size - 1);
66                 float gain = lut.eval(vec2(index, 0.0) + 0.5).r;
67                 return float4(linear * gain, rgba.a);
68             }
69         } else if (dimension == 3) {
70             if (key == 0) {
71                 float tx = linear.r * float(size - 1);
72                 float ty = linear.g * float(size - 1);
73                 float tz = linear.b * float(size - 1);
74 
75                 // calculate lower and upper bounds for each dimension
76                 int x = int(tx);
77                 int y = int(ty);
78                 int z = int(tz);
79 
80                 int i000 = x + y * size + z * size * size;
81                 int i100 = i000 + 1;
82                 int i010 = i000 + size;
83                 int i110 = i000 + size + 1;
84                 int i001 = i000 + size * size;
85                 int i101 = i000 + size * size + 1;
86                 int i011 = i000 + size * size + size;
87                 int i111 = i000 + size * size + size + 1;
88 
89                 // get 1d normalized indices
90                 float c000 = float(i000) / float(size * size * size);
91                 float c100 = float(i100) / float(size * size * size);
92                 float c010 = float(i010) / float(size * size * size);
93                 float c110 = float(i110) / float(size * size * size);
94                 float c001 = float(i001) / float(size * size * size);
95                 float c101 = float(i101) / float(size * size * size);
96                 float c011 = float(i011) / float(size * size * size);
97                 float c111 = float(i111) / float(size * size * size);
98 
99                 //TODO(b/377984618): support Tetrahedral interpolation
100                 // perform trilinear interpolation
101                 float3 c00 = mix(lut.eval(vec2(c000, 0.0) + 0.5).rgb,
102                                  lut.eval(vec2(c100, 0.0) + 0.5).rgb, linear.r);
103                 float3 c01 = mix(lut.eval(vec2(c001, 0.0) + 0.5).rgb,
104                                  lut.eval(vec2(c101, 0.0) + 0.5).rgb, linear.r);
105                 float3 c10 = mix(lut.eval(vec2(c010, 0.0) + 0.5).rgb,
106                                  lut.eval(vec2(c110, 0.0) + 0.5).rgb, linear.r);
107                 float3 c11 = mix(lut.eval(vec2(c011, 0.0) + 0.5).rgb,
108                                  lut.eval(vec2(c111, 0.0) + 0.5).rgb, linear.r);
109 
110                 float3 c0 = mix(c00, c10, linear.g);
111                 float3 c1 = mix(c01, c11, linear.g);
112 
113                 float3 val = mix(c0, c1, linear.b);
114 
115                 return float4(val, rgba.a);
116             }
117         }
118         return rgba;
119     })");
120 
121 // same as shader::toColorSpace function
122 // TODO: put this function in a general place
toColorSpace(ui::Dataspace dataspace)123 static ColorSpace toColorSpace(ui::Dataspace dataspace) {
124     switch (dataspace & HAL_DATASPACE_STANDARD_MASK) {
125         case HAL_DATASPACE_STANDARD_BT709:
126             return ColorSpace::sRGB();
127         case HAL_DATASPACE_STANDARD_DCI_P3:
128             return ColorSpace::DisplayP3();
129         case HAL_DATASPACE_STANDARD_BT2020:
130         case HAL_DATASPACE_STANDARD_BT2020_CONSTANT_LUMINANCE:
131             return ColorSpace::BT2020();
132         case HAL_DATASPACE_STANDARD_ADOBE_RGB:
133             return ColorSpace::AdobeRGB();
134         case HAL_DATASPACE_STANDARD_BT601_625:
135         case HAL_DATASPACE_STANDARD_BT601_625_UNADJUSTED:
136         case HAL_DATASPACE_STANDARD_BT601_525:
137         case HAL_DATASPACE_STANDARD_BT601_525_UNADJUSTED:
138         case HAL_DATASPACE_STANDARD_BT470M:
139         case HAL_DATASPACE_STANDARD_FILM:
140         case HAL_DATASPACE_STANDARD_UNSPECIFIED:
141         default:
142             return ColorSpace::sRGB();
143     }
144 }
145 
generateLutShader(sk_sp<SkShader> input,const std::vector<float> & buffers,const int32_t offset,const int32_t length,const int32_t dimension,const int32_t size,const int32_t samplingKey,ui::Dataspace srcDataspace)146 sk_sp<SkShader> LutShader::generateLutShader(sk_sp<SkShader> input,
147                                              const std::vector<float>& buffers,
148                                              const int32_t offset, const int32_t length,
149                                              const int32_t dimension, const int32_t size,
150                                              const int32_t samplingKey,
151                                              ui::Dataspace srcDataspace) {
152     SFTRACE_NAME("lut shader");
153     std::vector<half> buffer(length * 4); // 4 is for RGBA
154     auto d = static_cast<LutProperties::Dimension>(dimension);
155     if (d == LutProperties::Dimension::ONE_D) {
156         auto it = buffers.begin() + offset;
157         std::generate(buffer.begin(), buffer.end(), [it, i = 0]() mutable {
158             float val = (i++ % 4 == 0) ? *it++ : 0.0f;
159             return half(val);
160         });
161     } else {
162         for (int i = 0; i < length; i++) {
163             buffer[i * 4] = half(buffers[offset + i]);
164             buffer[i * 4 + 1] = half(buffers[offset + length + i]);
165             buffer[i * 4 + 2] = half(buffers[offset + length * 2 + i]);
166             buffer[i * 4 + 3] = half(0);
167         }
168     }
169     /**
170      * 1D Lut RGB/MAX_RGB
171      * (R0, 0, 0, 0)
172      * (R1, 0, 0, 0)
173      *
174      * 1D Lut CIE_Y
175      * (Y0, 0, 0, 0)
176      * (Y1, 0, 0, 0)
177      * ...
178      *
179      * 3D Lut MAX_RGB
180      * (R0, G0, B0, 0)
181      * (R1, G1, B1, 0)
182      * ...
183      */
184     SkImageInfo info = SkImageInfo::Make(length /* the number of rgba */ * 4, 1,
185                                          kRGBA_F16_SkColorType, kPremul_SkAlphaType);
186     SkBitmap bitmap;
187     bitmap.allocPixels(info);
188     if (!bitmap.installPixels(info, buffer.data(), info.minRowBytes())) {
189         LOG_ALWAYS_FATAL("unable to install pixels");
190     }
191 
192     sk_sp<SkImage> lutImage = SkImages::RasterFromBitmap(bitmap);
193     mBuilder->child("image") = input;
194     mBuilder->child("lut") =
195             lutImage->makeRawShader(SkTileMode::kClamp, SkTileMode::kClamp,
196                                     d == LutProperties::Dimension::ONE_D
197                                             ? SkSamplingOptions(SkFilterMode::kLinear)
198                                             : SkSamplingOptions());
199 
200     const int uSize = static_cast<int>(size);
201     const int uKey = static_cast<int>(samplingKey);
202     const int uDimension = static_cast<int>(dimension);
203     if (static_cast<LutProperties::SamplingKey>(samplingKey) == LutProperties::SamplingKey::CIE_Y) {
204         // Use predefined colorspaces of input dataspace so that we can get D65 illuminant
205         mat3 toXYZMatrix(toColorSpace(srcDataspace).getRGBtoXYZ());
206         mBuilder->uniform("luminanceCoefficients") =
207                 SkV3{toXYZMatrix[0][1], toXYZMatrix[1][1], toXYZMatrix[2][1]};
208     } else {
209         mBuilder->uniform("luminanceCoefficients") = SkV3{1.f, 1.f, 1.f};
210     }
211     mBuilder->uniform("size") = uSize;
212     mBuilder->uniform("key") = uKey;
213     mBuilder->uniform("dimension") = uDimension;
214     return mBuilder->makeShader();
215 }
216 
lutShader(sk_sp<SkShader> & input,std::shared_ptr<gui::DisplayLuts> displayLuts,ui::Dataspace srcDataspace,sk_sp<SkColorSpace> outColorSpace)217 sk_sp<SkShader> LutShader::lutShader(sk_sp<SkShader>& input,
218                                      std::shared_ptr<gui::DisplayLuts> displayLuts,
219                                      ui::Dataspace srcDataspace,
220                                      sk_sp<SkColorSpace> outColorSpace) {
221     if (mBuilder == nullptr) {
222         const static SkRuntimeEffect::Result instance = SkRuntimeEffect::MakeForShader(kShader);
223         mBuilder = std::make_unique<SkRuntimeShaderBuilder>(instance.effect);
224     }
225 
226     auto& fd = displayLuts->getLutFileDescriptor();
227     if (fd.ok()) {
228         // de-gamma the image without changing the primaries
229         SkImage* baseImage = input->isAImage((SkMatrix*)nullptr, (SkTileMode*)nullptr);
230         sk_sp<SkColorSpace> baseColorSpace = baseImage && baseImage->colorSpace()
231                 ? baseImage->refColorSpace()
232                 : SkColorSpace::MakeSRGB();
233         sk_sp<SkColorSpace> lutMathColorSpace = baseColorSpace->makeLinearGamma();
234         input = input->makeWithWorkingColorSpace(lutMathColorSpace);
235 
236         auto& offsets = displayLuts->offsets;
237         auto& lutProperties = displayLuts->lutProperties;
238         std::vector<float> buffers;
239         int fullLength = offsets[lutProperties.size() - 1];
240         if (lutProperties[lutProperties.size() - 1].dimension == 1) {
241             fullLength += lutProperties[lutProperties.size() - 1].size;
242         } else {
243             fullLength += (lutProperties[lutProperties.size() - 1].size *
244                            lutProperties[lutProperties.size() - 1].size *
245                            lutProperties[lutProperties.size() - 1].size * 3);
246         }
247         size_t bufferSize = fullLength * sizeof(float);
248 
249         // decode the shared memory of luts
250         float* ptr =
251                 (float*)mmap(NULL, bufferSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0);
252         if (ptr == MAP_FAILED) {
253             LOG_ALWAYS_FATAL("mmap failed");
254         }
255         buffers = std::vector<float>(ptr, ptr + fullLength);
256         munmap(ptr, bufferSize);
257 
258         for (size_t i = 0; i < offsets.size(); i++) {
259             int bufferSizePerLut = (i == offsets.size() - 1) ? buffers.size() - offsets[i]
260                                                              : offsets[i + 1] - offsets[i];
261             // divide by 3 for 3d Lut because of 3 (RGB) channels
262             if (static_cast<LutProperties::Dimension>(lutProperties[i].dimension) ==
263                 LutProperties::Dimension::THREE_D) {
264                 bufferSizePerLut /= 3;
265             }
266             input = generateLutShader(input, buffers, offsets[i], bufferSizePerLut,
267                                       lutProperties[i].dimension, lutProperties[i].size,
268                                       lutProperties[i].samplingKey, srcDataspace);
269         }
270 
271         auto colorXformLutToDst =
272                 SkColorFilterPriv::MakeColorSpaceXform(lutMathColorSpace, outColorSpace);
273         input = input->makeWithColorFilter(colorXformLutToDst);
274     }
275     return input;
276 }
277 
278 } // namespace skia
279 } // namespace renderengine
280 } // namespace android