diff --git a/src/python/gem5/utils/looppoint.py b/src/python/gem5/utils/looppoint.py index d1851a8478..8e01e3030f 100644 --- a/src/python/gem5/utils/looppoint.py +++ b/src/python/gem5/utils/looppoint.py @@ -28,7 +28,7 @@ from m5.util import fatal from m5.params import PcCountPair from pathlib import Path -from typing import List, Dict +from typing import List, Dict, Tuple from gem5.components.processors.abstract_processor import AbstractProcessor from m5.objects import PcCountTrackerManager import csv @@ -330,76 +330,55 @@ class LoopPointCheckpoint(LoopPoint): class LoopPointRestore(LoopPoint): - def __init__(self, looppoint_file: Path, checkpoint_path: Path) -> None: + def __init__(self, looppoint_file: Path, region_id: int) -> None: """ - This class is specifically designed to take in the LoopPoint data file and - generator information needed to restore a checkpoint taken by the + This class is specifically designed to take in the LoopPoint data file + and generator information needed to restore a checkpoint taken by the LoopPointCheckPoint. :param looppoint_file: a json file generated by gem5 that has all the LoopPoint data information - :param checkpoint_path: the director of the checkpoint taken by the gem5 - standard library looppoint_save_checkpoint_generator - + :param region_id: The region ID we will be restoring to. """ - _json_file = {} - _targets = [] - _region_id = {} - - self.profile_restore( - looppoint_file, checkpoint_path, _targets, _json_file, _region_id - ) - - super().__init__( - _targets, - _region_id, - _json_file, - ) - - def profile_restore( - self, - looppoint_file_path: Path, - checkpoint_dir: Path, - targets: List[PcCountPair], - json_file: Dict[int, Dict], - region_id: Dict[PcCountPair, int], - ) -> None: - """ - This function is used to profile data from the LoopPoint data file to - information needed to restore the LoopPoint checkpoint - :param looppoint_file_path: the director of the LoopPoint data file - :param targets: a list of PcCountPair - :param json_file: a dictionary for all the LoopPoint data - :param region_id: a dictionary for all the significant PcCountPair and - its corresponding region id - """ - regex = re.compile(r"cpt.Region([0-9]+)") - rid = regex.findall(checkpoint_dir.as_posix())[0] - # finds out the region id from the directory name - with open(looppoint_file_path) as file: + with open(looppoint_file) as file: json_file = json.load(file) - if rid not in json_file: - # if the region id does not exist in the LoopPoint data file - # raise a fatal message - fatal(f"{rid} is not a valid region\n") - region = json_file[rid] - if "warmup" in region: - if "relative" not in region["simulation"]["start"]: - # if there are not relative counts for the PC Count pair - # then it means there is not enough information to restore - # this checkpoint - fatal(f"region {rid} doesn't have relative count info\n") - start = PcCountPair( - region["simulation"]["start"]["pc"], - region["simulation"]["start"]["relative"], - ) - region_id[start] = rid - targets.append(start) - if "relative" not in region["simulation"]["end"]: - fatal(f"region {rid} doesn't have relative count info\n") - end = PcCountPair( - region["simulation"]["end"]["pc"], - region["simulation"]["end"]["relative"], + + targets, regions = self.get_region( + json_file=json_file, region_id=region_id + ) + + super().__init__(targets=targets, regions=regions, json_file=json_file) + + def get_region( + self, json_file: Dict[int, Dict], region_id: int + ) -> Tuple[List[PcCountPair], Dict[PcCountPair, int]]: + to_return_region = {} + to_return_targets = [] + + if region_id not in json_file: + # if the region id does not exist in the LoopPoint data + # file raise a fatal message + fatal(f"{region_id} is not a valid region\n") + region = json_file[region_id] + if "warmup" in region: + if "relative" not in region["simulation"]["start"]: + # if there are not relative counts for the PC Count + # pair then it means there is not enough information to + # restore this checkpoint + fatal(f"region {region_id} doesn't have relative count info\n") + start = PcCountPair( + region["simulation"]["start"]["pc"], + region["simulation"]["start"]["relative"], ) - region_id[end] = rid - targets.append(end) + to_return_region[start] = region_id + to_return_targets.append(start) + if "relative" not in region["simulation"]["end"]: + fatal(f"region {region_id} doesn't have relative count info\n") + end = PcCountPair( + region["simulation"]["end"]["pc"], + region["simulation"]["end"]["relative"], + ) + to_return_region[end] = region_id + to_return_targets.append(end) + + return to_return_targets, to_return_region