xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/common.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_COMMON_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_COMMON_H_
17 
18 #include <string>
19 
20 #include "absl/strings/string_view.h"
21 #include "tensorflow/core/data/service/common.pb.h"
22 #include "tensorflow/core/framework/dataset_options.pb.h"
23 #include "tensorflow/core/platform/status.h"
24 #include "tensorflow/core/platform/statusor.h"
25 #include "tensorflow/core/platform/types.h"
26 #include "tensorflow/core/protobuf/data_service.pb.h"
27 
28 namespace tensorflow {
29 namespace data {
30 
31 // Increment this when making backwards-incompatible changes to communication
32 // between tf.data clients and servers.
33 constexpr int kDataServiceVersion = 5;
34 
35 // If the user starts a colocated tf.data worker on each TF host, the worker
36 // is tagged "COLOCATED". This tag is used to avoid reading from tf.data
37 // workers on other TF hosts when the host runs a local tf.data service worker.
38 constexpr absl::string_view kColocatedWorkerTag = "COLOCATED";
39 
40 // Returns true if `processing_mode` specifies no sharding policy.
41 bool IsNoShard(const ProcessingModeDef& processing_mode);
42 
43 // Returns true if `processing_mode` is dynamic sharding.
44 bool IsDynamicShard(const ProcessingModeDef& processing_mode);
45 
46 // Returns true if `processing_mode` is static sharding.
47 bool IsStaticShard(const ProcessingModeDef& processing_mode);
48 
49 // Returns an internal error if `processing_mode` is invalid.
50 Status ValidateProcessingMode(const ProcessingModeDef& processing_mode);
51 
52 // Converts tf.data service `sharding_policy` to `AutoShardPolicy`. Returns an
53 // internal error if `sharding_policy` is not supported.
54 StatusOr<AutoShardPolicy> ToAutoShardPolicy(
55     ProcessingModeDef::ShardingPolicy sharding_policy);
56 
57 // Parses a string representing a `TargetWorkers` (case-insensitive).
58 // Returns InvalidArgument if the string is not recognized.
59 StatusOr<TargetWorkers> ParseTargetWorkers(absl::string_view s);
60 
61 // Converts a `TargetWorkers` enum value to its string representation.
62 std::string TargetWorkersToString(TargetWorkers target_workers);
63 
64 // Parses a string representing a `DeploymentMode` (case-insensitive).
65 // Returns InvalidArgument if the string is not recognized.
66 StatusOr<DeploymentMode> ParseDeploymentMode(absl::string_view s);
67 
68 // Returns true if `status` is a retriable error that indicates preemption.
69 bool IsPreemptedError(const Status& status);
70 
71 // Base class for data service clients. Data service clients are
72 // threadsafe.
73 class DataServiceClientBase {
74  public:
DataServiceClientBase(const std::string & address,const std::string & protocol)75   DataServiceClientBase(const std::string& address, const std::string& protocol)
76       : address_(address), protocol_(protocol) {}
77 
78   virtual ~DataServiceClientBase() = default;
79   // Not copyable or movable.
80   DataServiceClientBase(const DataServiceClientBase&) = delete;
81   DataServiceClientBase& operator=(const DataServiceClientBase&) = delete;
82 
83   // Initializes the client. Calling `Initialize()` is not required since the
84   // first RPC will perform any necessary initialization. However, it can be
85   // useful to call `Initialize()` proactively so that any errors that happen
86   // during initialization can be surfaced earlier.
Initialize()87   Status Initialize() { return EnsureInitialized(); }
88 
89  protected:
90   // Initializes the client if it isn't already initialized.
91   virtual Status EnsureInitialized() = 0;
92 
93   const std::string address_;
94   const std::string protocol_;
95 };
96 
97 }  // namespace data
98 }  // namespace tensorflow
99 
100 #endif  // TENSORFLOW_CORE_DATA_SERVICE_COMMON_H_
101