diff --git a/src/mem/shared_memory_server.cc b/src/mem/shared_memory_server.cc index 6344ee0388..3e49164e6d 100644 --- a/src/mem/shared_memory_server.cc +++ b/src/mem/shared_memory_server.cc @@ -39,6 +39,7 @@ #include #include #include +#include #include "base/logging.hh" #include "base/output.hh" @@ -49,54 +50,37 @@ namespace gem5 namespace memory { +namespace +{ + +ListenSocketPtr +buildListenSocket(const std::string &path, const std::string &name) +{ + fatal_if(path.empty(), "%s: Empty socket path", name); + if (path[0] == '@') + return listenSocketUnixAbstractConfig(path.substr(1)).build(name); + + std::filesystem::path p(path); + return listenSocketUnixFileConfig( + p.parent_path(), p.filename()).build(name); +} + +} // anonymous namespace + SharedMemoryServer::SharedMemoryServer(const SharedMemoryServerParams& params) : SimObject(params), - sockAddr(UnixSocketAddr::build(params.server_path)), system(params.system), - serverFd(-1) + listener(buildListenSocket(params.server_path, name())) { fatal_if(system == nullptr, "Requires a system to share memory from!"); - // Create a new unix socket. - serverFd = ListenSocket::socketCloexec(AF_UNIX, SOCK_STREAM, 0); - panic_if(serverFd < 0, "%s: cannot create unix socket: %s", name(), - strerror(errno)); + listener->listen(); - const auto& [serv_addr, addr_size, is_abstract, formatted_path] = sockAddr; - - if (!is_abstract) { - // Ensure the unix socket path to use is not occupied. Also, if there's - // actually anything to be removed, warn the user something might be - // off. - bool old_sock_removed = unlink(serv_addr.sun_path) == 0; - warn_if(old_sock_removed, - "%s: server path %s was occupied and will be replaced. Please " - "make sure there is no other server using the same path.", - name(), serv_addr.sun_path); - } - int bind_retv = bind( - serverFd, reinterpret_cast(&serv_addr), addr_size); - fatal_if(bind_retv != 0, "%s: cannot bind unix socket '%s': %s", name(), - formatted_path, strerror(errno)); - // Start listening. - int listen_retv = listen(serverFd, 1); - fatal_if(listen_retv != 0, "%s: listen failed: %s", name(), - strerror(errno)); - listenSocketEvent.reset(new ListenSocketEvent(serverFd, this)); + listenSocketEvent.reset(new ListenSocketEvent(listener->getfd(), this)); pollQueue.schedule(listenSocketEvent.get()); - inform("%s: listening at %s", name(), formatted_path); + inform("%s: listening at %s", name(), *listener); } -SharedMemoryServer::~SharedMemoryServer() -{ - if (!sockAddr.isAbstract) { - int unlink_retv = unlink(sockAddr.addr.sun_path); - warn_if(unlink_retv != 0, "%s: cannot unlink unix socket: %s", name(), - strerror(errno)); - } - int close_retv = close(serverFd); - warn_if(close_retv != 0, "%s: cannot close unix socket: %s", name(), - strerror(errno)); -} +SharedMemoryServer::~SharedMemoryServer() {} SharedMemoryServer::BaseShmPollEvent::BaseShmPollEvent( int fd, SharedMemoryServer* shm_server) @@ -130,10 +114,7 @@ SharedMemoryServer::BaseShmPollEvent::tryReadAll(void* buffer, size_t size) void SharedMemoryServer::ListenSocketEvent::process(int revents) { - panic_if(revents & (POLLERR | POLLNVAL), "%s: listen socket is broken", - name()); - int cli_fd = ListenSocket::acceptCloexec(pfd.fd, nullptr, nullptr); - panic_if(cli_fd < 0, "%s: accept failed: %s", name(), strerror(errno)); + int cli_fd = shmServer->listener->accept(); inform("%s: accept new connection %d", name(), cli_fd); shmServer->clientSocketEvents[cli_fd].reset( new ClientSocketEvent(cli_fd, shmServer)); diff --git a/src/mem/shared_memory_server.hh b/src/mem/shared_memory_server.hh index d9fbeb3f20..a4ef63d541 100644 --- a/src/mem/shared_memory_server.hh +++ b/src/mem/shared_memory_server.hh @@ -83,10 +83,10 @@ class SharedMemoryServer : public SimObject void process(int revent) override; }; - UnixSocketAddr sockAddr; System* system; - int serverFd; + ListenSocketPtr listener; + std::unique_ptr listenSocketEvent; std::unordered_map> clientSocketEvents;