1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.platform.test.flag.junit; 18 19 import static org.junit.Assume.assumeFalse; 20 21 import android.platform.test.flag.util.FlagReadException; 22 import android.platform.test.flag.util.FlagSetException; 23 24 import com.google.common.base.CaseFormat; 25 import com.google.common.collect.Sets; 26 27 import org.junit.rules.TestRule; 28 import org.junit.runner.Description; 29 import org.junit.runners.model.Statement; 30 31 import java.lang.reflect.Field; 32 import java.util.HashMap; 33 import java.util.HashSet; 34 import java.util.Map; 35 import java.util.Objects; 36 import java.util.Set; 37 import java.util.concurrent.ConcurrentHashMap; 38 import java.util.function.BiPredicate; 39 import java.util.function.Predicate; 40 41 import javax.annotation.Nonnull; 42 import javax.annotation.Nullable; 43 44 /** A {@link TestRule} that helps to set flag values in unit test. */ 45 public final class SetFlagsRule implements TestRule { 46 private static final String FAKE_FEATURE_FLAGS_IMPL_CLASS_NAME = "FakeFeatureFlagsImpl"; 47 private static final String REAL_FEATURE_FLAGS_IMPL_CLASS_NAME = "FeatureFlagsImpl"; 48 private static final String CUSTOM_FEATURE_FLAGS_CLASS_NAME = "CustomFeatureFlags"; 49 private static final String FEATURE_FLAGS_CLASS_NAME = "FeatureFlags"; 50 private static final String FEATURE_FLAGS_FIELD_NAME = "FEATURE_FLAGS"; 51 private static final String FLAGS_CLASS_NAME = "Flags"; 52 private static final String FLAG_CONSTANT_PREFIX = "FLAG_"; 53 private static final String SET_FLAG_METHOD_NAME = "setFlag"; 54 private static final String RESET_ALL_METHOD_NAME = "resetAll"; 55 private static final String IS_FLAG_READ_ONLY_OPTIMIZED_METHOD_NAME = "isFlagReadOnlyOptimized"; 56 57 // Store instances for entire life of a SetFlagsRule instance 58 private final Map<Class<?>, Object> mFlagsClassToFakeFlagsImpl = new HashMap<>(); 59 private final Map<Class<?>, Object> mFlagsClassToRealFlagsImpl = new HashMap<>(); 60 61 // Store classes that are currently mutated by this rule 62 private final Set<Class<?>> mMutatedFlagsClasses = new HashSet<>(); 63 64 // Any flags added to this list cannot be set imperatively (i.e. with enableFlags/disableFlags) 65 private final Set<String> mLockedFlagNames = new HashSet<>(); 66 67 // listener to be called before setting a flag 68 private final Listener mListener; 69 70 // TODO(322377082): remove repackage prefix list 71 private static final String[] REPACKAGE_PREFIX_LIST = 72 new String[] { 73 "", "com.android.internal.hidden_from_bootclasspath.", 74 }; 75 private final Map<String, Set<String>> mPackageToRepackage = new HashMap<>(); 76 77 private final boolean mIsInitWithDefault; 78 private FlagsParameterization mFlagsParameterization; 79 private boolean mIsRuleEvaluating = false; 80 81 public enum DefaultInitValueType { 82 /** 83 * Initialize flag value as null 84 * 85 * <p>Flag value need to be set before using 86 */ 87 NULL_DEFAULT, 88 89 /** 90 * Initialize flag value with the default value from the device 91 * 92 * <p>If flag value is not overridden by adb, then the default value is from the release 93 * configuration when the test is built. 94 */ 95 DEVICE_DEFAULT, 96 } 97 SetFlagsRule()98 public SetFlagsRule() { 99 this(DefaultInitValueType.DEVICE_DEFAULT); 100 } 101 SetFlagsRule(DefaultInitValueType defaultType)102 public SetFlagsRule(DefaultInitValueType defaultType) { 103 this(defaultType, null); 104 } 105 SetFlagsRule(@ullable FlagsParameterization flagsParameterization)106 public SetFlagsRule(@Nullable FlagsParameterization flagsParameterization) { 107 this(DefaultInitValueType.DEVICE_DEFAULT, flagsParameterization); 108 } 109 SetFlagsRule( DefaultInitValueType defaultType, @Nullable FlagsParameterization flagsParameterization)110 public SetFlagsRule( 111 DefaultInitValueType defaultType, 112 @Nullable FlagsParameterization flagsParameterization) { 113 this(defaultType, flagsParameterization, null); 114 } 115 SetFlagsRule( DefaultInitValueType defaultType, @Nullable FlagsParameterization flagsParameterization, @Nullable Listener listener)116 private SetFlagsRule( 117 DefaultInitValueType defaultType, 118 @Nullable FlagsParameterization flagsParameterization, 119 @Nullable Listener listener) { 120 mIsInitWithDefault = defaultType == DefaultInitValueType.DEVICE_DEFAULT; 121 mFlagsParameterization = flagsParameterization; 122 if (flagsParameterization != null) { 123 mLockedFlagNames.addAll(flagsParameterization.mOverrides.keySet()); 124 } 125 mListener = listener; 126 } 127 128 /** 129 * Set the FlagsParameterization to be used during this test. This cannot be used to override a 130 * previous call, and cannot be called once the rule has been evaluated. 131 */ setFlagsParameterization(@onnull FlagsParameterization flagsParameterization)132 public void setFlagsParameterization(@Nonnull FlagsParameterization flagsParameterization) { 133 Objects.requireNonNull(flagsParameterization, "FlagsParameterization cannot be cleared"); 134 if (mFlagsParameterization != null) { 135 throw new AssertionError("FlagsParameterization cannot be overridden"); 136 } 137 if (mIsRuleEvaluating) { 138 throw new AssertionError("Cannot set FlagsParameterization once the rule is running"); 139 } 140 ensureFlagsAreUnset(); 141 mFlagsParameterization = flagsParameterization; 142 mLockedFlagNames.addAll(flagsParameterization.mOverrides.keySet()); 143 } 144 145 /** 146 * Enables the given flags. 147 * 148 * @param fullFlagNames The name of the flags in the flag class with the format 149 * {packageName}.{flagName} 150 * 151 * @deprecated Annotate your test or class with <code>@EnableFlags(String...)</code> instead 152 */ 153 @Deprecated enableFlags(String... fullFlagNames)154 public void enableFlags(String... fullFlagNames) { 155 if (!mIsRuleEvaluating) { 156 throw new IllegalStateException("Not allowed to set flags outside test and setup code"); 157 } 158 for (String fullFlagName : fullFlagNames) { 159 if (mLockedFlagNames.contains(fullFlagName)) { 160 throw new FlagSetException(fullFlagName, "Not allowed to change locked flags"); 161 } 162 setFlagValue(fullFlagName, true); 163 } 164 } 165 166 /** 167 * Disables the given flags. 168 * 169 * @param fullFlagNames The name of the flags in the flag class with the format 170 * {packageName}.{flagName} 171 * 172 * @deprecated Annotate your test or class with <code>@DisableFlags(String...)</code> instead 173 */ 174 @Deprecated disableFlags(String... fullFlagNames)175 public void disableFlags(String... fullFlagNames) { 176 if (!mIsRuleEvaluating) { 177 throw new IllegalStateException("Not allowed to set flags outside test and setup code"); 178 } 179 for (String fullFlagName : fullFlagNames) { 180 if (mLockedFlagNames.contains(fullFlagName)) { 181 throw new FlagSetException(fullFlagName, "Not allowed to change locked flags"); 182 } 183 setFlagValue(fullFlagName, false); 184 } 185 } 186 ensureFlagsAreUnset()187 private void ensureFlagsAreUnset() { 188 if (!mFlagsClassToFakeFlagsImpl.isEmpty()) { 189 throw new IllegalStateException("Some flags were set before the rule was initialized"); 190 } 191 } 192 193 @Override apply(Statement base, Description description)194 public Statement apply(Statement base, Description description) { 195 return new Statement() { 196 @Override 197 public void evaluate() throws Throwable { 198 Throwable throwable = null; 199 try { 200 if (mListener != null) { 201 mListener.onStartedEvaluating(); 202 } 203 AnnotationsRetriever.FlagAnnotations flagAnnotations = 204 AnnotationsRetriever.getFlagAnnotations(description); 205 assertAnnotationsMatchParameterization(flagAnnotations, mFlagsParameterization); 206 flagAnnotations.assumeAllSetFlagsMatchParameterization(mFlagsParameterization); 207 if (mFlagsParameterization != null) { 208 ensureFlagsAreUnset(); 209 for (Map.Entry<String, Boolean> pair : 210 mFlagsParameterization.mOverrides.entrySet()) { 211 setFlagValue(pair.getKey(), pair.getValue()); 212 } 213 } 214 for (Map.Entry<String, Boolean> pair : 215 flagAnnotations.mSetFlagValues.entrySet()) { 216 setFlagValue(pair.getKey(), pair.getValue()); 217 } 218 mLockedFlagNames.addAll(flagAnnotations.mRequiredFlagValues.keySet()); 219 mLockedFlagNames.addAll(flagAnnotations.mSetFlagValues.keySet()); 220 mIsRuleEvaluating = true; 221 base.evaluate(); 222 } catch (Throwable t) { 223 throwable = t; 224 } finally { 225 mIsRuleEvaluating = false; 226 try { 227 resetFlags(); 228 } catch (Throwable t) { 229 if (throwable != null) { 230 t.addSuppressed(throwable); 231 } 232 throwable = t; 233 } 234 try { 235 if (mListener != null) { 236 mListener.onFinishedEvaluating(); 237 } 238 } catch (Throwable t) { 239 if (throwable != null) { 240 t.addSuppressed(throwable); 241 } 242 throwable = t; 243 } 244 } 245 if (throwable != null) throw throwable; 246 } 247 }; 248 } 249 250 private static void assertAnnotationsMatchParameterization( 251 AnnotationsRetriever.FlagAnnotations flagAnnotations, 252 FlagsParameterization parameterization) { 253 if (parameterization == null) return; 254 Set<String> parameterizedFlags = parameterization.mOverrides.keySet(); 255 Set<String> requiredFlags = flagAnnotations.mRequiredFlagValues.keySet(); 256 // Assert that NO Annotation-Required flag is in the parameterization 257 Set<String> parameterizedAndRequiredFlags = 258 Sets.intersection(parameterizedFlags, requiredFlags); 259 if (!parameterizedAndRequiredFlags.isEmpty()) { 260 throw new AssertionError( 261 "The following flags have required values (per @RequiresFlagsEnabled or" 262 + " @RequiresFlagsDisabled) but they are part of the" 263 + " FlagParameterization: " 264 + parameterizedAndRequiredFlags); 265 } 266 } 267 268 private void setFlagValue(String fullFlagName, boolean value) { 269 if (!fullFlagName.contains(".")) { 270 throw new FlagSetException( 271 fullFlagName, "Flag name is not the expected format {packgeName}.{flagName}."); 272 } 273 // Get all packages containing Flags referencing the same fullFlagName. 274 Set<String> packageSet = getPackagesContainsFlag(fullFlagName); 275 276 for (String packageName : packageSet) { 277 setFlagValue(Flag.createFlag(fullFlagName, packageName), value); 278 } 279 } 280 281 private Set<String> getPackagesContainsFlag(String fullFlagName) { 282 return getAllPackagesForFlag(fullFlagName, mPackageToRepackage); 283 } 284 285 private static Set<String> getAllPackagesForFlag( 286 String fullFlagName, Map<String, Set<String>> packageToRepackage) { 287 288 String packageName = Flag.getFlagPackageName(fullFlagName); 289 Set<String> packageSet = packageToRepackage.getOrDefault(packageName, new HashSet<>()); 290 291 if (!packageSet.isEmpty()) { 292 return packageSet; 293 } 294 295 for (String prefix : REPACKAGE_PREFIX_LIST) { 296 String repackagedName = String.format("%s%s", prefix, packageName); 297 String flagClassName = String.format("%s.%s", repackagedName, FLAGS_CLASS_NAME); 298 try { 299 Class.forName(flagClassName, false, SetFlagsRule.class.getClassLoader()); 300 packageSet.add(repackagedName); 301 } catch (ClassNotFoundException e) { 302 // Skip if the class is not found 303 // An error will be thrown if no package containing flags referencing 304 // the passed in flag 305 } 306 } 307 packageToRepackage.put(packageName, packageSet); 308 if (packageSet.isEmpty()) { 309 throw new FlagSetException( 310 fullFlagName, 311 "Cannot find package containing Flags class referencing to this flag."); 312 } 313 return packageSet; 314 } 315 316 private void setFlagValue(Flag flag, boolean value) { 317 if (mListener != null) { 318 mListener.onBeforeSetFlag(flag, value); 319 } 320 321 Object fakeFlagsImplInstance = null; 322 323 Class<?> flagsClass = getFlagClassFromFlag(flag); 324 fakeFlagsImplInstance = getOrCreateFakeFlagsImp(flagsClass); 325 326 if (!mMutatedFlagsClasses.contains(flagsClass)) { 327 // Replace FeatureFlags in Flags class with FakeFeatureFlagsImpl 328 replaceFlagsImpl(flagsClass, fakeFlagsImplInstance); 329 mMutatedFlagsClasses.add(flagsClass); 330 } 331 332 // If the test is trying to set the flag value on a read_only flag in an optimized build 333 // skip this test, since it is not a valid testing case 334 // The reason for skipping instead of throwning error here is all read_write flag will be 335 // change to read_only in the final release configuration. Thus the test could be executed 336 // in other release configuration cases 337 // TODO(b/337449119): SetFlagsRule should still run tests that are consistent with the 338 // read-only values of flags. But be careful, if a ClassRule exists, the value returned by 339 // the original FeatureFlags instance may be overridden, and reading it may not be allowed. 340 boolean isOptimized = verifyFlagReadOnlyAndOptimized(fakeFlagsImplInstance, flag); 341 assumeFalse( 342 String.format( 343 "Flag %s is read_only, and the code is optimized. " 344 + " The flag value should not be modified on this build" 345 + " Skip this test.", 346 flag.fullFlagName()), 347 isOptimized); 348 349 // Set desired flag value in the FakeFeatureFlagsImpl 350 setFlagValueInFakeFeatureFlagsImpl(fakeFlagsImplInstance, flag, value); 351 } 352 353 private static Class<?> getFlagClassFromFlag(Flag flag) { 354 String className = flag.flagsClassName(); 355 Class<?> flagsClass = null; 356 try { 357 flagsClass = Class.forName(className); 358 } catch (ClassNotFoundException e) { 359 throw new FlagSetException( 360 flag.fullFlagName(), 361 String.format( 362 "Can not load the Flags class %s to set its values. Please check the " 363 + "flag name and ensure that the aconfig auto generated " 364 + "library is in the dependency.", 365 className), 366 e); 367 } 368 return flagsClass; 369 } 370 371 private static Class<?> getFlagClassFromFlagsClassName(String className) { 372 if (!className.endsWith("." + FLAGS_CLASS_NAME)) { 373 throw new FlagSetException( 374 className, 375 "Can not watch this Flags class because it is not named 'Flags'. Please ensure" 376 + " your @UsesFlags() annotations only reference the Flags classes."); 377 } 378 try { 379 return Class.forName(className); 380 } catch (ClassNotFoundException e) { 381 throw new FlagSetException( 382 className, 383 "Cannot load this Flags class to set its values. Please check the flag name and" 384 + " ensure that the aconfig auto generated library is in the dependency.", 385 e); 386 } 387 } 388 389 private boolean getFlagValue(Object featureFlagsImpl, Flag flag) { 390 // Must be consistent with method name in aconfig auto generated code. 391 String methodName = getFlagMethodName(flag); 392 String fullFlagName = flag.fullFlagName(); 393 394 try { 395 Object result = 396 featureFlagsImpl.getClass().getMethod(methodName).invoke(featureFlagsImpl); 397 if (result instanceof Boolean) { 398 return (Boolean) result; 399 } 400 throw new FlagReadException( 401 fullFlagName, 402 String.format( 403 "Flag type is %s, not boolean", result.getClass().getSimpleName())); 404 } catch (NoSuchMethodException e) { 405 throw new FlagReadException( 406 fullFlagName, 407 String.format( 408 "No method %s in the Flags class %s to read the flag value. Please" 409 + " check the flag name.", 410 methodName, featureFlagsImpl.getClass().getName()), 411 e); 412 } catch (ReflectiveOperationException e) { 413 throw new FlagReadException( 414 fullFlagName, 415 String.format( 416 "Fail to get value of flag %s from instance %s", 417 fullFlagName, featureFlagsImpl.getClass().getName()), 418 e); 419 } 420 } 421 422 private String getFlagMethodName(Flag flag) { 423 return CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, flag.simpleFlagName()); 424 } 425 426 private void setFlagValueInFakeFeatureFlagsImpl( 427 Object fakeFeatureFlagsImpl, Flag flag, boolean value) { 428 String fullFlagName = flag.fullFlagName(); 429 try { 430 fakeFeatureFlagsImpl 431 .getClass() 432 .getMethod(SET_FLAG_METHOD_NAME, String.class, boolean.class) 433 .invoke(fakeFeatureFlagsImpl, fullFlagName, value); 434 } catch (NoSuchMethodException e) { 435 throw new FlagSetException( 436 fullFlagName, 437 String.format( 438 "Flag implementation %s is not fake implementation", 439 fakeFeatureFlagsImpl.getClass().getName()), 440 e); 441 } catch (ReflectiveOperationException e) { 442 throw new FlagSetException(fullFlagName, e); 443 } 444 } 445 446 private static boolean verifyFlagReadOnlyAndOptimized(Object fakeFeatureFlagsImpl, Flag flag) { 447 String fullFlagName = flag.fullFlagName(); 448 try { 449 boolean result = 450 (Boolean) 451 fakeFeatureFlagsImpl 452 .getClass() 453 .getMethod( 454 IS_FLAG_READ_ONLY_OPTIMIZED_METHOD_NAME, String.class) 455 .invoke(fakeFeatureFlagsImpl, fullFlagName); 456 return result; 457 } catch (NoSuchMethodException e) { 458 // If the flag is generated under exported mode, then it doesn't have this method 459 String simpleClassName = fakeFeatureFlagsImpl.getClass().getSimpleName(); 460 if (simpleClassName.equals(FAKE_FEATURE_FLAGS_IMPL_CLASS_NAME)) { 461 return false; 462 } 463 if (simpleClassName.equals(CUSTOM_FEATURE_FLAGS_CLASS_NAME)) { 464 return false; 465 } 466 throw new FlagSetException( 467 fullFlagName, 468 String.format( 469 "Cannot check whether flag is optimized. " 470 + "Flag implementation %s is not fake implementation", 471 fakeFeatureFlagsImpl.getClass().getName()), 472 e); 473 } catch (ReflectiveOperationException e) { 474 throw new FlagSetException(fullFlagName, e); 475 } 476 } 477 478 @Nonnull 479 private Object getOrCreateFakeFlagsImp(Class<?> flagsClass) { 480 Object fakeFlagsImplInstance = mFlagsClassToFakeFlagsImpl.get(flagsClass); 481 if (fakeFlagsImplInstance != null) { 482 return fakeFlagsImplInstance; 483 } 484 485 String packageName = flagsClass.getPackageName(); 486 String fakeClassName = 487 String.format("%s.%s", packageName, FAKE_FEATURE_FLAGS_IMPL_CLASS_NAME); 488 String interfaceName = String.format("%s.%s", packageName, FEATURE_FLAGS_CLASS_NAME); 489 490 Object realFlagsImplInstance = readFlagsImpl(flagsClass); 491 mFlagsClassToRealFlagsImpl.put(flagsClass, realFlagsImplInstance); 492 493 try { 494 Class<?> flagImplClass = Class.forName(fakeClassName); 495 Class<?> flagInterface = Class.forName(interfaceName); 496 fakeFlagsImplInstance = 497 flagImplClass 498 .getConstructor(flagInterface) 499 .newInstance(mIsInitWithDefault ? realFlagsImplInstance : null); 500 } catch (ReflectiveOperationException e) { 501 throw new UnsupportedOperationException( 502 String.format( 503 "Cannot create FakeFeatureFlagsImpl in Flags class %s.", 504 flagsClass.getName()), 505 e); 506 } 507 508 mFlagsClassToFakeFlagsImpl.put(flagsClass, fakeFlagsImplInstance); 509 510 return fakeFlagsImplInstance; 511 } 512 513 private static void replaceFlagsImpl(Class<?> flagsClass, Object flagsImplInstance) { 514 Field featureFlagsField = getFeatureFlagsField(flagsClass); 515 try { 516 featureFlagsField.set(null, flagsImplInstance); 517 } catch (IllegalAccessException e) { 518 throw new UnsupportedOperationException( 519 String.format( 520 "Cannot replace FeatureFlagsImpl to %s.", 521 flagsImplInstance.getClass().getName()), 522 e); 523 } 524 } 525 526 private static Object readFlagsImpl(Class<?> flagsClass) { 527 Field featureFlagsField = getFeatureFlagsField(flagsClass); 528 try { 529 return featureFlagsField.get(null); 530 } catch (IllegalAccessException e) { 531 throw new UnsupportedOperationException( 532 String.format( 533 "Cannot get FeatureFlags from Flags class %s.", flagsClass.getName()), 534 e); 535 } 536 } 537 538 private static Field getFeatureFlagsField(Class<?> flagsClass) { 539 Field featureFlagsField = null; 540 try { 541 featureFlagsField = flagsClass.getDeclaredField(FEATURE_FLAGS_FIELD_NAME); 542 } catch (ReflectiveOperationException e) { 543 throw new UnsupportedOperationException( 544 String.format( 545 "Cannot store FeatureFlagsImpl in Flag %s.", flagsClass.getName()), 546 e); 547 } 548 featureFlagsField.setAccessible(true); 549 return featureFlagsField; 550 } 551 552 private void resetFlags() { 553 String flagsClassName = null; 554 try { 555 for (Class<?> flagsClass : mMutatedFlagsClasses) { 556 flagsClassName = flagsClass.getName(); 557 Object fakeFlagsImplInstance = mFlagsClassToFakeFlagsImpl.get(flagsClass); 558 Object flagsImplInstance = mFlagsClassToRealFlagsImpl.get(flagsClass); 559 // Replace FeatureFlags in Flags class with real FeatureFlagsImpl 560 replaceFlagsImpl(flagsClass, flagsImplInstance); 561 fakeFlagsImplInstance 562 .getClass() 563 .getMethod(RESET_ALL_METHOD_NAME) 564 .invoke(fakeFlagsImplInstance); 565 } 566 mMutatedFlagsClasses.clear(); 567 } catch (Exception e) { 568 throw new FlagSetException(flagsClassName, e); 569 } 570 } 571 572 /** An interface that provides hooks to the ClassRule. */ 573 private interface Listener { 574 /** Called before a flag is set. */ 575 void onBeforeSetFlag(SetFlagsRule.Flag flag, boolean value); 576 577 /** Called after the rule has started evaluating for a test. */ 578 void onStartedEvaluating(); 579 580 /** Called after the rule has finished evaluating for a test. */ 581 void onFinishedEvaluating(); 582 } 583 584 /** 585 * A @ClassRule which adds extra consistency checks for SetFlagsRule. 586 * <li>Requires that tests monitor the Flags class of any flag that is set. 587 * <li>Fails a test if a flag that is set was read before the test started. 588 */ 589 public static class ClassRule implements TestRule { 590 /** The flags classes that are requested to be watched during construction. */ 591 private final Set<Class<?>> mGlobalFlagsClassesToWatch = new HashSet<>(); 592 593 /** The flags packages that are allowed to be set, for quick per-flag lookup */ 594 private final Set<String> mSettableFlagsPackages = new HashSet<>(); 595 596 /** The mapping from the Flags classes to the real implementations */ 597 private final Map<Class<?>, Object> mFlagsClassToRealFlagsImpl = new HashMap<>(); 598 599 /** The mapping from the Flags classes to the watcher implementations */ 600 private final Map<Class<?>, Object> mFlagsClassToWatcherImpl = new HashMap<>(); 601 602 /** The flags classes that have actually been mutated */ 603 private final Set<Class<?>> mMutatedFlagsClasses = new HashSet<>(); 604 605 /** The flag values set by class annotations */ 606 private final Map<String, Boolean> mClassLevelSetFlagValues = new ConcurrentHashMap<>(); 607 608 /** 609 * The individual flags which have been read from prior to tests starting, mapped to the 610 * stack trace of the first read. 611 */ 612 private final Map<String, FirstFlagRead> mFirstReadOutsideTestsByFlag = 613 new ConcurrentHashMap<>(); 614 615 /** 616 * The individual flags which have been read from within a test, mapped to the stack trace 617 * of the first read. 618 */ 619 private final Map<String, FirstFlagRead> mFirstReadWithinTestByFlag = 620 new ConcurrentHashMap<>(); 621 622 /** repackage cache */ 623 private final Map<String, Set<String>> mPackageToRepackage = new HashMap<>(); 624 625 /** The depth of the ClassRule evaluating on potentially nested suites */ 626 private int mSuiteRunDepth = 0; 627 628 /** Whether the SetFlagsRule is evaluating for a test */ 629 private boolean mIsTestRunning = false; 630 631 /** Typical constructor takes an initial list flags classes to watch */ 632 public ClassRule(Class<?>... flagsClasses) { 633 for (Class<?> flagsClass : flagsClasses) { 634 mGlobalFlagsClassesToWatch.add(flagsClass); 635 } 636 } 637 638 /** Listener to be notified of events in any created SetFlagsRule */ 639 private SetFlagsRule.Listener mListener = 640 new SetFlagsRule.Listener() { 641 @Override 642 public void onBeforeSetFlag(SetFlagsRule.Flag flag, boolean value) { 643 if (!mIsTestRunning) { 644 throw new IllegalStateException("Inner rule should be running!"); 645 } 646 assertFlagCanBeSet(flag, value); 647 } 648 649 @Override 650 public void onStartedEvaluating() { 651 if (mSuiteRunDepth == 0) { 652 throw new IllegalStateException("Outer rule should be running!"); 653 } 654 if (mIsTestRunning) { 655 throw new IllegalStateException("Inner rule is still running!"); 656 } 657 mIsTestRunning = true; 658 } 659 660 @Override 661 public void onFinishedEvaluating() { 662 if (!mIsTestRunning) { 663 throw new IllegalStateException("Inner rule did not start!"); 664 } 665 mIsTestRunning = false; 666 checkAllFlagsWatchersRestored(); 667 mFirstReadWithinTestByFlag.clear(); 668 } 669 }; 670 671 /** 672 * Creates a SetFlagsRule which will work as normal, but additionally enforce the guarantees 673 * about not setting flags that were read within the ClassRule 674 */ 675 public SetFlagsRule createSetFlagsRule() { 676 return createSetFlagsRule(null); 677 } 678 679 /** 680 * Creates a SetFlagsRule with parameterization which will work as normal, but additionally 681 * enforce the guarantees about not setting flags that were read within the ClassRule 682 */ 683 public SetFlagsRule createSetFlagsRule( 684 @Nullable FlagsParameterization flagsParameterization) { 685 return new SetFlagsRule( 686 DefaultInitValueType.DEVICE_DEFAULT, flagsParameterization, mListener); 687 } 688 689 private boolean isFlagsClassMonitored(SetFlagsRule.Flag flag) { 690 return mSettableFlagsPackages.contains(flag.flagPackageName()); 691 } 692 693 private void assertFlagCanBeSet(SetFlagsRule.Flag flag, boolean value) { 694 Exception firstReadWithinTest = mFirstReadWithinTestByFlag.get(flag.fullFlagName()); 695 if (firstReadWithinTest != null) { 696 throw new FlagSetException( 697 flag.fullFlagName(), 698 "This flag was locked when it was read earlier in this test. To fix this" 699 + " error, always use @EnableFlags() and @DisableFlags() to set" 700 + " flags, which ensures flags are set before even any" 701 + " @Before-annotated setup methods.", 702 firstReadWithinTest); 703 } 704 Exception firstReadOutsideTest = mFirstReadOutsideTestsByFlag.get(flag.fullFlagName()); 705 if (firstReadOutsideTest != null) { 706 throw new FlagSetException( 707 flag.fullFlagName(), 708 "This flag was locked when it was read outside of the test code; likely" 709 + " during initialization of the test class. To fix this error," 710 + " move test fixture initialization code into your" 711 + " @Before-annotated setup method, and ensure you are using" 712 + " @EnableFlags() and @DisableFlags() to set flags.", 713 firstReadOutsideTest); 714 } 715 if (!isFlagsClassMonitored(flag)) { 716 throw new FlagSetException( 717 flag.fullFlagName(), 718 "This flag's class is not monitored. Always use @EnableFlags() and" 719 + " @DisableFlags() on the class or method instead of" 720 + " .enableFlags() or .disableFlags() to prevent this error. When" 721 + " using FlagsParameterization, add `@UsesFlags(" 722 + flag.flagPackageName() 723 + ".Flags.class)` to the test class. As a last resort, pass the" 724 + " Flags class to the constructor of your" 725 + " SetFlagsRule.ClassRule."); 726 } 727 // Detect errors where the rule messed up and set the wrong flag value. 728 Boolean classLevelValue = mClassLevelSetFlagValues.get(flag.fullFlagName()); 729 if (classLevelValue != null && classLevelValue != value) { 730 throw new FlagSetException( 731 flag.fullFlagName(), 732 "This flag's value was set at the class level to a different value."); 733 } 734 } 735 736 private void checkInstanceOfRealFlagsImpl(Object actual) { 737 if (!actual.getClass().getSimpleName().equals(REAL_FEATURE_FLAGS_IMPL_CLASS_NAME)) { 738 throw new IllegalStateException( 739 String.format( 740 "Wrong impl type during setup: '%s' is not a %s", 741 actual, REAL_FEATURE_FLAGS_IMPL_CLASS_NAME)); 742 } 743 } 744 745 private void checkSameAs(Object expected, Object actual) { 746 if (expected != actual) { 747 throw new IllegalStateException( 748 String.format( 749 "Wrong impl instance found during teardown: expected %s but was %s", 750 expected, actual)); 751 } 752 } 753 754 private Object getOrCreateFlagReadWatcher(Class<?> flagsClass) { 755 Object watcher = mFlagsClassToWatcherImpl.get(flagsClass); 756 if (watcher != null) { 757 return watcher; 758 } 759 Object flagsImplInstance = readFlagsImpl(flagsClass); 760 // strict mode: ensure that the current impl is the real impl 761 checkInstanceOfRealFlagsImpl(flagsImplInstance); 762 // save the real impl for restoration later 763 mFlagsClassToRealFlagsImpl.put(flagsClass, flagsImplInstance); 764 watcher = newFlagReadWatcher(flagsClass, flagsImplInstance); 765 mFlagsClassToWatcherImpl.put(flagsClass, watcher); 766 return watcher; 767 } 768 769 private void recordFlagRead(String flagName) { 770 if (mIsTestRunning) { 771 mFirstReadWithinTestByFlag.computeIfAbsent(flagName, FirstFlagRead::new); 772 } else { 773 mFirstReadOutsideTestsByFlag.computeIfAbsent(flagName, FirstFlagRead::new); 774 } 775 } 776 777 private Object newFlagReadWatcher(Class<?> flagsClass, Object flagsImplInstance) { 778 String packageName = flagsClass.getPackageName(); 779 String customClassName = 780 String.format("%s.%s", packageName, CUSTOM_FEATURE_FLAGS_CLASS_NAME); 781 BiPredicate<String, Predicate<Object>> getValueImpl = 782 (flagName, predicate) -> { 783 // Flags set at the class level pose no consistency risk 784 Boolean value = mClassLevelSetFlagValues.get(flagName); 785 if (value != null) { 786 return value; 787 } 788 recordFlagRead(flagName); 789 return predicate.test(flagsImplInstance); 790 }; 791 try { 792 Class<?> customFlagsClass = Class.forName(customClassName); 793 return customFlagsClass.getConstructor(BiPredicate.class).newInstance(getValueImpl); 794 } catch (ReflectiveOperationException e) { 795 throw new UnsupportedOperationException( 796 String.format( 797 "Cannot create CustomFeatureFlags in Flags class %s.", 798 flagsClass.getName()), 799 e); 800 } 801 } 802 803 /** Get the package name of the flags in this class. This is the non-repackaged name. */ 804 private String getFlagPackageName(Class<?> flagsClass) { 805 String classPackageName = flagsClass.getPackageName(); 806 String shortestPackageName = classPackageName; 807 for (String prefix : REPACKAGE_PREFIX_LIST) { 808 if (prefix.isEmpty()) continue; 809 if (classPackageName.startsWith(prefix)) { 810 String unprefixedPackage = classPackageName.substring(prefix.length()); 811 if (unprefixedPackage.length() < shortestPackageName.length()) { 812 shortestPackageName = unprefixedPackage; 813 } 814 } 815 } 816 return shortestPackageName; 817 } 818 819 private void setupClassLevelFlagValues(Description description) { 820 mClassLevelSetFlagValues.putAll( 821 AnnotationsRetriever.getFlagAnnotations(description).mSetFlagValues); 822 } 823 824 private void setupFlagsWatchers(Description description) { 825 // Start with the static list of Flags classes to watch 826 Set<Class<?>> flagsClassesToWatch = new HashSet<>(mGlobalFlagsClassesToWatch); 827 // Collect the Flags classes from @UsedFlags annotations on the Descriptor 828 Set<String> usedFlagsClasses = AnnotationsRetriever.getAllUsedFlagsClasses(description); 829 for (String flagsClassName : usedFlagsClasses) { 830 flagsClassesToWatch.add(getFlagClassFromFlagsClassName(flagsClassName)); 831 } 832 // Now setup watchers on the provided Flags classes 833 for (Class<?> flagsClass : flagsClassesToWatch) { 834 setupFlagsWatcher(flagsClass, getFlagPackageName(flagsClass)); 835 } 836 // Get all annotated flags and then the distinct packages for each flag 837 Set<String> setFlags = AnnotationsRetriever.getAllAnnotationSetFlags(description); 838 Set<String> extraFlagPackages = new HashSet<>(); 839 for (String setFlag : setFlags) { 840 extraFlagPackages.add(Flag.getFlagPackageName(setFlag)); 841 } 842 // Do not bother with flags that are already monitored 843 extraFlagPackages.removeAll(mSettableFlagsPackages); 844 // Expand packages to all repackaged versions, stored as Flag objects 845 Set<Flag> extraWildcardFlags = new HashSet<>(); 846 for (String extraFlagPackage : extraFlagPackages) { 847 String fullFlagName = extraFlagPackage + ".*"; 848 Set<String> packages = getAllPackagesForFlag(fullFlagName, mPackageToRepackage); 849 for (String packageName : packages) { 850 Flag flag = Flag.createFlag(fullFlagName, packageName); 851 extraWildcardFlags.add(flag); 852 } 853 } 854 // Set up watchers for each wildcard flag 855 for (Flag flag : extraWildcardFlags) { 856 Class<?> flagsClass = getFlagClassFromFlag(flag); 857 setupFlagsWatcher(flagsClass, flag.flagPackageName()); 858 } 859 } 860 861 private void setupFlagsWatcher(Class<?> flagsClass, String flagPackageName) { 862 if (mMutatedFlagsClasses.contains(flagsClass)) { 863 throw new IllegalStateException( 864 String.format("Flags class %s is already mutated", flagsClass.getName())); 865 } 866 Object watcher = getOrCreateFlagReadWatcher(flagsClass); 867 replaceFlagsImpl(flagsClass, watcher); 868 mMutatedFlagsClasses.add(flagsClass); 869 mSettableFlagsPackages.add(flagPackageName); 870 } 871 872 private void teardownFlagsWatchers() { 873 try { 874 for (Class<?> flagsClass : mMutatedFlagsClasses) { 875 Object flagsImplInstance = mFlagsClassToRealFlagsImpl.get(flagsClass); 876 // strict mode: ensure that the watcher is still in place 877 Object watcher = readFlagsImpl(flagsClass); 878 checkSameAs(mFlagsClassToWatcherImpl.get(flagsClass), watcher); 879 // Replace FeatureFlags in Flags class with real FeatureFlagsImpl 880 replaceFlagsImpl(flagsClass, flagsImplInstance); 881 } 882 mMutatedFlagsClasses.clear(); 883 mSettableFlagsPackages.clear(); 884 mFirstReadOutsideTestsByFlag.clear(); 885 } catch (IllegalStateException e) { 886 throw e; 887 } catch (Exception e) { 888 throw new IllegalStateException("Failed to teardown Flags watchers", e); 889 } 890 if (mIsTestRunning) { 891 throw new IllegalStateException("An inner SetFlagsRule is still running"); 892 } 893 if (!mFirstReadWithinTestByFlag.isEmpty()) { 894 throw new IllegalStateException("An inner SetFlagsRule did not fully clean up"); 895 } 896 } 897 898 private void checkAllFlagsWatchersRestored() { 899 for (Class<?> flagsClass : mMutatedFlagsClasses) { 900 Object watcher = readFlagsImpl(flagsClass); 901 checkSameAs(mFlagsClassToWatcherImpl.get(flagsClass), watcher); 902 } 903 } 904 905 @Override 906 public Statement apply(Statement base, Description description) { 907 return new Statement() { 908 @Override 909 public void evaluate() throws Throwable { 910 Throwable throwable = null; 911 final int initialDepth = mSuiteRunDepth; 912 try { 913 mSuiteRunDepth++; 914 if (initialDepth == 0) { 915 setupFlagsWatchers(description); 916 setupClassLevelFlagValues(description); 917 } 918 base.evaluate(); 919 } catch (Throwable t) { 920 throwable = t; 921 } finally { 922 mSuiteRunDepth--; 923 try { 924 if (initialDepth == 0) { 925 mClassLevelSetFlagValues.clear(); 926 teardownFlagsWatchers(); 927 } 928 if (mSuiteRunDepth != initialDepth) { 929 throw new IllegalStateException( 930 String.format( 931 "Evaluations were not correctly nested: initial" 932 + " depth was %d but final depth was %d", 933 initialDepth, mSuiteRunDepth)); 934 } 935 } catch (Throwable t) { 936 if (throwable != null) { 937 t.addSuppressed(throwable); 938 } 939 throwable = t; 940 } 941 } 942 if (throwable != null) throw throwable; 943 } 944 }; 945 } 946 } 947 948 private static class FirstFlagRead extends Exception { 949 FirstFlagRead(String flagName) { 950 super(String.format("Flag '%s' was first read at this location:", flagName)); 951 } 952 } 953 954 private static class Flag { 955 private static final String PACKAGE_NAME_SIMPLE_NAME_SEPARATOR = "."; 956 private final String mFullFlagName; 957 private final String mFlagPackageName; 958 private final String mClassPackageName; 959 private final String mSimpleFlagName; 960 961 public static String getFlagPackageName(String fullFlagName) { 962 int index = fullFlagName.lastIndexOf(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR); 963 return fullFlagName.substring(0, index); 964 } 965 966 public static Flag createFlag(String fullFlagName) { 967 int index = fullFlagName.lastIndexOf(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR); 968 String packageName = fullFlagName.substring(0, index); 969 return createFlag(fullFlagName, packageName); 970 } 971 972 public static Flag createFlag(String fullFlagName, String classPackageName) { 973 if (!fullFlagName.contains(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR) 974 || !classPackageName.contains(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR)) { 975 throw new IllegalArgumentException( 976 String.format( 977 "Flag %s is invalid. The format should be {packageName}" 978 + ".{simpleFlagName}", 979 fullFlagName)); 980 } 981 int index = fullFlagName.lastIndexOf(PACKAGE_NAME_SIMPLE_NAME_SEPARATOR); 982 String flagPackageName = fullFlagName.substring(0, index); 983 String simpleFlagName = fullFlagName.substring(index + 1); 984 985 return new Flag(fullFlagName, flagPackageName, classPackageName, simpleFlagName); 986 } 987 988 private Flag( 989 String fullFlagName, 990 String flagPackageName, 991 String classPackageName, 992 String simpleFlagName) { 993 this.mFullFlagName = fullFlagName; 994 this.mFlagPackageName = flagPackageName; 995 this.mClassPackageName = classPackageName; 996 this.mSimpleFlagName = simpleFlagName; 997 } 998 999 public String fullFlagName() { 1000 return mFullFlagName; 1001 } 1002 1003 public String flagPackageName() { 1004 return mFlagPackageName; 1005 } 1006 1007 public String classPackageName() { 1008 return mClassPackageName; 1009 } 1010 1011 public String simpleFlagName() { 1012 return mSimpleFlagName; 1013 } 1014 1015 public String flagsClassName() { 1016 return String.format("%s.%s", classPackageName(), FLAGS_CLASS_NAME); 1017 } 1018 } 1019 } 1020