base: socket: add UnixSocketAddr for representing socket paths

Added UnixSocketAddr that wraps around sockaddr_un. Using this
wrapper, users can create both file based sockets as well as
abstract sockets.

Change-Id: Ibf105b92a6a6ac7fc9136ed307f824c83e45c06c
Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/66471
Maintainer: Gabe Black <gabe.black@gmail.com>
Reviewed-by: Gabe Black <gabe.black@gmail.com>
Tested-by: kokoro <noreply+kokoro@google.com>
This commit is contained in:
Simon Park
2023-01-03 01:07:03 -08:00
parent 313f557b93
commit 15cb9c7abe
7 changed files with 216 additions and 32 deletions

View File

@@ -68,7 +68,8 @@ Source('pollevent.cc')
Source('random.cc')
Source('remote_gdb.cc')
Source('socket.cc')
GTest('socket.test', 'socket.test.cc', 'socket.cc')
SourceLib('z', tags='socket_test')
GTest('socket.test', 'socket.test.cc', 'socket.cc', 'output.cc', with_tag('socket_test'))
Source('statistics.cc')
Source('str.cc', add_tags=['gem5 trace', 'gem5 serialize'])
GTest('str.test', 'str.test.cc', 'str.cc')

View File

@@ -35,22 +35,88 @@
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>
#include <cerrno>
#include "base/logging.hh"
#include "base/output.hh"
#include "base/str.hh"
#include "base/types.hh"
#include "sim/byteswap.hh"
namespace gem5
{
namespace
{
bool
isSocketNameAbstract(const std::string &path)
{
if (path.empty()) {
return false;
}
// No null byte should be present in the path
return path.front() == '@';
}
std::string
resolve(const std::string &path)
{
if (path.empty()) {
return path;
}
if (isSocketNameAbstract(path)) {
return '\0' + path.substr(1);
}
return simout.resolve(path);
}
} // namespace
bool ListenSocket::listeningDisabled = false;
bool ListenSocket::anyListening = false;
bool ListenSocket::bindToLoopback = false;
UnixSocketAddr
UnixSocketAddr::build(const std::string &path)
{
sockaddr_un addr = {.sun_family = AF_UNIX, .sun_path = {}};
const bool is_abstract = isSocketNameAbstract(path);
size_t max_len = sizeof(addr.sun_path);
if (!is_abstract) {
// File based socket names need to be null terminated
max_len -= 1;
}
std::string resolved_path = resolve(path);
std::string fmt_path = replace(resolved_path, '\0', '@');
if (resolved_path.size() > max_len) {
resolved_path = resolved_path.substr(0, max_len);
const std::string untruncated_path = std::move(fmt_path);
fmt_path = replace(resolved_path, '\0', '@');
warn("SocketPath: unix socket path truncated from '%s' to '%s'",
untruncated_path, fmt_path);
}
// We can't use strncpy here, since abstract sockets start with \0 which
// will make strncpy think that the string is empty.
memcpy(addr.sun_path, resolved_path.c_str(), resolved_path.size());
// We can't use sizeof(sockaddr_un) for abstract sockets, since all
// sizeof(sun_path) bytes are used in representing the path.
const size_t path_size =
is_abstract ? resolved_path.size() : sizeof(addr.sun_path);
const size_t addr_size = offsetof(sockaddr_un, sun_path) + path_size;
return UnixSocketAddr{.addr = std::move(addr),
.addrSize = addr_size,
.isAbstract = is_abstract,
.formattedPath = std::move(fmt_path)};
}
void
ListenSocket::cleanup()
{

View File

@@ -31,10 +31,44 @@
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <string>
namespace gem5
{
/**
* @brief Wrapper around sockaddr_un, so that it can be used for both file
* based unix sockets as well as abstract unix sockets.
*/
struct UnixSocketAddr
{
/**
* @brief Builds UnixSocketAddr from the given path.
* @pre: `path` either represents a file based unix socket, or an abstract
* unix socket. If `path` represents an abstract socket, it should
* start with the character '@', and it should not have any null
* bytes in the name.
* @param path: Pathname, where the socket should be instantiated.
* @return UnixSocketAddr
*/
static UnixSocketAddr build(const std::string &path);
sockaddr_un addr;
// Size of `sockaddr_un addr`. This is equal to sizeof(sockaddr_un) if
// `addr` represents a normal file based unix socket. For abstract sockets
// however, the size could be different. Because all sizeof(sun_path) is
// used to represent the name of an abstract socket, addrSize for abstract
// sockets only count the number of characters actually used by sun_path,
// excluding any trailing null bytes.
size_t addrSize;
bool isAbstract;
// Formatted string for file based sockets look the same as addr.sun_path.
// For abstract sockets however, all null bytes are replaced with @
std::string formattedPath;
};
class ListenSocket
{
protected:

View File

@@ -28,6 +28,10 @@
#include <gtest/gtest.h>
#include <cstring>
#include <sstream>
#include <utility>
#include "base/gtest/logging.hh"
#include "base/socket.hh"
@@ -41,6 +45,79 @@ using namespace gem5;
* socket.cc have not been fully tested due to interaction with system-calls.
*/
namespace {
std::string
repeat(const std::string& str, size_t n)
{
std::stringstream ss;
for (int i = 0; i < n; ++i) {
ss << str;
}
return ss.str();
}
} // namespace
TEST(UnixSocketAddrTest, AbstractSocket)
{
UnixSocketAddr sock_addr = UnixSocketAddr::build("@abstract");
EXPECT_EQ(AF_UNIX, sock_addr.addr.sun_family);
// null byte will not show, so compare from the first byte
EXPECT_STREQ("abstract", sock_addr.addr.sun_path + 1);
EXPECT_TRUE(sock_addr.isAbstract);
EXPECT_STREQ("@abstract", sock_addr.formattedPath.c_str());
}
TEST(UnixSocketAddrTest, TruncatedAbstractSocket)
{
// Test that address is truncated if longer than sizeof(sun_path)
constexpr size_t MaxSize = sizeof(std::declval<sockaddr_un>().sun_path);
// >sizeof(sun_path) bytes
std::string addr = "@" + repeat("123456789", 100);
ASSERT_GT(addr.size(), MaxSize);
std::string truncated_addr = addr.substr(0, MaxSize);
UnixSocketAddr sock_addr = UnixSocketAddr::build(addr);
EXPECT_EQ(AF_UNIX, sock_addr.addr.sun_family);
// Use memcmp so that we can compare null bytes as well
std::string null_formatted = '\0' + truncated_addr.substr(1);
EXPECT_EQ(0, std::memcmp(null_formatted.c_str(), sock_addr.addr.sun_path,
MaxSize));
EXPECT_TRUE(sock_addr.isAbstract);
EXPECT_EQ(truncated_addr, sock_addr.formattedPath);
}
TEST(UnixSocketAddrTest, FileBasedSocket)
{
std::string addr = "/home/parent/dir/x";
UnixSocketAddr sock_addr = UnixSocketAddr::build(addr);
EXPECT_EQ(AF_UNIX, sock_addr.addr.sun_family);
EXPECT_STREQ(addr.c_str(), sock_addr.addr.sun_path);
EXPECT_FALSE(sock_addr.isAbstract);
EXPECT_EQ(addr, sock_addr.formattedPath);
}
TEST(UnixSocketAddrTest, TruncatedFileBasedSocket)
{
// sun_path should null terminate, so test that address is truncated if
// longer than sizeof(sun_path) - 1 bytes.
constexpr size_t MaxSize =
sizeof(std::declval<sockaddr_un>().sun_path) - 1;
// >sizeof(sun_path) - 1 bytes
std::string addr = "/" + repeat("123456789", 100);
ASSERT_GT(addr.size(), MaxSize);
std::string truncated_addr = addr.substr(0, MaxSize);
UnixSocketAddr sock_addr = UnixSocketAddr::build(addr);
EXPECT_EQ(AF_UNIX, sock_addr.addr.sun_family);
EXPECT_STREQ(truncated_addr.c_str(), sock_addr.addr.sun_path);
EXPECT_FALSE(sock_addr.isAbstract);
EXPECT_EQ(truncated_addr, sock_addr.formattedPath);
}
class MockListenSocket : public ListenSocket
{
public:

View File

@@ -32,6 +32,7 @@
#ifndef __BASE_STR_HH__
#define __BASE_STR_HH__
#include <algorithm>
#include <cstring>
#include <limits>
#include <locale>
@@ -251,6 +252,14 @@ startswith(const std::string &s, const std::string &prefix)
return (s.compare(0, prefix.size(), prefix) == 0);
}
inline std::string
replace(const std::string &s, char from, char to)
{
std::string replaced = s;
std::replace(replaced.begin(), replaced.end(), from, to);
return replaced;
}
} // namespace gem5
#endif //__BASE_STR_HH__

View File

@@ -34,7 +34,6 @@
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>
#include <algorithm>
@@ -44,7 +43,6 @@
#include "base/logging.hh"
#include "base/output.hh"
#include "base/pollevent.hh"
#include "base/socket.hh"
namespace gem5
{
@@ -52,51 +50,49 @@ namespace memory
{
SharedMemoryServer::SharedMemoryServer(const SharedMemoryServerParams& params)
: SimObject(params), unixSocketPath(simout.resolve(params.server_path)),
system(params.system), serverFd(-1)
: SimObject(params),
sockAddr(UnixSocketAddr::build(params.server_path)),
system(params.system),
serverFd(-1)
{
fatal_if(system == nullptr, "Requires a system to share memory from!");
// Create a new unix socket.
serverFd = ListenSocket::socketCloexec(AF_UNIX, SOCK_STREAM, 0);
panic_if(serverFd < 0, "%s: cannot create unix socket: %s", name(),
strerror(errno));
// Bind to the specified path.
sockaddr_un serv_addr = {};
serv_addr.sun_family = AF_UNIX;
strncpy(serv_addr.sun_path, unixSocketPath.c_str(),
sizeof(serv_addr.sun_path) - 1);
// If the target path is truncated, warn the user that the actual path is
// different and update the target path.
if (strlen(serv_addr.sun_path) != unixSocketPath.size()) {
warn("%s: unix socket path truncated, expect '%s' but get '%s'",
name(), unixSocketPath, serv_addr.sun_path);
unixSocketPath = serv_addr.sun_path;
const auto& [serv_addr, addr_size, is_abstract, formatted_path] = sockAddr;
if (!is_abstract) {
// Ensure the unix socket path to use is not occupied. Also, if there's
// actually anything to be removed, warn the user something might be
// off.
bool old_sock_removed = unlink(serv_addr.sun_path) == 0;
warn_if(old_sock_removed,
"%s: server path %s was occupied and will be replaced. Please "
"make sure there is no other server using the same path.",
name(), serv_addr.sun_path);
}
// Ensure the unix socket path to use is not occupied. Also, if there's
// actually anything to be removed, warn the user something might be off.
bool old_sock_removed = unlink(unixSocketPath.c_str()) == 0;
warn_if(old_sock_removed,
"%s: the server path %s was occupied and will be replaced. Please "
"make sure there is no other server using the same path.",
name(), unixSocketPath);
int bind_retv = bind(serverFd, reinterpret_cast<sockaddr*>(&serv_addr),
sizeof(serv_addr));
fatal_if(bind_retv != 0, "%s: cannot bind unix socket: %s", name(),
strerror(errno));
int bind_retv = bind(
serverFd, reinterpret_cast<const sockaddr*>(&serv_addr), addr_size);
fatal_if(bind_retv != 0, "%s: cannot bind unix socket '%s': %s", name(),
formatted_path, strerror(errno));
// Start listening.
int listen_retv = listen(serverFd, 1);
fatal_if(listen_retv != 0, "%s: listen failed: %s", name(),
strerror(errno));
listenSocketEvent.reset(new ListenSocketEvent(serverFd, this));
pollQueue.schedule(listenSocketEvent.get());
inform("%s: listening at %s", name(), unixSocketPath);
inform("%s: listening at %s", name(), formatted_path);
}
SharedMemoryServer::~SharedMemoryServer()
{
int unlink_retv = unlink(unixSocketPath.c_str());
warn_if(unlink_retv != 0, "%s: cannot unlink unix socket: %s", name(),
strerror(errno));
if (!sockAddr.isAbstract) {
int unlink_retv = unlink(sockAddr.addr.sun_path);
warn_if(unlink_retv != 0, "%s: cannot unlink unix socket: %s", name(),
strerror(errno));
}
int close_retv = close(serverFd);
warn_if(close_retv != 0, "%s: cannot close unix socket: %s", name(),
strerror(errno));

View File

@@ -33,6 +33,7 @@
#include <unordered_map>
#include "base/pollevent.hh"
#include "base/socket.hh"
#include "params/SharedMemoryServer.hh"
#include "sim/sim_object.hh"
#include "sim/system.hh"
@@ -82,7 +83,7 @@ class SharedMemoryServer : public SimObject
void process(int revent) override;
};
std::string unixSocketPath;
UnixSocketAddr sockAddr;
System* system;
int serverFd;