1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/flags.h"
17 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
18 #include "tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.h"
19 #include "tensorflow/core/lib/core/status_test_util.h"
20 
21 namespace tensorflow {
22 namespace {
23 
// Checks that persistent XLA compilation cache entries are still used after
// the HLO module names stored in them have been altered. This relies on
// main() having turned off strict signature checks and having set the cache
// prefix to "my_test_prefix".
TEST_F(XlaCompilationCacheSerializeTest, PersistentCacheOptionsTest) {
  // Exclusive upper bound on the batch sizes exercised below. 4 is a magic
  // number to detect non-determinism in TF when running the test.
  constexpr int kBatchLimit = 4;
  constexpr char kNoEntriesError[] =
      "Did not find any persistent XLA compilation cache entries to alter.";

  GraphDef test_graph = GetTestGraph({-1, 4});

  // First pass: warm up the persistent cache(s). Nothing has been cached yet,
  // so the listener must not report any persistent cache use.
  listener()->ClearListenerHistory();
  for (int batch = 1; batch != kBatchLimit; ++batch) {
    TF_ASSERT_OK(ExecuteWithBatch(test_graph, batch));
  }
  TF_ASSERT_OK(
      listener()->VerifyListenerHistory(/*expect_persistent_cache_use=*/false));

  // Start cluster numbering over so that the next round of executions gets
  // the same cluster numbering as the warm-up round.
  testing::ResetClusterSequenceNumber();

  // Altering without passing a prefix is expected to fail with a "no entries
  // found" error.
  const auto alter_status =
      AlterPersistentCacheEntryHloModuleNames(tensorflow::testing::TmpDir());
  EXPECT_FALSE(alter_status.ok());
  EXPECT_TRUE(
      absl::StrContains(alter_status.error_message(), kNoEntriesError));

  // Altering with the prefix used by this test must succeed.
  TF_ASSERT_OK(AlterPersistentCacheEntryHloModuleNames(
      tensorflow::testing::TmpDir(), "my_test_prefix"));

  // Second pass: every execution should now hit the persistent cache even
  // though the entries' HLO modules were altered (strict signature checks are
  // disabled).
  listener()->ClearListenerHistory();
  for (int batch = 1; batch != kBatchLimit; ++batch) {
    TF_ASSERT_OK(ExecuteWithBatch(test_graph, batch));
  }
  TF_ASSERT_OK(
      listener()->VerifyListenerHistory(/*expect_persistent_cache_use=*/true));
}
60 
61 }  // namespace
62 }  // namespace tensorflow
63 
main(int argc,char ** argv)64 int main(int argc, char** argv) {
65   tensorflow::GetMarkForCompilationPassFlags()
66       ->tf_xla_deterministic_cluster_names = true;
67   tensorflow::GetMarkForCompilationPassFlags()
68       ->tf_xla_persistent_cache_directory = tensorflow::testing::TmpDir();
69   tensorflow::GetMarkForCompilationPassFlags()
70       ->tf_xla_disable_strict_signature_checks = true;
71   tensorflow::GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix =
72       "my_test_prefix";
73   ::testing::InitGoogleTest(&argc, argv);
74   return RUN_ALL_TESTS();
75 }
76