Searched refs:TensorWithLayout (Results 1 – 3 of 3) sorted by relevance
/aosp_15_r20/external/tensorflow/tensorflow/dtensor/cc/ |
H A D | dtensor_device_util.cc | 92 std::unique_ptr<TensorWithLayout> BroadcastResourceTensor( in BroadcastResourceTensor() 145 auto ret = TensorWithLayout::Dummy(shape, dtype, mesh, layout); in BroadcastResourceTensor() 153 StatusOr<std::unique_ptr<TensorWithLayout>> result = TensorWithLayout::Wrap( in BroadcastResourceTensor() 258 tensorflow::Fprint128 TensorWithLayout::CacheKey() const { in CacheKey() 272 std::unique_ptr<TensorWithLayout> TensorWithLayout::Broadcast( in Broadcast() 297 auto ret = TensorWithLayout::Dummy(shape, dtype, mesh, layout); in Broadcast() 325 std::unique_ptr<TensorWithLayout> result(new TensorWithLayout( in Broadcast() 331 StatusOr<std::unique_ptr<TensorWithLayout>> TensorWithLayout::Wrap( in Wrap() 338 return std::unique_ptr<TensorWithLayout>( in Wrap() 339 new TensorWithLayout(std::move(tensor), mesh, layout, *shape)); in Wrap() [all …]
|
H A D | dtensor_device_util.h | 209 class TensorWithLayout { 213 static std::unique_ptr<TensorWithLayout> Broadcast( 219 static StatusOr<std::unique_ptr<TensorWithLayout>> Wrap( 224 static std::unique_ptr<TensorWithLayout> Dummy( 228 virtual ~TensorWithLayout() {} in ~TensorWithLayout() 328 TensorWithLayout(std::unique_ptr<parallel_device::ParallelTensor> tensor, 374 class ResourceHandleWithLayout : public TensorWithLayout { 431 : TensorWithLayout(std::move(tensor), mesh, layout, local_shape, in ResourceHandleWithLayout() 449 class SparseTensorWithLayout : public TensorWithLayout { 451 static StatusOr<std::unique_ptr<TensorWithLayout>> Wrap( [all …]
|
H A D | dtensor_device.cc | 87 const std::vector<TensorWithLayout*>& inputs, const DeviceSet& device_set, in PipeliningPartitionerRun() 343 const std::vector<TensorWithLayout*>& inputs, 358 const std::vector<TensorWithLayout*>& inputs, 365 std::unique_ptr<TensorWithLayout> t, 368 void RecordInShapeLayoutCache(const TensorWithLayout& tensor); 492 return reinterpret_cast<TensorWithLayout*>(data)->global_shape().size(); in TensorWithLayoutNumDims() 496 return reinterpret_cast<TensorWithLayout*>(data)->global_shape()[dim_index]; in TensorWithLayoutDim() 500 delete reinterpret_cast<TensorWithLayout*>(data); in TensorWithLayoutDeallocator() 505 reinterpret_cast<TensorWithLayout*>(data)->SummarizeValue(); in TensorWithLayoutSummarize() 510 TFE_Context* context, std::unique_ptr<TensorWithLayout> t, in MakeLayoutTensorHandle() [all …]
|