From 15cb9c7abef94d53135351422284c8651ce0133b Mon Sep 17 00:00:00 2001 From: Simon Park Date: Tue, 3 Jan 2023 01:07:03 -0800 Subject: [PATCH] base: socket: add UnixSocketAddr for representing socket paths Added UnixSocketAddr that wraps around sockaddr_un. Using this wrapper, users can create both file based sockets as well as abstract sockets. Change-Id: Ibf105b92a6a6ac7fc9136ed307f824c83e45c06c Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/66471 Maintainer: Gabe Black Reviewed-by: Gabe Black Tested-by: kokoro --- src/base/SConscript | 3 +- src/base/socket.cc | 66 ++++++++++++++++++++++++++++ src/base/socket.hh | 34 +++++++++++++++ src/base/socket.test.cc | 77 +++++++++++++++++++++++++++++++++ src/base/str.hh | 9 ++++ src/mem/shared_memory_server.cc | 56 +++++++++++------------- src/mem/shared_memory_server.hh | 3 +- 7 files changed, 216 insertions(+), 32 deletions(-) diff --git a/src/base/SConscript b/src/base/SConscript index e751d0b5ef..4a6b65fa72 100644 --- a/src/base/SConscript +++ b/src/base/SConscript @@ -68,7 +68,8 @@ Source('pollevent.cc') Source('random.cc') Source('remote_gdb.cc') Source('socket.cc') -GTest('socket.test', 'socket.test.cc', 'socket.cc') +SourceLib('z', tags='socket_test') +GTest('socket.test', 'socket.test.cc', 'socket.cc', 'output.cc', with_tag('socket_test')) Source('statistics.cc') Source('str.cc', add_tags=['gem5 trace', 'gem5 serialize']) GTest('str.test', 'str.test.cc', 'str.cc') diff --git a/src/base/socket.cc b/src/base/socket.cc index 5cf67fdd90..23f2b40d1a 100644 --- a/src/base/socket.cc +++ b/src/base/socket.cc @@ -35,22 +35,88 @@ #include #include #include +#include #include #include #include "base/logging.hh" +#include "base/output.hh" +#include "base/str.hh" #include "base/types.hh" #include "sim/byteswap.hh" namespace gem5 { +namespace +{ + +bool +isSocketNameAbstract(const std::string &path) +{ + if (path.empty()) { + return false; + } + // No null byte should be present in the path + return path.front() == '@'; +} + +std::string +resolve(const std::string &path) +{ + if (path.empty()) { + return path; + } + if (isSocketNameAbstract(path)) { + return '\0' + path.substr(1); + } + return simout.resolve(path); +} + +} // namespace bool ListenSocket::listeningDisabled = false; bool ListenSocket::anyListening = false; bool ListenSocket::bindToLoopback = false; +UnixSocketAddr +UnixSocketAddr::build(const std::string &path) +{ + sockaddr_un addr = {.sun_family = AF_UNIX, .sun_path = {}}; + + const bool is_abstract = isSocketNameAbstract(path); + size_t max_len = sizeof(addr.sun_path); + if (!is_abstract) { + // File based socket names need to be null terminated + max_len -= 1; + } + + std::string resolved_path = resolve(path); + std::string fmt_path = replace(resolved_path, '\0', '@'); + if (resolved_path.size() > max_len) { + resolved_path = resolved_path.substr(0, max_len); + const std::string untruncated_path = std::move(fmt_path); + fmt_path = replace(resolved_path, '\0', '@'); + warn("SocketPath: unix socket path truncated from '%s' to '%s'", + untruncated_path, fmt_path); + } + + // We can't use strncpy here, since abstract sockets start with \0 which + // will make strncpy think that the string is empty. + memcpy(addr.sun_path, resolved_path.c_str(), resolved_path.size()); + // We can't use sizeof(sockaddr_un) for abstract sockets, since all + // sizeof(sun_path) bytes are used in representing the path. + const size_t path_size = + is_abstract ? resolved_path.size() : sizeof(addr.sun_path); + const size_t addr_size = offsetof(sockaddr_un, sun_path) + path_size; + + return UnixSocketAddr{.addr = std::move(addr), + .addrSize = addr_size, + .isAbstract = is_abstract, + .formattedPath = std::move(fmt_path)}; +} + void ListenSocket::cleanup() { diff --git a/src/base/socket.hh b/src/base/socket.hh index 3375ccc80a..f3b2760518 100644 --- a/src/base/socket.hh +++ b/src/base/socket.hh @@ -31,10 +31,44 @@ #include #include +#include + +#include namespace gem5 { +/** + * @brief Wrapper around sockaddr_un, so that it can be used for both file + * based unix sockets as well as abstract unix sockets. + */ +struct UnixSocketAddr +{ + /** + * @brief Builds UnixSocketAddr from the given path. + * @pre: `path` either represents a file based unix socket, or an abstract + * unix socket. If `path` represents an abstract socket, it should + * start with the character '@', and it should not have any null + * bytes in the name. + * @param path: Pathname, where the socket should be instantiated. + * @return UnixSocketAddr + */ + static UnixSocketAddr build(const std::string &path); + + sockaddr_un addr; + // Size of `sockaddr_un addr`. This is equal to sizeof(sockaddr_un) if + // `addr` represents a normal file based unix socket. For abstract sockets + // however, the size could be different. Because all sizeof(sun_path) is + // used to represent the name of an abstract socket, addrSize for abstract + // sockets only count the number of characters actually used by sun_path, + // excluding any trailing null bytes. + size_t addrSize; + bool isAbstract; + // Formatted string for file based sockets look the same as addr.sun_path. + // For abstract sockets however, all null bytes are replaced with @ + std::string formattedPath; +}; + class ListenSocket { protected: diff --git a/src/base/socket.test.cc b/src/base/socket.test.cc index a267f8ce43..1ab1f21070 100644 --- a/src/base/socket.test.cc +++ b/src/base/socket.test.cc @@ -28,6 +28,10 @@ #include +#include +#include +#include + #include "base/gtest/logging.hh" #include "base/socket.hh" @@ -41,6 +45,79 @@ using namespace gem5; * socket.cc have not been fully tested due to interaction with system-calls. */ +namespace { + +std::string +repeat(const std::string& str, size_t n) +{ + std::stringstream ss; + for (int i = 0; i < n; ++i) { + ss << str; + } + return ss.str(); +} + +} // namespace + +TEST(UnixSocketAddrTest, AbstractSocket) +{ + UnixSocketAddr sock_addr = UnixSocketAddr::build("@abstract"); + EXPECT_EQ(AF_UNIX, sock_addr.addr.sun_family); + // null byte will not show, so compare from the first byte + EXPECT_STREQ("abstract", sock_addr.addr.sun_path + 1); + EXPECT_TRUE(sock_addr.isAbstract); + EXPECT_STREQ("@abstract", sock_addr.formattedPath.c_str()); +} + +TEST(UnixSocketAddrTest, TruncatedAbstractSocket) +{ + // Test that address is truncated if longer than sizeof(sun_path) + constexpr size_t MaxSize = sizeof(std::declval().sun_path); + + // >sizeof(sun_path) bytes + std::string addr = "@" + repeat("123456789", 100); + ASSERT_GT(addr.size(), MaxSize); + std::string truncated_addr = addr.substr(0, MaxSize); + + UnixSocketAddr sock_addr = UnixSocketAddr::build(addr); + EXPECT_EQ(AF_UNIX, sock_addr.addr.sun_family); + // Use memcmp so that we can compare null bytes as well + std::string null_formatted = '\0' + truncated_addr.substr(1); + EXPECT_EQ(0, std::memcmp(null_formatted.c_str(), sock_addr.addr.sun_path, + MaxSize)); + EXPECT_TRUE(sock_addr.isAbstract); + EXPECT_EQ(truncated_addr, sock_addr.formattedPath); +} + +TEST(UnixSocketAddrTest, FileBasedSocket) +{ + std::string addr = "/home/parent/dir/x"; + UnixSocketAddr sock_addr = UnixSocketAddr::build(addr); + EXPECT_EQ(AF_UNIX, sock_addr.addr.sun_family); + EXPECT_STREQ(addr.c_str(), sock_addr.addr.sun_path); + EXPECT_FALSE(sock_addr.isAbstract); + EXPECT_EQ(addr, sock_addr.formattedPath); +} + +TEST(UnixSocketAddrTest, TruncatedFileBasedSocket) +{ + // sun_path should null terminate, so test that address is truncated if + // longer than sizeof(sun_path) - 1 bytes. + constexpr size_t MaxSize = + sizeof(std::declval().sun_path) - 1; + + // >sizeof(sun_path) - 1 bytes + std::string addr = "/" + repeat("123456789", 100); + ASSERT_GT(addr.size(), MaxSize); + std::string truncated_addr = addr.substr(0, MaxSize); + + UnixSocketAddr sock_addr = UnixSocketAddr::build(addr); + EXPECT_EQ(AF_UNIX, sock_addr.addr.sun_family); + EXPECT_STREQ(truncated_addr.c_str(), sock_addr.addr.sun_path); + EXPECT_FALSE(sock_addr.isAbstract); + EXPECT_EQ(truncated_addr, sock_addr.formattedPath); +} + class MockListenSocket : public ListenSocket { public: diff --git a/src/base/str.hh b/src/base/str.hh index 00409ff3d7..855fb43b28 100644 --- a/src/base/str.hh +++ b/src/base/str.hh @@ -32,6 +32,7 @@ #ifndef __BASE_STR_HH__ #define __BASE_STR_HH__ +#include #include #include #include @@ -251,6 +252,14 @@ startswith(const std::string &s, const std::string &prefix) return (s.compare(0, prefix.size(), prefix) == 0); } +inline std::string +replace(const std::string &s, char from, char to) +{ + std::string replaced = s; + std::replace(replaced.begin(), replaced.end(), from, to); + return replaced; +} + } // namespace gem5 #endif //__BASE_STR_HH__ diff --git a/src/mem/shared_memory_server.cc b/src/mem/shared_memory_server.cc index bee663bd37..6344ee0388 100644 --- a/src/mem/shared_memory_server.cc +++ b/src/mem/shared_memory_server.cc @@ -34,7 +34,6 @@ #include #include #include -#include #include #include @@ -44,7 +43,6 @@ #include "base/logging.hh" #include "base/output.hh" #include "base/pollevent.hh" -#include "base/socket.hh" namespace gem5 { @@ -52,51 +50,49 @@ namespace memory { SharedMemoryServer::SharedMemoryServer(const SharedMemoryServerParams& params) - : SimObject(params), unixSocketPath(simout.resolve(params.server_path)), - system(params.system), serverFd(-1) + : SimObject(params), + sockAddr(UnixSocketAddr::build(params.server_path)), + system(params.system), + serverFd(-1) { fatal_if(system == nullptr, "Requires a system to share memory from!"); // Create a new unix socket. serverFd = ListenSocket::socketCloexec(AF_UNIX, SOCK_STREAM, 0); panic_if(serverFd < 0, "%s: cannot create unix socket: %s", name(), strerror(errno)); - // Bind to the specified path. - sockaddr_un serv_addr = {}; - serv_addr.sun_family = AF_UNIX; - strncpy(serv_addr.sun_path, unixSocketPath.c_str(), - sizeof(serv_addr.sun_path) - 1); - // If the target path is truncated, warn the user that the actual path is - // different and update the target path. - if (strlen(serv_addr.sun_path) != unixSocketPath.size()) { - warn("%s: unix socket path truncated, expect '%s' but get '%s'", - name(), unixSocketPath, serv_addr.sun_path); - unixSocketPath = serv_addr.sun_path; + + const auto& [serv_addr, addr_size, is_abstract, formatted_path] = sockAddr; + + if (!is_abstract) { + // Ensure the unix socket path to use is not occupied. Also, if there's + // actually anything to be removed, warn the user something might be + // off. + bool old_sock_removed = unlink(serv_addr.sun_path) == 0; + warn_if(old_sock_removed, + "%s: server path %s was occupied and will be replaced. Please " + "make sure there is no other server using the same path.", + name(), serv_addr.sun_path); } - // Ensure the unix socket path to use is not occupied. Also, if there's - // actually anything to be removed, warn the user something might be off. - bool old_sock_removed = unlink(unixSocketPath.c_str()) == 0; - warn_if(old_sock_removed, - "%s: the server path %s was occupied and will be replaced. Please " - "make sure there is no other server using the same path.", - name(), unixSocketPath); - int bind_retv = bind(serverFd, reinterpret_cast(&serv_addr), - sizeof(serv_addr)); - fatal_if(bind_retv != 0, "%s: cannot bind unix socket: %s", name(), - strerror(errno)); + int bind_retv = bind( + serverFd, reinterpret_cast(&serv_addr), addr_size); + fatal_if(bind_retv != 0, "%s: cannot bind unix socket '%s': %s", name(), + formatted_path, strerror(errno)); // Start listening. int listen_retv = listen(serverFd, 1); fatal_if(listen_retv != 0, "%s: listen failed: %s", name(), strerror(errno)); listenSocketEvent.reset(new ListenSocketEvent(serverFd, this)); pollQueue.schedule(listenSocketEvent.get()); - inform("%s: listening at %s", name(), unixSocketPath); + inform("%s: listening at %s", name(), formatted_path); } SharedMemoryServer::~SharedMemoryServer() { - int unlink_retv = unlink(unixSocketPath.c_str()); - warn_if(unlink_retv != 0, "%s: cannot unlink unix socket: %s", name(), - strerror(errno)); + if (!sockAddr.isAbstract) { + int unlink_retv = unlink(sockAddr.addr.sun_path); + warn_if(unlink_retv != 0, "%s: cannot unlink unix socket: %s", name(), + strerror(errno)); + } int close_retv = close(serverFd); warn_if(close_retv != 0, "%s: cannot close unix socket: %s", name(), strerror(errno)); diff --git a/src/mem/shared_memory_server.hh b/src/mem/shared_memory_server.hh index 8f573fef3b..d9fbeb3f20 100644 --- a/src/mem/shared_memory_server.hh +++ b/src/mem/shared_memory_server.hh @@ -33,6 +33,7 @@ #include #include "base/pollevent.hh" +#include "base/socket.hh" #include "params/SharedMemoryServer.hh" #include "sim/sim_object.hh" #include "sim/system.hh" @@ -82,7 +83,7 @@ class SharedMemoryServer : public SimObject void process(int revent) override; }; - std::string unixSocketPath; + UnixSocketAddr sockAddr; System* system; int serverFd;