1 #pragma once 2 3 #include <poll.h> 4 #include <sys/socket.h> 5 #include <sys/stat.h> 6 #include <sys/types.h> 7 #include <sys/un.h> 8 #include <unistd.h> 9 #include <cstddef> 10 #include <cstdio> 11 #include <cstring> 12 #include <string> 13 14 #include <libshm/alloc_info.h> 15 #include <libshm/err.h> 16 17 class Socket { 18 public: 19 int socket_fd; 20 21 protected: Socket()22 Socket() { 23 SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0)); 24 } 25 Socket(const Socket& other) = delete; Socket(Socket && other)26 Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) { 27 other.socket_fd = -1; 28 }; Socket(int fd)29 explicit Socket(int fd) : socket_fd(fd) {} 30 ~Socket()31 virtual ~Socket() { 32 if (socket_fd != -1) 33 close(socket_fd); 34 } 35 prepare_address(const char * path)36 struct sockaddr_un prepare_address(const char* path) { 37 struct sockaddr_un address; 38 address.sun_family = AF_UNIX; 39 strcpy(address.sun_path, path); 40 return address; 41 } 42 43 // Implemented based on https://man7.org/linux/man-pages/man7/unix.7.html address_length(struct sockaddr_un address)44 size_t address_length(struct sockaddr_un address) { 45 return offsetof(sockaddr_un, sun_path) + strlen(address.sun_path) + 1; 46 } 47 recv(void * _buffer,size_t num_bytes)48 void recv(void* _buffer, size_t num_bytes) { 49 char* buffer = (char*)_buffer; 50 size_t bytes_received = 0; 51 ssize_t step_received; 52 struct pollfd pfd = {}; 53 pfd.fd = socket_fd; 54 pfd.events = POLLIN; 55 while (bytes_received < num_bytes) { 56 SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000)); 57 if (pfd.revents & POLLIN) { 58 SYSCHECK_ERR_RETURN_NEG1( 59 step_received = 60 ::read(socket_fd, buffer, num_bytes - bytes_received)); 61 if (step_received == 0) 62 throw std::runtime_error("Other end has closed the connection"); 63 bytes_received += step_received; 64 buffer += step_received; 65 } else if (pfd.revents & (POLLERR | POLLHUP)) { 66 throw std::runtime_error( 67 "An error occurred while waiting for the data"); 68 } else { 69 throw std::runtime_error( 70 "Shared memory manager connection has timed out"); 71 } 72 } 73 } 74 send(const void * _buffer,size_t num_bytes)75 void send(const void* _buffer, size_t num_bytes) { 76 const char* buffer = (const char*)_buffer; 77 size_t bytes_sent = 0; 78 ssize_t step_sent; 79 while (bytes_sent < num_bytes) { 80 SYSCHECK_ERR_RETURN_NEG1( 81 step_sent = ::write(socket_fd, buffer, num_bytes)); 82 bytes_sent += step_sent; 83 buffer += step_sent; 84 } 85 } 86 }; 87 88 class ManagerSocket : public Socket { 89 public: ManagerSocket(int fd)90 explicit ManagerSocket(int fd) : Socket(fd) {} 91 receive()92 AllocInfo receive() { 93 AllocInfo info; 94 recv(&info, sizeof(info)); 95 return info; 96 } 97 confirm()98 void confirm() { 99 send("OK", 2); 100 } 101 }; 102 103 class ManagerServerSocket : public Socket { 104 public: ManagerServerSocket(const std::string & path)105 explicit ManagerServerSocket(const std::string& path) { 106 socket_path = path; 107 try { 108 struct sockaddr_un address = prepare_address(path.c_str()); 109 size_t len = address_length(address); 110 SYSCHECK_ERR_RETURN_NEG1( 111 bind(socket_fd, (struct sockaddr*)&address, len)); 112 SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10)); 113 } catch (std::exception&) { 114 SYSCHECK_ERR_RETURN_NEG1(close(socket_fd)); 115 throw; 116 } 117 } 118 remove()119 void remove() { 120 struct stat file_stat; 121 if (fstat(socket_fd, &file_stat) == 0) 122 SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str())); 123 } 124 ~ManagerServerSocket()125 virtual ~ManagerServerSocket() { 126 unlink(socket_path.c_str()); 127 } 128 accept()129 ManagerSocket accept() { 130 int client_fd; 131 struct sockaddr_un addr; 132 socklen_t addr_len = sizeof(addr); 133 SYSCHECK_ERR_RETURN_NEG1( 134 client_fd = ::accept(socket_fd, (struct sockaddr*)&addr, &addr_len)); 135 return ManagerSocket(client_fd); 136 } 137 138 std::string socket_path; 139 }; 140 141 class ClientSocket : public Socket { 142 public: ClientSocket(const std::string & path)143 explicit ClientSocket(const std::string& path) { 144 try { 145 struct sockaddr_un address = prepare_address(path.c_str()); 146 size_t len = address_length(address); 147 SYSCHECK_ERR_RETURN_NEG1( 148 connect(socket_fd, (struct sockaddr*)&address, len)); 149 } catch (std::exception&) { 150 SYSCHECK_ERR_RETURN_NEG1(close(socket_fd)); 151 throw; 152 } 153 } 154 register_allocation(AllocInfo & info)155 void register_allocation(AllocInfo& info) { 156 char buffer[3] = {0, 0, 0}; 157 send(&info, sizeof(info)); 158 recv(buffer, 2); 159 if (strcmp(buffer, "OK") != 0) 160 throw std::runtime_error( 161 "Shared memory manager didn't respond with an OK"); 162 } 163 register_deallocation(AllocInfo & info)164 void register_deallocation(AllocInfo& info) { 165 send(&info, sizeof(info)); 166 } 167 }; 168