/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ZstdUtil.h"

#include <string.h>

#include <memory>
#include <string>
#include <string_view>
#include <vector>

#include <android-base/logging.h>
#include <zstd.h>

namespace simpleperf {

namespace {

// Growable output buffer shared by the compressor and decompressor.
//
// Layout of buffer_: [already consumed | data (data_size_ bytes) | free space].
// data_pos_ is the offset of the first unconsumed byte.
class CompressionOutBuffer {
 public:
  // min_free_size: free space PrepareForInput() guarantees before each zstd call.
  // It is also the initial buffer size.
  explicit CompressionOutBuffer(size_t min_free_size)
      : min_free_size_(min_free_size), buffer_(min_free_size) {}

  const char* DataStart() const { return buffer_.data() + data_pos_; }
  size_t DataSize() const { return data_size_; }
  char* FreeStart() { return buffer_.data() + data_pos_ + data_size_; }
  size_t FreeSize() const { return buffer_.size() - data_pos_ - data_size_; }

  // Compacts unconsumed data to the front of the buffer, then grows the buffer if fewer than
  // min_free_size_ bytes are free. This may move or reallocate the storage, so it invalidates
  // pointers previously returned by DataStart()/FreeStart().
  void PrepareForInput() {
    if (data_pos_ > 0) {
      if (data_size_ > 0) {
        memmove(buffer_.data(), buffer_.data() + data_pos_, data_size_);
      }
      data_pos_ = 0;
    }
    if (FreeSize() < min_free_size_) {
      // buffer_.size() never drops below min_free_size_ and data_size_ <= buffer_.size(),
      // so a single doubling leaves at least min_free_size_ bytes free.
      buffer_.resize(buffer_.size() * 2);
    }
  }

  // Marks `size` bytes written at FreeStart() as data.
  void ProduceData(size_t size) {
    data_size_ += size;
    CHECK_LE(data_pos_ + data_size_, buffer_.size());
  }

  // Drops `size` bytes from the front of the data region.
  void ConsumeData(size_t size) {
    CHECK_LE(size, data_size_);
    data_pos_ += size;
    data_size_ -= size;
  }

 private:
  const size_t min_free_size_;
  std::vector<char> buffer_;
  size_t data_pos_ = 0;
  size_t data_size_ = 0;
};

using ZSTD_CCtx_pointer = std::unique_ptr<ZSTD_CCtx, decltype(&ZSTD_freeCCtx)>;

// Streaming zstd compressor. AddInputData() compresses with ZSTD_e_continue. FlushOutputData()
// compresses with ZSTD_e_end, which ends the current zstd frame, so the output is a sequence of
// frames. total_input_size_/total_output_size_ are counters inherited from Compressor.
class ZstdCompressor final : public Compressor {
 public:
  explicit ZstdCompressor(ZSTD_CCtx_pointer cctx)
      : cctx_(std::move(cctx)), out_buffer_(ZSTD_CStreamOutSize()) {}

  bool AddInputData(const char* data, size_t size) override {
    ZSTD_inBuffer input = {data, size, 0};
    while (input.pos < input.size) {
      if (!CompressStep(&input, ZSTD_e_continue, nullptr)) {
        return false;
      }
    }
    total_input_size_ += size;
    return true;
  }

  bool FlushOutputData() override {
    if (flushed_input_size_ == total_input_size_) {
      // Nothing new since the last successful flush.
      return true;
    }
    ZSTD_inBuffer input = {nullptr, 0, 0};
    size_t remaining = 0;
    do {
      if (!CompressStep(&input, ZSTD_e_end, &remaining)) {
        return false;
      }
    } while (remaining != 0);
    // Record progress only once the frame is completely written. Otherwise a failed flush
    // would be reported as successful by the next call.
    flushed_input_size_ = total_input_size_;
    return true;
  }

  // The returned view is invalidated by the next AddInputData()/FlushOutputData() call.
  std::string_view GetOutputData() override {
    return std::string_view(out_buffer_.DataStart(), out_buffer_.DataSize());
  }

  void ConsumeOutputData(size_t size) override { out_buffer_.ConsumeData(size); }

 private:
  // Performs one ZSTD_compressStream2() call into the free space of out_buffer_.
  // On success, stores the zstd "bytes remaining to flush" hint in *remaining (if non-null).
  bool CompressStep(ZSTD_inBuffer* input, ZSTD_EndDirective mode, size_t* remaining) {
    out_buffer_.PrepareForInput();
    ZSTD_outBuffer output = {out_buffer_.FreeStart(), out_buffer_.FreeSize(), 0};
    size_t ret = ZSTD_compressStream2(cctx_.get(), &output, input, mode);
    if (ZSTD_isError(ret)) {
      LOG(ERROR) << "ZSTD_compressStream2() failed: " << ZSTD_getErrorName(ret);
      return false;
    }
    out_buffer_.ProduceData(output.pos);
    total_output_size_ += output.pos;
    if (remaining != nullptr) {
      *remaining = ret;
    }
    return true;
  }

  ZSTD_CCtx_pointer cctx_;
  CompressionOutBuffer out_buffer_;
  uint64_t flushed_input_size_ = 0;
};

using ZSTD_DCtx_pointer = std::unique_ptr<ZSTD_DCtx, decltype(&ZSTD_freeDCtx)>;

// Streaming zstd decompressor. Input may be split at arbitrary byte boundaries across
// AddInputData() calls.
class ZstdDecompressor final : public Decompressor {
 public:
  explicit ZstdDecompressor(ZSTD_DCtx_pointer dctx)
      : dctx_(std::move(dctx)), out_buffer_(ZSTD_DStreamOutSize()) {}

  bool AddInputData(const char* data, size_t size) override {
    ZSTD_inBuffer input = {data, size, 0};
    while (input.pos < input.size) {
      out_buffer_.PrepareForInput();
      ZSTD_outBuffer output = {out_buffer_.FreeStart(), out_buffer_.FreeSize(), 0};
      size_t remaining = ZSTD_decompressStream(dctx_.get(), &output, &input);
      if (ZSTD_isError(remaining)) {
        LOG(ERROR) << "ZSTD_decompressStream() failed: " << ZSTD_getErrorName(remaining);
        return false;
      }
      out_buffer_.ProduceData(output.pos);
      // zstd returns 0 only when the current frame is fully decoded and flushed.
      frame_incomplete_ = (remaining != 0);
    }
    return true;
  }

  // Returns true if the input fed so far stops in the middle of a zstd frame.
  bool IsFrameIncomplete() const { return frame_incomplete_; }

  // The returned view is invalidated by the next AddInputData() call.
  std::string_view GetOutputData() override {
    return std::string_view(out_buffer_.DataStart(), out_buffer_.DataSize());
  }

  void ConsumeOutputData(size_t size) override { out_buffer_.ConsumeData(size); }

 private:
  ZSTD_DCtx_pointer dctx_;
  CompressionOutBuffer out_buffer_;
  bool frame_incomplete_ = false;
};

// Creates a ZstdDecompressor, or returns nullptr if the zstd context can't be allocated.
std::unique_ptr<ZstdDecompressor> NewZstdDecompressor() {
  ZSTD_DCtx_pointer dctx(ZSTD_createDCtx(), ZSTD_freeDCtx);
  if (!dctx) {
    LOG(ERROR) << "ZSTD_createDCtx() failed";
    return nullptr;
  }
  return std::make_unique<ZstdDecompressor>(std::move(dctx));
}

}  // namespace

Compressor::~Compressor() {}

Decompressor::~Decompressor() {}

std::unique_ptr<Compressor> CreateZstdCompressor(size_t compression_level) {
  ZSTD_CCtx_pointer cctx(ZSTD_createCCtx(), ZSTD_freeCCtx);
  if (!cctx) {
    LOG(ERROR) << "ZSTD_createCCtx() failed";
    return nullptr;
  }
  // ZSTD_CCtx_setParameter() takes an int value; make the narrowing explicit.
  size_t err = ZSTD_CCtx_setParameter(cctx.get(), ZSTD_c_compressionLevel,
                                      static_cast<int>(compression_level));
  if (ZSTD_isError(err)) {
    LOG(ERROR) << "failed to set compression level: " << ZSTD_getErrorName(err);
    return nullptr;
  }
  return std::make_unique<ZstdCompressor>(std::move(cctx));
}

std::unique_ptr<Decompressor> CreateZstdDecompressor() {
  return NewZstdDecompressor();
}

// One-shot compression. Replaces output_data with the compressed bytes.
// Empty input produces empty output.
bool ZstdCompress(const char* input_data, size_t input_size, std::string& output_data) {
  std::unique_ptr<Compressor> compressor = CreateZstdCompressor();
  CHECK(compressor != nullptr);
  if (!compressor->AddInputData(input_data, input_size)) {
    return false;
  }
  if (!compressor->FlushOutputData()) {
    return false;
  }
  std::string_view output = compressor->GetOutputData();
  output_data.assign(output.data(), output.size());
  return true;
}

// One-shot decompression. Replaces output_data with the decompressed bytes.
// Fails on corrupt input and on input that ends in the middle of a zstd frame.
bool ZstdDecompress(const char* input_data, size_t input_size, std::string& output_data) {
  std::unique_ptr<ZstdDecompressor> decompressor = NewZstdDecompressor();
  CHECK(decompressor != nullptr);
  if (!decompressor->AddInputData(input_data, input_size)) {
    return false;
  }
  if (decompressor->IsFrameIncomplete()) {
    LOG(ERROR) << "ZstdDecompress() failed: input ends in the middle of a zstd frame";
    return false;
  }
  std::string_view output = decompressor->GetOutputData();
  output_data.assign(output.data(), output.size());
  return true;
}

}  // namespace simpleperf