/*
 * Copyright 2019 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "fcp/base/random_token.h"
#include "fcp/tensorflow/external_dataset.h"
#include "fcp/tensorflow/status.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"

namespace fcp {

/**
 * ExternalDataset op-kernel. Delegates to an ExternalDatasetProvider, found
 * from the ExternalDatasetProviderRegistry (a HostObjectRegistry).
 *
 * Inputs:
 *   selector: An opaque string scalar. Forwarded to the stub.
 *   token: String scalar. It should encode a token obtained from
 *          ExternalDatasetProviderRegistry::Register.
* * See TensorFlow's guide to making custom dataset ops: * https://www.tensorflow.org/guide/extend/formats */ class ExternalDatasetOp : public tensorflow::data::DatasetOpKernel { public: using tensorflow::data::DatasetOpKernel::DatasetOpKernel; void MakeDataset(tensorflow::OpKernelContext* ctx, tensorflow::data::DatasetBase** output) override { tensorflow::tstring token_str; OP_REQUIRES_OK(ctx, tensorflow::data::ParseScalarArgument( ctx, "token", &token_str)); absl::Span token_bytes = token_str; OP_REQUIRES(ctx, token_bytes.size() == kRandomTokenSizeInBytes, tensorflow::errors::InvalidArgument(absl::StrFormat( "Tokens have a fixed size. Expected: %d; Actual %d", kRandomTokenSizeInBytes, token_bytes.size()))); RandomToken token = RandomToken::FromBytes(token_bytes); tensorflow::tstring selector_str; OP_REQUIRES_OK(ctx, tensorflow::data::ParseScalarArgument( ctx, "selector", &selector_str)); std::optional> maybe_provider = ExternalDatasetProviderRegistry::TryLookup(token); OP_REQUIRES(ctx, maybe_provider.has_value(), tensorflow::errors::InvalidArgument( "A dataset provider is not currently registered for the " "provided token: ", token.ToPrintableString())); std::shared_ptr provider = *std::move(maybe_provider); StatusOr> maybe_dataset = provider->MakeDataset(selector_str); // The provider might not like the given selector. 
if (!maybe_dataset.ok()) { ctx->SetStatus(ConvertToTensorFlowStatus(maybe_dataset.status())); return; } *output = new Dataset(ctx, std::move(maybe_dataset).value()); } private: class Dataset : public tensorflow::data::DatasetBase { public: Dataset(tensorflow::OpKernelContext* ctx, std::unique_ptr stub) : DatasetBase(tensorflow::data::DatasetContext(ctx)), stub_(std::move(stub)) {} std::unique_ptr MakeIteratorInternal( const std::string& prefix) const override { std::unique_ptr iter = stub_->MakeIterator(); Iterator::Params params{ this, tensorflow::strings::StrCat(prefix, "::ExternalDataset")}; return std::unique_ptr( new Iterator(params, std::move(iter))); } // Each iterator element is just a scalar string. const tensorflow::DataTypeVector& output_dtypes() const override { static auto* const dtypes = new tensorflow::DataTypeVector({tensorflow::DT_STRING}); return *dtypes; } const std::vector& output_shapes() const override { static std::vector* shapes = new std::vector({{}}); return *shapes; } std::string DebugString() const override { return "ExternalDatasetOp::Dataset"; } tensorflow::Status InputDatasets( std::vector* inputs) const override { // ExternalDatast has no input datasets, so just return OK. return tensorflow::OkStatus(); } // The `DatasetBase::CheckExternalState()` method was introduced on 8/7/2019. We // use the `TF_GRAPH_DEF_VERSION` value (which is updated daily) to determine if // we should add its override. 
#if TF_GRAPH_DEF_VERSION > 125 tensorflow::Status CheckExternalState() const override { return tensorflow::OkStatus(); } #endif protected: tensorflow::Status AsGraphDefInternal( tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b, tensorflow::Node** output) const override { return ::tensorflow::errors::Unimplemented( DebugString(), " does not support serialization."); } private: class Iterator : public tensorflow::data::DatasetIterator { public: explicit Iterator(const Params& params, std::unique_ptr stub) : DatasetIterator(params), stub_(std::move(stub)) {} tensorflow::Status GetNextInternal( tensorflow::data::IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) override { StatusOr maybe_element; { absl::MutexLock _(&mu_); maybe_element = stub_->GetNext(); } if (maybe_element.ok()) { std::string element = std::move(maybe_element).value(); // The {} at the end specifies a scalar tensor. tensorflow::Tensor element_tensor(ctx->allocator({}), tensorflow::DT_STRING, {}); element_tensor.scalar()() = element; *end_of_sequence = false; out_tensors->push_back(std::move(element_tensor)); return tensorflow::OkStatus(); } else { *end_of_sequence = true; if (maybe_element.status().code() == StatusCode::kOutOfRange) { return tensorflow::OkStatus(); } else { return ConvertToTensorFlowStatus(maybe_element.status()); } } } protected: tensorflow::Status SaveInternal( // `::tensorflow::data::SerializationContext` argument was added on // 2020-03-17 when `TF_GRAPH_DEF_VERSION` was defined to 343. 
#if TF_GRAPH_DEF_VERSION > 343 tensorflow::data::SerializationContext* ctx, #endif tensorflow::data::IteratorStateWriter* writer) override { return ::tensorflow::errors::Unimplemented( "Save / Restore of an ExternalDataset iterator is not supported"); } tensorflow::Status RestoreInternal( tensorflow::data::IteratorContext* ctx, tensorflow::data::IteratorStateReader* reader) override { return ::tensorflow::errors::Unimplemented( "Save / Restore of an ExternalDataset iterator is not supported"); } private: std::unique_ptr stub_; absl::Mutex mu_; }; // Private members of Dataset std::unique_ptr stub_; }; }; REGISTER_OP("ExternalDataset") .Input("token: string") .Input("selector: string") .Output("handle: variant") .SetIsStateful() .SetShapeFn(tensorflow::shape_inference::ScalarShape); REGISTER_KERNEL_BUILDER(Name("ExternalDataset").Device(tensorflow::DEVICE_CPU), ExternalDatasetOp); } // namespace fcp