/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/dtensor/cc/dtensor_utils.h"
17
18 #include <cstdlib>
19
20 #include "absl/strings/numbers.h"
21 #include "tensorflow/core/platform/logging.h"
22
23 namespace tensorflow {
24 namespace dtensor {
25
// LINT.IfChange
ClientId()27 int ClientId() {
28 char* client_id_str = std::getenv("DTENSOR_CLIENT_ID");
29 if (client_id_str == nullptr) return 0;
30 int client_id;
31 if (absl::SimpleAtoi(client_id_str, &client_id)) return client_id;
32 LOG(WARNING) << "Invalid DTENSOR_CLIENT_ID, using the default value 0.";
33 return 0;
34 }
// LINT.ThenChange(//tensorflow/dtensor/python/dtensor_device.py)
36
// LINT.IfChange
NumClients()38 int NumClients() {
39 char* num_clients_str = std::getenv("DTENSOR_NUM_CLIENTS");
40 if (num_clients_str == nullptr) return 1;
41 int num_clients;
42 if (absl::SimpleAtoi(num_clients_str, &num_clients)) return num_clients;
43 LOG(WARNING) << "Invalid DTENSOR_NUM_CLIENTS, using the default value 1.";
44 return 1;
45 }
// LINT.ThenChange(//tensorflow/dtensor/python/dtensor_device.py)
47
// Returns true iff the DTENSOR_LOG_ON_ALL_TASKS environment variable is set.
//
// Only the presence of the variable is checked; its value is never inspected,
// so an empty string or "0" also yields true.
bool LogOnAllTasks() {
  return std::getenv("DTENSOR_LOG_ON_ALL_TASKS") != nullptr;
}
53
// Returns true iff the DTENSOR_LOG_OP_BY_OP environment variable is set.
//
// Only the presence of the variable is checked; its value is never inspected,
// so an empty string or "0" also yields true.
bool LogOpByOp() {
  return std::getenv("DTENSOR_LOG_OP_BY_OP") != nullptr;
}
59
LayoutPropagationMaxSteps()60 int LayoutPropagationMaxSteps() {
61 char* dtensor_layout_propagation_max_steps_str =
62 std::getenv("DTENSOR_LAYOUT_PROPAGATION_MAX_STEPS");
63 if (dtensor_layout_propagation_max_steps_str == nullptr) return 500;
64 int dtensor_layout_propagation_max_steps;
65 if (absl::SimpleAtoi(dtensor_layout_propagation_max_steps_str,
66 &dtensor_layout_propagation_max_steps))
67 return dtensor_layout_propagation_max_steps;
68 LOG(WARNING) << "Invalid DTENSOR_LAYOUT_PROPAGATION_MAX_STEPS, using "
69 "the default value 500.";
70 return 500;
71 }
72
// Returns true iff the DTENSOR_ENABLE_MIXED_PRECISION_REDUCE environment
// variable is set.
//
// Only the presence of the variable is checked; its value is never inspected,
// so an empty string or "0" also yields true.
bool EnableMixedPrecisionReduce() {
  return std::getenv("DTENSOR_ENABLE_MIXED_PRECISION_REDUCE") != nullptr;
}
79
// Returns true iff the DTENSOR_DO_NOT_FUSE_REDUCE_SCATTER environment variable
// is set.
//
// Only the presence of the variable is checked; its value is never inspected,
// so an empty string or "0" also yields true.
bool DoNotFuseReduceScatter() {
  return std::getenv("DTENSOR_DO_NOT_FUSE_REDUCE_SCATTER") != nullptr;
}
86
ReduceInBfloat16MaxGroupSize()87 int ReduceInBfloat16MaxGroupSize() {
88 char* dtensor_reduce_in_bfloat16_max_group_size_str =
89 std::getenv("DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE");
90 if (dtensor_reduce_in_bfloat16_max_group_size_str == nullptr) return 8;
91 int dtensor_reduce_in_bfloat16_max_group_size;
92 if (absl::SimpleAtoi(dtensor_reduce_in_bfloat16_max_group_size_str,
93 &dtensor_reduce_in_bfloat16_max_group_size))
94 return dtensor_reduce_in_bfloat16_max_group_size;
95 LOG(WARNING) << "Invalid DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE, using "
96 "the default value 8.";
97 return 8;
98 }
99
// Returns true iff the DTENSOR_ENABLE_CHECKPOINT_V2 environment variable is
// set.
//
// Only the presence of the variable is checked; its value is never inspected,
// so an empty string or "0" also yields true.
bool DTensorCheckpointV2Enabled() {
  // Compare explicitly instead of relying on implicit pointer-to-bool
  // conversion, matching the other presence-only flags in this file.
  return std::getenv("DTENSOR_ENABLE_CHECKPOINT_V2") != nullptr;
}
103
104 } // namespace dtensor
105 } // namespace tensorflow
106