diff --git a/src/python/gem5/components/boards/abstract_board.py b/src/python/gem5/components/boards/abstract_board.py index 4ea8866009..aba080e239 100644 --- a/src/python/gem5/components/boards/abstract_board.py +++ b/src/python/gem5/components/boards/abstract_board.py @@ -28,7 +28,7 @@ from abc import ABCMeta, abstractmethod import inspect from .mem_mode import MemMode, mem_mode_to_string -from ...resources.workload import AbstractWorkload +from ...resources.resource import WorkloadResource from m5.objects import ( AddrRange, @@ -198,7 +198,7 @@ class AbstractBoard: ) return self._is_fs - def set_workload(self, workload: AbstractWorkload) -> None: + def set_workload(self, workload: WorkloadResource) -> None: """ Set the workload for this board to run. diff --git a/src/python/gem5/resources/resource.py b/src/python/gem5/resources/resource.py index bc9f4480ba..2f50f36631 100644 --- a/src/python/gem5/resources/resource.py +++ b/src/python/gem5/resources/resource.py @@ -35,7 +35,7 @@ from .downloader import get_resource from .looppoint import LooppointCsvLoader, LooppointJsonLoader from ..isas import ISA, get_isa_from_str -from typing import Optional, Dict, Union, Type, Tuple, List +from typing import Optional, Dict, Union, Type, Tuple, List, Any from .client import get_resource_json_obj @@ -554,6 +554,79 @@ class SimpointDirectoryResource(SimpointResource): return simpoint_list, weight_list +class WorkloadResource(AbstractResource): + """A workload resource. This resource is used to specify a workload to run + on a board. It contains the function to call and the parameters to pass to + that function. + """ + + def __init__( + self, + resource_version: Optional[str] = None, + function: str = None, + resoucres: Dict[str, Dict[str, str]] = None, + additional_params: Dict[str, str] = None, + description: Optional[str] = None, + source: Optional[str] = None, + local_path: Optional[str] = None, + ): + """ + :param function: The function to call on the board. + :param parameters: The parameters to pass to the function. + """ + + super().__init__( + local_path=local_path, + description=description, + source=source, + resource_version=resource_version, + ) + + self._func = function + self._params = {} + for key in resoucres.keys(): + assert isinstance(key, str) + value = resoucres[key] + assert isinstance(value, dict) + self._params[key] = obtain_resource( + value["id"], + resource_version=value["resource_version"], + ) + for key in additional_params.keys(): + assert isinstance(key, str) + value = additional_params[key] + assert isinstance(value, str) + self._params[key] = value + + def get_function_str(self) -> str: + """ + Returns the name of the workload function to be run. + + This function is called via the AbstractBoard's `set_workload` + function. The parameters from the `get_parameters` function are passed + to this function. + """ + return self._func + + def get_parameters(self) -> Dict[str, Any]: + """ + Returns a dictionary mapping the workload parameters to their values. + + These parameters are passed to the function specified by + `get_function_str` via the AbstractBoard's `set_workload` function. + """ + return self._params + + def set_parameter(self, parameter: str, value: Any) -> None: + """ + Used to set or override a workload parameter + + :param parameter: The parameter of the function to set. + :param value: The value to set to the parameter. + """ + self._params[parameter] = value + + def obtain_resource( resource_id: str, resource_directory: Optional[str] = None, @@ -640,7 +713,7 @@ def obtain_resource( # Obtain the type from the JSON. From this we will determine what subclass # of `AbstractResource` we are to create and return. resources_category = resource_json["category"] - + print(resource_json) if resources_category == "resource": # This is a stop-gap measure to ensure to work with older versions of # the "resource.json" file. These should be replaced with their @@ -812,4 +885,5 @@ _get_resource_json_type_map = { "resource": Resource, "looppoint-pinpoint-csv": LooppointCsvResource, "looppoint-json": LooppointJsonResource, + "workload": WorkloadResource, }