1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/literal_util.h"
17
18 #include "tensorflow/compiler/tf2xla/shape_util.h"
19 #include "tensorflow/compiler/tf2xla/type_util.h"
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/core/common_runtime/dma_helper.h"
22
23 namespace tensorflow {
24
HostTensorToBorrowingLiteral(const Tensor & host_tensor,xla::BorrowingLiteral * literal)25 Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
26 xla::BorrowingLiteral* literal) {
27 xla::Shape xla_shape;
28 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
29 host_tensor.shape(), &xla_shape));
30 return HostTensorToBorrowingLiteral(xla_shape, host_tensor, literal);
31 }
32
HostTensorToBorrowingLiteral(const xla::Shape & xla_shape,const Tensor & host_tensor,xla::BorrowingLiteral * literal)33 Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape,
34 const Tensor& host_tensor,
35 xla::BorrowingLiteral* literal) {
36 const auto& tshape = host_tensor.shape();
37 TF_RET_CHECK(tshape.IsFullyDefined() &&
38 tshape.dims() == xla_shape.dimensions_size() &&
39 tshape.dim_sizes() == xla_shape.dimensions())
40 << "Provided xla::Shape must have the same dims as the Tensor shape.";
41 *literal = xla::BorrowingLiteral(
42 static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
43 return OkStatus();
44 }
45
HostTensorToLiteral(const Tensor & host_tensor)46 StatusOr<xla::Literal> HostTensorToLiteral(const Tensor& host_tensor) {
47 xla::BorrowingLiteral literal;
48 TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(host_tensor, &literal));
49 return literal.Clone();
50 }
51
HostTensorToMutableBorrowingLiteral(Tensor * host_tensor,xla::MutableBorrowingLiteral * literal)52 Status HostTensorToMutableBorrowingLiteral(
53 Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) {
54 xla::Shape xla_shape;
55 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor->dtype(),
56 host_tensor->shape(), &xla_shape));
57 return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal);
58 }
59
HostTensorToMutableBorrowingLiteral(const xla::Shape & xla_shape,Tensor * host_tensor,xla::MutableBorrowingLiteral * literal)60 Status HostTensorToMutableBorrowingLiteral(
61 const xla::Shape& xla_shape, Tensor* host_tensor,
62 xla::MutableBorrowingLiteral* literal) {
63 *literal = xla::MutableBorrowingLiteral(
64 static_cast<const char*>(DMAHelper::base(host_tensor)), xla_shape);
65
66 return OkStatus();
67 }
68
HostTensorsToBorrowingLiteralTuple(absl::Span<const Tensor> host_tensors,xla::BorrowingLiteral * literal)69 Status HostTensorsToBorrowingLiteralTuple(absl::Span<const Tensor> host_tensors,
70 xla::BorrowingLiteral* literal) {
71 std::vector<const char*> buf_ptrs;
72 buf_ptrs.reserve(host_tensors.size());
73 std::vector<xla::Shape> tensor_shapes(host_tensors.size());
74
75 for (int i = 0, end = host_tensors.size(); i < end; i++) {
76 // Validate runtime shapes and fail if it doesn't match the contract.
77 const Tensor* tensor = &host_tensors[i];
78 buf_ptrs.emplace_back(static_cast<const char*>(DMAHelper::base(tensor)));
79 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(tensor->dtype(), tensor->shape(),
80 &tensor_shapes[i]));
81 }
82
83 *literal = xla::BorrowingLiteral(
84 buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes));
85
86 return OkStatus();
87 }
88
CopyLiteralToHostTensor(const xla::LiteralSlice & literal,Tensor * host_tensor)89 Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal,
90 Tensor* host_tensor) {
91 TF_RET_CHECK(literal.shape().IsArray() &&
92 xla::ShapeUtil::ElementsIn(literal.shape()) ==
93 host_tensor->NumElements());
94 xla::PrimitiveType primitive_type;
95 TF_RETURN_IF_ERROR(
96 DataTypeToPrimitiveType(host_tensor->dtype(), &primitive_type));
97 if (literal.shape().element_type() != primitive_type) {
98 return errors::InvalidArgument(
99 "Cannot convert literal of type ",
100 xla::PrimitiveType_Name(literal.shape().element_type()),
101 " to tensor of type ", DataTypeString(host_tensor->dtype()));
102 }
103 size_t total_bytes = host_tensor->TotalBytes();
104 if (total_bytes > 0) {
105 const void* src_ptr = literal.untyped_data();
106 void* dst_ptr = DMAHelper::base(host_tensor);
107 memcpy(dst_ptr, src_ptr, total_bytes);
108 }
109 return OkStatus();
110 }
111
LiteralToHostTensor(const xla::LiteralSlice & literal,DataType target_type,Tensor * host_tensor)112 Status LiteralToHostTensor(const xla::LiteralSlice& literal,
113 DataType target_type, Tensor* host_tensor) {
114 TensorShape shape;
115 TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape));
116 *host_tensor = Tensor(target_type, shape);
117 return CopyLiteralToHostTensor(literal, host_tensor);
118 }
119
120 } // namespace tensorflow
121