1*1a96fba6SXin Li // Copyright 2015 The Chromium OS Authors. All rights reserved.
2*1a96fba6SXin Li // Use of this source code is governed by a BSD-style license that can be
3*1a96fba6SXin Li // found in the LICENSE file.
4*1a96fba6SXin Li
5*1a96fba6SXin Li #include <brillo/streams/stream_utils.h>
6*1a96fba6SXin Li
7*1a96fba6SXin Li #include <algorithm>
8*1a96fba6SXin Li #include <limits>
9*1a96fba6SXin Li #include <memory>
10*1a96fba6SXin Li #include <utility>
11*1a96fba6SXin Li #include <vector>
12*1a96fba6SXin Li
13*1a96fba6SXin Li #include <base/bind.h>
14*1a96fba6SXin Li #include <brillo/message_loops/message_loop.h>
15*1a96fba6SXin Li #include <brillo/streams/stream_errors.h>
16*1a96fba6SXin Li
17*1a96fba6SXin Li namespace brillo {
18*1a96fba6SXin Li namespace stream_utils {
19*1a96fba6SXin Li
20*1a96fba6SXin Li namespace {
21*1a96fba6SXin Li
22*1a96fba6SXin Li // Status of asynchronous CopyData operation.
23*1a96fba6SXin Li struct CopyDataState {
24*1a96fba6SXin Li brillo::StreamPtr in_stream;
25*1a96fba6SXin Li brillo::StreamPtr out_stream;
26*1a96fba6SXin Li std::vector<uint8_t> buffer;
27*1a96fba6SXin Li uint64_t remaining_to_copy;
28*1a96fba6SXin Li uint64_t size_copied;
29*1a96fba6SXin Li CopyDataSuccessCallback success_callback;
30*1a96fba6SXin Li CopyDataErrorCallback error_callback;
31*1a96fba6SXin Li };
32*1a96fba6SXin Li
33*1a96fba6SXin Li // Async CopyData I/O error callback.
OnCopyDataError(const std::shared_ptr<CopyDataState> & state,const brillo::Error * error)34*1a96fba6SXin Li void OnCopyDataError(const std::shared_ptr<CopyDataState>& state,
35*1a96fba6SXin Li const brillo::Error* error) {
36*1a96fba6SXin Li state->error_callback.Run(std::move(state->in_stream),
37*1a96fba6SXin Li std::move(state->out_stream), error);
38*1a96fba6SXin Li }
39*1a96fba6SXin Li
40*1a96fba6SXin Li // Forward declaration.
41*1a96fba6SXin Li void PerformRead(const std::shared_ptr<CopyDataState>& state);
42*1a96fba6SXin Li
43*1a96fba6SXin Li // Callback from read operation for CopyData. Writes the read data to the output
44*1a96fba6SXin Li // stream and invokes PerformRead when done to restart the copy cycle.
PerformWrite(const std::shared_ptr<CopyDataState> & state,size_t size)45*1a96fba6SXin Li void PerformWrite(const std::shared_ptr<CopyDataState>& state, size_t size) {
46*1a96fba6SXin Li if (size == 0) {
47*1a96fba6SXin Li state->success_callback.Run(std::move(state->in_stream),
48*1a96fba6SXin Li std::move(state->out_stream),
49*1a96fba6SXin Li state->size_copied);
50*1a96fba6SXin Li return;
51*1a96fba6SXin Li }
52*1a96fba6SXin Li state->size_copied += size;
53*1a96fba6SXin Li CHECK_GE(state->remaining_to_copy, size);
54*1a96fba6SXin Li state->remaining_to_copy -= size;
55*1a96fba6SXin Li
56*1a96fba6SXin Li brillo::ErrorPtr error;
57*1a96fba6SXin Li bool success = state->out_stream->WriteAllAsync(
58*1a96fba6SXin Li state->buffer.data(), size, base::Bind(&PerformRead, state),
59*1a96fba6SXin Li base::Bind(&OnCopyDataError, state), &error);
60*1a96fba6SXin Li
61*1a96fba6SXin Li if (!success)
62*1a96fba6SXin Li OnCopyDataError(state, error.get());
63*1a96fba6SXin Li }
64*1a96fba6SXin Li
65*1a96fba6SXin Li // Performs the read part of asynchronous CopyData operation. Reads the data
66*1a96fba6SXin Li // from input stream and invokes PerformWrite when done to write the data to
67*1a96fba6SXin Li // the output stream.
PerformRead(const std::shared_ptr<CopyDataState> & state)68*1a96fba6SXin Li void PerformRead(const std::shared_ptr<CopyDataState>& state) {
69*1a96fba6SXin Li brillo::ErrorPtr error;
70*1a96fba6SXin Li const uint64_t buffer_size = state->buffer.size();
71*1a96fba6SXin Li // |buffer_size| is guaranteed to fit in size_t, so |size_to_read| value will
72*1a96fba6SXin Li // also not overflow size_t, so the static_cast below is safe.
73*1a96fba6SXin Li size_t size_to_read =
74*1a96fba6SXin Li static_cast<size_t>(std::min(buffer_size, state->remaining_to_copy));
75*1a96fba6SXin Li if (size_to_read == 0)
76*1a96fba6SXin Li return PerformWrite(state, 0); // Nothing more to read. Finish operation.
77*1a96fba6SXin Li bool success = state->in_stream->ReadAsync(
78*1a96fba6SXin Li state->buffer.data(), size_to_read, base::Bind(PerformWrite, state),
79*1a96fba6SXin Li base::Bind(OnCopyDataError, state), &error);
80*1a96fba6SXin Li
81*1a96fba6SXin Li if (!success)
82*1a96fba6SXin Li OnCopyDataError(state, error.get());
83*1a96fba6SXin Li }
84*1a96fba6SXin Li
85*1a96fba6SXin Li } // anonymous namespace
86*1a96fba6SXin Li
ErrorStreamClosed(const base::Location & location,ErrorPtr * error)87*1a96fba6SXin Li bool ErrorStreamClosed(const base::Location& location,
88*1a96fba6SXin Li ErrorPtr* error) {
89*1a96fba6SXin Li Error::AddTo(error,
90*1a96fba6SXin Li location,
91*1a96fba6SXin Li errors::stream::kDomain,
92*1a96fba6SXin Li errors::stream::kStreamClosed,
93*1a96fba6SXin Li "Stream is closed");
94*1a96fba6SXin Li return false;
95*1a96fba6SXin Li }
96*1a96fba6SXin Li
ErrorOperationNotSupported(const base::Location & location,ErrorPtr * error)97*1a96fba6SXin Li bool ErrorOperationNotSupported(const base::Location& location,
98*1a96fba6SXin Li ErrorPtr* error) {
99*1a96fba6SXin Li Error::AddTo(error,
100*1a96fba6SXin Li location,
101*1a96fba6SXin Li errors::stream::kDomain,
102*1a96fba6SXin Li errors::stream::kOperationNotSupported,
103*1a96fba6SXin Li "Stream operation not supported");
104*1a96fba6SXin Li return false;
105*1a96fba6SXin Li }
106*1a96fba6SXin Li
ErrorReadPastEndOfStream(const base::Location & location,ErrorPtr * error)107*1a96fba6SXin Li bool ErrorReadPastEndOfStream(const base::Location& location,
108*1a96fba6SXin Li ErrorPtr* error) {
109*1a96fba6SXin Li Error::AddTo(error,
110*1a96fba6SXin Li location,
111*1a96fba6SXin Li errors::stream::kDomain,
112*1a96fba6SXin Li errors::stream::kPartialData,
113*1a96fba6SXin Li "Reading past the end of stream");
114*1a96fba6SXin Li return false;
115*1a96fba6SXin Li }
116*1a96fba6SXin Li
ErrorOperationTimeout(const base::Location & location,ErrorPtr * error)117*1a96fba6SXin Li bool ErrorOperationTimeout(const base::Location& location,
118*1a96fba6SXin Li ErrorPtr* error) {
119*1a96fba6SXin Li Error::AddTo(error,
120*1a96fba6SXin Li location,
121*1a96fba6SXin Li errors::stream::kDomain,
122*1a96fba6SXin Li errors::stream::kTimeout,
123*1a96fba6SXin Li "Operation timed out");
124*1a96fba6SXin Li return false;
125*1a96fba6SXin Li }
126*1a96fba6SXin Li
CheckInt64Overflow(const base::Location & location,uint64_t position,int64_t offset,ErrorPtr * error)127*1a96fba6SXin Li bool CheckInt64Overflow(const base::Location& location,
128*1a96fba6SXin Li uint64_t position,
129*1a96fba6SXin Li int64_t offset,
130*1a96fba6SXin Li ErrorPtr* error) {
131*1a96fba6SXin Li if (offset < 0) {
132*1a96fba6SXin Li // Subtracting the offset. Make sure we do not underflow.
133*1a96fba6SXin Li uint64_t unsigned_offset = static_cast<uint64_t>(-offset);
134*1a96fba6SXin Li if (position >= unsigned_offset)
135*1a96fba6SXin Li return true;
136*1a96fba6SXin Li } else {
137*1a96fba6SXin Li // Adding the offset. Make sure we do not overflow unsigned 64 bits first.
138*1a96fba6SXin Li if (position <= std::numeric_limits<uint64_t>::max() - offset) {
139*1a96fba6SXin Li // We definitely will not overflow the unsigned 64 bit integer.
140*1a96fba6SXin Li // Now check that we end up within the limits of signed 64 bit integer.
141*1a96fba6SXin Li uint64_t new_position = position + offset;
142*1a96fba6SXin Li uint64_t max = std::numeric_limits<int64_t>::max();
143*1a96fba6SXin Li if (new_position <= max)
144*1a96fba6SXin Li return true;
145*1a96fba6SXin Li }
146*1a96fba6SXin Li }
147*1a96fba6SXin Li Error::AddTo(error,
148*1a96fba6SXin Li location,
149*1a96fba6SXin Li errors::stream::kDomain,
150*1a96fba6SXin Li errors::stream::kInvalidParameter,
151*1a96fba6SXin Li "The stream offset value is out of range");
152*1a96fba6SXin Li return false;
153*1a96fba6SXin Li }
154*1a96fba6SXin Li
CalculateStreamPosition(const base::Location & location,int64_t offset,Stream::Whence whence,uint64_t current_position,uint64_t stream_size,uint64_t * new_position,ErrorPtr * error)155*1a96fba6SXin Li bool CalculateStreamPosition(const base::Location& location,
156*1a96fba6SXin Li int64_t offset,
157*1a96fba6SXin Li Stream::Whence whence,
158*1a96fba6SXin Li uint64_t current_position,
159*1a96fba6SXin Li uint64_t stream_size,
160*1a96fba6SXin Li uint64_t* new_position,
161*1a96fba6SXin Li ErrorPtr* error) {
162*1a96fba6SXin Li uint64_t pos = 0;
163*1a96fba6SXin Li switch (whence) {
164*1a96fba6SXin Li case Stream::Whence::FROM_BEGIN:
165*1a96fba6SXin Li pos = 0;
166*1a96fba6SXin Li break;
167*1a96fba6SXin Li
168*1a96fba6SXin Li case Stream::Whence::FROM_CURRENT:
169*1a96fba6SXin Li pos = current_position;
170*1a96fba6SXin Li break;
171*1a96fba6SXin Li
172*1a96fba6SXin Li case Stream::Whence::FROM_END:
173*1a96fba6SXin Li pos = stream_size;
174*1a96fba6SXin Li break;
175*1a96fba6SXin Li
176*1a96fba6SXin Li default:
177*1a96fba6SXin Li Error::AddTo(error,
178*1a96fba6SXin Li location,
179*1a96fba6SXin Li errors::stream::kDomain,
180*1a96fba6SXin Li errors::stream::kInvalidParameter,
181*1a96fba6SXin Li "Invalid stream position whence");
182*1a96fba6SXin Li return false;
183*1a96fba6SXin Li }
184*1a96fba6SXin Li
185*1a96fba6SXin Li if (!CheckInt64Overflow(location, pos, offset, error))
186*1a96fba6SXin Li return false;
187*1a96fba6SXin Li
188*1a96fba6SXin Li *new_position = static_cast<uint64_t>(pos + offset);
189*1a96fba6SXin Li return true;
190*1a96fba6SXin Li }
191*1a96fba6SXin Li
CopyData(StreamPtr in_stream,StreamPtr out_stream,const CopyDataSuccessCallback & success_callback,const CopyDataErrorCallback & error_callback)192*1a96fba6SXin Li void CopyData(StreamPtr in_stream,
193*1a96fba6SXin Li StreamPtr out_stream,
194*1a96fba6SXin Li const CopyDataSuccessCallback& success_callback,
195*1a96fba6SXin Li const CopyDataErrorCallback& error_callback) {
196*1a96fba6SXin Li CopyData(std::move(in_stream), std::move(out_stream),
197*1a96fba6SXin Li std::numeric_limits<uint64_t>::max(), 4096, success_callback,
198*1a96fba6SXin Li error_callback);
199*1a96fba6SXin Li }
200*1a96fba6SXin Li
CopyData(StreamPtr in_stream,StreamPtr out_stream,uint64_t max_size_to_copy,size_t buffer_size,const CopyDataSuccessCallback & success_callback,const CopyDataErrorCallback & error_callback)201*1a96fba6SXin Li void CopyData(StreamPtr in_stream,
202*1a96fba6SXin Li StreamPtr out_stream,
203*1a96fba6SXin Li uint64_t max_size_to_copy,
204*1a96fba6SXin Li size_t buffer_size,
205*1a96fba6SXin Li const CopyDataSuccessCallback& success_callback,
206*1a96fba6SXin Li const CopyDataErrorCallback& error_callback) {
207*1a96fba6SXin Li auto state = std::make_shared<CopyDataState>();
208*1a96fba6SXin Li state->in_stream = std::move(in_stream);
209*1a96fba6SXin Li state->out_stream = std::move(out_stream);
210*1a96fba6SXin Li state->buffer.resize(buffer_size);
211*1a96fba6SXin Li state->remaining_to_copy = max_size_to_copy;
212*1a96fba6SXin Li state->size_copied = 0;
213*1a96fba6SXin Li state->success_callback = success_callback;
214*1a96fba6SXin Li state->error_callback = error_callback;
215*1a96fba6SXin Li brillo::MessageLoop::current()->PostTask(FROM_HERE,
216*1a96fba6SXin Li base::BindOnce(&PerformRead, state));
217*1a96fba6SXin Li }
218*1a96fba6SXin Li
219*1a96fba6SXin Li } // namespace stream_utils
220*1a96fba6SXin Li } // namespace brillo
221