xref: /aosp_15_r20/external/cronet/net/third_party/quiche/src/quiche/quic/moqt/tools/chat_client_bin.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright (c) 2023 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <poll.h>
6 #include <unistd.h>
7 
8 #include <cstdint>
9 #include <fstream>
10 #include <iostream>
11 #include <memory>
12 #include <optional>
13 #include <sstream>
14 #include <string>
15 #include <utility>
16 #include <vector>
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/functional/bind_front.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/string_view.h"
22 #include "quiche/quic/core/crypto/proof_verifier.h"
23 #include "quiche/quic/core/io/quic_default_event_loop.h"
24 #include "quiche/quic/core/io/quic_event_loop.h"
25 #include "quiche/quic/core/quic_default_clock.h"
26 #include "quiche/quic/core/quic_server_id.h"
27 #include "quiche/quic/core/quic_time.h"
28 #include "quiche/quic/moqt/moqt_messages.h"
29 #include "quiche/quic/moqt/moqt_session.h"
30 #include "quiche/quic/moqt/moqt_track.h"
31 #include "quiche/quic/moqt/tools/moqt_client.h"
32 #include "quiche/quic/platform/api/quic_default_proof_providers.h"
33 #include "quiche/quic/platform/api/quic_socket_address.h"
34 #include "quiche/quic/tools/fake_proof_verifier.h"
35 #include "quiche/quic/tools/interactive_cli.h"
36 #include "quiche/quic/tools/quic_name_lookup.h"
37 #include "quiche/quic/tools/quic_url.h"
38 #include "quiche/common/platform/api/quiche_command_line_flags.h"
39 #include "quiche/common/platform/api/quiche_export.h"
40 
// --disable_certificate_verification: when true, ChatClient's constructor
// installs quic::FakeProofVerifier instead of the default proof verifier, so
// the server certificate is not checked.
DEFINE_QUICHE_COMMAND_LINE_FLAG(
    bool, disable_certificate_verification, false,
    "If true, don't verify the server certificate.");

// --output_file: when non-empty, chat messages are written to this file
// (via ChatClient::WriteToFile) instead of the interactive terminal UI; input
// is still read from stdin by main().
DEFINE_QUICHE_COMMAND_LINE_FLAG(
    std::string, output_file, "",
    "chat messages will stream to a file instead of stdout");
47     "chat messages will stream to a file instead of stdout");
48 
49 class ChatClient {
50  public:
ChatClient(quic::QuicServerId & server_id,std::string path,std::string username,std::string chat_id)51   ChatClient(quic::QuicServerId& server_id, std::string path,
52              std::string username, std::string chat_id)
53       : chat_id_(chat_id),
54         username_(username),
55         my_track_name_(UsernameToTrackName(username)),
56         catalog_name_("moq-chat/" + chat_id, "/catalog") {
57     quic::QuicDefaultClock* clock = quic::QuicDefaultClock::Get();
58     std::cout << "Connecting to host " << server_id.host() << " port "
59               << server_id.port() << " path " << path << "\n";
60     event_loop_ = quic::GetDefaultEventLoop()->Create(clock);
61     quic::QuicSocketAddress peer_address =
62         quic::tools::LookupAddress(AF_UNSPEC, server_id);
63     std::unique_ptr<quic::ProofVerifier> verifier;
64     const bool ignore_certificate = quiche::GetQuicheCommandLineFlag(
65         FLAGS_disable_certificate_verification);
66     output_filename_ = quiche::GetQuicheCommandLineFlag(FLAGS_output_file);
67     if (!output_filename_.empty()) {
68       output_file_.open(output_filename_);
69       output_file_ << "Chat transcript:\n";
70       output_file_.flush();
71     }
72     if (ignore_certificate) {
73       verifier = std::make_unique<quic::FakeProofVerifier>();
74     } else {
75       verifier = quic::CreateDefaultProofVerifier(server_id.host());
76     }
77     client_ = std::make_unique<moqt::MoqtClient>(
78         peer_address, server_id, std::move(verifier), event_loop_.get());
79     session_callbacks_.session_established_callback = [this]() {
80       std::cout << "Session established\n";
81       session_is_open_ = true;
82       if (output_filename_.empty()) {  // Use the CLI.
83         cli_ = std::make_unique<quic::InteractiveCli>(
84             event_loop_.get(),
85             absl::bind_front(&ChatClient::OnTerminalLineInput, this));
86         cli_->PrintLine("Fully connected. Enter '/exit' to exit the chat.\n");
87       }
88     };
89     session_callbacks_.session_terminated_callback =
90         [this](absl::string_view error_message) {
91           std::cerr << "Closed session, reason = " << error_message << "\n";
92           session_is_open_ = false;
93         };
94     session_callbacks_.session_deleted_callback = [this]() {
95       session_ = nullptr;
96     };
97     client_->Connect(path, std::move(session_callbacks_));
98   }
99 
OnTerminalLineInput(absl::string_view input_message)100   void OnTerminalLineInput(absl::string_view input_message) {
101     if (input_message.empty()) {
102       return;
103     }
104     if (input_message == "/exit") {
105       session_is_open_ = false;
106       return;
107     }
108     session_->PublishObject(my_track_name_, next_sequence_.group++,
109                             next_sequence_.object, /*object_send_order=*/0,
110                             input_message, true);
111   }
112 
session_is_open() const113   bool session_is_open() const { return session_is_open_; }
is_syncing() const114   bool is_syncing() const {
115     return !catalog_group_.has_value() || subscribes_to_make_ > 0 ||
116            !session_->HasSubscribers(my_track_name_);
117   }
118 
RunEventLoop()119   void RunEventLoop() {
120     event_loop_->RunEventLoopOnce(quic::QuicTime::Delta::FromMilliseconds(500));
121   }
122 
has_output_file()123   bool has_output_file() { return !output_filename_.empty(); }
124 
WriteToFile(absl::string_view user,absl::string_view message)125   void WriteToFile(absl::string_view user, absl::string_view message) {
126     output_file_ << user << ": " << message << "\n\n";
127     output_file_.flush();
128   }
129 
130   class QUICHE_EXPORT RemoteTrackVisitor : public moqt::RemoteTrack::Visitor {
131    public:
RemoteTrackVisitor(ChatClient * client)132     RemoteTrackVisitor(ChatClient* client) : client_(client) {}
133 
OnReply(const moqt::FullTrackName & full_track_name,std::optional<absl::string_view> reason_phrase)134     void OnReply(const moqt::FullTrackName& full_track_name,
135                  std::optional<absl::string_view> reason_phrase) override {
136       client_->subscribes_to_make_--;
137       if (full_track_name == client_->catalog_name_) {
138         std::cout << "Subscription to catalog ";
139       } else {
140         std::cout << "Subscription to user " << full_track_name.track_namespace
141                   << " ";
142       }
143       if (reason_phrase.has_value()) {
144         std::cout << "REJECTED, reason = " << *reason_phrase << "\n";
145       } else {
146         std::cout << "ACCEPTED\n";
147       }
148     }
149 
OnObjectFragment(const moqt::FullTrackName & full_track_name,uint64_t group_sequence,uint64_t object_sequence,uint64_t,moqt::MoqtForwardingPreference,absl::string_view object,bool end_of_message)150     void OnObjectFragment(
151         const moqt::FullTrackName& full_track_name, uint64_t group_sequence,
152         uint64_t object_sequence, uint64_t /*object_send_order*/,
153         moqt::MoqtForwardingPreference /*forwarding_preference*/,
154         absl::string_view object, bool end_of_message) override {
155       if (!end_of_message) {
156         std::cerr << "Error: received partial message despite requesting "
157                      "buffering\n";
158       }
159       if (full_track_name == client_->catalog_name_) {
160         if (group_sequence < client_->catalog_group_) {
161           std::cout << "Ignoring old catalog";
162           return;
163         }
164         client_->ProcessCatalog(object, this, group_sequence, object_sequence);
165         return;
166       }
167       std::string username = full_track_name.track_namespace;
168       username = username.substr(username.find_last_of('/') + 1);
169       if (!client_->other_users_.contains(username)) {
170         std::cout << "Username " << username << "doesn't exist\n";
171         return;
172       }
173       if (client_->has_output_file()) {
174         client_->WriteToFile(username, object);
175         return;
176       }
177       if (cli_ != nullptr) {
178         std::string full_output = absl::StrCat(username, ": ", object);
179         cli_->PrintLine(full_output);
180       }
181     }
182 
set_cli(quic::InteractiveCli * cli)183     void set_cli(quic::InteractiveCli* cli) { cli_ = cli; }
184 
185    private:
186     ChatClient* client_;
187     quic::InteractiveCli* cli_;
188   };
189 
190   // returns false on error
AnnounceAndSubscribe()191   bool AnnounceAndSubscribe() {
192     session_ = client_->session();
193     if (session_ == nullptr) {
194       std::cout << "Failed to connect.\n";
195       return false;
196     }
197     // By not sending a visitor, the application will not fulfill subscriptions
198     // to previous objects.
199     session_->AddLocalTrack(my_track_name_,
200                             moqt::MoqtForwardingPreference::kObject, nullptr);
201     moqt::MoqtOutgoingAnnounceCallback announce_callback =
202         [&](absl::string_view track_namespace,
203             std::optional<moqt::MoqtAnnounceErrorReason> reason) {
204           if (reason.has_value()) {
205             std::cout << "ANNOUNCE rejected, " << reason->reason_phrase << "\n";
206             session_->Error(moqt::MoqtError::kInternalError,
207                             "Local ANNOUNCE rejected");
208             return;
209           }
210           std::cout << "ANNOUNCE for " << track_namespace << " accepted\n";
211           return;
212         };
213     std::cout << "Announcing " << my_track_name_.track_namespace << "\n";
214     session_->Announce(my_track_name_.track_namespace,
215                        std::move(announce_callback));
216     remote_track_visitor_ = std::make_unique<RemoteTrackVisitor>(this);
217     if (!session_->SubscribeCurrentGroup(
218             catalog_name_.track_namespace, catalog_name_.track_name,
219             remote_track_visitor_.get(), username_)) {
220       std::cout << "Failed to get catalog for " << chat_id_ << "\n";
221       return false;
222     }
223     return true;
224   }
225 
226  private:
UsernameToTrackName(absl::string_view username)227   moqt::FullTrackName UsernameToTrackName(absl::string_view username) {
228     return moqt::FullTrackName(
229         absl::StrCat("moq-chat/", chat_id_, "/participant/", username), "");
230   }
231 
232   // Objects from the same catalog group arrive on the same stream, and in
233   // object sequence order.
ProcessCatalog(absl::string_view object,moqt::RemoteTrack::Visitor * visitor,uint64_t group_sequence,uint64_t object_sequence)234   void ProcessCatalog(absl::string_view object,
235                       moqt::RemoteTrack::Visitor* visitor,
236                       uint64_t group_sequence, uint64_t object_sequence) {
237     std::string message(object);
238     std::istringstream f(message);
239     std::string line;
240     bool got_version = true;
241     if (object_sequence == 0) {
242       std::cout << "Received new Catalog. Users:\n";
243       got_version = false;
244     }
245     while (std::getline(f, line)) {
246       if (!got_version) {
247         if (line != "version=1") {
248           session_->Error(moqt::MoqtError::kProtocolViolation,
249                           "Catalog does not begin with version");
250           return;
251         }
252         got_version = true;
253         continue;
254       }
255       if (line.empty()) {
256         continue;
257       }
258       std::string user;
259       bool add = true;
260       if (object_sequence > 0) {
261         switch (line[0]) {
262           case '-':
263             add = false;
264             break;
265           case '+':
266             break;
267           default:
268             std::cerr << "Catalog update with neither + nor -\n";
269             return;
270         }
271         user = line.substr(1, line.size() - 1);
272       } else {
273         user = line;
274       }
275       if (username_ == user) {
276         std::cout << user << "\n";
277         continue;
278       }
279       if (!add) {
280         // TODO: Unsubscribe from the user that's leaving
281         std::cout << user << "left the chat\n";
282         other_users_.erase(user);
283         continue;
284       }
285       if (object_sequence == 0) {
286         std::cout << user << "\n";
287       } else {
288         std::cout << user << "joined the chat\n";
289       }
290       auto it = other_users_.find(user);
291       if (it == other_users_.end()) {
292         moqt::FullTrackName to_subscribe = UsernameToTrackName(user);
293         auto new_user = other_users_.emplace(
294             std::make_pair(user, ChatUser(to_subscribe, group_sequence)));
295         ChatUser& user_record = new_user.first->second;
296         session_->SubscribeRelative(user_record.full_track_name.track_namespace,
297                                     user_record.full_track_name.track_name, 0,
298                                     0, visitor);
299         subscribes_to_make_++;
300       } else {
301         if (it->second.from_group == group_sequence) {
302           session_->Error(moqt::MoqtError::kProtocolViolation,
303                           "User listed twice in Catalog");
304           return;
305         }
306         it->second.from_group = group_sequence;
307       }
308     }
309     if (object_sequence == 0) {  // Eliminate users that are no longer present
310       for (const auto& it : other_users_) {
311         if (it.second.from_group != group_sequence) {
312           other_users_.erase(it.first);
313         }
314       }
315     }
316     catalog_group_ = group_sequence;
317   }
318 
319   struct ChatUser {
320     moqt::FullTrackName full_track_name;
321     uint64_t from_group;
ChatUserChatClient::ChatUser322     ChatUser(moqt::FullTrackName& ftn, uint64_t group)
323         : full_track_name(ftn), from_group(group) {}
324   };
325 
326   // Basic session information
327   const std::string chat_id_;
328   const std::string username_;
329   const moqt::FullTrackName my_track_name_;
330 
331   // General state variables
332   std::unique_ptr<quic::QuicEventLoop> event_loop_;
333   bool session_is_open_ = false;
334   moqt::MoqtSession* session_ = nullptr;
335   std::unique_ptr<moqt::MoqtClient> client_;
336   moqt::MoqtSessionCallbacks session_callbacks_;
337 
338   // Related to syncing.
339   std::optional<uint64_t> catalog_group_;
340   moqt::FullTrackName catalog_name_;
341   absl::flat_hash_map<std::string, ChatUser> other_users_;
342   int subscribes_to_make_ = 1;
343 
344   // Related to subscriptions/announces
345   // TODO: One for each subscribe
346   std::unique_ptr<RemoteTrackVisitor> remote_track_visitor_;
347 
348   // Handling incoming and outgoing messages
349   moqt::FullSequence next_sequence_ = {0, 0};
350 
351   // Used when chat output goes to file.
352   std::ofstream output_file_;
353   std::string output_filename_;
354 
355   // Used when there is no output file, and both input and output are in the
356   // terminal.
357   std::unique_ptr<quic::InteractiveCli> cli_;
358 };
359 
360 // A client for MoQT over chat, used for interop testing. See
361 // https://afrind.github.io/draft-frindell-moq-chat/draft-frindell-moq-chat.html
main(int argc,char * argv[])362 int main(int argc, char* argv[]) {
363   const char* usage = "Usage: chat_client [options] <url> <username> <chat-id>";
364   std::vector<std::string> args =
365       quiche::QuicheParseCommandLineFlags(usage, argc, argv);
366   if (args.size() != 3) {
367     quiche::QuichePrintCommandLineFlagHelp(usage);
368     return 1;
369   }
370   quic::QuicUrl url(args[0], "https");
371   quic::QuicServerId server_id(url.host(), url.port());
372   std::string path = url.PathParamsQuery();
373   std::string username = args[1];
374   std::string chat_id = args[2];
375   ChatClient client(server_id, path, username, chat_id);
376 
377   while (!client.session_is_open()) {
378     client.RunEventLoop();
379   }
380 
381   if (!client.AnnounceAndSubscribe()) {
382     return 1;
383   }
384   while (client.is_syncing()) {
385     client.RunEventLoop();
386   }
387   if (!client.session_is_open()) {
388     return 1;  // Something went wrong in connecting.
389   }
390   if (!client.has_output_file()) {
391     while (client.session_is_open()) {
392       client.RunEventLoop();
393     }
394     return 0;
395   }
396   // There is an output file.
397   std::cout << "Fully connected. Messages are in the output file. Exit the "
398             << "session by entering /exit\n";
399   struct pollfd poll_settings = {
400       0,
401       POLLIN,
402       POLLIN,
403   };
404   while (client.session_is_open()) {
405     std::string message_to_send;
406     while (poll(&poll_settings, 1, 0) <= 0) {
407       client.RunEventLoop();
408     }
409     std::getline(std::cin, message_to_send);
410     client.OnTerminalLineInput(message_to_send);
411     client.WriteToFile(username, message_to_send);
412   }
413   return 0;
414 }
415