1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.platform.test.flag.junit;
18 
19 import static org.junit.Assume.assumeFalse;
20 
21 import android.platform.test.flag.util.FlagReadException;
22 import android.platform.test.flag.util.FlagSetException;
23 
24 import com.google.common.base.CaseFormat;
25 import com.google.common.collect.Sets;
26 
27 import org.junit.rules.TestRule;
28 import org.junit.runner.Description;
29 import org.junit.runners.model.Statement;
30 
31 import java.lang.reflect.Field;
32 import java.util.HashMap;
33 import java.util.HashSet;
34 import java.util.Map;
35 import java.util.Objects;
36 import java.util.Set;
37 import java.util.concurrent.ConcurrentHashMap;
38 import java.util.function.BiPredicate;
39 import java.util.function.Predicate;
40 
41 import javax.annotation.Nonnull;
42 import javax.annotation.Nullable;
43 
/** A {@link TestRule} that sets flag values for unit tests. */
45 public final class SetFlagsRule implements TestRule {
    // Simple name of the generated fake implementation, loaded from the Flags class's package.
    private static final String FAKE_FEATURE_FLAGS_IMPL_CLASS_NAME = "FakeFeatureFlagsImpl";
    // NOTE(review): not referenced in the visible part of this file.
    private static final String REAL_FEATURE_FLAGS_IMPL_CLASS_NAME = "FeatureFlagsImpl";
    // Simple class name accepted as a fake even when it lacks isFlagReadOnlyOptimized().
    private static final String CUSTOM_FEATURE_FLAGS_CLASS_NAME = "CustomFeatureFlags";
    // Simple name of the interface that the fake's constructor takes as its only argument.
    private static final String FEATURE_FLAGS_CLASS_NAME = "FeatureFlags";
    // Static field in the Flags class that is swapped between real and fake implementations.
    private static final String FEATURE_FLAGS_FIELD_NAME = "FEATURE_FLAGS";
    // Simple name of the generated Flags class that holds FEATURE_FLAGS.
    private static final String FLAGS_CLASS_NAME = "Flags";
    // NOTE(review): not referenced in the visible part of this file.
    private static final String FLAG_CONSTANT_PREFIX = "FLAG_";
    // Reflectively-invoked methods on the fake implementation.
    private static final String SET_FLAG_METHOD_NAME = "setFlag";
    private static final String RESET_ALL_METHOD_NAME = "resetAll";
    private static final String IS_FLAG_READ_ONLY_OPTIMIZED_METHOD_NAME = "isFlagReadOnlyOptimized";

    // Store instances for entire life of a SetFlagsRule instance
    // (fakes are created lazily per Flags class and never removed; the "real" map holds whatever
    // FEATURE_FLAGS contained before this rule first replaced it, so it can be restored).
    private final Map<Class<?>, Object> mFlagsClassToFakeFlagsImpl = new HashMap<>();
    private final Map<Class<?>, Object> mFlagsClassToRealFlagsImpl = new HashMap<>();

    // Store classes that are currently mutated by this rule
    // (their FEATURE_FLAGS field points at the fake; emptied by resetFlags()).
    private final Set<Class<?>> mMutatedFlagsClasses = new HashSet<>();

    // Any flags added to this set cannot be set imperatively (i.e. with enableFlags/disableFlags).
    // Populated from the parameterization and from the test's flag annotations.
    private final Set<String> mLockedFlagNames = new HashSet<>();

    // Optional hooks invoked around evaluation and before each flag is set; null unless this rule
    // was created through the private constructor (e.g. by ClassRule).
    private final Listener mListener;

    // TODO(322377082): remove repackage prefix list
    // Package prefixes tried when locating a Flags class; "" means the flag's own package.
    private static final String[] REPACKAGE_PREFIX_LIST =
            new String[] {
                "", "com.android.internal.hidden_from_bootclasspath.",
            };
    // Cache: flag package name -> packages (original and/or repackaged) where a Flags class exists.
    private final Map<String, Set<String>> mPackageToRepackage = new HashMap<>();

    // True when fakes should wrap the real implementation (DEVICE_DEFAULT) instead of null.
    private final boolean mIsInitWithDefault;
    // Parameterized flag overrides applied at the start of each evaluation; may be null.
    private FlagsParameterization mFlagsParameterization;
    // True only while the wrapped test statement is running.
    private boolean mIsRuleEvaluating = false;
80 
    /** Controls what the fake FeatureFlags implementation created by this rule delegates to. */
    public enum DefaultInitValueType {
        /**
         * Initialize flag values as null (the fake is constructed without a backing
         * implementation).
         *
         * <p>Flag values need to be set before they are used.
         */
        NULL_DEFAULT,

        /**
         * Initialize flag values from the FeatureFlags implementation that the Flags class held
         * before this rule replaced it (i.e. the device's values).
         *
         * <p>If a flag value is not overridden by adb, then the default value comes from the
         * release configuration used when the test was built.
         */
        DEVICE_DEFAULT,
    }
97 
SetFlagsRule()98     public SetFlagsRule() {
99         this(DefaultInitValueType.DEVICE_DEFAULT);
100     }
101 
SetFlagsRule(DefaultInitValueType defaultType)102     public SetFlagsRule(DefaultInitValueType defaultType) {
103         this(defaultType, null);
104     }
105 
SetFlagsRule(@ullable FlagsParameterization flagsParameterization)106     public SetFlagsRule(@Nullable FlagsParameterization flagsParameterization) {
107         this(DefaultInitValueType.DEVICE_DEFAULT, flagsParameterization);
108     }
109 
SetFlagsRule( DefaultInitValueType defaultType, @Nullable FlagsParameterization flagsParameterization)110     public SetFlagsRule(
111             DefaultInitValueType defaultType,
112             @Nullable FlagsParameterization flagsParameterization) {
113         this(defaultType, flagsParameterization, null);
114     }
115 
SetFlagsRule( DefaultInitValueType defaultType, @Nullable FlagsParameterization flagsParameterization, @Nullable Listener listener)116     private SetFlagsRule(
117             DefaultInitValueType defaultType,
118             @Nullable FlagsParameterization flagsParameterization,
119             @Nullable Listener listener) {
120         mIsInitWithDefault = defaultType == DefaultInitValueType.DEVICE_DEFAULT;
121         mFlagsParameterization = flagsParameterization;
122         if (flagsParameterization != null) {
123             mLockedFlagNames.addAll(flagsParameterization.mOverrides.keySet());
124         }
125         mListener = listener;
126     }
127 
128     /**
129      * Set the FlagsParameterization to be used during this test. This cannot be used to override a
130      * previous call, and cannot be called once the rule has been evaluated.
131      */
setFlagsParameterization(@onnull FlagsParameterization flagsParameterization)132     public void setFlagsParameterization(@Nonnull FlagsParameterization flagsParameterization) {
133         Objects.requireNonNull(flagsParameterization, "FlagsParameterization cannot be cleared");
134         if (mFlagsParameterization != null) {
135             throw new AssertionError("FlagsParameterization cannot be overridden");
136         }
137         if (mIsRuleEvaluating) {
138             throw new AssertionError("Cannot set FlagsParameterization once the rule is running");
139         }
140         ensureFlagsAreUnset();
141         mFlagsParameterization = flagsParameterization;
142         mLockedFlagNames.addAll(flagsParameterization.mOverrides.keySet());
143     }
144 
145     /**
146      * Enables the given flags.
147      *
148      * @param fullFlagNames The name of the flags in the flag class with the format
149      *     {packageName}.{flagName}
150      *
151      * @deprecated Annotate your test or class with <code>@EnableFlags(String...)</code> instead
152      */
153     @Deprecated
enableFlags(String... fullFlagNames)154     public void enableFlags(String... fullFlagNames) {
155         if (!mIsRuleEvaluating) {
156             throw new IllegalStateException("Not allowed to set flags outside test and setup code");
157         }
158         for (String fullFlagName : fullFlagNames) {
159             if (mLockedFlagNames.contains(fullFlagName)) {
160                 throw new FlagSetException(fullFlagName, "Not allowed to change locked flags");
161             }
162             setFlagValue(fullFlagName, true);
163         }
164     }
165 
166     /**
167      * Disables the given flags.
168      *
169      * @param fullFlagNames The name of the flags in the flag class with the format
170      *     {packageName}.{flagName}
171      *
172      * @deprecated Annotate your test or class with <code>@DisableFlags(String...)</code> instead
173      */
174     @Deprecated
disableFlags(String... fullFlagNames)175     public void disableFlags(String... fullFlagNames) {
176         if (!mIsRuleEvaluating) {
177             throw new IllegalStateException("Not allowed to set flags outside test and setup code");
178         }
179         for (String fullFlagName : fullFlagNames) {
180             if (mLockedFlagNames.contains(fullFlagName)) {
181                 throw new FlagSetException(fullFlagName, "Not allowed to change locked flags");
182             }
183             setFlagValue(fullFlagName, false);
184         }
185     }
186 
ensureFlagsAreUnset()187     private void ensureFlagsAreUnset() {
188         if (!mFlagsClassToFakeFlagsImpl.isEmpty()) {
189             throw new IllegalStateException("Some flags were set before the rule was initialized");
190         }
191     }
192 
    /**
     * Wraps {@code base} so that flags are configured before the test body runs and restored
     * afterwards.
     *
     * <p>Before the test: notifies the listener (if any), reads the flag annotations for
     * {@code description}, checks them against the parameterization, applies the parameterization
     * overrides and then the annotation-set values, and locks the annotated flags so that
     * {@code enableFlags}/{@code disableFlags} cannot change them.
     *
     * <p>After the test (always, from {@code finally}): clears the evaluating state, restores the
     * original FeatureFlags implementations via {@code resetFlags()}, and notifies the listener.
     */
    @Override
    public Statement apply(Statement base, Description description) {
        return new Statement() {
            @Override
            public void evaluate() throws Throwable {
                // The failure to rethrow once every cleanup step has run.
                Throwable throwable = null;
                try {
                    if (mListener != null) {
                        mListener.onStartedEvaluating();
                    }
                    AnnotationsRetriever.FlagAnnotations flagAnnotations =
                            AnnotationsRetriever.getFlagAnnotations(description);
                    // Required flags (@RequiresFlags*) must not also be parameterized.
                    assertAnnotationsMatchParameterization(flagAnnotations, mFlagsParameterization);
                    // NOTE(review): presumably skips the test via an assumption when annotation
                    // values conflict with the parameterization; helper not visible here.
                    flagAnnotations.assumeAllSetFlagsMatchParameterization(mFlagsParameterization);
                    if (mFlagsParameterization != null) {
                        // Fails if this rule already created fakes before evaluation began.
                        ensureFlagsAreUnset();
                        for (Map.Entry<String, Boolean> pair :
                                mFlagsParameterization.mOverrides.entrySet()) {
                            setFlagValue(pair.getKey(), pair.getValue());
                        }
                    }
                    // Annotation-set values are applied after the parameterization overrides.
                    for (Map.Entry<String, Boolean> pair :
                            flagAnnotations.mSetFlagValues.entrySet()) {
                        setFlagValue(pair.getKey(), pair.getValue());
                    }
                    // NOTE(review): these names are never removed from mLockedFlagNames, so they
                    // stay locked if this rule instance is reused for another test.
                    mLockedFlagNames.addAll(flagAnnotations.mRequiredFlagValues.keySet());
                    mLockedFlagNames.addAll(flagAnnotations.mSetFlagValues.keySet());
                    // From here on enableFlags/disableFlags are permitted.
                    mIsRuleEvaluating = true;
                    base.evaluate();
                } catch (Throwable t) {
                    throwable = t;
                } finally {
                    mIsRuleEvaluating = false;
                    // Each later failure becomes the primary exception and carries the earlier one
                    // as suppressed. NOTE(review): a cleanup failure therefore masks the test's
                    // own failure as the primary cause.
                    try {
                        resetFlags();
                    } catch (Throwable t) {
                        if (throwable != null) {
                            t.addSuppressed(throwable);
                        }
                        throwable = t;
                    }
                    try {
                        if (mListener != null) {
                            mListener.onFinishedEvaluating();
                        }
                    } catch (Throwable t) {
                        if (throwable != null) {
                            t.addSuppressed(throwable);
                        }
                        throwable = t;
                    }
                }
                if (throwable != null) throw throwable;
            }
        };
    }
249 
250     private static void assertAnnotationsMatchParameterization(
251             AnnotationsRetriever.FlagAnnotations flagAnnotations,
252             FlagsParameterization parameterization) {
253         if (parameterization == null) return;
254         Set<String> parameterizedFlags = parameterization.mOverrides.keySet();
255         Set<String> requiredFlags = flagAnnotations.mRequiredFlagValues.keySet();
256         // Assert that NO Annotation-Required flag is in the parameterization
257         Set<String> parameterizedAndRequiredFlags =
258                 Sets.intersection(parameterizedFlags, requiredFlags);
259         if (!parameterizedAndRequiredFlags.isEmpty()) {
260             throw new AssertionError(
261                     "The following flags have required values (per @RequiresFlagsEnabled or"
262                             + " @RequiresFlagsDisabled) but they are part of the"
263                             + " FlagParameterization: "
264                             + parameterizedAndRequiredFlags);
265         }
266     }
267 
268     private void setFlagValue(String fullFlagName, boolean value) {
269         if (!fullFlagName.contains(".")) {
270             throw new FlagSetException(
271                     fullFlagName, "Flag name is not the expected format {packgeName}.{flagName}.");
272         }
273         // Get all packages containing Flags referencing the same fullFlagName.
274         Set<String> packageSet = getPackagesContainsFlag(fullFlagName);
275 
276         for (String packageName : packageSet) {
277             setFlagValue(Flag.createFlag(fullFlagName, packageName), value);
278         }
279     }
280 
281     private Set<String> getPackagesContainsFlag(String fullFlagName) {
282         return getAllPackagesForFlag(fullFlagName, mPackageToRepackage);
283     }
284 
285     private static Set<String> getAllPackagesForFlag(
286             String fullFlagName, Map<String, Set<String>> packageToRepackage) {
287 
288         String packageName = Flag.getFlagPackageName(fullFlagName);
289         Set<String> packageSet = packageToRepackage.getOrDefault(packageName, new HashSet<>());
290 
291         if (!packageSet.isEmpty()) {
292             return packageSet;
293         }
294 
295         for (String prefix : REPACKAGE_PREFIX_LIST) {
296             String repackagedName = String.format("%s%s", prefix, packageName);
297             String flagClassName = String.format("%s.%s", repackagedName, FLAGS_CLASS_NAME);
298             try {
299                 Class.forName(flagClassName, false, SetFlagsRule.class.getClassLoader());
300                 packageSet.add(repackagedName);
301             } catch (ClassNotFoundException e) {
302                 // Skip if the class is not found
303                 // An error will be thrown if no package containing flags referencing
304                 // the passed in flag
305             }
306         }
307         packageToRepackage.put(packageName, packageSet);
308         if (packageSet.isEmpty()) {
309             throw new FlagSetException(
310                     fullFlagName,
311                     "Cannot find package containing Flags class referencing to this flag.");
312         }
313         return packageSet;
314     }
315 
316     private void setFlagValue(Flag flag, boolean value) {
317         if (mListener != null) {
318             mListener.onBeforeSetFlag(flag, value);
319         }
320 
321         Object fakeFlagsImplInstance = null;
322 
323         Class<?> flagsClass = getFlagClassFromFlag(flag);
324         fakeFlagsImplInstance = getOrCreateFakeFlagsImp(flagsClass);
325 
326         if (!mMutatedFlagsClasses.contains(flagsClass)) {
327             // Replace FeatureFlags in Flags class with FakeFeatureFlagsImpl
328             replaceFlagsImpl(flagsClass, fakeFlagsImplInstance);
329             mMutatedFlagsClasses.add(flagsClass);
330         }
331 
332         // If the test is trying to set the flag value on a read_only flag in an optimized build
333         // skip this test, since it is not a valid testing case
334         // The reason for skipping instead of throwning error here is all read_write flag will be
335         // change to read_only in the final release configuration. Thus the test could be executed
336         // in other release configuration cases
337         // TODO(b/337449119): SetFlagsRule should still run tests that are consistent with the
338         // read-only values of flags. But be careful, if a ClassRule exists, the value returned by
339         // the original FeatureFlags instance may be overridden, and reading it may not be allowed.
340         boolean isOptimized = verifyFlagReadOnlyAndOptimized(fakeFlagsImplInstance, flag);
341         assumeFalse(
342                 String.format(
343                         "Flag %s is read_only, and the code is optimized. "
344                                 + " The flag value should not be modified on this build"
345                                 + " Skip this test.",
346                         flag.fullFlagName()),
347                 isOptimized);
348 
349         // Set desired flag value in the FakeFeatureFlagsImpl
350         setFlagValueInFakeFeatureFlagsImpl(fakeFlagsImplInstance, flag, value);
351     }
352 
353     private static Class<?> getFlagClassFromFlag(Flag flag) {
354         String className = flag.flagsClassName();
355         Class<?> flagsClass = null;
356         try {
357             flagsClass = Class.forName(className);
358         } catch (ClassNotFoundException e) {
359             throw new FlagSetException(
360                     flag.fullFlagName(),
361                     String.format(
362                             "Can not load the Flags class %s to set its values. Please check the "
363                                     + "flag name and ensure that the aconfig auto generated "
364                                     + "library is in the dependency.",
365                             className),
366                     e);
367         }
368         return flagsClass;
369     }
370 
371     private static Class<?> getFlagClassFromFlagsClassName(String className) {
372         if (!className.endsWith("." + FLAGS_CLASS_NAME)) {
373             throw new FlagSetException(
374                     className,
375                     "Can not watch this Flags class because it is not named 'Flags'. Please ensure"
376                             + " your @UsesFlags() annotations only reference the Flags classes.");
377         }
378         try {
379             return Class.forName(className);
380         } catch (ClassNotFoundException e) {
381             throw new FlagSetException(
382                     className,
383                     "Cannot load this Flags class to set its values. Please check the flag name and"
384                         + " ensure that the aconfig auto generated library is in the dependency.",
385                     e);
386         }
387     }
388 
389     private boolean getFlagValue(Object featureFlagsImpl, Flag flag) {
390         // Must be consistent with method name in aconfig auto generated code.
391         String methodName = getFlagMethodName(flag);
392         String fullFlagName = flag.fullFlagName();
393 
394         try {
395             Object result =
396                     featureFlagsImpl.getClass().getMethod(methodName).invoke(featureFlagsImpl);
397             if (result instanceof Boolean) {
398                 return (Boolean) result;
399             }
400             throw new FlagReadException(
401                     fullFlagName,
402                     String.format(
403                             "Flag type is %s, not boolean", result.getClass().getSimpleName()));
404         } catch (NoSuchMethodException e) {
405             throw new FlagReadException(
406                     fullFlagName,
407                     String.format(
408                             "No method %s in the Flags class %s to read the flag value. Please"
409                                     + " check the flag name.",
410                             methodName, featureFlagsImpl.getClass().getName()),
411                     e);
412         } catch (ReflectiveOperationException e) {
413             throw new FlagReadException(
414                     fullFlagName,
415                     String.format(
416                             "Fail to get value of flag %s from instance %s",
417                             fullFlagName, featureFlagsImpl.getClass().getName()),
418                     e);
419         }
420     }
421 
422     private String getFlagMethodName(Flag flag) {
423         return CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, flag.simpleFlagName());
424     }
425 
426     private void setFlagValueInFakeFeatureFlagsImpl(
427             Object fakeFeatureFlagsImpl, Flag flag, boolean value) {
428         String fullFlagName = flag.fullFlagName();
429         try {
430             fakeFeatureFlagsImpl
431                     .getClass()
432                     .getMethod(SET_FLAG_METHOD_NAME, String.class, boolean.class)
433                     .invoke(fakeFeatureFlagsImpl, fullFlagName, value);
434         } catch (NoSuchMethodException e) {
435             throw new FlagSetException(
436                     fullFlagName,
437                     String.format(
438                             "Flag implementation %s is not fake implementation",
439                             fakeFeatureFlagsImpl.getClass().getName()),
440                     e);
441         } catch (ReflectiveOperationException e) {
442             throw new FlagSetException(fullFlagName, e);
443         }
444     }
445 
446     private static boolean verifyFlagReadOnlyAndOptimized(Object fakeFeatureFlagsImpl, Flag flag) {
447         String fullFlagName = flag.fullFlagName();
448         try {
449             boolean result =
450                     (Boolean)
451                             fakeFeatureFlagsImpl
452                                     .getClass()
453                                     .getMethod(
454                                             IS_FLAG_READ_ONLY_OPTIMIZED_METHOD_NAME, String.class)
455                                     .invoke(fakeFeatureFlagsImpl, fullFlagName);
456             return result;
457         } catch (NoSuchMethodException e) {
458             // If the flag is generated under exported mode, then it doesn't have this method
459             String simpleClassName = fakeFeatureFlagsImpl.getClass().getSimpleName();
460             if (simpleClassName.equals(FAKE_FEATURE_FLAGS_IMPL_CLASS_NAME)) {
461                 return false;
462             }
463             if (simpleClassName.equals(CUSTOM_FEATURE_FLAGS_CLASS_NAME)) {
464                 return false;
465             }
466             throw new FlagSetException(
467                     fullFlagName,
468                     String.format(
469                             "Cannot check whether flag is optimized. "
470                                     + "Flag implementation %s is not fake implementation",
471                             fakeFeatureFlagsImpl.getClass().getName()),
472                     e);
473         } catch (ReflectiveOperationException e) {
474             throw new FlagSetException(fullFlagName, e);
475         }
476     }
477 
478     @Nonnull
479     private Object getOrCreateFakeFlagsImp(Class<?> flagsClass) {
480         Object fakeFlagsImplInstance = mFlagsClassToFakeFlagsImpl.get(flagsClass);
481         if (fakeFlagsImplInstance != null) {
482             return fakeFlagsImplInstance;
483         }
484 
485         String packageName = flagsClass.getPackageName();
486         String fakeClassName =
487                 String.format("%s.%s", packageName, FAKE_FEATURE_FLAGS_IMPL_CLASS_NAME);
488         String interfaceName = String.format("%s.%s", packageName, FEATURE_FLAGS_CLASS_NAME);
489 
490         Object realFlagsImplInstance = readFlagsImpl(flagsClass);
491         mFlagsClassToRealFlagsImpl.put(flagsClass, realFlagsImplInstance);
492 
493         try {
494             Class<?> flagImplClass = Class.forName(fakeClassName);
495             Class<?> flagInterface = Class.forName(interfaceName);
496             fakeFlagsImplInstance =
497                     flagImplClass
498                             .getConstructor(flagInterface)
499                             .newInstance(mIsInitWithDefault ? realFlagsImplInstance : null);
500         } catch (ReflectiveOperationException e) {
501             throw new UnsupportedOperationException(
502                     String.format(
503                             "Cannot create FakeFeatureFlagsImpl in Flags class %s.",
504                             flagsClass.getName()),
505                     e);
506         }
507 
508         mFlagsClassToFakeFlagsImpl.put(flagsClass, fakeFlagsImplInstance);
509 
510         return fakeFlagsImplInstance;
511     }
512 
513     private static void replaceFlagsImpl(Class<?> flagsClass, Object flagsImplInstance) {
514         Field featureFlagsField = getFeatureFlagsField(flagsClass);
515         try {
516             featureFlagsField.set(null, flagsImplInstance);
517         } catch (IllegalAccessException e) {
518             throw new UnsupportedOperationException(
519                     String.format(
520                             "Cannot replace FeatureFlagsImpl to %s.",
521                             flagsImplInstance.getClass().getName()),
522                     e);
523         }
524     }
525 
526     private static Object readFlagsImpl(Class<?> flagsClass) {
527         Field featureFlagsField = getFeatureFlagsField(flagsClass);
528         try {
529             return featureFlagsField.get(null);
530         } catch (IllegalAccessException e) {
531             throw new UnsupportedOperationException(
532                     String.format(
533                             "Cannot get FeatureFlags from Flags class %s.", flagsClass.getName()),
534                     e);
535         }
536     }
537 
538     private static Field getFeatureFlagsField(Class<?> flagsClass) {
539         Field featureFlagsField = null;
540         try {
541             featureFlagsField = flagsClass.getDeclaredField(FEATURE_FLAGS_FIELD_NAME);
542         } catch (ReflectiveOperationException e) {
543             throw new UnsupportedOperationException(
544                     String.format(
545                             "Cannot store FeatureFlagsImpl in Flag %s.", flagsClass.getName()),
546                     e);
547         }
548         featureFlagsField.setAccessible(true);
549         return featureFlagsField;
550     }
551 
552     private void resetFlags() {
553         String flagsClassName = null;
554         try {
555             for (Class<?> flagsClass : mMutatedFlagsClasses) {
556                 flagsClassName = flagsClass.getName();
557                 Object fakeFlagsImplInstance = mFlagsClassToFakeFlagsImpl.get(flagsClass);
558                 Object flagsImplInstance = mFlagsClassToRealFlagsImpl.get(flagsClass);
559                 // Replace FeatureFlags in Flags class with real FeatureFlagsImpl
560                 replaceFlagsImpl(flagsClass, flagsImplInstance);
561                 fakeFlagsImplInstance
562                         .getClass()
563                         .getMethod(RESET_ALL_METHOD_NAME)
564                         .invoke(fakeFlagsImplInstance);
565             }
566             mMutatedFlagsClasses.clear();
567         } catch (Exception e) {
568             throw new FlagSetException(flagsClassName, e);
569         }
570     }
571 
572     /** An interface that provides hooks to the ClassRule. */
573     private interface Listener {
574         /** Called before a flag is set. */
575         void onBeforeSetFlag(SetFlagsRule.Flag flag, boolean value);
576 
577         /** Called after the rule has started evaluating for a test. */
578         void onStartedEvaluating();
579 
580         /** Called after the rule has finished evaluating for a test. */
581         void onFinishedEvaluating();
582     }
583 
    /**
     * A {@code @ClassRule} which adds extra consistency checks for SetFlagsRule.
     *
     * <ul>
     *   <li>Requires that tests monitor the Flags class of any flag that is set.
     *   <li>Fails a test if a flag that is set was read before the test started.
     * </ul>
     */
589     public static class ClassRule implements TestRule {
590         /** The flags classes that are requested to be watched during construction. */
591         private final Set<Class<?>> mGlobalFlagsClassesToWatch = new HashSet<>();
592 
593         /** The flags packages that are allowed to be set, for quick per-flag lookup */
594         private final Set<String> mSettableFlagsPackages = new HashSet<>();
595 
596         /** The mapping from the Flags classes to the real implementations */
597         private final Map<Class<?>, Object> mFlagsClassToRealFlagsImpl = new HashMap<>();
598 
599         /** The mapping from the Flags classes to the watcher implementations */
600         private final Map<Class<?>, Object> mFlagsClassToWatcherImpl = new HashMap<>();
601 
602         /** The flags classes that have actually been mutated */
603         private final Set<Class<?>> mMutatedFlagsClasses = new HashSet<>();
604 
605         /** The flag values set by class annotations */
606         private final Map<String, Boolean> mClassLevelSetFlagValues = new ConcurrentHashMap<>();
607 
608         /**
609          * The individual flags which have been read from prior to tests starting, mapped to the
610          * stack trace of the first read.
611          */
612         private final Map<String, FirstFlagRead> mFirstReadOutsideTestsByFlag =
613                 new ConcurrentHashMap<>();
614 
615         /**
616          * The individual flags which have been read from within a test, mapped to the stack trace
617          * of the first read.
618          */
619         private final Map<String, FirstFlagRead> mFirstReadWithinTestByFlag =
620                 new ConcurrentHashMap<>();
621 
622         /** repackage cache */
623         private final Map<String, Set<String>> mPackageToRepackage = new HashMap<>();
624 
625         /** The depth of the ClassRule evaluating on potentially nested suites */
626         private int mSuiteRunDepth = 0;
627 
628         /** Whether the SetFlagsRule is evaluating for a test */
629         private boolean mIsTestRunning = false;
630 
631         /** Typical constructor takes an initial list flags classes to watch */
632         public ClassRule(Class<?>... flagsClasses) {
633             for (Class<?> flagsClass : flagsClasses) {
634                 mGlobalFlagsClassesToWatch.add(flagsClass);
635             }
636         }
637 
638         /** Listener to be notified of events in any created SetFlagsRule */
639         private SetFlagsRule.Listener mListener =
640                 new SetFlagsRule.Listener() {
641                     @Override
642                     public void onBeforeSetFlag(SetFlagsRule.Flag flag, boolean value) {
643                         if (!mIsTestRunning) {
644                             throw new IllegalStateException("Inner rule should be running!");
645                         }
646                         assertFlagCanBeSet(flag, value);
647                     }
648 
649                     @Override
650                     public void onStartedEvaluating() {
651                         if (mSuiteRunDepth == 0) {
652                             throw new IllegalStateException("Outer rule should be running!");
653                         }
654                         if (mIsTestRunning) {
655                             throw new IllegalStateException("Inner rule is still running!");
656                         }
657                         mIsTestRunning = true;
658                     }
659 
660                     @Override
661                     public void onFinishedEvaluating() {
662                         if (!mIsTestRunning) {
663                             throw new IllegalStateException("Inner rule did not start!");
664                         }
665                         mIsTestRunning = false;
666                         checkAllFlagsWatchersRestored();
667                         mFirstReadWithinTestByFlag.clear();
668                     }
669                 };
670 
671         /**
672          * Creates a SetFlagsRule which will work as normal, but additionally enforce the guarantees
673          * about not setting flags that were read within the ClassRule
674          */
675         public SetFlagsRule createSetFlagsRule() {
676             return createSetFlagsRule(null);
677         }
678 
679         /**
680          * Creates a SetFlagsRule with parameterization which will work as normal, but additionally
681          * enforce the guarantees about not setting flags that were read within the ClassRule
682          */
683         public SetFlagsRule createSetFlagsRule(
684                 @Nullable FlagsParameterization flagsParameterization) {
685             return new SetFlagsRule(
686                     DefaultInitValueType.DEVICE_DEFAULT, flagsParameterization, mListener);
687         }
688 
689         private boolean isFlagsClassMonitored(SetFlagsRule.Flag flag) {
690             return mSettableFlagsPackages.contains(flag.flagPackageName());
691         }
692 
693         private void assertFlagCanBeSet(SetFlagsRule.Flag flag, boolean value) {
694             Exception firstReadWithinTest = mFirstReadWithinTestByFlag.get(flag.fullFlagName());
695             if (firstReadWithinTest != null) {
696                 throw new FlagSetException(
697                         flag.fullFlagName(),
698                         "This flag was locked when it was read earlier in this test. To fix this"
699                                 + " error, always use @EnableFlags() and @DisableFlags() to set"
700                                 + " flags, which ensures flags are set before even any"
701                                 + " @Before-annotated setup methods.",
702                         firstReadWithinTest);
703             }
704             Exception firstReadOutsideTest = mFirstReadOutsideTestsByFlag.get(flag.fullFlagName());
705             if (firstReadOutsideTest != null) {
706                 throw new FlagSetException(
707                         flag.fullFlagName(),
708                         "This flag was locked when it was read outside of the test code; likely"
709                                 + " during initialization of the test class. To fix this error,"
710                                 + " move test fixture initialization code into your"
711                                 + " @Before-annotated setup method, and ensure you are using"
712                                 + " @EnableFlags() and @DisableFlags() to set flags.",
713                         firstReadOutsideTest);
714             }
715             if (!isFlagsClassMonitored(flag)) {
716                 throw new FlagSetException(
717                         flag.fullFlagName(),
718                         "This flag's class is not monitored. Always use @EnableFlags() and"
719                                 + " @DisableFlags() on the class or method instead of"
720                                 + " .enableFlags() or .disableFlags() to prevent this error. When"
721                                 + " using FlagsParameterization, add `@UsesFlags("
722                                 + flag.flagPackageName()
723                                 + ".Flags.class)` to the test class. As a last resort, pass the"
724                                 + " Flags class to the constructor of your"
725                                 + " SetFlagsRule.ClassRule.");
726             }
727             // Detect errors where the rule messed up and set the wrong flag value.
728             Boolean classLevelValue = mClassLevelSetFlagValues.get(flag.fullFlagName());
729             if (classLevelValue != null && classLevelValue != value) {
730                 throw new FlagSetException(
731                         flag.fullFlagName(),
732                         "This flag's value was set at the class level to a different value.");
733             }
734         }
735 
736         private void checkInstanceOfRealFlagsImpl(Object actual) {
737             if (!actual.getClass().getSimpleName().equals(REAL_FEATURE_FLAGS_IMPL_CLASS_NAME)) {
738                 throw new IllegalStateException(
739                         String.format(
740                                 "Wrong impl type during setup: '%s' is not a %s",
741                                 actual, REAL_FEATURE_FLAGS_IMPL_CLASS_NAME));
742             }
743         }
744 
745         private void checkSameAs(Object expected, Object actual) {
746             if (expected != actual) {
747                 throw new IllegalStateException(
748                         String.format(
749                                 "Wrong impl instance found during teardown: expected %s but was %s",
750                                 expected, actual));
751             }
752         }
753 
754         private Object getOrCreateFlagReadWatcher(Class<?> flagsClass) {
755             Object watcher = mFlagsClassToWatcherImpl.get(flagsClass);
756             if (watcher != null) {
757                 return watcher;
758             }
759             Object flagsImplInstance = readFlagsImpl(flagsClass);
760             // strict mode: ensure that the current impl is the real impl
761             checkInstanceOfRealFlagsImpl(flagsImplInstance);
762             // save the real impl for restoration later
763             mFlagsClassToRealFlagsImpl.put(flagsClass, flagsImplInstance);
764             watcher = newFlagReadWatcher(flagsClass, flagsImplInstance);
765             mFlagsClassToWatcherImpl.put(flagsClass, watcher);
766             return watcher;
767         }
768 
769         private void recordFlagRead(String flagName) {
770             if (mIsTestRunning) {
771                 mFirstReadWithinTestByFlag.computeIfAbsent(flagName, FirstFlagRead::new);
772             } else {
773                 mFirstReadOutsideTestsByFlag.computeIfAbsent(flagName, FirstFlagRead::new);
774             }
775         }
776 
777         private Object newFlagReadWatcher(Class<?> flagsClass, Object flagsImplInstance) {
778             String packageName = flagsClass.getPackageName();
779             String customClassName =
780                     String.format("%s.%s", packageName, CUSTOM_FEATURE_FLAGS_CLASS_NAME);
781             BiPredicate<String, Predicate<Object>> getValueImpl =
782                     (flagName, predicate) -> {
783                         // Flags set at the class level pose no consistency risk
784                         Boolean value = mClassLevelSetFlagValues.get(flagName);
785                         if (value != null) {
786                             return value;
787                         }
788                         recordFlagRead(flagName);
789                         return predicate.test(flagsImplInstance);
790                     };
791             try {
792                 Class<?> customFlagsClass = Class.forName(customClassName);
793                 return customFlagsClass.getConstructor(BiPredicate.class).newInstance(getValueImpl);
794             } catch (ReflectiveOperationException e) {
795                 throw new UnsupportedOperationException(
796                         String.format(
797                                 "Cannot create CustomFeatureFlags in Flags class %s.",
798                                 flagsClass.getName()),
799                         e);
800             }
801         }
802 
803         /** Get the package name of the flags in this class. This is the non-repackaged name. */
804         private String getFlagPackageName(Class<?> flagsClass) {
805             String classPackageName = flagsClass.getPackageName();
806             String shortestPackageName = classPackageName;
807             for (String prefix : REPACKAGE_PREFIX_LIST) {
808                 if (prefix.isEmpty()) continue;
809                 if (classPackageName.startsWith(prefix)) {
810                     String unprefixedPackage = classPackageName.substring(prefix.length());
811                     if (unprefixedPackage.length() < shortestPackageName.length()) {
812                         shortestPackageName = unprefixedPackage;
813                     }
814                 }
815             }
816             return shortestPackageName;
817         }
818 
        /**
         * Copies the flag values that AnnotationsRetriever reports as set by annotations on
         * {@code description} into {@code mClassLevelSetFlagValues}. The read watchers answer
         * those flags from that map without recording a (locking) read.
         */
        private void setupClassLevelFlagValues(Description description) {
            mClassLevelSetFlagValues.putAll(
                    AnnotationsRetriever.getFlagAnnotations(description).mSetFlagValues);
        }
823 
824         private void setupFlagsWatchers(Description description) {
825             // Start with the static list of Flags classes to watch
826             Set<Class<?>> flagsClassesToWatch = new HashSet<>(mGlobalFlagsClassesToWatch);
827             // Collect the Flags classes from @UsedFlags annotations on the Descriptor
828             Set<String> usedFlagsClasses = AnnotationsRetriever.getAllUsedFlagsClasses(description);
829             for (String flagsClassName : usedFlagsClasses) {
830                 flagsClassesToWatch.add(getFlagClassFromFlagsClassName(flagsClassName));
831             }
832             // Now setup watchers on the provided Flags classes
833             for (Class<?> flagsClass : flagsClassesToWatch) {
834                 setupFlagsWatcher(flagsClass, getFlagPackageName(flagsClass));
835             }
836             // Get all annotated flags and then the distinct packages for each flag
837             Set<String> setFlags = AnnotationsRetriever.getAllAnnotationSetFlags(description);
838             Set<String> extraFlagPackages = new HashSet<>();
839             for (String setFlag : setFlags) {
840                 extraFlagPackages.add(Flag.getFlagPackageName(setFlag));
841             }
842             // Do not bother with flags that are already monitored
843             extraFlagPackages.removeAll(mSettableFlagsPackages);
844             // Expand packages to all repackaged versions, stored as Flag objects
845             Set<Flag> extraWildcardFlags = new HashSet<>();
846             for (String extraFlagPackage : extraFlagPackages) {
847                 String fullFlagName = extraFlagPackage + ".*";
848                 Set<String> packages = getAllPackagesForFlag(fullFlagName, mPackageToRepackage);
849                 for (String packageName : packages) {
850                     Flag flag = Flag.createFlag(fullFlagName, packageName);
851                     extraWildcardFlags.add(flag);
852                 }
853             }
854             // Set up watchers for each wildcard flag
855             for (Flag flag : extraWildcardFlags) {
856                 Class<?> flagsClass = getFlagClassFromFlag(flag);
857                 setupFlagsWatcher(flagsClass, flag.flagPackageName());
858             }
859         }
860 
861         private void setupFlagsWatcher(Class<?> flagsClass, String flagPackageName) {
862             if (mMutatedFlagsClasses.contains(flagsClass)) {
863                 throw new IllegalStateException(
864                         String.format("Flags class %s is already mutated", flagsClass.getName()));
865             }
866             Object watcher = getOrCreateFlagReadWatcher(flagsClass);
867             replaceFlagsImpl(flagsClass, watcher);
868             mMutatedFlagsClasses.add(flagsClass);
869             mSettableFlagsPackages.add(flagPackageName);
870         }
871 
872         private void teardownFlagsWatchers() {
873             try {
874                 for (Class<?> flagsClass : mMutatedFlagsClasses) {
875                     Object flagsImplInstance = mFlagsClassToRealFlagsImpl.get(flagsClass);
876                     // strict mode: ensure that the watcher is still in place
877                     Object watcher = readFlagsImpl(flagsClass);
878                     checkSameAs(mFlagsClassToWatcherImpl.get(flagsClass), watcher);
879                     // Replace FeatureFlags in Flags class with real FeatureFlagsImpl
880                     replaceFlagsImpl(flagsClass, flagsImplInstance);
881                 }
882                 mMutatedFlagsClasses.clear();
883                 mSettableFlagsPackages.clear();
884                 mFirstReadOutsideTestsByFlag.clear();
885             } catch (IllegalStateException e) {
886                 throw e;
887             } catch (Exception e) {
888                 throw new IllegalStateException("Failed to teardown Flags watchers", e);
889             }
890             if (mIsTestRunning) {
891                 throw new IllegalStateException("An inner SetFlagsRule is still running");
892             }
893             if (!mFirstReadWithinTestByFlag.isEmpty()) {
894                 throw new IllegalStateException("An inner SetFlagsRule did not fully clean up");
895             }
896         }
897 
898         private void checkAllFlagsWatchersRestored() {
899             for (Class<?> flagsClass : mMutatedFlagsClasses) {
900                 Object watcher = readFlagsImpl(flagsClass);
901                 checkSameAs(mFlagsClassToWatcherImpl.get(flagsClass), watcher);
902             }
903         }
904 
        /**
         * Wraps {@code base} so that only the outermost evaluation installs Flags watchers and
         * class-level flag values before running, and removes them afterwards. Nested evaluations
         * only adjust {@code mSuiteRunDepth}.
         *
         * <p>If teardown or the nesting check fails, that failure is thrown. Any failure from
         * {@code base} is attached to it as a suppressed exception. Otherwise the failure from
         * {@code base}, if any, is rethrown.
         */
        @Override
        public Statement apply(Statement base, Description description) {
            return new Statement() {
                @Override
                public void evaluate() throws Throwable {
                    Throwable throwable = null;
                    // Depth before this evaluation; 0 means this is the outermost (suite) run.
                    final int initialDepth = mSuiteRunDepth;
                    try {
                        mSuiteRunDepth++;
                        if (initialDepth == 0) {
                            setupFlagsWatchers(description);
                            setupClassLevelFlagValues(description);
                        }
                        base.evaluate();
                    } catch (Throwable t) {
                        // Defer rethrowing until the cleanup in finally has run.
                        throwable = t;
                    } finally {
                        mSuiteRunDepth--;
                        try {
                            if (initialDepth == 0) {
                                mClassLevelSetFlagValues.clear();
                                teardownFlagsWatchers();
                            }
                            // Detects unbalanced increments/decrements from nested evaluations.
                            if (mSuiteRunDepth != initialDepth) {
                                throw new IllegalStateException(
                                        String.format(
                                                "Evaluations were not correctly nested: initial"
                                                        + " depth was %d but final depth was %d",
                                                initialDepth, mSuiteRunDepth));
                            }
                        } catch (Throwable t) {
                            // The cleanup failure becomes primary; keep the original as suppressed.
                            if (throwable != null) {
                                t.addSuppressed(throwable);
                            }
                            throwable = t;
                        }
                    }
                    if (throwable != null) throw throwable;
                }
            };
        }
946     }
947 
948     private static class FirstFlagRead extends Exception {
949         FirstFlagRead(String flagName) {
950             super(String.format("Flag '%s' was first read at this location:", flagName));
951         }
952     }
953 
954     private static class Flag {
955         private static final String PACKAGE_NAME_SIMPLE_NAME_SEPARATOR = ".";
956         private final String mFullFlagName;
957         private final String mFlagPackageName;
958         private final String mClassPackageName;
959         private final String mSimpleFlagName;
960 
961         public static String getFlagPackageName(String fullFlagName) {
962             int index = fullFlagName.lastIndexOf(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR);
963             return fullFlagName.substring(0, index);
964         }
965 
966         public static Flag createFlag(String fullFlagName) {
967             int index = fullFlagName.lastIndexOf(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR);
968             String packageName = fullFlagName.substring(0, index);
969             return createFlag(fullFlagName, packageName);
970         }
971 
972         public static Flag createFlag(String fullFlagName, String classPackageName) {
973             if (!fullFlagName.contains(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR)
974                     || !classPackageName.contains(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR)) {
975                 throw new IllegalArgumentException(
976                         String.format(
977                                 "Flag %s is invalid. The format should be {packageName}"
978                                         + ".{simpleFlagName}",
979                                 fullFlagName));
980             }
981             int index = fullFlagName.lastIndexOf(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR);
982             String flagPackageName = fullFlagName.substring(0, index);
983             String simpleFlagName = fullFlagName.substring(index + 1);
984 
985             return new Flag(fullFlagName, flagPackageName, classPackageName, simpleFlagName);
986         }
987 
988         private Flag(
989                 String fullFlagName,
990                 String flagPackageName,
991                 String classPackageName,
992                 String simpleFlagName) {
993             this.mFullFlagName = fullFlagName;
994             this.mFlagPackageName = flagPackageName;
995             this.mClassPackageName = classPackageName;
996             this.mSimpleFlagName = simpleFlagName;
997         }
998 
999         public String fullFlagName() {
1000             return mFullFlagName;
1001         }
1002 
1003         public String flagPackageName() {
1004             return mFlagPackageName;
1005         }
1006 
1007         public String classPackageName() {
1008             return mClassPackageName;
1009         }
1010 
1011         public String simpleFlagName() {
1012             return mSimpleFlagName;
1013         }
1014 
1015         public String flagsClassName() {
1016             return String.format("%s.%s", classPackageName(), FLAGS_CLASS_NAME);
1017         }
1018     }
1019 }
1020