stdlib,tests: Add Pyunit tests to check Pyunit nav, fix bugs

Bigs fixed of note:

1. The 'find' method has been fixed to work. This involved making
   'children' a class implemented per-subclass as required.
2. The 'get_all_stats_of_name' method has been removed. This was not
   working at all correctly and is largely doing what 'find' does.
2. The functionality to get an element in a vector via an attribute call
   (i.e., self.vector1 == self.vector[1]) has been implemented this
   maintaining backwards compatibility with the regular Python stats.

Change-Id: I31a4ccc723937018a3038dcdf491c82629ddbbb2
This commit is contained in:
Bobby R. Bruce
2024-05-29 23:47:46 -07:00
parent 2d4a213046
commit 7f0290985f
4 changed files with 374 additions and 45 deletions

View File

@@ -26,10 +26,12 @@
import re
from typing import (
Any,
Callable,
List,
Optional,
Pattern,
Tuple,
Union,
)
@@ -60,20 +62,11 @@ class AbstractStat(SerializableStat):
If it returns ``True``, then the child is yielded.
Otherwise, the child is skipped. If not provided then
all children are returned.
Note: This is method must be implemented in AbstractStat subclasses
which have children, otherwise it will return an empty list.
"""
to_return = []
for attr in self.__dict__:
obj = getattr(self, attr)
if isinstance(obj, AbstractStat):
if (predicate and predicate(attr)) or not predicate:
to_return.append(obj)
if recursive:
to_return = to_return + obj.children(
predicate=predicate, recursive=True
)
return to_return
return []
def find(self, regex: Union[str, Pattern]) -> List["AbstractStat"]:
"""Find all stats that match the name, recursively through all the
@@ -100,5 +93,51 @@ class AbstractStat(SerializableStat):
lambda _name: re.match(pattern, _name), recursive=True
)
def _get_vector_item(self, item: str) -> Optional[Tuple[str, int, Any]]:
"""It has been the case in gem5 that SimObject vectors are stored as
strings such as "cpu0" or "cpu1". This function splits the string into
the SimObject name and index, (e.g.: ["cpu", 0] and ["cpu", 1]) and
returns the item for that name and it's index. If the string cannot be
split into a SimObject name and index, or if the SimObject does not
exit at `Simobject[index]`, the function returns None.
"""
regex = re.compile("[0-9]+$")
match = regex.search(item)
if not match:
return None
match_str = match.group()
assert match_str.isdigit(), f"Regex match must be a digit: {match_str}"
vector_index = int(match_str)
vector_name = item[: (-1 * len(match_str))]
if hasattr(self, vector_name):
vector = getattr(self, vector_name)
try:
vector_value = vector[vector_index]
return vector_name, vector_index, vector_value
except KeyError:
pass
return None
def __iter__(self):
return iter(self.__dict__)
def __getattr__(self, item: str) -> Any:
vector_item = self._get_vector_item(item)
if not vector_item:
return None
assert (
len(vector_item) == 3
), f"Vector item must have 3 elements: {vector_item}"
return vector_item[2]
def __getitem__(self, item: str):
return getattr(self, item)
def __contains__(self, item: Any) -> bool:
return (
isinstance(item, str) and self._get_vector_item(item)
) or hasattr(self, item)

View File

@@ -26,6 +26,7 @@
from typing import (
Any,
Callable,
Dict,
List,
Optional,
@@ -62,6 +63,23 @@ class Group(AbstractStat):
for key, value in kwargs.items():
setattr(self, key, value)
def children(
self,
predicate: Optional[Callable[[str], bool]] = None,
recursive: bool = False,
) -> List["AbstractStat"]:
to_return = []
for attr in self.__dict__:
obj = getattr(self, attr)
if isinstance(obj, AbstractStat):
if (predicate and predicate(attr)) or not predicate:
to_return.append(obj)
if recursive:
to_return = to_return + obj.children(
predicate=predicate, recursive=True
)
return to_return
class SimObjectGroup(Group):
"""A group of statistics encapulated within a SimObject."""
@@ -93,30 +111,22 @@ class SimObjectVectorGroup(Group):
def __len__(self):
return len(self.value)
def get_all_stats_of_name(self, name: str) -> List[AbstractStat]:
"""
Get all the stats in the vector of that name. Useful for performing
operations on all the stats of the same name in a vector.
"""
to_return = []
for stat in self.value:
if hasattr(stat, name):
to_return.append(getattr(stat, name))
# If the name is in the format "sim.bla.whatever", we are looking for
# the "bla.whatever" stats in the "sim" group.
# This is messy, but it works.
name_split = name.split(".")
if len(name_split) == 1:
return to_return
if name_split[0] not in self:
return to_return
to_return.extend(
self[name_split[0]].get_all_stats_of_name(".".join(name_split[1:]))
)
return to_return
def __getitem__(self, item: int):
return self.value[item]
def __contains__(self, item):
if isinstance(item, int):
return item >= 0 and item < len(self)
def children(
self,
predicate: Optional[Callable[[str], bool]] = None,
recursive: bool = False,
) -> List["AbstractStat"]:
to_return = []
for child in self.value:
to_return = to_return + child.children(
predicate=predicate, recursive=recursive
)
return to_return

View File

@@ -27,8 +27,10 @@
from abc import ABC
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Union,
)
@@ -102,16 +104,32 @@ class Vector(Statistic):
description=description,
)
def __getitem__(self, index: Union[int, str, float]) -> Scalar:
def __getitem__(self, item: Union[int, str, float]) -> Scalar:
assert self.value != None
# In the case of string, we cast strings to integers of floats if they
# are numeric. This avoids users having to cast strings to integers.
if isinstance(index, str):
if index.isindex():
index = int(index)
elif index.isnumeric():
index = float(index)
return self.value[index]
if isinstance(item, str):
if item.isdigit():
item = int(item)
elif item.isnumeric():
item = float(item)
return self.value[item]
def __contains__(self, item) -> bool:
assert self.value != None
if isinstance(item, str):
if item.isdigit():
item = int(item)
elif item.isnumeric():
item = float(item)
return item in self.value
def __iner__(self) -> None:
return iter(self.value)
def __len__(self) -> int:
assert self.value != None
return len(self.value.values())
def size(self) -> int:
"""
@@ -141,6 +159,26 @@ class Vector(Statistic):
assert self.value != None
return sum(float(self.value[key]) for key in self.values)
def children(
self,
predicate: Optional[Callable[[str], bool]] = None,
recursive: bool = False,
) -> List["AbstractStat"]:
to_return = []
for attr in self.value.keys():
obj = self.value[attr]
if isinstance(obj, AbstractStat):
if (
isinstance(attr, str)
and (predicate and predicate(attr))
or not predicate
):
to_return.append(obj)
to_return = to_return + obj.children(
predicate=predicate, recursive=True
)
return to_return
class Vector2d(Statistic):
"""
@@ -179,6 +217,12 @@ class Vector2d(Statistic):
"""Returns the total number of elements."""
return self.x_size() * self.y_size()
def __len__(self) -> int:
return self.x_size()
def __iter__(self):
return iter(self.keys())
def total(self) -> int:
"""The total (sum) of all the entries in the 2d vector/"""
assert self.value is not None
@@ -199,6 +243,34 @@ class Vector2d(Statistic):
index = float(index)
return self.value[index]
def children(
self,
predicate: Optional[Callable[[str], bool]] = None,
recursive: bool = False,
) -> List["AbstractStat"]:
to_return = []
for attr in self.value.keys():
obj = self.value[attr]
if (
isinstance(attr, str)
and (predicate and predicate(attr))
or not predicate
):
to_return.append(obj)
to_return = to_return + obj.children(
predicate=predicate, recursive=True
)
return to_return
def __contains__(self, item) -> bool:
assert self.value is not None
if isinstance(item, str):
if item.isdigit():
item = int(item)
elif item.isnumeric():
item = float(item)
return item in self.value
class Distribution(Vector):
"""