1 // Copyright 2021 Code Intelligence GmbH
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Modified from
16 // https://raw.githubusercontent.com/google/atheris/034284dc4bb1ad4f4ab6ba5d34fb4dca7c633660/fuzzed_data_provider.cc
17 //
18 // Original license and copyright notices:
19 //
20 // Copyright 2020 Google LLC
21 //
22 // Licensed under the Apache License, Version 2.0 (the "License");
23 // you may not use this file except in compliance with the License.
24 // You may obtain a copy of the License at
25 //
26 //      http://www.apache.org/licenses/LICENSE-2.0
27 //
28 // Unless required by applicable law or agreed to in writing, software
29 // distributed under the License is distributed on an "AS IS" BASIS,
30 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31 // See the License for the specific language governing permissions and
32 // limitations under the License.
33 //
34 // Modified from
35 // https://github.com/llvm/llvm-project/blob/70de7e0d9a95b7fcd7c105b06bd90fdf4e01f563/compiler-rt/include/fuzzer/FuzzedDataProvider.h
36 //
37 // Original license and copyright notices:
38 //
39 //===- FuzzedDataProvider.h - Utility header for fuzz targets ---*- C++ -* ===//
40 //
41 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42 // See https://llvm.org/LICENSE.txt for license information.
43 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44 //
45 
46 #include <algorithm>
47 #include <cstdint>
48 #include <limits>
49 #include <string>
50 #include <tuple>
51 #include <type_traits>
52 
53 #include "com_code_intelligence_jazzer_driver_FuzzedDataProviderImpl.h"
54 
55 namespace {
56 
57 jfieldID gDataPtrField = nullptr;
58 jfieldID gRemainingBytesField = nullptr;
59 
ThrowIllegalArgumentException(JNIEnv & env,const std::string & message)60 void ThrowIllegalArgumentException(JNIEnv &env, const std::string &message) {
61   jclass illegal_argument_exception =
62       env.FindClass("java/lang/IllegalArgumentException");
63   env.ThrowNew(illegal_argument_exception, message.c_str());
64 }
65 
66 template <typename T>
67 struct JniArrayType {};
68 
69 #define JNI_ARRAY_TYPE(lower_case, sentence_case)                    \
70   template <>                                                        \
71   struct JniArrayType<j##lower_case> {                               \
72     typedef j##lower_case type;                                      \
73     typedef j##lower_case##Array array_type;                         \
74     static constexpr array_type (JNIEnv::*kNewArrayFunc)(jsize) =    \
75         &JNIEnv::New##sentence_case##Array;                          \
76     static constexpr void (JNIEnv::*kSetArrayRegionFunc)(            \
77         array_type array, jsize start, jsize len,                    \
78         const type *buf) = &JNIEnv::Set##sentence_case##ArrayRegion; \
79   };
80 
81 JNI_ARRAY_TYPE(boolean, Boolean);
82 JNI_ARRAY_TYPE(byte, Byte);
83 JNI_ARRAY_TYPE(short, Short);
84 JNI_ARRAY_TYPE(int, Int);
85 JNI_ARRAY_TYPE(long, Long);
86 
87 template <typename T>
88 typename JniArrayType<T>::array_type JNICALL
ConsumeIntegralArray(JNIEnv & env,jobject self,jint max_length)89 ConsumeIntegralArray(JNIEnv &env, jobject self, jint max_length) {
90   if (max_length < 0) {
91     ThrowIllegalArgumentException(env, "maxLength must not be negative");
92     return nullptr;
93   }
94   // Arrays of integral types are considered data and thus consumed from the
95   // beginning of the buffer.
96   const auto *dataPtr =
97       reinterpret_cast<const uint8_t *>(env.GetLongField(self, gDataPtrField));
98   jint remainingBytes = env.GetIntField(self, gRemainingBytesField);
99 
100   jint max_num_bytes =
101       std::min(static_cast<jint>(sizeof(T)) * max_length, remainingBytes);
102   jsize actual_length = max_num_bytes / sizeof(T);
103   jint actual_num_bytes = sizeof(T) * actual_length;
104   auto array = (env.*(JniArrayType<T>::kNewArrayFunc))(actual_length);
105   (env.*(JniArrayType<T>::kSetArrayRegionFunc))(
106       array, 0, actual_length, reinterpret_cast<const T *>(dataPtr));
107 
108   env.SetLongField(self, gDataPtrField, (jlong)(dataPtr + actual_num_bytes));
109   env.SetIntField(self, gRemainingBytesField,
110                   remainingBytes - actual_num_bytes);
111 
112   return array;
113 }
114 
115 template <typename T>
ConsumeRemainingAsArray(JNIEnv & env,jobject self)116 jbyteArray JNICALL ConsumeRemainingAsArray(JNIEnv &env, jobject self) {
117   return ConsumeIntegralArray<T>(env, self, std::numeric_limits<jint>::max());
118 }
119 
120 template <typename T>
ConsumeIntegralInRange(JNIEnv & env,jobject self,T min,T max)121 T JNICALL ConsumeIntegralInRange(JNIEnv &env, jobject self, T min, T max) {
122   uint64_t range = static_cast<uint64_t>(max) - min;
123   uint64_t result = 0;
124   jint offset = 0;
125 
126   const auto *dataPtr =
127       reinterpret_cast<const uint8_t *>(env.GetLongField(self, gDataPtrField));
128   jint remainingBytes = env.GetIntField(self, gRemainingBytesField);
129 
130   while (offset < 8 * sizeof(T) && (range >> offset) > 0 &&
131          remainingBytes != 0) {
132     --remainingBytes;
133     result = (result << 8u) | dataPtr[remainingBytes];
134     offset += 8;
135   }
136 
137   env.SetIntField(self, gRemainingBytesField, remainingBytes);
138   // dataPtr hasn't been modified, so we don't need to update gDataPtrField.
139 
140   if (range != std::numeric_limits<T>::max())
141     // We accept modulo bias in favor of reading a dynamic number of bytes as
142     // this would make it harder for the fuzzer to mutate towards values from
143     // the table of recent compares.
144     result = result % (range + 1);
145 
146   return static_cast<T>(min + result);
147 }
148 
149 template <typename T>
ConsumeIntegral(JNIEnv & env,jobject self)150 T JNICALL ConsumeIntegral(JNIEnv &env, jobject self) {
151   // First generate an unsigned value and then (safely) cast it to a signed
152   // integral type. By doing this rather than calling ConsumeIntegralInRange
153   // with bounds [signed_min, signed_max], we ensure that there is a direct
154   // correspondence between the consumed raw bytes and the result (e.g., 0
155   // corresponds to 0 and not to signed_min). This should help mutating
156   // towards entries of the table of recent compares.
157   using UnsignedT = typename std::make_unsigned<T>::type;
158   static_assert(
159       std::numeric_limits<UnsignedT>::is_modulo,
160       "Unsigned to signed conversion requires modulo-based overflow handling");
161   return static_cast<T>(ConsumeIntegralInRange<UnsignedT>(
162       env, self, 0, std::numeric_limits<UnsignedT>::max()));
163 }
164 
ConsumeBool(JNIEnv & env,jobject self)165 bool JNICALL ConsumeBool(JNIEnv &env, jobject self) {
166   return ConsumeIntegral<uint8_t>(env, self) & 1u;
167 }
168 
ConsumeCharInternal(JNIEnv & env,jobject self,bool filter_surrogates)169 jchar ConsumeCharInternal(JNIEnv &env, jobject self, bool filter_surrogates) {
170   auto raw_codepoint = ConsumeIntegral<jchar>(env, self);
171   if (filter_surrogates && raw_codepoint >= 0xd800 && raw_codepoint < 0xe000)
172     raw_codepoint -= 0xd800;
173   return raw_codepoint;
174 }
175 
ConsumeChar(JNIEnv & env,jobject self)176 jchar JNICALL ConsumeChar(JNIEnv &env, jobject self) {
177   return ConsumeCharInternal(env, self, false);
178 }
179 
ConsumeCharNoSurrogates(JNIEnv & env,jobject self)180 jchar JNICALL ConsumeCharNoSurrogates(JNIEnv &env, jobject self) {
181   return ConsumeCharInternal(env, self, true);
182 }
183 
184 template <typename T>
ConsumeProbability(JNIEnv & env,jobject self)185 T JNICALL ConsumeProbability(JNIEnv &env, jobject self) {
186   using IntegralType =
187       typename std::conditional<(sizeof(T) <= sizeof(uint32_t)), uint32_t,
188                                 uint64_t>::type;
189   T result = static_cast<T>(ConsumeIntegral<IntegralType>(env, self));
190   result /= static_cast<T>(std::numeric_limits<IntegralType>::max());
191   return result;
192 }
193 
194 template <typename T>
ConsumeFloatInRange(JNIEnv & env,jobject self,T min,T max)195 T JNICALL ConsumeFloatInRange(JNIEnv &env, jobject self, T min, T max) {
196   T range;
197   T result = min;
198 
199   // Deal with overflow, in the event min and max are very far apart
200   if (min < 0 && max > 0 && min + std::numeric_limits<T>::max() < max) {
201     range = (max / 2) - (min / 2);
202     if (ConsumeBool(env, self)) {
203       result += range;
204     }
205   } else {
206     range = max - min;
207   }
208 
209   T probability = ConsumeProbability<T>(env, self);
210   return result + range * probability;
211 }
212 
213 template <typename T>
ConsumeRegularFloat(JNIEnv & env,jobject self)214 T JNICALL ConsumeRegularFloat(JNIEnv &env, jobject self) {
215   return ConsumeFloatInRange(env, self, std::numeric_limits<T>::lowest(),
216                              std::numeric_limits<T>::max());
217 }
218 
219 template <typename T>
ConsumeFloat(JNIEnv & env,jobject self)220 T JNICALL ConsumeFloat(JNIEnv &env, jobject self) {
221   if (env.GetIntField(self, gRemainingBytesField) == 0) return 0.0;
222 
223   auto type_val = ConsumeIntegral<uint8_t>(env, self);
224 
225   if (type_val <= 10) {
226     // Consume the same amount of bytes as for a regular float/double
227     ConsumeRegularFloat<T>(env, self);
228 
229     switch (type_val) {
230       case 0:
231         return 0.0;
232       case 1:
233         return -0.0;
234       case 2:
235         return std::numeric_limits<T>::infinity();
236       case 3:
237         return -std::numeric_limits<T>::infinity();
238       case 4:
239         return std::numeric_limits<T>::quiet_NaN();
240       case 5:
241         return std::numeric_limits<T>::denorm_min();
242       case 6:
243         return -std::numeric_limits<T>::denorm_min();
244       case 7:
245         return std::numeric_limits<T>::min();
246       case 8:
247         return -std::numeric_limits<T>::min();
248       case 9:
249         return std::numeric_limits<T>::max();
250       case 10:
251         return -std::numeric_limits<T>::max();
252       default:
253         abort();
254     }
255   }
256 
257   T regular = ConsumeRegularFloat<T>(env, self);
258   return regular;
259 }
260 
// Stand-in for C++20's std::countl_one on a single byte: returns how many
// consecutive one bits the byte starts with, counted from the most
// significant bit (0 to 8).
inline __attribute__((always_inline)) uint8_t countl_one(uint8_t byte) {
  const unsigned int inverted = static_cast<uint8_t>(~byte);
  // __builtin_clz is undefined for 0, which happens exactly for byte == 0xFF.
  if (inverted == 0) return 8;
  // Move the inverted byte into the top eight bits of the 32-bit operand so
  // that its leading zeros are the leading ones of `byte`.
  return static_cast<uint8_t>(__builtin_clz(inverted << 24u));
}
268 
// Rewrites `byte` in place into the UTF-8 continuation byte form 0b10XXXXXX,
// keeping its six low bits.
inline __attribute__((always_inline)) void ForceContinuationByte(
    uint8_t &byte) {
  constexpr uint8_t kPayloadMask = 0b00111111;
  constexpr uint8_t kContinuationTag = 0b10000000;
  byte = static_cast<uint8_t>((byte & kPayloadMask) | kContinuationTag);
}
274 
// Byte values with a special role when generating modified UTF-8.
//
// The zero character is coded as the two bytes 0xC0 0x80.
constexpr uint8_t kTwoByteZeroLeadingByte = 0xC0;
constexpr uint8_t kTwoByteZeroContinuationByte = 0x80;
// Leading byte of a 3-byte sequence with all four payload bits cleared.
constexpr uint8_t kThreeByteLowLeadingByte = 0xE0;
// Leading byte shared by all 3-byte encoded surrogates.
constexpr uint8_t kSurrogateLeadingByte = 0xED;
279 
// States of the byte-by-byte modified UTF-8 fix-up in FixUpModifiedUtf8. Each
// state names the kind of byte expected next.
enum class Utf8GenerationState {
  // The next byte starts a new character.
  LeadingByte_Generic,
  // A single backslash was just skipped; the next byte decides whether the
  // string ends or a literal backslash is emitted.
  LeadingByte_AfterBackslash,
  // The next byte is the continuation byte of a 2-byte sequence.
  ContinuationByte_Generic,
  // As above, but the leading byte has so few payload bits set that the
  // sequence could encode an ASCII character and may need adjusting.
  ContinuationByte_LowLeadingByte,
  // The next byte is the first continuation byte of a 3-byte sequence whose
  // leading byte is kThreeByteLowLeadingByte.
  FirstContinuationByte_LowLeadingByte,
  // The next byte is the first continuation byte of a 3-byte sequence whose
  // leading byte is kSurrogateLeadingByte.
  FirstContinuationByte_SurrogateLeadingByte,
  // The next byte is the first continuation byte of any other 3-byte sequence.
  FirstContinuationByte_Generic,
  // The next byte is the last byte of a 3-byte sequence that is not part of a
  // surrogate pair.
  SecondContinuationByte_Generic,
  // A high surrogate has been completed; the next byte becomes the leading
  // byte of the low surrogate.
  LeadingByte_LowSurrogate,
  // The next byte is the first continuation byte of a low surrogate.
  FirstContinuationByte_LowSurrogate,
  // The next byte is the last byte of a high surrogate.
  SecondContinuationByte_HighSurrogate,
  // The next byte is the last byte of a low surrogate.
  SecondContinuationByte_LowSurrogate,
};
294 
295 // Consumes up to `max_bytes` arbitrary bytes pointed to by `ptr` and returns a
296 // valid "modified UTF-8" string of length at most `max_length` that resembles
297 // the input bytes as closely as possible as well as the number of consumed
298 // bytes. If `stop_on_slash` is true, then the string will end on the first
299 // single consumed '\'.
300 //
301 // "Modified UTF-8" is the string encoding used by the JNI. It is the same as
302 // the legacy encoding CESU-8, but with `\0` coded on two bytes. In these
303 // encodings, code points requiring 4 bytes in modern UTF-8 are represented as
304 // two surrogates, each of which is coded on 3 bytes.
305 //
306 // This function has been designed with the following goals in mind:
307 // 1. The generated string should be biased towards containing ASCII characters
308 //    as these are often the ones that affect control flow directly.
309 // 2. Correctly encoded data (e.g. taken from the table of recent compares)
310 //    should be emitted unchanged.
311 // 3. The raw fuzzer input should be preserved as far as possible, but the
312 //    output must always be correctly encoded.
313 //
314 // The JVM accepts string in two encodings: UTF-16 and modified UTF-8.
315 // Generating UTF-16 would make it harder to fulfill the first design goal and
316 // would potentially hinder compatibility with corpora using the much more
317 // widely used UTF-8 encoding, which is reasonably similar to modified UTF-8. As
318 // a result, this function uses modified UTF-8.
319 //
320 // See Algorithm 1 of https://arxiv.org/pdf/2010.03090.pdf for more details on
321 // the individual cases involved in determining the validity of a UTF-8 string.
322 template <bool ascii_only, bool stop_on_backslash>
FixUpModifiedUtf8(const uint8_t * data,jint max_bytes,jint max_length)323 std::pair<std::string, jint> FixUpModifiedUtf8(const uint8_t *data,
324                                                jint max_bytes,
325                                                jint max_length) {
326   std::string str;
327   // Every character in modified UTF-8 is coded on at most six bytes. Every
328   // consumed byte is transformed into at most one code unit, except for the
329   // case of a zero byte which requires two bytes.
330   if (ascii_only) {
331     str.reserve(std::min(2 * static_cast<std::size_t>(max_length),
332                          2 * static_cast<std::size_t>(max_bytes)));
333   } else {
334     str.reserve(std::min(6 * static_cast<std::size_t>(max_length),
335                          2 * static_cast<std::size_t>(max_bytes)));
336   }
337 
338   Utf8GenerationState state = Utf8GenerationState::LeadingByte_Generic;
339   const uint8_t *pos = data;
340   const auto data_end = data + max_bytes;
341   for (jint length = 0; length < max_length && pos != data_end; ++pos) {
342     uint8_t c = *pos;
343     if (ascii_only) {
344       // Clamp to 7-bit ASCII range.
345       c &= 0x7Fu;
346     }
347     // Fix up c or previously read bytes according to the value of c and the
348     // current state. In the end, add the fixed up code unit c to the string.
349     // Exception: The zero character has to be coded on two bytes and is the
350     // only case in which an iteration of the loop adds two code units.
351     switch (state) {
352       case Utf8GenerationState::LeadingByte_Generic: {
353         switch (ascii_only ? 0 : countl_one(c)) {
354           case 0: {
355             // valid - 1-byte code point (ASCII)
356             // The zero character has to be coded on two bytes in modified
357             // UTF-8.
358             if (c == 0) {
359               str += static_cast<char>(kTwoByteZeroLeadingByte);
360               c = kTwoByteZeroContinuationByte;
361             } else if (stop_on_backslash && c == '\\') {
362               state = Utf8GenerationState::LeadingByte_AfterBackslash;
363               // The slash either signals the end of the string or is skipped,
364               // so don't append anything.
365               continue;
366             }
367             // Remain in state LeadingByte.
368             ++length;
369             break;
370           }
371           case 1: {
372             // invalid - continuation byte at leader byte position
373             // Fix it up to be of the form 0b110XXXXX and fall through to the
374             // case of a 2-byte sequence.
375             c |= 1u << 6u;
376             c &= ~(1u << 5u);
377             [[fallthrough]];
378           }
379           case 2: {
380             // (most likely) valid - start of a 2-byte sequence
381             // ASCII characters must be coded on a single byte, so we must
382             // ensure that the lower two bits combined with the six non-header
383             // bits of the following byte do not form a 7-bit ASCII value. This
384             // could only be the case if at most the lowest bit is set.
385             if ((c & 0b00011110u) == 0) {
386               state = Utf8GenerationState::ContinuationByte_LowLeadingByte;
387             } else {
388               state = Utf8GenerationState::ContinuationByte_Generic;
389             }
390             break;
391           }
392           // The default case falls through to the case of three leading ones
393           // coming right after.
394           default: {
395             // invalid - at least four leading ones
396             // In the case of exactly four leading ones, this would be valid
397             // UTF-8, but is not valid in the JVM's modified UTF-8 encoding.
398             // Fix it up by clearing the fourth leading one and falling through
399             // to the 3-byte case.
400             c &= ~(1u << 4u);
401             [[fallthrough]];
402           }
403           case 3: {
404             // valid - start of a 3-byte sequence
405             if (c == kThreeByteLowLeadingByte) {
406               state = Utf8GenerationState::FirstContinuationByte_LowLeadingByte;
407             } else if (c == kSurrogateLeadingByte) {
408               state = Utf8GenerationState::
409                   FirstContinuationByte_SurrogateLeadingByte;
410             } else {
411               state = Utf8GenerationState::FirstContinuationByte_Generic;
412             }
413             break;
414           }
415         }
416         break;
417       }
418       case Utf8GenerationState::LeadingByte_AfterBackslash: {
419         if (c != '\\') {
420           // Mark the current byte as consumed.
421           ++pos;
422           goto done;
423         }
424         // A double backslash is consumed as a single one. As we skipped the
425         // first one, emit the second one as usual.
426         state = Utf8GenerationState::LeadingByte_Generic;
427         ++length;
428         break;
429       }
430       case Utf8GenerationState::ContinuationByte_LowLeadingByte: {
431         ForceContinuationByte(c);
432         // Preserve the zero character, which is coded on two bytes in modified
433         // UTF-8. In all other cases ensure that we are not incorrectly encoding
434         // an ASCII character on two bytes by setting the eighth least
435         // significant bit of the encoded value (second least significant bit of
436         // the leading byte).
437         auto previous_c = static_cast<uint8_t>(str.back());
438         if (previous_c != kTwoByteZeroLeadingByte ||
439             c != kTwoByteZeroContinuationByte) {
440           str.back() = static_cast<char>(previous_c | (1u << 1u));
441         }
442         state = Utf8GenerationState::LeadingByte_Generic;
443         ++length;
444         break;
445       }
446       case Utf8GenerationState::ContinuationByte_Generic: {
447         ForceContinuationByte(c);
448         state = Utf8GenerationState::LeadingByte_Generic;
449         ++length;
450         break;
451       }
452       case Utf8GenerationState::FirstContinuationByte_LowLeadingByte: {
453         ForceContinuationByte(c);
454         // Ensure that the current code point could not have been coded on two
455         // bytes. As two bytes encode up to 11 bits and three bytes encode up
456         // to 16 bits, we thus have to make it such that the five highest bits
457         // are not all zero. Four of these bits are the non-header bits of the
458         // leader byte. Thus, set the highest non-header bit in this byte (fifth
459         // highest in the encoded value).
460         c |= 1u << 5u;
461         state = Utf8GenerationState::SecondContinuationByte_Generic;
462         break;
463       }
464       case Utf8GenerationState::FirstContinuationByte_SurrogateLeadingByte: {
465         ForceContinuationByte(c);
466         if (c & (1u << 5u)) {
467           // Start with a high surrogate (0xD800-0xDBFF). c contains the second
468           // byte and the first two bits of the third byte. The first two bits
469           // of this second byte are fixed to 10 (in 0x8-0xB).
470           c |= 1u << 5u;
471           c &= ~(1u << 4u);
472           // The high surrogate must be followed by a low surrogate.
473           state = Utf8GenerationState::SecondContinuationByte_HighSurrogate;
474         } else {
475           state = Utf8GenerationState::SecondContinuationByte_Generic;
476         }
477         break;
478       }
479       case Utf8GenerationState::FirstContinuationByte_Generic: {
480         ForceContinuationByte(c);
481         state = Utf8GenerationState::SecondContinuationByte_Generic;
482         break;
483       }
484       case Utf8GenerationState::SecondContinuationByte_HighSurrogate: {
485         ForceContinuationByte(c);
486         state = Utf8GenerationState::LeadingByte_LowSurrogate;
487         ++length;
488         break;
489       }
490       case Utf8GenerationState::SecondContinuationByte_LowSurrogate:
491       case Utf8GenerationState::SecondContinuationByte_Generic: {
492         ForceContinuationByte(c);
493         state = Utf8GenerationState::LeadingByte_Generic;
494         ++length;
495         break;
496       }
497       case Utf8GenerationState::LeadingByte_LowSurrogate: {
498         // We have to emit a low surrogate leading byte, which is a fixed value.
499         // We still consume a byte from the input to make fuzzer changes more
500         // stable and preserve valid surrogate pairs picked up from e.g. the
501         // table of recent compares.
502         c = kSurrogateLeadingByte;
503         state = Utf8GenerationState::FirstContinuationByte_LowSurrogate;
504         break;
505       }
506       case Utf8GenerationState::FirstContinuationByte_LowSurrogate: {
507         ForceContinuationByte(c);
508         // Low surrogates are code points in the range 0xDC00-0xDFFF. c contains
509         // the second byte and the first two bits of the third byte. The first
510         // two bits of this second byte are fixed to 11 (in 0xC-0xF).
511         c |= (1u << 5u) | (1u << 4u);
512         // The second continuation byte of a low surrogate is not restricted,
513         // but we need to track it differently to allow for correct backtracking
514         // if it isn't completed.
515         state = Utf8GenerationState::SecondContinuationByte_LowSurrogate;
516         break;
517       }
518     }
519     str += static_cast<uint8_t>(c);
520   }
521 
522   // Backtrack the current incomplete character.
523   switch (state) {
524     case Utf8GenerationState::SecondContinuationByte_LowSurrogate:
525       str.pop_back();
526       [[fallthrough]];
527     case Utf8GenerationState::FirstContinuationByte_LowSurrogate:
528       str.pop_back();
529       [[fallthrough]];
530     case Utf8GenerationState::LeadingByte_LowSurrogate:
531       str.pop_back();
532       [[fallthrough]];
533     case Utf8GenerationState::SecondContinuationByte_Generic:
534     case Utf8GenerationState::SecondContinuationByte_HighSurrogate:
535       str.pop_back();
536       [[fallthrough]];
537     case Utf8GenerationState::ContinuationByte_Generic:
538     case Utf8GenerationState::ContinuationByte_LowLeadingByte:
539     case Utf8GenerationState::FirstContinuationByte_Generic:
540     case Utf8GenerationState::FirstContinuationByte_LowLeadingByte:
541     case Utf8GenerationState::FirstContinuationByte_SurrogateLeadingByte:
542       str.pop_back();
543       [[fallthrough]];
544     case Utf8GenerationState::LeadingByte_Generic:
545     case Utf8GenerationState::LeadingByte_AfterBackslash:
546       // No backtracking required.
547       break;
548   }
549 
550 done:
551   return std::make_pair(str, pos - data);
552 }
553 }  // namespace
554 
555 namespace jazzer {
556 // Exposed for testing only.
FixUpModifiedUtf8(const uint8_t * data,jint max_bytes,jint max_length,bool ascii_only,bool stop_on_backslash)557 std::pair<std::string, jint> FixUpModifiedUtf8(const uint8_t *data,
558                                                jint max_bytes, jint max_length,
559                                                bool ascii_only,
560                                                bool stop_on_backslash) {
561   if (ascii_only) {
562     if (stop_on_backslash) {
563       return ::FixUpModifiedUtf8<true, true>(data, max_bytes, max_length);
564     } else {
565       return ::FixUpModifiedUtf8<true, false>(data, max_bytes, max_length);
566     }
567   } else {
568     if (stop_on_backslash) {
569       return ::FixUpModifiedUtf8<false, true>(data, max_bytes, max_length);
570     } else {
571       return ::FixUpModifiedUtf8<false, false>(data, max_bytes, max_length);
572     }
573   }
574 }
575 }  // namespace jazzer
576 
577 namespace {
ConsumeStringInternal(JNIEnv & env,jobject self,jint max_length,bool ascii_only,bool stop_on_backslash)578 jstring ConsumeStringInternal(JNIEnv &env, jobject self, jint max_length,
579                               bool ascii_only, bool stop_on_backslash) {
580   if (max_length < 0) {
581     ThrowIllegalArgumentException(env, "maxLength must not be negative");
582     return nullptr;
583   }
584 
585   const auto *dataPtr =
586       reinterpret_cast<const uint8_t *>(env.GetLongField(self, gDataPtrField));
587   jint remainingBytes = env.GetIntField(self, gRemainingBytesField);
588 
589   if (max_length == 0 || remainingBytes == 0) return env.NewStringUTF("");
590 
591   if (remainingBytes == 1) {
592     env.SetIntField(self, gRemainingBytesField, 0);
593     return env.NewStringUTF("");
594   }
595 
596   std::string str;
597   jint consumed_bytes;
598   std::tie(str, consumed_bytes) = jazzer::FixUpModifiedUtf8(
599       dataPtr, remainingBytes, max_length, ascii_only, stop_on_backslash);
600   env.SetLongField(self, gDataPtrField, (jlong)(dataPtr + consumed_bytes));
601   env.SetIntField(self, gRemainingBytesField, remainingBytes - consumed_bytes);
602   return env.NewStringUTF(str.c_str());
603 }
604 
ConsumeAsciiString(JNIEnv & env,jobject self,jint max_length)605 jstring JNICALL ConsumeAsciiString(JNIEnv &env, jobject self, jint max_length) {
606   return ConsumeStringInternal(env, self, max_length, true, true);
607 }
608 
ConsumeString(JNIEnv & env,jobject self,jint max_length)609 jstring JNICALL ConsumeString(JNIEnv &env, jobject self, jint max_length) {
610   return ConsumeStringInternal(env, self, max_length, false, true);
611 }
612 
ConsumeRemainingAsAsciiString(JNIEnv & env,jobject self)613 jstring JNICALL ConsumeRemainingAsAsciiString(JNIEnv &env, jobject self) {
614   return ConsumeStringInternal(env, self, std::numeric_limits<jint>::max(),
615                                true, false);
616 }
617 
ConsumeRemainingAsString(JNIEnv & env,jobject self)618 jstring JNICALL ConsumeRemainingAsString(JNIEnv &env, jobject self) {
619   return ConsumeStringInternal(env, self, std::numeric_limits<jint>::max(),
620                                false, false);
621 }
622 
RemainingBytes(JNIEnv & env,jobject self)623 std::size_t RemainingBytes(JNIEnv &env, jobject self) {
624   return env.GetIntField(self, gRemainingBytesField);
625 }
626 
627 const JNINativeMethod kFuzzedDataMethods[]{
628     {(char *)"consumeBoolean", (char *)"()Z", (void *)&ConsumeBool},
629     {(char *)"consumeByte", (char *)"()B", (void *)&ConsumeIntegral<jbyte>},
630     {(char *)"consumeByteUnchecked", (char *)"(BB)B",
631      (void *)&ConsumeIntegralInRange<jbyte>},
632     {(char *)"consumeShort", (char *)"()S", (void *)&ConsumeIntegral<jshort>},
633     {(char *)"consumeShortUnchecked", (char *)"(SS)S",
634      (void *)&ConsumeIntegralInRange<jshort>},
635     {(char *)"consumeInt", (char *)"()I", (void *)&ConsumeIntegral<jint>},
636     {(char *)"consumeIntUnchecked", (char *)"(II)I",
637      (void *)&ConsumeIntegralInRange<jint>},
638     {(char *)"consumeLong", (char *)"()J", (void *)&ConsumeIntegral<jlong>},
639     {(char *)"consumeLongUnchecked", (char *)"(JJ)J",
640      (void *)&ConsumeIntegralInRange<jlong>},
641     {(char *)"consumeFloat", (char *)"()F", (void *)&ConsumeFloat<jfloat>},
642     {(char *)"consumeRegularFloat", (char *)"()F",
643      (void *)&ConsumeRegularFloat<jfloat>},
644     {(char *)"consumeRegularFloatUnchecked", (char *)"(FF)F",
645      (void *)&ConsumeFloatInRange<jfloat>},
646     {(char *)"consumeProbabilityFloat", (char *)"()F",
647      (void *)&ConsumeProbability<jfloat>},
648     {(char *)"consumeDouble", (char *)"()D", (void *)&ConsumeFloat<jdouble>},
649     {(char *)"consumeRegularDouble", (char *)"()D",
650      (void *)&ConsumeRegularFloat<jdouble>},
651     {(char *)"consumeRegularDoubleUnchecked", (char *)"(DD)D",
652      (void *)&ConsumeFloatInRange<jdouble>},
653     {(char *)"consumeProbabilityDouble", (char *)"()D",
654      (void *)&ConsumeProbability<jdouble>},
655     {(char *)"consumeChar", (char *)"()C", (void *)&ConsumeChar},
656     {(char *)"consumeCharUnchecked", (char *)"(CC)C",
657      (void *)&ConsumeIntegralInRange<jchar>},
658     {(char *)"consumeCharNoSurrogates", (char *)"()C",
659      (void *)&ConsumeCharNoSurrogates},
660     {(char *)"consumeAsciiString", (char *)"(I)Ljava/lang/String;",
661      (void *)&ConsumeAsciiString},
662     {(char *)"consumeRemainingAsAsciiString", (char *)"()Ljava/lang/String;",
663      (void *)&ConsumeRemainingAsAsciiString},
664     {(char *)"consumeString", (char *)"(I)Ljava/lang/String;",
665      (void *)&ConsumeString},
666     {(char *)"consumeRemainingAsString", (char *)"()Ljava/lang/String;",
667      (void *)&ConsumeRemainingAsString},
668     {(char *)"consumeBooleans", (char *)"(I)[Z",
669      (void *)&ConsumeIntegralArray<jboolean>},
670     {(char *)"consumeBytes", (char *)"(I)[B",
671      (void *)&ConsumeIntegralArray<jbyte>},
672     {(char *)"consumeShorts", (char *)"(I)[S",
673      (void *)&ConsumeIntegralArray<jshort>},
674     {(char *)"consumeInts", (char *)"(I)[I",
675      (void *)&ConsumeIntegralArray<jint>},
676     {(char *)"consumeLongs", (char *)"(I)[J",
677      (void *)&ConsumeIntegralArray<jlong>},
678     {(char *)"consumeRemainingAsBytes", (char *)"()[B",
679      (void *)&ConsumeRemainingAsArray<jbyte>},
680     {(char *)"remainingBytes", (char *)"()I", (void *)&RemainingBytes},
681 };
682 const jint kNumFuzzedDataMethods =
683     sizeof(kFuzzedDataMethods) / sizeof(kFuzzedDataMethods[0]);
684 }  // namespace
685 
686 [[maybe_unused]] void
Java_com_code_1intelligence_jazzer_driver_FuzzedDataProviderImpl_nativeInit(JNIEnv * env,jclass clazz)687 Java_com_code_1intelligence_jazzer_driver_FuzzedDataProviderImpl_nativeInit(
688     JNIEnv *env, jclass clazz) {
689   env->RegisterNatives(clazz, kFuzzedDataMethods, kNumFuzzedDataMethods);
690   gDataPtrField = env->GetFieldID(clazz, "dataPtr", "J");
691   gRemainingBytesField = env->GetFieldID(clazz, "remainingBytes", "I");
692 }
693