python: Improve type annotations in pystats

This fixes some errors and warning when running mypy.

`gem5/src/python/m5/ext> mypy pystats`

There is one error that is ignored, which is a bug in mypy. See
https://github.com/python/mypy/issues/6040

Change-Id: I18b648c059da12bd30d612f0e265930b976f22b4
Signed-off-by: Jason Lowe-Power <jason@lowepower.com>
Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/42644
Reviewed-by: Andreas Sandberg <andreas.sandberg@arm.com>
Maintainer: Bobby R. Bruce <bbruce@ucdavis.edu>
Tested-by: kokoro <noreply+kokoro@google.com>
This commit is contained in:
Jason Lowe-Power
2021-03-09 11:42:09 -08:00
committed by Jason Lowe-Power
parent 2dfa2ddc6f
commit 91f4ea6683
3 changed files with 31 additions and 19 deletions

View File

@@ -25,7 +25,8 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import re
from typing import Callable, Dict, Iterator, List, Optional, Pattern, Union
from typing import Callable, Dict, Iterator, List, Mapping, Optional, Pattern,\
Union
from .jsonserializable import JsonSerializable
from .statistic import Scalar, Statistic
@@ -118,8 +119,10 @@ class Group(JsonSerializable):
precompiled regex or a string in regex format
"""
if isinstance(regex, str):
regex = re.compile(regex)
yield from self.children(lambda _name: regex.search(_name))
pattern = re.compile(regex)
else:
pattern = regex
yield from self.children(lambda _name: bool(pattern.search(_name)))
class Vector(Group):
"""
@@ -129,7 +132,7 @@ class Vector(Group):
accordance to decisions made in relation to
https://gem5.atlassian.net/browse/GEM5-867.
"""
def __init__(self, scalar_map: Dict[str,Scalar]):
def __init__(self, scalar_map: Mapping[str,Scalar]):
super(Vector, self).__init__(
type="Vector",
time_conversion=None,

View File

@@ -24,11 +24,12 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from json.decoder import JSONDecodeError
from .simstat import SimStat
from .statistic import Scalar, Distribution, Accumulator
from .statistic import Scalar, Distribution, Accumulator, Statistic
from .group import Group, Vector
import json
from typing import IO
from typing import IO, Union
class JsonLoader(json.JSONDecoder):
"""
@@ -46,9 +47,11 @@ class JsonLoader(json.JSONDecoder):
"""
def __init__(self):
json.JSONDecoder.__init__(self, object_hook=self.__json_to_simstat)
super(JsonLoader, self).__init__(self,
object_hook=self.__json_to_simstat
)
def __json_to_simstat(self, d: dict) -> SimStat:
def __json_to_simstat(self, d: dict) -> Union[SimStat,Statistic,Group]:
if 'type' in d:
if d['type'] == 'Scalar':
d.pop('type', None)
@@ -69,6 +72,11 @@ class JsonLoader(json.JSONDecoder):
d.pop('type', None)
d.pop('time_conversion', None)
return Vector(d)
else:
raise ValueError(
f"SimStat object has invalid type {d['type']}"
)
else:
return SimStat(**d)

View File

@@ -25,7 +25,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from abc import ABC
from typing import Any, Optional, Union, List
from typing import Any, Iterable, Optional, Union, List
from .jsonserializable import JsonSerializable
from .storagetype import StorageType
@@ -76,13 +76,13 @@ class BaseScalarVector(Statistic):
"""
value: List[Union[int,float]]
def __init__(self, value: List[Union[int,float]],
def __init__(self, value: Iterable[Union[int,float]],
type: Optional[str] = None,
unit: Optional[str] = None,
description: Optional[str] = None,
datatype: Optional[StorageType] = None):
super(BaseScalarVector, self).__init__(
value=value,
value=list(value),
type=type,
unit=unit,
description=description,
@@ -104,7 +104,7 @@ class BaseScalarVector(Statistic):
from statistics import mean as statistics_mean
return statistics_mean(self.value)
def count(self) -> int:
def count(self) -> float:
"""
Returns the count across all the bins.
@@ -114,7 +114,6 @@ class BaseScalarVector(Statistic):
The sum of all bin values.
"""
assert(self.value != None)
assert(isinstance(self.value, List))
return sum(self.value)
@@ -128,7 +127,6 @@ class Distribution(BaseScalarVector):
It is assumed each bucket is of equal size.
"""
value: List[int]
min: Union[float, int]
max: Union[float, int]
num_bins: int
@@ -139,7 +137,7 @@ class Distribution(BaseScalarVector):
overflow: Optional[int]
logs: Optional[float]
def __init__(self, value: List[int],
def __init__(self, value: Iterable[int],
min: Union[float, int],
max: Union[float, int],
num_bins: int,
@@ -179,12 +177,12 @@ class Accumulator(BaseScalarVector):
A statistical type representing an accumulator.
"""
count: int
_count: int
min: Union[int, float]
max: Union[int, float]
sum_squared: Optional[int]
def __init__(self, value: List[Union[int,float]],
def __init__(self, value: Iterable[Union[int,float]],
count: int,
min: Union[int, float],
max: Union[int, float],
@@ -200,7 +198,10 @@ class Accumulator(BaseScalarVector):
datatype=datatype,
)
self.count = count
self._count = count
self.min = min
self.max = max
self.sum_squared = sum_squared
self.sum_squared = sum_squared
def count(self) -> int:
return self._count