From 3bdbe482c23369f2c19c4074d252858e35552341 Mon Sep 17 00:00:00 2001 From: Gabriel Busnot Date: Tue, 24 Jan 2023 09:59:30 +0000 Subject: [PATCH] base: Strengthen safe_cast and make it work for reference types safe_cast now supports the exact same types as dynamic_cast would. In particular, it now supports l-value references and rejects r-value references. The non-debug version has also been updated to make it build only in the same cases as the debug version of safe_cast would. Change-Id: I86692561c169b1ad063000c990a52ea80c6637ca Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/67453 Reviewed-by: Giacomo Travaglini Maintainer: Giacomo Travaglini Tested-by: kokoro --- src/base/cast.hh | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/src/base/cast.hh b/src/base/cast.hh index cdc3c624a7..01464d9b2c 100644 --- a/src/base/cast.hh +++ b/src/base/cast.hh @@ -30,6 +30,8 @@ #define __BASE_CAST_HH__ #include +#include +#include "base/logging.hh" namespace gem5 { @@ -44,10 +46,20 @@ namespace gem5 template inline T -safe_cast(U ptr) +safe_cast(U&& ref_or_ptr) { - T ret = dynamic_cast(ptr); - assert(ret); + /* + * srd::forward used in conjunction with forwarding references (template T + * + T&&) ensures that dynamic_cast will see the exact same type that was + * passed to safe_cast (a.k.a., perfect forwarding). + * + * Not using std::forward would make safe_cast compile with references to + * temporary objects and thus return a dangling reference. + */ + T ret = dynamic_cast(std::forward(ref_or_ptr)); + if constexpr (std::is_pointer_v) { + gem5_assert(ret); + } return ret; } @@ -59,9 +71,19 @@ safe_cast(U ptr) template inline T -safe_cast(U ptr) +safe_cast(U&& ref_or_ptr) { - return static_cast(ptr); + /* + * safe_cast should be reserved to polymorphic types while static_cast is + * also allowed for non-polymorphic types. It could make safe_cast able to + * compile in a non-debug build and fail in a debug build. + */ + static_assert(std::is_polymorphic_v< + std::remove_pointer_t< + std::remove_reference_t< + U>> + >); + return static_cast(std::forward(ref_or_ptr)); } #endif