diff --git a/src/python/m5/ext/pystats/abstract_stat.py b/src/python/m5/ext/pystats/abstract_stat.py index 19932a1e29..98adda314b 100644 --- a/src/python/m5/ext/pystats/abstract_stat.py +++ b/src/python/m5/ext/pystats/abstract_stat.py @@ -26,10 +26,12 @@ import re from typing import ( + Any, Callable, List, Optional, Pattern, + Tuple, Union, ) @@ -60,20 +62,11 @@ class AbstractStat(SerializableStat): If it returns ``True``, then the child is yielded. Otherwise, the child is skipped. If not provided then all children are returned. + + Note: This method must be implemented in AbstractStat subclasses + which have children, otherwise it will return an empty list. """ - - to_return = [] - for attr in self.__dict__: - obj = getattr(self, attr) - if isinstance(obj, AbstractStat): - if (predicate and predicate(attr)) or not predicate: - to_return.append(obj) - if recursive: - to_return = to_return + obj.children( - predicate=predicate, recursive=True - ) - - return to_return + return [] def find(self, regex: Union[str, Pattern]) -> List["AbstractStat"]: """Find all stats that match the name, recursively through all the @@ -100,5 +93,51 @@ class AbstractStat(SerializableStat): lambda _name: re.match(pattern, _name), recursive=True ) + def _get_vector_item(self, item: str) -> Optional[Tuple[str, int, Any]]: + """It has been the case in gem5 that SimObject vectors are stored as + strings such as "cpu0" or "cpu1". This function splits the string into + the SimObject name and index, (e.g.: ["cpu", 0] and ["cpu", 1]) and + returns the item for that name and its index. If the string cannot be + split into a SimObject name and index, or if the SimObject does not + exist at `SimObject[index]`, the function returns None.
+ """ + regex = re.compile("[0-9]+$") + match = regex.search(item) + if not match: + return None + + match_str = match.group() + + assert match_str.isdigit(), f"Regex match must be a digit: {match_str}" + vector_index = int(match_str) + vector_name = item[: (-1 * len(match_str))] + + if hasattr(self, vector_name): + vector = getattr(self, vector_name) + try: + vector_value = vector[vector_index] + return vector_name, vector_index, vector_value + except (KeyError, IndexError, TypeError): + pass + return None + + def __iter__(self): + return iter(self.__dict__) + + def __getattr__(self, item: str) -> Any: + vector_item = self._get_vector_item(item) + if not vector_item: + raise AttributeError(item) + + assert ( + len(vector_item) == 3 + ), f"Vector item must have 3 elements: {vector_item}" + return vector_item[2] + def __getitem__(self, item: str): return getattr(self, item) + + def __contains__(self, item: Any) -> bool: + return ( + isinstance(item, str) and self._get_vector_item(item) + ) or hasattr(self, item) diff --git a/src/python/m5/ext/pystats/group.py b/src/python/m5/ext/pystats/group.py index 7bdc5523c9..d1808221e7 100644 --- a/src/python/m5/ext/pystats/group.py +++ b/src/python/m5/ext/pystats/group.py @@ -26,6 +26,7 @@ from typing import ( Any, + Callable, Dict, List, Optional, @@ -62,6 +63,23 @@ class Group(AbstractStat): for key, value in kwargs.items(): setattr(self, key, value) + def children( + self, + predicate: Optional[Callable[[str], bool]] = None, + recursive: bool = False, + ) -> List["AbstractStat"]: + to_return = [] + for attr in self.__dict__: + obj = getattr(self, attr) + if isinstance(obj, AbstractStat): + if (predicate and predicate(attr)) or not predicate: + to_return.append(obj) + if recursive: + to_return = to_return + obj.children( + predicate=predicate, recursive=True + ) + return to_return + class SimObjectGroup(Group): """A group of statistics encapulated within a SimObject.""" @@ -93,30 +111,22 @@ class SimObjectVectorGroup(Group): def __len__(self): return
len(self.value) - def get_all_stats_of_name(self, name: str) -> List[AbstractStat]: - """ - Get all the stats in the vector of that name. Useful for performing - operations on all the stats of the same name in a vector. - """ - to_return = [] - for stat in self.value: - if hasattr(stat, name): - to_return.append(getattr(stat, name)) - - # If the name is in the format "sim.bla.whatever", we are looking for - # the "bla.whatever" stats in the "sim" group. - # This is messy, but it works. - name_split = name.split(".") - if len(name_split) == 1: - return to_return - - if name_split[0] not in self: - return to_return - - to_return.extend( - self[name_split[0]].get_all_stats_of_name(".".join(name_split[1:])) - ) - return to_return - def __getitem__(self, item: int): return self.value[item] + + def __contains__(self, item): + if isinstance(item, int): + return item >= 0 and item < len(self) + + def children( + self, + predicate: Optional[Callable[[str], bool]] = None, + recursive: bool = False, + ) -> List["AbstractStat"]: + to_return = [] + for child in self.value: + to_return = to_return + child.children( + predicate=predicate, recursive=recursive + ) + + return to_return diff --git a/src/python/m5/ext/pystats/statistic.py b/src/python/m5/ext/pystats/statistic.py index b3a8d3aa12..fb98ceb93e 100644 --- a/src/python/m5/ext/pystats/statistic.py +++ b/src/python/m5/ext/pystats/statistic.py @@ -27,8 +27,10 @@ from abc import ABC from typing import ( Any, + Callable, Dict, Iterable, + List, Optional, Union, ) @@ -102,16 +104,32 @@ class Vector(Statistic): description=description, ) - def __getitem__(self, index: Union[int, str, float]) -> Scalar: + def __getitem__(self, item: Union[int, str, float]) -> Scalar: assert self.value != None # In the case of string, we cast strings to integers of floats if they # are numeric. This avoids users having to cast strings to integers. 
- if isinstance(index, str): - if index.isindex(): - index = int(index) - elif index.isnumeric(): - index = float(index) - return self.value[index] + if isinstance(item, str): + if item.isdigit(): + item = int(item) + elif item.isnumeric(): + item = float(item) + return self.value[item] + + def __contains__(self, item) -> bool: + assert self.value != None + if isinstance(item, str): + if item.isdigit(): + item = int(item) + elif item.isnumeric(): + item = float(item) + return item in self.value + + def __iter__(self): + return iter(self.value) + + def __len__(self) -> int: + assert self.value != None + return len(self.value.values()) def size(self) -> int: """ @@ -141,6 +159,26 @@ class Vector(Statistic): assert self.value != None return sum(float(self.value[key]) for key in self.values) + def children( + self, + predicate: Optional[Callable[[str], bool]] = None, + recursive: bool = False, + ) -> List["AbstractStat"]: + to_return = [] + for attr in self.value.keys(): + obj = self.value[attr] + if isinstance(obj, AbstractStat): + if ( + isinstance(attr, str) + and (predicate and predicate(attr)) + or not predicate + ): + to_return.append(obj) + to_return = to_return + obj.children( + predicate=predicate, recursive=True + ) + return to_return + class Vector2d(Statistic): """ @@ -179,6 +217,12 @@ class Vector2d(Statistic): """Returns the total number of elements.""" return self.x_size() * self.y_size() + def __len__(self) -> int: + return self.x_size() + + def __iter__(self): + return iter(self.keys()) + def total(self) -> int: """The total (sum) of all the entries in the 2d vector/""" assert self.value is not None @@ -199,6 +243,34 @@ class Vector2d(Statistic): index = float(index) return self.value[index] + def children( + self, + predicate: Optional[Callable[[str], bool]] = None, + recursive: bool = False, + ) -> List["AbstractStat"]: + to_return = [] + for attr in self.value.keys(): + obj = self.value[attr] + if ( + isinstance(attr, str) + and (predicate
and predicate(attr)) + or not predicate + ): + to_return.append(obj) + to_return = to_return + obj.children( + predicate=predicate, recursive=True + ) + return to_return + + def __contains__(self, item) -> bool: + assert self.value is not None + if isinstance(item, str): + if item.isdigit(): + item = int(item) + elif item.isnumeric(): + item = float(item) + return item in self.value + class Distribution(Vector): """ diff --git a/tests/pyunit/pystats/pyunit_pystats.py b/tests/pyunit/pystats/pyunit_pystats.py new file mode 100644 index 0000000000..70839c245b --- /dev/null +++ b/tests/pyunit/pystats/pyunit_pystats.py @@ -0,0 +1,208 @@ +# Copyright (c) 2024 The Regents of The University of California +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer; +# redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution; +# neither the name of the copyright holders nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import unittest +from datetime import datetime + +from m5.ext.pystats import ( + Distribution, + Scalar, + SimObjectGroup, + SimObjectVectorGroup, + SimStat, + SparseHist, + Vector, + Vector2d, +) + + +def _get_mock_simstat() -> SimStat: + """Used to create a mock SimStat for testing. + This SimStat contains all simstat Statistic types and attempts to use + most of the different types of values that can be stored in a Statistic. + """ + simobject_vector_group = SimObjectVectorGroup( + value=[ + SimObjectGroup( + **{ + "vector2d": Vector2d( + value={ + 0: Vector( + value={ + "a": Scalar(value=1, description="one"), + "b": Scalar(value=2.0, description="two"), + "c": Scalar(value=-3, description="three"), + } + ), + 1: Vector( + value={ + 1: Scalar(value=4), + 0.2: Scalar(value=5.0), + 0.3: Scalar(value=6), + }, + description="vector 1", + ), + }, + description="vector 2d", + ), + } + ), + SimObjectGroup( + **{ + "distribution": Distribution( + value={ + 0: Scalar(1), + 1: Scalar(2), + 2: Scalar(3), + 3: Scalar(4), + 4: Scalar(5), + }, + min=0, + max=4, + num_bins=5, + bin_size=1, + ), + "sparse_hist": SparseHist( + value={ + 0.5: Scalar(4), + 0.51: Scalar(1), + 0.511: Scalar(4), + 5: Scalar(2), + }, + description="sparse hist", + ), + }, + ), + ], + ) + + return SimStat( + creation_time=datetime.fromisoformat("2021-01-01T00:00:00"), + time_conversion=None, + simulated_begin_time=123, + simulated_end_time=558644,
simobject_vector=simobject_vector_group, + ) + + +class NavigatingPyStatsTestCase(unittest.TestCase): + """A test case for navigating the PyStats data structure, primarily + on how to access children of a SimStat object, and the "find" methods to + search for a specific statistic. + """ + + def setUp(self) -> None: + """Overrides the setUp method to create a mock SimStat for testing. + Runs before each test method. + """ + self.failFast = True + self.simstat = _get_mock_simstat() + super().setUp() + + def test_simstat_index(self): + self.assertTrue("simobject_vector" in self.simstat) + self.assertIsInstance( + self.simstat["simobject_vector"], SimObjectVectorGroup + ) + + def test_simstat_attribute(self): + self.assertTrue(hasattr(self.simstat, "simobject_vector")) + self.assertIsInstance( + self.simstat.simobject_vector, SimObjectVectorGroup + ) + + def test_simobject_vector_attribute(self): + # To maintain compatibility with the old way of accessing the vector, + # the simobject vector's values can be accessed by attributes of that + # simobject vector name and the index appended to it. + # E.g., `simstat.simobject_vector0` is the same + # as `simstat.simobject_vector[0]`. In cases where there is already + # an attribute with the same name as the vector+index, the attribute + # will be returned.
+ self.assertEqual( + self.simstat.simobject_vector0, self.simstat.simobject_vector[0] + ) + + def test_simobject_vector_index(self): + self.assertIsInstance(self.simstat.simobject_vector[0], SimObjectGroup) + + def test_simobject_group_index(self): + self.assertTrue("vector2d" in self.simstat.simobject_vector[0]) + self.assertIsInstance( + self.simstat.simobject_vector[0]["vector2d"], Vector2d + ) + + def test_simobject_group_attribute(self): + self.assertTrue(hasattr(self.simstat.simobject_vector[0], "vector2d")) + self.assertIsInstance( + self.simstat.simobject_vector[0].vector2d, Vector2d + ) + + def test_vector2d_index(self): + self.assertEqual(2, len(self.simstat.simobject_vector[0]["vector2d"])) + self.assertTrue(0 in self.simstat.simobject_vector[0].vector2d) + self.assertIsInstance( + self.simstat.simobject_vector[0].vector2d[0], Vector + ) + + def test_vector_index_int(self): + self.assertEqual(3, len(self.simstat.simobject_vector[0].vector2d[1])) + self.assertTrue(1 in self.simstat.simobject_vector[0].vector2d[1]) + self.assertIsInstance( + self.simstat.simobject_vector[0].vector2d[1][1], Scalar + ) + + def test_vector_index_str(self): + self.assertEqual(3, len(self.simstat.simobject_vector[0].vector2d[0])) + self.assertTrue("a" in self.simstat.simobject_vector[0].vector2d[0]) + self.assertIsInstance( + self.simstat.simobject_vector[0].vector2d[0]["a"], Scalar + ) + + def test_vector_index_float(self): + self.assertEqual(3, len(self.simstat.simobject_vector[0].vector2d[1])) + self.assertTrue(0.2 in self.simstat.simobject_vector[0].vector2d[1]) + self.assertIsInstance( + self.simstat.simobject_vector[0].vector2d[1][0.2], Scalar + ) + + def test_distribution_index(self): + self.assertTrue(0 in self.simstat.simobject_vector[1]["distribution"]) + self.assertIsInstance( + self.simstat.simobject_vector[1]["distribution"][0], Scalar + ) + + def test_sparse_hist_index(self): + self.assertTrue(0.5 in self.simstat.simobject_vector[1]["sparse_hist"])
self.assertIsInstance( + self.simstat.simobject_vector[1]["sparse_hist"][0.5], Scalar + ) + + def test_pystat_find(self): + self.assertEqual( + self.simstat.find("sparse_hist"), + [self.simstat.simobject_vector[1]["sparse_hist"]], + )