1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/flags.h"
17 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
18 #include "tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.h"
19 #include "tensorflow/core/lib/core/status_test_util.h"
20
21 namespace tensorflow {
22 namespace {
23
// Exercises the persistent compilation cache in three phases:
//   1. a cold pass that must not use the persistent cache,
//   2. a second pass that must be served from the persistent cache,
//   3. a pass after the cache entries' HLO module names were altered, where
//      every execution is expected to fail with a mismatch error.
TEST_F(XlaCompilationCacheSerializeTest, PersistentCacheTest) {
  // Exclusive upper bound on the batch sizes that get executed. 4 is a magic
  // number chosen to detect non-determinism in TF when running the test.
  constexpr int kBatchEnd = 4;

  GraphDef graph = GetTestGraph({-1, 4});

  // Phase 1: warm up the persistent cache(s) with several runs. Nothing has
  // been cached yet, so the listener must not report persistent cache use.
  listener()->ClearListenerHistory();
  for (int batch = 1; batch < kBatchEnd; ++batch) {
    TF_ASSERT_OK(ExecuteWithBatch(graph, batch));
  }
  TF_ASSERT_OK(
      listener()->VerifyListenerHistory(/*expect_persistent_cache_use=*/false));

  // Cluster numbering has to restart between sessions so that the next pass
  // produces the same cluster numbering as the previous one.
  testing::ResetClusterSequenceNumber();

  // Phase 2: identical runs; this time each one should hit the persistent
  // cache.
  listener()->ClearListenerHistory();
  for (int batch = 1; batch < kBatchEnd; ++batch) {
    TF_ASSERT_OK(ExecuteWithBatch(graph, batch));
  }
  TF_ASSERT_OK(
      listener()->VerifyListenerHistory(/*expect_persistent_cache_use=*/true));

  // Restart the cluster numbering once more for the final pass.
  testing::ResetClusterSequenceNumber();

  // Tamper with the HLO module names of the entries stored in the cache
  // directory.
  TF_ASSERT_OK(
      AlterPersistentCacheEntryHloModuleNames(tensorflow::testing::TmpDir()));

  // Phase 3: every run should now fail, because the persistent cache entries'
  // HLO modules no longer match.
  for (int batch = 1; batch < kBatchEnd; ++batch) {
    const auto run_status = ExecuteWithBatch(graph, batch);
    EXPECT_FALSE(run_status.ok());
    EXPECT_TRUE(absl::StrContains(run_status.error_message(),
                                  "Serialized HLO does not match."));
  }
}
64
65 } // namespace
66 } // namespace tensorflow
67
main(int argc,char ** argv)68 int main(int argc, char** argv) {
69 tensorflow::GetMarkForCompilationPassFlags()
70 ->tf_xla_deterministic_cluster_names = true;
71 tensorflow::GetMarkForCompilationPassFlags()
72 ->tf_xla_persistent_cache_directory = tensorflow::testing::TmpDir();
73 ::testing::InitGoogleTest(&argc, argv);
74 return RUN_ALL_TESTS();
75 }
76