stdlib,tests: Add Pyunit tests to check Pyunit nav, fix bugs

Bugs fixed of note:

1. The 'find' method has been fixed to work. This involved making
   'children' a method implemented per-subclass as required.
2. The 'get_all_stats_of_name' method has been removed. This was not
   working at all correctly and is largely doing what 'find' does.
3. The functionality to get an element in a vector via an attribute call
   (i.e., self.vector1 == self.vector[1]) has been implemented, thus
   maintaining backwards compatibility with the regular Python stats.

Change-Id: I31a4ccc723937018a3038dcdf491c82629ddbbb2
This commit is contained in:
Bobby R. Bruce
2024-05-29 23:47:46 -07:00
parent 2d4a213046
commit 7f0290985f
4 changed files with 374 additions and 45 deletions

View File

@@ -26,10 +26,12 @@
import re
from typing import (
Any,
Callable,
List,
Optional,
Pattern,
Tuple,
Union,
)
@@ -60,20 +62,11 @@ class AbstractStat(SerializableStat):
If it returns ``True``, then the child is yielded.
Otherwise, the child is skipped. If not provided then
all children are returned.
Note: This method must be implemented in AbstractStat subclasses
which have children, otherwise it will return an empty list.
"""
to_return = []
for attr in self.__dict__:
obj = getattr(self, attr)
if isinstance(obj, AbstractStat):
if (predicate and predicate(attr)) or not predicate:
to_return.append(obj)
if recursive:
to_return = to_return + obj.children(
predicate=predicate, recursive=True
)
return to_return
return []
def find(self, regex: Union[str, Pattern]) -> List["AbstractStat"]:
"""Find all stats that match the name, recursively through all the
@@ -100,5 +93,51 @@ class AbstractStat(SerializableStat):
lambda _name: re.match(pattern, _name), recursive=True
)
def _get_vector_item(self, item: str) -> Optional[Tuple[str, int, Any]]:
"""It has been the case in gem5 that SimObject vectors are stored as
strings such as "cpu0" or "cpu1". This function splits the string into
the SimObject name and index, (e.g.: ["cpu", 0] and ["cpu", 1]) and
returns the item for that name and it's index. If the string cannot be
split into a SimObject name and index, or if the SimObject does not
exit at `Simobject[index]`, the function returns None.
"""
regex = re.compile("[0-9]+$")
match = regex.search(item)
if not match:
return None
match_str = match.group()
assert match_str.isdigit(), f"Regex match must be a digit: {match_str}"
vector_index = int(match_str)
vector_name = item[: (-1 * len(match_str))]
if hasattr(self, vector_name):
vector = getattr(self, vector_name)
try:
vector_value = vector[vector_index]
return vector_name, vector_index, vector_value
except KeyError:
pass
return None
def __iter__(self):
return iter(self.__dict__)
def __getattr__(self, item: str) -> Any:
vector_item = self._get_vector_item(item)
if not vector_item:
return None
assert (
len(vector_item) == 3
), f"Vector item must have 3 elements: {vector_item}"
return vector_item[2]
def __getitem__(self, item: str):
return getattr(self, item)
def __contains__(self, item: Any) -> bool:
return (
isinstance(item, str) and self._get_vector_item(item)
) or hasattr(self, item)

View File

@@ -26,6 +26,7 @@
from typing import (
Any,
Callable,
Dict,
List,
Optional,
@@ -62,6 +63,23 @@ class Group(AbstractStat):
for key, value in kwargs.items():
setattr(self, key, value)
def children(
    self,
    predicate: Optional[Callable[[str], bool]] = None,
    recursive: bool = False,
) -> List["AbstractStat"]:
    """Collect the AbstractStat attributes of this group.

    :param predicate: Optional filter called with each attribute name; an
                      attribute is included only when it returns a truthy
                      value. With no predicate every stat is included.
    :param recursive: When ``True``, the children of every AbstractStat
                      attribute are collected as well, whether or not the
                      attribute itself passed the predicate.

    :returns: The matching stats, each followed by its own descendants
              when ``recursive`` is set.
    """
    collected = []
    for name in self.__dict__:
        candidate = getattr(self, name)
        if not isinstance(candidate, AbstractStat):
            continue
        if not predicate or predicate(name):
            collected.append(candidate)
        if recursive:
            collected = collected + candidate.children(
                predicate=predicate, recursive=True
            )
    return collected
class SimObjectGroup(Group):
"""A group of statistics encapsulated within a SimObject."""
@@ -93,30 +111,22 @@ class SimObjectVectorGroup(Group):
def __len__(self):
return len(self.value)
def get_all_stats_of_name(self, name: str) -> List[AbstractStat]:
"""
Get all the stats in the vector of that name. Useful for performing
operations on all the stats of the same name in a vector.

:param name: The attribute name to collect from the elements of
``self.value``. May be dotted (e.g. "sim.bla.whatever"), in which case
the part after the first dot is looked up on ``self["sim"]``.

:returns: A list of the attributes found; empty if none were found.
"""
to_return = []
# Collect the attribute called ``name`` from every element that has one.
for stat in self.value:
if hasattr(stat, name):
to_return.append(getattr(stat, name))
# If the name is in the format "sim.bla.whatever", we are looking for
# the "bla.whatever" stats in the "sim" group.
# This is messy, but it works.
name_split = name.split(".")
# A name without a dot has nothing further to descend into.
if len(name_split) == 1:
return to_return
# NOTE(review): ``name_split[0]`` is a str, while the ``__getitem__`` and
# ``__contains__`` visible next to this method appear to expect an int
# index -- confirm whether the recursive step below is reachable.
if name_split[0] not in self:
return to_return
# Descend into the named entry with the remainder of the dotted name.
to_return.extend(
self[name_split[0]].get_all_stats_of_name(".".join(name_split[1:]))
)
return to_return
def __getitem__(self, item: int):
return self.value[item]
def __contains__(self, item):
if isinstance(item, int):
return item >= 0 and item < len(self)
def children(
    self,
    predicate: Optional[Callable[[str], bool]] = None,
    recursive: bool = False,
) -> List["AbstractStat"]:
    """Gather the children of every element held in ``self.value``.

    The elements themselves are not added to the result; only what each
    element's own ``children`` call returns, in element order.

    :param predicate: Passed through unchanged to each element.
    :param recursive: Passed through unchanged to each element.

    :returns: The concatenation of every element's children.
    """
    return [
        stat
        for element in self.value
        for stat in element.children(
            predicate=predicate, recursive=recursive
        )
    ]

View File

@@ -27,8 +27,10 @@
from abc import ABC
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Union,
)
@@ -102,16 +104,32 @@ class Vector(Statistic):
description=description,
)
def __getitem__(self, index: Union[int, str, float]) -> Scalar:
def __getitem__(self, item: Union[int, str, float]) -> Scalar:
assert self.value != None
# In the case of string, we cast strings to integers of floats if they
# are numeric. This avoids users having to cast strings to integers.
if isinstance(index, str):
if index.isindex():
index = int(index)
elif index.isnumeric():
index = float(index)
return self.value[index]
if isinstance(item, str):
if item.isdigit():
item = int(item)
elif item.isnumeric():
item = float(item)
return self.value[item]
def __contains__(self, item) -> bool:
assert self.value != None
if isinstance(item, str):
if item.isdigit():
item = int(item)
elif item.isnumeric():
item = float(item)
return item in self.value
def __iner__(self) -> None:
return iter(self.value)
def __len__(self) -> int:
assert self.value != None
return len(self.value.values())
def size(self) -> int:
"""
@@ -141,6 +159,26 @@ class Vector(Statistic):
assert self.value != None
return sum(float(self.value[key]) for key in self.values)
def children(
    self,
    predicate: Optional[Callable[[str], bool]] = None,
    recursive: bool = False,
) -> List["AbstractStat"]:
    """Collect the AbstractStat entries stored in ``self.value``.

    :param predicate: Optional filter called with an entry's key. Only
                      string keys are offered to it, so with a predicate
                      set an entry under a non-string key never matches.
                      With no predicate every stat entry is included.
    :param recursive: When ``True``, the children of every stat entry are
                      collected as well, whether or not the entry itself
                      passed the predicate. When ``False`` only direct
                      entries are returned.

    :returns: The matching entries, each followed by its own descendants
              when ``recursive`` is set.
    """
    to_return: List["AbstractStat"] = []
    for key, obj in self.value.items():
        if not isinstance(obj, AbstractStat):
            continue
        if not predicate or (isinstance(key, str) and predicate(key)):
            to_return.append(obj)
        # Descend only on request, as the ``recursive`` flag promises.
        if recursive:
            to_return.extend(
                obj.children(predicate=predicate, recursive=True)
            )
    return to_return
class Vector2d(Statistic):
"""
@@ -179,6 +217,12 @@ class Vector2d(Statistic):
"""Returns the total number of elements."""
return self.x_size() * self.y_size()
def __len__(self) -> int:
return self.x_size()
def __iter__(self):
return iter(self.keys())
def total(self) -> int:
"""The total (sum) of all the entries in the 2d vector/"""
assert self.value is not None
@@ -199,6 +243,34 @@ class Vector2d(Statistic):
index = float(index)
return self.value[index]
def children(
    self,
    predicate: Optional[Callable[[str], bool]] = None,
    recursive: bool = False,
) -> List["AbstractStat"]:
    """Collect the entries stored in ``self.value``.

    :param predicate: Optional filter called with an entry's key. Only
                      string keys are offered to it, so with a predicate
                      set an entry under a non-string key never matches.
                      With no predicate every entry is included.
    :param recursive: When ``True``, the children of every entry are
                      collected as well, whether or not the entry itself
                      passed the predicate. When ``False`` only direct
                      entries are returned.

    :returns: The matching entries, each followed by its own descendants
              when ``recursive`` is set.
    """
    to_return: List["AbstractStat"] = []
    for key, obj in self.value.items():
        if not predicate or (isinstance(key, str) and predicate(key)):
            to_return.append(obj)
        # Descend only on request, as the ``recursive`` flag promises.
        if recursive:
            to_return.extend(
                obj.children(predicate=predicate, recursive=True)
            )
    return to_return
def __contains__(self, item) -> bool:
assert self.value is not None
if isinstance(item, str):
if item.isdigit():
item = int(item)
elif item.isnumeric():
item = float(item)
return item in self.value
class Distribution(Vector):
"""

View File

@@ -0,0 +1,208 @@
# Copyright (c) 2024 The Regents of The University of California
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met: redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer;
# redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution;
# neither the name of the copyright holders nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import unittest
from datetime import datetime
from m5.ext.pystats import (
Distribution,
Scalar,
SimObjectGroup,
SimObjectVectorGroup,
SimStat,
SparseHist,
Vector,
Vector2d,
)
def _get_mock_simstat() -> SimStat:
    """Build the mock SimStat used by the tests in this file.

    The SimStat holds a single SimObjectVectorGroup of two
    SimObjectGroups. The first contains a Vector2d made of two Vectors of
    Scalars (one keyed by strings, one keyed by an int and floats); the
    second contains a Distribution and a SparseHist.
    """
    str_keyed_vector = Vector(
        value={
            "a": Scalar(value=1, description="one"),
            "b": Scalar(value=2.0, description="two"),
            "c": Scalar(value=-3, description="three"),
        }
    )
    number_keyed_vector = Vector(
        value={
            1: Scalar(value=4),
            0.2: Scalar(value=5.0),
            0.3: Scalar(value=6),
        },
        description="vector 1",
    )
    vector2d_group = SimObjectGroup(
        vector2d=Vector2d(
            value={0: str_keyed_vector, 1: number_keyed_vector},
            description="vector 2d",
        ),
    )

    distribution = Distribution(
        value={
            0: Scalar(1),
            1: Scalar(2),
            2: Scalar(3),
            3: Scalar(4),
            4: Scalar(5),
        },
        min=0,
        max=4,
        num_bins=5,
        bin_size=1,
    )
    sparse_hist = SparseHist(
        value={
            0.5: Scalar(4),
            0.51: Scalar(1),
            0.511: Scalar(4),
            5: Scalar(2),
        },
        description="sparse hist",
    )
    histogram_group = SimObjectGroup(
        distribution=distribution,
        sparse_hist=sparse_hist,
    )

    simobject_vector_group = SimObjectVectorGroup(
        value=[vector2d_group, histogram_group],
    )

    return SimStat(
        creation_time=datetime.fromisoformat("2021-01-01T00:00:00"),
        time_conversion=None,
        simulated_begin_time=123,
        simulated_end_time=558644,
        simobject_vector=simobject_vector_group,
    )
class NavigatingPyStatsTestCase(unittest.TestCase):
    """A test case for navigating the PyStats data structure, primarily
    on how to access children of a SimStat object, and the "find" methods to
    search for a specific statistic.
    """

    def setUp(self) -> None:
        """Overrides the setUp method to create a mock SimStat for testing.
        Runs before each test method.
        """
        self.failFast = True
        self.simstat = _get_mock_simstat()
        super().setUp()

    def test_simstat_index(self):
        self.assertIn("simobject_vector", self.simstat)
        self.assertIsInstance(
            self.simstat["simobject_vector"], SimObjectVectorGroup
        )

    def test_simstat_attribute(self):
        self.assertTrue(hasattr(self.simstat, "simobject_vector"))
        self.assertIsInstance(
            self.simstat.simobject_vector, SimObjectVectorGroup
        )

    def test_simobject_vector_attribute(self):
        # To maintain compatibility with the old way of accessing the vector,
        # the simobject vector's values can be accessed by attributes named
        # after the simobject vector with the index appended.
        # E.g., `simstat.simobject_vector0` is the same as
        # `simstat.simobject_vector[0]`. In cases where there is already
        # an attribute with the same name as the vector+index, the attribute
        # will be returned.
        self.assertEqual(
            self.simstat.simobject_vector0, self.simstat.simobject_vector[0]
        )

    def test_simobject_vector_index(self):
        # `assertTrue(x, cls)` would treat `cls` as the failure message and
        # only check truthiness; the type check needs `assertIsInstance`.
        self.assertIsInstance(
            self.simstat.simobject_vector[0], SimObjectGroup
        )

    def test_simobject_group_index(self):
        self.assertIn("vector2d", self.simstat.simobject_vector[0])
        self.assertIsInstance(
            self.simstat.simobject_vector[0]["vector2d"], Vector2d
        )

    def test_simobject_group_attribute(self):
        self.assertTrue(hasattr(self.simstat.simobject_vector[0], "vector2d"))
        self.assertIsInstance(
            self.simstat.simobject_vector[0].vector2d, Vector2d
        )

    def test_vector2d_index(self):
        self.assertEqual(2, len(self.simstat.simobject_vector[0]["vector2d"]))
        self.assertIn(0, self.simstat.simobject_vector[0].vector2d)
        self.assertIsInstance(
            self.simstat.simobject_vector[0].vector2d[0], Vector
        )

    def test_vector_index_int(self):
        self.assertEqual(3, len(self.simstat.simobject_vector[0].vector2d[1]))
        self.assertIn(1, self.simstat.simobject_vector[0].vector2d[1])
        self.assertIsInstance(
            self.simstat.simobject_vector[0].vector2d[1][1], Scalar
        )

    def test_vector_index_str(self):
        self.assertEqual(3, len(self.simstat.simobject_vector[0].vector2d[0]))
        self.assertIn("a", self.simstat.simobject_vector[0].vector2d[0])
        self.assertIsInstance(
            self.simstat.simobject_vector[0].vector2d[0]["a"], Scalar
        )

    def test_vector_index_float(self):
        self.assertEqual(3, len(self.simstat.simobject_vector[0].vector2d[1]))
        self.assertIn(0.2, self.simstat.simobject_vector[0].vector2d[1])
        self.assertIsInstance(
            self.simstat.simobject_vector[0].vector2d[1][0.2], Scalar
        )

    def test_distribution_index(self):
        self.assertIn(0, self.simstat.simobject_vector[1]["distribution"])
        self.assertIsInstance(
            self.simstat.simobject_vector[1]["distribution"][0], Scalar
        )

    def test_sparse_hist_index(self):
        self.assertIn(0.5, self.simstat.simobject_vector[1]["sparse_hist"])
        self.assertIsInstance(
            self.simstat.simobject_vector[1]["sparse_hist"][0.5], Scalar
        )

    def test_pystat_find(self):
        self.assertEqual(
            self.simstat.find("sparse_hist"),
            [self.simstat.simobject_vector[1]["sparse_hist"]],
        )