xref: /aosp_15_r20/external/pytorch/torch/lib/libshm/socket.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <poll.h>
4 #include <sys/socket.h>
5 #include <sys/stat.h>
6 #include <sys/types.h>
7 #include <sys/un.h>
8 #include <unistd.h>
9 #include <cstddef>
10 #include <cstdio>
11 #include <cstring>
12 #include <string>
13 
14 #include <libshm/alloc_info.h>
15 #include <libshm/err.h>
16 
17 class Socket {
18  public:
19   int socket_fd;
20 
21  protected:
Socket()22   Socket() {
23     SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
24   }
25   Socket(const Socket& other) = delete;
Socket(Socket && other)26   Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) {
27     other.socket_fd = -1;
28   };
Socket(int fd)29   explicit Socket(int fd) : socket_fd(fd) {}
30 
~Socket()31   virtual ~Socket() {
32     if (socket_fd != -1)
33       close(socket_fd);
34   }
35 
prepare_address(const char * path)36   struct sockaddr_un prepare_address(const char* path) {
37     struct sockaddr_un address;
38     address.sun_family = AF_UNIX;
39     strcpy(address.sun_path, path);
40     return address;
41   }
42 
43   // Implemented based on https://man7.org/linux/man-pages/man7/unix.7.html
address_length(struct sockaddr_un address)44   size_t address_length(struct sockaddr_un address) {
45     return offsetof(sockaddr_un, sun_path) + strlen(address.sun_path) + 1;
46   }
47 
recv(void * _buffer,size_t num_bytes)48   void recv(void* _buffer, size_t num_bytes) {
49     char* buffer = (char*)_buffer;
50     size_t bytes_received = 0;
51     ssize_t step_received;
52     struct pollfd pfd = {};
53     pfd.fd = socket_fd;
54     pfd.events = POLLIN;
55     while (bytes_received < num_bytes) {
56       SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000));
57       if (pfd.revents & POLLIN) {
58         SYSCHECK_ERR_RETURN_NEG1(
59             step_received =
60                 ::read(socket_fd, buffer, num_bytes - bytes_received));
61         if (step_received == 0)
62           throw std::runtime_error("Other end has closed the connection");
63         bytes_received += step_received;
64         buffer += step_received;
65       } else if (pfd.revents & (POLLERR | POLLHUP)) {
66         throw std::runtime_error(
67             "An error occurred while waiting for the data");
68       } else {
69         throw std::runtime_error(
70             "Shared memory manager connection has timed out");
71       }
72     }
73   }
74 
send(const void * _buffer,size_t num_bytes)75   void send(const void* _buffer, size_t num_bytes) {
76     const char* buffer = (const char*)_buffer;
77     size_t bytes_sent = 0;
78     ssize_t step_sent;
79     while (bytes_sent < num_bytes) {
80       SYSCHECK_ERR_RETURN_NEG1(
81           step_sent = ::write(socket_fd, buffer, num_bytes));
82       bytes_sent += step_sent;
83       buffer += step_sent;
84     }
85   }
86 };
87 
88 class ManagerSocket : public Socket {
89  public:
ManagerSocket(int fd)90   explicit ManagerSocket(int fd) : Socket(fd) {}
91 
receive()92   AllocInfo receive() {
93     AllocInfo info;
94     recv(&info, sizeof(info));
95     return info;
96   }
97 
confirm()98   void confirm() {
99     send("OK", 2);
100   }
101 };
102 
103 class ManagerServerSocket : public Socket {
104  public:
ManagerServerSocket(const std::string & path)105   explicit ManagerServerSocket(const std::string& path) {
106     socket_path = path;
107     try {
108       struct sockaddr_un address = prepare_address(path.c_str());
109       size_t len = address_length(address);
110       SYSCHECK_ERR_RETURN_NEG1(
111           bind(socket_fd, (struct sockaddr*)&address, len));
112       SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10));
113     } catch (std::exception&) {
114       SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
115       throw;
116     }
117   }
118 
remove()119   void remove() {
120     struct stat file_stat;
121     if (fstat(socket_fd, &file_stat) == 0)
122       SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str()));
123   }
124 
~ManagerServerSocket()125   virtual ~ManagerServerSocket() {
126     unlink(socket_path.c_str());
127   }
128 
accept()129   ManagerSocket accept() {
130     int client_fd;
131     struct sockaddr_un addr;
132     socklen_t addr_len = sizeof(addr);
133     SYSCHECK_ERR_RETURN_NEG1(
134         client_fd = ::accept(socket_fd, (struct sockaddr*)&addr, &addr_len));
135     return ManagerSocket(client_fd);
136   }
137 
138   std::string socket_path;
139 };
140 
141 class ClientSocket : public Socket {
142  public:
ClientSocket(const std::string & path)143   explicit ClientSocket(const std::string& path) {
144     try {
145       struct sockaddr_un address = prepare_address(path.c_str());
146       size_t len = address_length(address);
147       SYSCHECK_ERR_RETURN_NEG1(
148           connect(socket_fd, (struct sockaddr*)&address, len));
149     } catch (std::exception&) {
150       SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
151       throw;
152     }
153   }
154 
register_allocation(AllocInfo & info)155   void register_allocation(AllocInfo& info) {
156     char buffer[3] = {0, 0, 0};
157     send(&info, sizeof(info));
158     recv(buffer, 2);
159     if (strcmp(buffer, "OK") != 0)
160       throw std::runtime_error(
161           "Shared memory manager didn't respond with an OK");
162   }
163 
register_deallocation(AllocInfo & info)164   void register_deallocation(AllocInfo& info) {
165     send(&info, sizeof(info));
166   }
167 };
168