xref: /aosp_15_r20/external/pigweed/pw_stream/mpsc_stream.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_stream/mpsc_stream.h"
16 
17 #include <cstring>
18 #include <mutex>
19 
20 #include "pw_assert/check.h"
21 
22 namespace pw::stream {
23 namespace {
24 
25 // Wait to receive a thread notification with an optional timeout.
Await(sync::TimedThreadNotification & notification,const std::optional<chrono::SystemClock::duration> & timeout)26 bool Await(sync::TimedThreadNotification& notification,
27            const std::optional<chrono::SystemClock::duration>& timeout) {
28   if (timeout.has_value()) {
29     return notification.try_acquire_for(*timeout);
30   }
31   // Block indefinitely.
32   notification.acquire();
33   return true;
34 }
35 
36 }  // namespace
37 
CreateMpscStream(MpscReader & reader,MpscWriter & writer)38 void CreateMpscStream(MpscReader& reader, MpscWriter& writer) {
39   reader.Close();
40   std::lock_guard rlock(reader.mutex_);
41   PW_CHECK(reader.writers_.empty());
42   std::lock_guard wlock(writer.mutex_);
43   writer.CloseLocked();
44   reader.writers_.push_front(writer);
45   reader.IncreaseLimitLocked(Stream::kUnlimited);
46   writer.reader_ = &reader;
47 }
48 
49 ////////////////////////////////////////////////////////////////////////////////
50 // MpscWriter methods.
51 
MpscWriter::MpscWriter(const MpscWriter& other) : MpscWriter() {
  // Start from a default-constructed writer and let copy assignment attach it
  // to the same reader as `other`.
  operator=(other);
}
55 
// Makes this writer a copy of `other`: it is detached from its current reader
// (if any) and attached to `other`'s reader with the same timeout, limit and
// last-write count.
MpscWriter& MpscWriter::operator=(const MpscWriter& other) {
  // Without this guard, self-assignment would Close() this writer and then
  // copy its now-disconnected state back into itself.
  if (this == &other) {
    return *this;
  }

  // Detach from any reader this writer is currently attached to.
  Close();

  // Read the other object's internal state. Avoid holding both writers' locks
  // at once.
  other.mutex_.lock();
  MpscReader* reader = other.reader_;
  duration timeout = other.timeout_;
  size_t limit = other.limit_;
  size_t last_write = other.last_write_;
  other.mutex_.unlock();

  // Now update this object with the other's state.
  mutex_.lock();
  reader_ = reader;
  timeout_ = timeout;
  limit_ = limit;
  last_write_ = last_write;
  mutex_.unlock();

  // Register with the reader outside this writer's lock, adding this writer's
  // limit to the reader's total.
  // NOTE(review): if the reader is closed between reading `other` above and
  // this point, this writer is still added to the reader's list; confirm that
  // is the intended behavior.
  if (reader != nullptr) {
    std::lock_guard lock(reader->mutex_);
    reader->writers_.push_front(*this);
    reader->IncreaseLimitLocked(limit);
  }
  return *this;
}
84 
MpscWriter(MpscWriter && other)85 MpscWriter::MpscWriter(MpscWriter&& other) : MpscWriter() {
86   *this = std::move(other);
87 }
88 
operator =(MpscWriter && other)89 MpscWriter& MpscWriter::operator=(MpscWriter&& other) {
90   *this = other;
91   other.Close();
92   return *this;
93 }
94 
~MpscWriter()95 MpscWriter::~MpscWriter() { Close(); }
96 
connected() const97 bool MpscWriter::connected() const {
98   std::lock_guard lock(mutex_);
99   return reader_ != nullptr;
100 }
101 
last_write() const102 size_t MpscWriter::last_write() const {
103   std::lock_guard lock(mutex_);
104   return last_write_;
105 }
106 
SetTimeout(const duration & timeout)107 void MpscWriter::SetTimeout(const duration& timeout) {
108   std::lock_guard lock(mutex_);
109   timeout_ = timeout;
110 }
111 
SetLimit(size_t limit)112 void MpscWriter::SetLimit(size_t limit) {
113   std::lock_guard lock(mutex_);
114   if (reader_) {
115     reader_->DecreaseLimit(limit_);
116     reader_->IncreaseLimit(limit);
117   }
118   limit_ = limit;
119   if (limit_ == 0) {
120     CloseLocked();
121   }
122 }
123 
ConservativeLimit(LimitType type) const124 size_t MpscWriter::ConservativeLimit(LimitType type) const {
125   std::lock_guard lock(mutex_);
126   return reader_ != nullptr && type == LimitType::kWrite ? limit_ : 0;
127 }
128 
DoWrite(ConstByteSpan data)129 Status MpscWriter::DoWrite(ConstByteSpan data) {
130   // Check some conditions to see if an early exit is possible.
131   if (data.empty()) {
132     return OkStatus();
133   }
134   std::lock_guard lock(mutex_);
135   if (reader_ == nullptr) {
136     return Status::OutOfRange();
137   }
138   if (limit_ < data.size()) {
139     return Status::ResourceExhausted();
140   }
141   if (!write_request_.unlisted()) {
142     return Status::FailedPrecondition();
143   }
144   // Subscribe to the reader. This will enqueue this object's write request,
145   // which will be used to notify the writer when the reader has space available
146   // or has closed.
147   reader_->RequestWrite(write_request_);
148   last_write_ = 0;
149 
150   Status status;
151   while (!data.empty()) {
152     // Wait to be notified by the reader.
153     // Note: This manually unlocks and relocks the mutex currently held by the
154     // lock guard. It must not return while the mutex is not locked.
155     duration timeout = timeout_;
156     mutex_.unlock();
157     bool writeable = Await(write_request_.notification, timeout);
158     mutex_.lock();
159 
160     // Conditions may have changed while waiting; check again.
161     if (reader_ == nullptr) {
162       return Status::OutOfRange();
163     }
164     if (!writeable || limit_ < data.size()) {
165       status = Status::ResourceExhausted();
166       break;
167     }
168 
169     // Attempt to write data.
170     StatusWithSize result = reader_->WriteData(data, limit_);
171     last_write_ += result.size();
172     if (limit_ != kUnlimited) {
173       limit_ -= result.size();
174     }
175 
176     // WriteData() only returns an error if the reader is closed. In that case,
177     // or if the writer has written all of its data, the writer should close.
178     if (!result.ok() || limit_ == 0) {
179       CloseLocked();
180       return result.status();
181     }
182     data = data.subspan(result.size());
183   }
184 
185   // Unsubscribe from the reader.
186   reader_->CompleteWrite(write_request_);
187   return status;
188 }
189 
Close()190 void MpscWriter::Close() {
191   std::lock_guard lock(mutex_);
192   CloseLocked();
193 }
194 
CloseLocked()195 void MpscWriter::CloseLocked() {
196   if (reader_ != nullptr) {
197     std::lock_guard lock(reader_->mutex_);
198     reader_->CompleteWriteLocked(write_request_);
199     write_request_.notification.release();
200     if (reader_->writers_.remove(*this)) {
201       reader_->DecreaseLimitLocked(limit_);
202     }
203     if (reader_->writers_.empty()) {
204       reader_->readable_.release();
205     }
206     reader_ = nullptr;
207   }
208   limit_ = kUnlimited;
209 }
210 
211 ////////////////////////////////////////////////////////////////////////////////
212 // MpscReader methods.
213 
MpscReader()214 MpscReader::MpscReader() { last_request_ = write_requests_.begin(); }
215 
~MpscReader()216 MpscReader::~MpscReader() { Close(); }
217 
connected() const218 bool MpscReader::connected() const {
219   std::lock_guard lock(mutex_);
220   return !writers_.empty();
221 }
222 
SetBuffer(ByteSpan buffer)223 void MpscReader::SetBuffer(ByteSpan buffer) {
224   std::lock_guard lock(mutex_);
225   PW_CHECK(length_ == 0);
226   buffer_ = buffer;
227   offset_ = 0;
228 }
229 
SetTimeout(const duration & timeout)230 void MpscReader::SetTimeout(const duration& timeout) {
231   std::lock_guard lock(mutex_);
232   timeout_ = timeout;
233 }
234 
IncreaseLimit(size_t delta)235 void MpscReader::IncreaseLimit(size_t delta) {
236   std::lock_guard lock(mutex_);
237   IncreaseLimitLocked(delta);
238 }
239 
IncreaseLimitLocked(size_t delta)240 void MpscReader::IncreaseLimitLocked(size_t delta) {
241   if (delta == kUnlimited) {
242     ++num_unlimited_;
243     PW_CHECK_UINT_NE(num_unlimited_, 0);
244   } else if (limit_ != kUnlimited) {
245     PW_CHECK_UINT_LT(limit_, kUnlimited - delta);
246     limit_ += delta;
247   }
248 }
249 
DecreaseLimit(size_t delta)250 void MpscReader::DecreaseLimit(size_t delta) {
251   std::lock_guard lock(mutex_);
252   DecreaseLimitLocked(delta);
253 }
254 
DecreaseLimitLocked(size_t delta)255 void MpscReader::DecreaseLimitLocked(size_t delta) {
256   if (delta == kUnlimited) {
257     PW_CHECK_UINT_NE(num_unlimited_, 0);
258     --num_unlimited_;
259   } else if (limit_ != kUnlimited) {
260     PW_CHECK_UINT_LE(delta, limit_);
261     limit_ -= delta;
262   }
263 }
264 
ConservativeLimit(LimitType type) const265 size_t MpscReader::ConservativeLimit(LimitType type) const {
266   std::lock_guard lock(mutex_);
267   if (type != LimitType::kRead) {
268     return 0;
269   }
270   if (writers_.empty()) {
271     return length_;
272   }
273   if (num_unlimited_ != 0) {
274     return kUnlimited;
275   }
276   return limit_;
277 }
278 
RequestWrite(MpscWriter::Request & write_request)279 void MpscReader::RequestWrite(MpscWriter::Request& write_request) {
280   std::lock_guard lock(mutex_);
281   last_request_ = write_requests_.insert_after(last_request_, write_request);
282   CheckWriteableLocked();
283 }
284 
CheckWriteableLocked()285 void MpscReader::CheckWriteableLocked() {
286   if (write_requests_.empty()) {
287     return;
288   }
289   if (writers_.empty() || written_ < destination_.size() ||
290       length_ < buffer_.size()) {
291     MpscWriter::Request& write_request = write_requests_.front();
292     write_request.notification.release();
293   }
294 }
295 
WriteData(ConstByteSpan data,size_t limit)296 StatusWithSize MpscReader::WriteData(ConstByteSpan data, size_t limit) {
297   std::lock_guard lock(mutex_);
298   if (writers_.empty()) {
299     return StatusWithSize::OutOfRange(0);
300   }
301   size_t length = 0;
302   size_t available = buffer_.size() - length_;
303   if (written_ < destination_.size()) {
304     // A read is pending; copy directly into its buffer.
305     // Note: this condition is only true when the buffer is empty, so data
306     // order is preserved.
307     length = std::min(destination_.size() - written_, data.size());
308     memcpy(&destination_[written_], &data[0], length);
309     written_ += length;
310   } else if (available > 0) {
311     // The buffer has space for more data.
312     length = std::min(available, data.size());
313     size_t offset = (offset_ + length_) % buffer_.size();
314     size_t contiguous = buffer_.size() - offset;
315     if (length <= contiguous) {
316       memcpy(&buffer_[offset], &data[0], length);
317     } else {
318       memcpy(&buffer_[offset], &data[0], contiguous);
319       memcpy(&buffer_[0], &data[contiguous], length - contiguous);
320     }
321     length_ += length;
322   } else {
323     // If there is no space available, a write request can only be notified when
324     // its writer is closing. Do not notify the reader that data is available.
325     return StatusWithSize(0);
326   }
327   data = data.subspan(length);
328   // For unlimited writers, increase the read limit as needed.
329   // Do this before waking the reader and releasing the lock.
330   if (limit == kUnlimited) {
331     IncreaseLimitLocked(length);
332   }
333   readable_.release();
334   return StatusWithSize(length);
335 }
336 
CompleteWrite(MpscWriter::Request & write_request)337 void MpscReader::CompleteWrite(MpscWriter::Request& write_request) {
338   std::lock_guard lock(mutex_);
339   CompleteWriteLocked(write_request);
340 }
341 
CompleteWriteLocked(MpscWriter::Request & write_request)342 void MpscReader::CompleteWriteLocked(MpscWriter::Request& write_request) {
343   MpscWriter::Request& last_request = *last_request_;
344   write_requests_.remove(write_request);
345 
346   // If the last request is removed, find the new last request. This is O(n),
347   // but the oremoved element is first unless a request is being canceled due to
348   // its writer closing. Thus in the typical case of a successful write, this is
349   // O(1).
350   if (&last_request == &write_request) {
351     last_request_ = write_requests_.begin();
352     for (size_t i = 1; i < write_requests_.size(); ++i) {
353       ++last_request_;
354     }
355   }
356 
357   // The reader may have signaled this writer that it had space between the last
358   // call to WriteData() and this call. Check if that signal should be forwarded
359   // to the next write request.
360   CheckWriteableLocked();
361 }
362 
DoRead(ByteSpan destination)363 StatusWithSize MpscReader::DoRead(ByteSpan destination) {
364   if (destination.empty()) {
365     return StatusWithSize(0);
366   }
367   mutex_.lock();
368   PW_CHECK(!reading_, "All reads must happen from the same thread.");
369   reading_ = true;
370   Status status = OkStatus();
371   size_t length = 0;
372 
373   // Check for buffered data. Do this before checking if the reader is still
374   // connected in order to deliver data sent from a now-closed writer.
375   if (length_ != 0) {
376     length = std::min(length_, destination.size());
377     size_t contiguous = buffer_.size() - offset_;
378     if (length < contiguous) {
379       memcpy(&destination[0], &buffer_[offset_], length);
380       offset_ += length;
381     } else if (length == contiguous) {
382       memcpy(&destination[0], &buffer_[offset_], length);
383       offset_ = 0;
384     } else {
385       memcpy(&destination[0], &buffer_[offset_], contiguous);
386       offset_ = length - contiguous;
387       memcpy(&destination[contiguous], &buffer_[0], offset_);
388     }
389     length_ -= length;
390     DecreaseLimitLocked(length);
391     CheckWriteableLocked();
392 
393   } else {
394     // Register the output buffer to and wait for Write() to bypass the buffer
395     // and write directly into it. Note that the buffer is only bypassed when
396     // empty, so data order is preserved.
397     PW_CHECK(written_ == 0);
398     destination_ = destination;
399     CheckWriteableLocked();
400 
401     // The reader state may change while waiting, or even between acquiring the
402     // notification and acquiring the lock. As an example, the following
403     // sequence of events is possible:
404     //
405     //   1. A writer partially fills the output buffer and releases the
406     //      notification.
407     //   2. The reader acquires the notification.
408     //   3. Another writer fills the remainder of the buffer and releass the
409     //      notification *again*.
410     //   4. The reader acquires the lock.
411     //
412     // In this case, on the *next* read, the notification will be acquired
413     // immediately even if no data is available. As a result, this code loops
414     // until data is available.
415     while (status.ok()) {
416       bool readable = true;
417       if (!writers_.empty()) {
418         // Wait for a writer to provide data, or the reader to be closed.
419         duration timeout = timeout_;
420         mutex_.unlock();
421         readable = Await(readable_, timeout);
422         mutex_.lock();
423       }
424       if (!readable) {
425         status = Status::ResourceExhausted();
426       } else if (written_ != 0) {
427         break;
428       } else if (writers_.empty()) {
429         status = Status::OutOfRange();
430       }
431     }
432     destination_ = ByteSpan();
433     length = written_;
434     written_ = 0;
435     DecreaseLimitLocked(length);
436     CheckWriteableLocked();
437   }
438 
439   reading_ = false;
440   if (writers_.empty()) {
441     closeable_.release();
442   }
443   mutex_.unlock();
444   return StatusWithSize(status, length);
445 }
446 
ReadAll(ReadAllCallback callback)447 Status MpscReader::ReadAll(ReadAllCallback callback) {
448   mutex_.lock();
449   if (buffer_.empty()) {
450     mutex_.unlock();
451     return Status::FailedPrecondition();
452   }
453   PW_CHECK(!reading_, "All reads must happen from the same thread.");
454   reading_ = true;
455 
456   Status status = Status::OutOfRange();
457   while (true) {
458     // Check for buffered data. Do this before checking if the reader still has
459     // writers in order to deliver data sent from a now-closed writer.
460     if (length_ != 0) {
461       size_t length = std::min(buffer_.size() - offset_, length_);
462       ConstByteSpan data(&buffer_[offset_], length);
463       offset_ = (offset_ + length_) % buffer_.size();
464       length_ -= length;
465       DecreaseLimitLocked(data.size());
466       CheckWriteableLocked();
467       status = callback(data);
468       if (!status.ok()) {
469         break;
470       }
471     }
472     if (writers_.empty()) {
473       break;
474     }
475     // Wait for a writer to provide data.
476     duration timeout = timeout_;
477     mutex_.unlock();
478     bool readable = Await(readable_, timeout);
479     mutex_.lock();
480     if (!readable) {
481       status = Status::ResourceExhausted();
482       break;
483     }
484   }
485   reading_ = false;
486   if (writers_.empty()) {
487     closeable_.release();
488   }
489   mutex_.unlock();
490   return status;
491 }
492 
Close()493 void MpscReader::Close() {
494   mutex_.lock();
495   if (writers_.empty()) {
496     mutex_.unlock();
497     return;
498   }
499   IntrusiveList<MpscWriter> writers;
500   while (!writers_.empty()) {
501     MpscWriter& writer = writers_.front();
502     writers_.pop_front();
503     writers.push_front(writer);
504   }
505 
506   // Wait for any pending read to finish.
507   if (reading_) {
508     mutex_.unlock();
509     readable_.release();
510     closeable_.acquire();
511     mutex_.lock();
512   }
513 
514   num_unlimited_ = 0;
515   limit_ = 0;
516   written_ = 0;
517   offset_ = 0;
518   length_ = 0;
519   mutex_.unlock();
520 
521   for (auto& writer : writers) {
522     writer.Close();
523   }
524 }
525 
526 }  // namespace pw::stream
527