mem: Use HostSocket in the SharedMemoryServer.

Use a HostSocket parameter to accept connections, rather than a hand
implementation for unix domain sockets. This consolidates this code
with the code derived from it in ListenSocket, and also makes it
possible to connect to the SharedMemoryServer over an AF_INET socket.

Change-Id: I8e05434d08cffaebdf6c68a967e2ee7613c10a76
Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/69168
Maintainer: Gabe Black <gabeblack@google.com>
Tested-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Jui-min Lee <fcrh@google.com>
This commit is contained in:
Gabe Black
2023-03-18 23:10:03 -07:00
parent e79d6616dd
commit f9cf3de711
2 changed files with 26 additions and 45 deletions

View File

@@ -39,6 +39,7 @@
#include <algorithm>
#include <cerrno>
#include <cstring>
#include <filesystem>
#include "base/logging.hh"
#include "base/output.hh"
@@ -49,54 +50,37 @@ namespace gem5
namespace memory
{
namespace
{
ListenSocketPtr
buildListenSocket(const std::string &path, const std::string &name)
{
fatal_if(path.empty(), "%s: Empty socket path", name);
if (path[0] == '@')
return listenSocketUnixAbstractConfig(path.substr(1)).build(name);
std::filesystem::path p(path);
return listenSocketUnixFileConfig(
p.parent_path(), p.filename()).build(name);
}
} // anonymous namespace
SharedMemoryServer::SharedMemoryServer(const SharedMemoryServerParams& params)
: SimObject(params),
sockAddr(UnixSocketAddr::build(params.server_path)),
system(params.system),
serverFd(-1)
listener(buildListenSocket(params.server_path, name()))
{
fatal_if(system == nullptr, "Requires a system to share memory from!");
// Create a new unix socket.
serverFd = ListenSocket::socketCloexec(AF_UNIX, SOCK_STREAM, 0);
panic_if(serverFd < 0, "%s: cannot create unix socket: %s", name(),
strerror(errno));
listener->listen();
const auto& [serv_addr, addr_size, is_abstract, formatted_path] = sockAddr;
if (!is_abstract) {
// Ensure the unix socket path to use is not occupied. Also, if there's
// actually anything to be removed, warn the user something might be
// off.
bool old_sock_removed = unlink(serv_addr.sun_path) == 0;
warn_if(old_sock_removed,
"%s: server path %s was occupied and will be replaced. Please "
"make sure there is no other server using the same path.",
name(), serv_addr.sun_path);
}
int bind_retv = bind(
serverFd, reinterpret_cast<const sockaddr*>(&serv_addr), addr_size);
fatal_if(bind_retv != 0, "%s: cannot bind unix socket '%s': %s", name(),
formatted_path, strerror(errno));
// Start listening.
int listen_retv = listen(serverFd, 1);
fatal_if(listen_retv != 0, "%s: listen failed: %s", name(),
strerror(errno));
listenSocketEvent.reset(new ListenSocketEvent(serverFd, this));
listenSocketEvent.reset(new ListenSocketEvent(listener->getfd(), this));
pollQueue.schedule(listenSocketEvent.get());
inform("%s: listening at %s", name(), formatted_path);
inform("%s: listening at %s", name(), *listener);
}
SharedMemoryServer::~SharedMemoryServer()
{
if (!sockAddr.isAbstract) {
int unlink_retv = unlink(sockAddr.addr.sun_path);
warn_if(unlink_retv != 0, "%s: cannot unlink unix socket: %s", name(),
strerror(errno));
}
int close_retv = close(serverFd);
warn_if(close_retv != 0, "%s: cannot close unix socket: %s", name(),
strerror(errno));
}
SharedMemoryServer::~SharedMemoryServer() {}
SharedMemoryServer::BaseShmPollEvent::BaseShmPollEvent(
int fd, SharedMemoryServer* shm_server)
@@ -130,10 +114,7 @@ SharedMemoryServer::BaseShmPollEvent::tryReadAll(void* buffer, size_t size)
void
SharedMemoryServer::ListenSocketEvent::process(int revents)
{
panic_if(revents & (POLLERR | POLLNVAL), "%s: listen socket is broken",
name());
int cli_fd = ListenSocket::acceptCloexec(pfd.fd, nullptr, nullptr);
panic_if(cli_fd < 0, "%s: accept failed: %s", name(), strerror(errno));
int cli_fd = shmServer->listener->accept();
inform("%s: accept new connection %d", name(), cli_fd);
shmServer->clientSocketEvents[cli_fd].reset(
new ClientSocketEvent(cli_fd, shmServer));

View File

@@ -83,10 +83,10 @@ class SharedMemoryServer : public SimObject
void process(int revent) override;
};
UnixSocketAddr sockAddr;
System* system;
int serverFd;
ListenSocketPtr listener;
std::unique_ptr<ListenSocketEvent> listenSocketEvent;
std::unordered_map<int, std::unique_ptr<ClientSocketEvent>>
clientSocketEvents;