1 /* 2 * Copyright 2022 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_AGGREGATION_CORE_TENSOR_SHAPE_H_ 18 #define FCP_AGGREGATION_CORE_TENSOR_SHAPE_H_ 19 20 #include <cstddef> 21 #include <cstdint> 22 #include <initializer_list> 23 #include <utility> 24 #include <vector> 25 26 #include "fcp/base/monitoring.h" 27 28 #ifndef FCP_NANOLIBC 29 #include "fcp/aggregation/core/tensor.pb.h" 30 #endif 31 32 namespace fcp { 33 namespace aggregation { 34 35 // Represents a tensor shape as a collection of 36 // dimension sizes. 37 class TensorShape final { 38 public: 39 using DimSizesVector = std::vector<size_t>; 40 41 template <typename ForwardIterator> TensorShape(ForwardIterator first,ForwardIterator last)42 TensorShape(ForwardIterator first, ForwardIterator last) 43 : dim_sizes_(first, last) {} 44 TensorShape(std::initializer_list<size_t> dim_sizes)45 TensorShape(std::initializer_list<size_t> dim_sizes) 46 : dim_sizes_(dim_sizes) {} 47 48 #ifndef FCP_NANOLIBC 49 // Creates a TensorShape from a TensorShapeProto. 50 // Returns an error if any of the shape dimensions are unknown. 51 static StatusOr<TensorShape> FromProto(const TensorShapeProto& shape_proto); 52 53 // Returns a TensorShapeProto representation of the tensor shape. 54 TensorShapeProto ToProto() const; 55 #endif 56 57 // Gets the dimensions and their sizes. dim_sizes()58 const DimSizesVector& dim_sizes() const { return dim_sizes_; } 59 60 // Gets the total number of elements (which is a multiplication of sizes of 61 // all dimensions). 62 // For a scalar tensor with zero dimensions this returns 1. 63 size_t NumElements() const; 64 65 friend bool operator==(const TensorShape& a, const TensorShape& b) { 66 return a.dim_sizes_ == b.dim_sizes_; 67 } 68 69 friend bool operator!=(const TensorShape& a, const TensorShape& b) { 70 return a.dim_sizes_ != b.dim_sizes_; 71 } 72 73 private: TensorShape(DimSizesVector && dim_sizes)74 explicit TensorShape(DimSizesVector&& dim_sizes) 75 : dim_sizes_(std::move(dim_sizes)) {} 76 77 // TODO(team): Consider optimizing the storage for better inlining 78 // of small number of dimensions. 79 DimSizesVector dim_sizes_; 80 }; 81 82 } // namespace aggregation 83 } // namespace fcp 84 85 #endif // FCP_AGGREGATION_CORE_TENSOR_SHAPE_H_ 86