xref: /aosp_15_r20/external/pigweed/pw_rpc_transport/stream_rpc_dispatcher_test.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_rpc_transport/stream_rpc_dispatcher.h"
16 
17 #include <algorithm>
18 #include <atomic>
19 
20 #include "pw_bytes/span.h"
21 #include "pw_log/log.h"
22 #include "pw_span/span.h"
23 #include "pw_status/status.h"
24 #include "pw_stream/stream.h"
25 #include "pw_sync/mutex.h"
26 #include "pw_sync/thread_notification.h"
27 #include "pw_thread/thread.h"
28 #include "pw_thread_stl/options.h"
29 #include "pw_unit_test/framework.h"
30 
31 namespace pw::rpc {
32 namespace {
33 
34 using namespace std::chrono_literals;
35 
36 class TestIngress : public RpcIngressHandler {
37  public:
TestIngress(size_t num_bytes_expected)38   explicit TestIngress(size_t num_bytes_expected)
39       : num_bytes_expected_(num_bytes_expected) {}
40 
ProcessIncomingData(ConstByteSpan buffer)41   Status ProcessIncomingData(ConstByteSpan buffer) override {
42     if (num_bytes_expected_ > 0) {
43       std::copy(buffer.begin(), buffer.end(), std::back_inserter(received_));
44       num_bytes_expected_ -= std::min(num_bytes_expected_, buffer.size());
45     }
46     if (num_bytes_expected_ == 0) {
47       done_.release();
48     }
49     return OkStatus();
50   }
51 
received() const52   std::vector<std::byte> received() const { return received_; }
Wait()53   void Wait() { done_.acquire(); }
54 
55  private:
56   size_t num_bytes_expected_ = 0;
57   sync::ThreadNotification done_;
58   std::vector<std::byte> received_;
59 };
60 
61 class TestStream : public stream::NonSeekableReader {
62  public:
TestStream()63   TestStream() : position_(0) {}
64 
QueueData(ConstByteSpan data)65   void QueueData(ConstByteSpan data) {
66     std::lock_guard lock(send_mutex_);
67     std::copy(data.begin(), data.end(), std::back_inserter(to_send_));
68     available_.release();
69   }
70 
Stop()71   void Stop() {
72     stopped_ = true;
73     available_.release();
74   }
75 
76  private:
WaitForData()77   void WaitForData() {
78     while (!stopped_) {
79       {
80         std::lock_guard lock(send_mutex_);
81         if (position_ < to_send_.size()) {
82           break;
83         }
84       }
85 
86       available_.acquire();
87     }
88   }
89 
DoRead(ByteSpan out)90   StatusWithSize DoRead(ByteSpan out) final {
91     WaitForData();
92 
93     if (stopped_) {
94       return StatusWithSize(0);
95     }
96 
97     std::lock_guard lock(send_mutex_);
98 
99     if (position_ == to_send_.size()) {
100       return StatusWithSize::OutOfRange();
101     }
102 
103     size_t to_copy = std::min(out.size(), to_send_.size() - position_);
104     std::memcpy(out.data(), to_send_.data() + position_, to_copy);
105     position_ += to_copy;
106 
107     return StatusWithSize(to_copy);
108   }
109 
110   sync::Mutex send_mutex_;
111   std::vector<std::byte> to_send_;
112   std::atomic<bool> stopped_ = false;
113   size_t position_;
114   sync::ThreadNotification available_;
115 };
116 
// Checks that bytes queued on the stream are read on the dispatcher thread and
// forwarded to the ingress handler unchanged and in order, with no read errors.
TEST(StreamRpcDispatcherTest, RecvOk) {
  constexpr size_t kWriteSize = 10;
  // Use distinct, non-zero values so that dropped, reordered, or corrupted
  // bytes fail the content check below. An all-zero buffer would only
  // exercise the length.
  constexpr std::array<std::byte, kWriteSize> kWriteBuffer = {
      std::byte{0x01},
      std::byte{0x02},
      std::byte{0x03},
      std::byte{0x04},
      std::byte{0x05},
      std::byte{0x06},
      std::byte{0x07},
      std::byte{0x08},
      std::byte{0x09},
      std::byte{0x0a}};

  TestIngress test_ingress(kWriteSize);
  TestStream test_stream;

  auto dispatcher = StreamRpcDispatcher<kWriteSize>(test_stream, test_ingress);
  auto dispatcher_thread = Thread(thread::stl::Options(), dispatcher);

  test_stream.QueueData(kWriteBuffer);

  // Blocks until the ingress handler has been given kWriteSize bytes.
  test_ingress.Wait();

  // Request the dispatcher stop, then unblock any read waiting inside
  // TestStream so the thread can be joined. This presumably relies on the
  // dispatcher checking its stop request between reads.
  dispatcher.Stop();
  test_stream.Stop();
  dispatcher_thread.join();

  // Reading the handler's recorded bytes is safe once the thread is joined.
  auto received = test_ingress.received();
  EXPECT_EQ(received.size(), kWriteSize);
  EXPECT_TRUE(std::equal(received.begin(),
                         received.end(),
                         kWriteBuffer.begin(),
                         kWriteBuffer.end()));
  EXPECT_EQ(dispatcher.num_read_errors(), 0U);
}
139 
140 }  // namespace
141 }  // namespace pw::rpc
142