diff --git a/src/python/gem5/resources/resource.py b/src/python/gem5/resources/resource.py index 88c7617478..ea830b2a2a 100644 --- a/src/python/gem5/resources/resource.py +++ b/src/python/gem5/resources/resource.py @@ -27,6 +27,7 @@ import json import os from abc import ABCMeta +from functools import partial from pathlib import Path from typing import ( Any, @@ -96,6 +97,7 @@ class AbstractResource: local_path: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, ): """ :param local_path: The path on the host system where this resource is @@ -107,18 +109,21 @@ class AbstractResource: resource may be found. Not a required parameter. By default is ``None``. :param resource_version: Version of the resource itself. + :param downloader: A partial function which is used to download the + resource. If set, this is called if the resource is not present at the + specified `local_path`. """ - if local_path and not os.path.exists(local_path): - raise Exception( - f"Local path specified for resource, '{local_path}', does not " - "exist." - ) self._id = id self._local_path = local_path self._description = description self._source = source self._version = resource_version + self._downloader = downloader + + def get_id(self) -> str: + """Returns the ID of the resource.""" + return self._id def get_category_name(cls) -> str: raise NotImplementedError @@ -137,7 +142,19 @@ class AbstractResource: return self._version def get_local_path(self) -> Optional[str]: - """Returns the local path of the resource.""" + """Returns the local path of the resource. + + If specified the `downloader` partial function is called to download + the resource if it is not present or up-to-date at the specified + `local_path`. + """ + if self._downloader: + self._downloader() + if self._local_path and not os.path.exists(self._local_path): + raise Exception( + f"Local path specified for resource, '{self._local_path}', " + "does not exist." + ) return self._local_path def get_description(self) -> Optional[str]: @@ -161,24 +178,34 @@ class FileResource(AbstractResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): - if not os.path.isfile(local_path): - raise Exception( - f"FileResource path specified, '{local_path}', is not a file." - ) - super().__init__( local_path=local_path, id=id, description=description, source=source, resource_version=resource_version, + downloader=downloader, ) def get_category_name(cls) -> str: return "FileResource" + def get_local_path(self) -> Optional[str]: + # Here we override get_local_path to ensure the file exists. + file_path = super().get_local_path() + + if not file_path: + raise Exception("FileResource path not specified.") + + if not os.path.isfile(file_path): + raise Exception( + f"FileResource path specified, '{file_path}', is not a file." + ) + return file_path + class DirectoryResource(AbstractResource): """A resource consisting of a directory.""" @@ -190,25 +217,35 @@ class DirectoryResource(AbstractResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): - if not os.path.isdir(local_path): - raise Exception( - f"DirectoryResource path specified, {local_path}, is not a " - "directory." - ) - super().__init__( local_path=local_path, id=id, description=description, source=source, resource_version=resource_version, + downloader=downloader, ) def get_category_name(cls) -> str: return "DirectoryResource" + def get_local_path(self) -> Optional[str]: + # Here we override get_local_path to ensure the directory exists. + dir_path = super().get_local_path() + + if not dir_path: + raise Exception("DirectoryResource path not specified.") + + if not os.path.isdir(dir_path): + raise Exception( + f"DirectoryResource path specified, {dir_path}, is not a " + "directory." + ) + return dir_path + class DiskImageResource(FileResource): """A Disk Image resource.""" @@ -220,6 +257,7 @@ class DiskImageResource(FileResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, root_partition: Optional[str] = None, **kwargs, ): @@ -229,6 +267,7 @@ class DiskImageResource(FileResource): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) self._root_partition = root_partition @@ -250,6 +289,7 @@ class BinaryResource(FileResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, architecture: Optional[Union[ISA, str]] = None, **kwargs, ): @@ -259,6 +299,7 @@ class BinaryResource(FileResource): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) self._architecture = None @@ -286,6 +327,7 @@ class BootloaderResource(BinaryResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, architecture: Optional[Union[ISA, str]] = None, **kwargs, ): @@ -296,6 +338,7 @@ class BootloaderResource(BinaryResource): architecture=architecture, source=source, resource_version=resource_version, + downloader=downloader, ) def get_category_name(cls) -> str: @@ -312,6 +355,7 @@ class GitResource(DirectoryResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): super().__init__( @@ -320,6 +364,7 @@ class GitResource(DirectoryResource): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) def get_category_name(cls) -> str: @@ -336,6 +381,7 @@ class KernelResource(BinaryResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, architecture: Optional[Union[ISA, str]] = None, **kwargs, ): @@ -346,6 +392,7 @@ class KernelResource(BinaryResource): source=source, architecture=architecture, resource_version=resource_version, + downloader=downloader, ) def get_category_name(cls) -> str: @@ -367,6 +414,7 @@ class CheckpointResource(DirectoryResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): super().__init__( @@ -375,6 +423,7 @@ class CheckpointResource(DirectoryResource): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) def get_category_name(cls) -> str: @@ -399,6 +448,7 @@ class SimpointResource(AbstractResource): workload_name: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, local_path: Optional[str] = None, **kwargs, ): @@ -422,6 +472,7 @@ class SimpointResource(AbstractResource): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) self._weight_list = weight_list @@ -515,6 +566,7 @@ class LooppointCsvResource(FileResource, LooppointCsvLoader): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): FileResource.__init__( @@ -524,6 +576,7 @@ class LooppointCsvResource(FileResource, LooppointCsvLoader): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) LooppointCsvLoader.__init__(self, pinpoints_file=Path(local_path)) @@ -540,6 +593,7 @@ class LooppointJsonResource(FileResource, LooppointJsonLoader): region_id: Optional[Union[str, int]] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): FileResource.__init__( @@ -549,6 +603,7 @@ class LooppointJsonResource(FileResource, LooppointJsonLoader): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) LooppointJsonLoader.__init__( self, looppoint_file=local_path, region_id=region_id @@ -574,6 +629,7 @@ class SimpointDirectoryResource(SimpointResource): workload_name: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): """ @@ -606,6 +662,7 @@ class SimpointDirectoryResource(SimpointResource): id=id, description=description, source=source, + downloader=downloader, resource_version=resource_version, ) @@ -751,6 +808,40 @@ class SuiteResource(AbstractResource): } +class ShadowResource(AbstractResource): + """A special resource class which delays the `obtain_resource` call. It is, + in a sense, half constructed. Only when a function or attribute is called + which is is neither `get_id` or `get_resource_version` does this class + fully construct itself by calling the `obtain_resource_call` partial + function. + """ + + def __init__( + self, + id: str, + resource_version: str, + obtain_resource_call: partial, + ): + super().__init__( + id=id, + resource_version=resource_version, + ) + self._workload: Optional[AbstractResource] = None + self._obtain_resource_call = obtain_resource_call + + def __getattr__(self, attr): + """if getting the id or resource version, we keep the object in the + "shdow state" where the `obtain_resource` function has not been called. + When more information is needed by calling another attribute, we call + the `obtain_resource` function and store the result in the `_workload`. + """ + if attr in {"get_id", "get_resource_version"}: + return getattr(super(), attr) + if not self._workload: + self._workload = self._obtain_resource_call() + return getattr(self._workload, attr) + + class WorkloadResource(AbstractResource): """A workload resource. This resource is used to specify a workload to run on a board. It contains the function to call and the parameters to pass to @@ -873,6 +964,10 @@ def obtain_resource( gem5_version=gem5_version, ) + # This is is used to store the partial function which is used to download + # the resource when the `get_local_path` function is called. + downloader: Optional[partial] = None + # If the "url" field is specified, the resoruce must be downloaded. if "url" in resource_json and resource_json["url"]: # If the `to_path` parameter is set, we use that as the path to which @@ -922,7 +1017,8 @@ def obtain_resource( ) # Download the resource if it does not already exist. - get_resource( + downloader = partial( + get_resource, resource_name=resource_id, to_path=to_path, download_md5_mismatch=download_md5_mismatch, @@ -946,9 +1042,10 @@ def obtain_resource( return DiskImageResource( local_path=to_path, root_partition=root_partition, + downloader=downloader, **resource_json, ) - return CustomResource(local_path=to_path) + return CustomResource(local_path=to_path, downloader=downloader) assert resources_category in _get_resource_json_type_map resource_class = _get_resource_json_type_map[resources_category] @@ -958,12 +1055,17 @@ def obtain_resource( workloads_obj = {} for workload in workloads: workloads_obj[ - obtain_resource( - workload["id"], + ShadowResource( + id=workload["id"], resource_version=workload["resource_version"], - resource_directory=resource_directory, - clients=clients, - gem5_version=gem5_version, + obtain_resource_call=partial( + obtain_resource, + workload["id"], + resource_version=workload["resource_version"], + resource_directory=resource_directory, + clients=clients, + gem5_version=gem5_version, + ), ) ] = set(workload["input_group"]) resource_json["workloads"] = workloads_obj @@ -996,7 +1098,9 @@ def obtain_resource( # Once we know what AbstractResource subclass we are using, we create it. # The fields in the JSON object are assumed to map like-for-like to the # subclass contructor, so we can pass the resource_json map directly. - return resource_class(local_path=to_path, **resource_json) + return resource_class( + local_path=to_path, downloader=downloader, **resource_json + ) def _get_default_resource_dir() -> str: