diff --git a/src/python/gem5/resources/resource.py b/src/python/gem5/resources/resource.py index b5ece9dbb0..6966c8a7c6 100644 --- a/src/python/gem5/resources/resource.py +++ b/src/python/gem5/resources/resource.py @@ -27,6 +27,7 @@ import json import os from abc import ABCMeta +from functools import partial from pathlib import Path from typing import ( Any, @@ -94,6 +95,7 @@ class AbstractResource: local_path: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, ): """ :param local_path: The path on the host system where this resource is @@ -104,18 +106,17 @@ class AbstractResource: string should navigate users to where the source for this resource may be found. Not a required parameter. By default is None. :param resource_version: Version of the resource itself. + :param downloader: A partial function which is used to download the + resource. If set, this is called if the resource is not present at the + specified `local_path`. """ - if local_path and not os.path.exists(local_path): - raise Exception( - f"Local path specified for resource, '{local_path}', does not " - "exist." - ) self._id = id self._local_path = local_path self._description = description self._source = source self._version = resource_version + self._downloader = downloader def get_category_name(cls) -> str: raise NotImplementedError @@ -134,7 +135,19 @@ class AbstractResource: return self._version def get_local_path(self) -> Optional[str]: - """Returns the local path of the resource.""" + """Returns the local path of the resource. + + If specified the `downloader` partial function is called to download + the resource if it is not present or up-to-date at the specified + `local_path`. + """ + if self._downloader: + self._downloader() + if self._local_path and not os.path.exists(self._local_path): + raise Exception( + f"Local path specified for resource, '{self._local_path}', " + "does not exist." + ) return self._local_path def get_description(self) -> Optional[str]: @@ -158,24 +171,34 @@ class FileResource(AbstractResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): - if not os.path.isfile(local_path): - raise Exception( - f"FileResource path specified, '{local_path}', is not a file." - ) - super().__init__( local_path=local_path, id=id, description=description, source=source, resource_version=resource_version, + downloader=downloader, ) def get_category_name(cls) -> str: return "FileResource" + def get_local_path(self) -> Optional[str]: + # Here we override get_local_path to ensure the file exists. + file_path = super().get_local_path() + + if not file_path: + raise Exception("FileResource path not specified.") + + if not os.path.isfile(file_path): + raise Exception( + f"FileResource path specified, '{file_path}', is not a file." + ) + return file_path + class DirectoryResource(AbstractResource): """A resource consisting of a directory.""" @@ -187,25 +210,35 @@ class DirectoryResource(AbstractResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): - if not os.path.isdir(local_path): - raise Exception( - f"DirectoryResource path specified, {local_path}, is not a " - "directory." - ) - super().__init__( local_path=local_path, id=id, description=description, source=source, resource_version=resource_version, + downloader=downloader, ) def get_category_name(cls) -> str: return "DirectoryResource" + def get_local_path(self) -> Optional[str]: + # Here we override get_local_path to ensure the directory exists. + dir_path = super().get_local_path() + + if not dir_path: + raise Exception("DirectoryResource path not specified.") + + if not os.path.isdir(dir_path): + raise Exception( + f"DirectoryResource path specified, {dir_path}, is not a " + "directory." + ) + return dir_path + class DiskImageResource(FileResource): """A Disk Image resource.""" @@ -217,6 +250,7 @@ class DiskImageResource(FileResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, root_partition: Optional[str] = None, **kwargs, ): @@ -226,6 +260,7 @@ class DiskImageResource(FileResource): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) self._root_partition = root_partition @@ -247,6 +282,7 @@ class BinaryResource(FileResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, architecture: Optional[Union[ISA, str]] = None, **kwargs, ): @@ -256,6 +292,7 @@ class BinaryResource(FileResource): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) self._architecture = None @@ -283,6 +320,7 @@ class BootloaderResource(BinaryResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, architecture: Optional[Union[ISA, str]] = None, **kwargs, ): @@ -293,6 +331,7 @@ class BootloaderResource(BinaryResource): architecture=architecture, source=source, resource_version=resource_version, + downloader=downloader, ) def get_category_name(cls) -> str: @@ -309,6 +348,7 @@ class GitResource(DirectoryResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): super().__init__( @@ -317,6 +357,7 @@ class GitResource(DirectoryResource): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) def get_category_name(cls) -> str: @@ -333,6 +374,7 @@ class KernelResource(BinaryResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, architecture: Optional[Union[ISA, str]] = None, **kwargs, ): @@ -343,6 +385,7 @@ class KernelResource(BinaryResource): source=source, architecture=architecture, resource_version=resource_version, + downloader=downloader, ) def get_category_name(cls) -> str: @@ -364,6 +407,7 @@ class CheckpointResource(DirectoryResource): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): super().__init__( @@ -372,6 +416,7 @@ class CheckpointResource(DirectoryResource): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) def get_category_name(cls) -> str: @@ -396,6 +441,7 @@ class SimpointResource(AbstractResource): workload_name: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, local_path: Optional[str] = None, **kwargs, ): @@ -417,6 +463,7 @@ class SimpointResource(AbstractResource): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) self._weight_list = weight_list @@ -510,6 +557,7 @@ class LooppointCsvResource(FileResource, LooppointCsvLoader): resource_version: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): FileResource.__init__( @@ -519,6 +567,7 @@ class LooppointCsvResource(FileResource, LooppointCsvLoader): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) LooppointCsvLoader.__init__(self, pinpoints_file=Path(local_path)) @@ -535,6 +584,7 @@ class LooppointJsonResource(FileResource, LooppointJsonLoader): region_id: Optional[Union[str, int]] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): FileResource.__init__( @@ -544,6 +594,7 @@ class LooppointJsonResource(FileResource, LooppointJsonLoader): description=description, source=source, resource_version=resource_version, + downloader=downloader, ) LooppointJsonLoader.__init__( self, looppoint_file=local_path, region_id=region_id @@ -569,6 +620,7 @@ class SimpointDirectoryResource(SimpointResource): workload_name: Optional[str] = None, description: Optional[str] = None, source: Optional[str] = None, + downloader: Optional[partial] = None, **kwargs, ): """ @@ -601,6 +653,7 @@ class SimpointDirectoryResource(SimpointResource): id=id, description=description, source=source, + downloader=downloader, resource_version=resource_version, ) @@ -868,6 +921,10 @@ def obtain_resource( gem5_version=gem5_version, ) + # This is is used to store the partial function which is used to download + # the resource when the `get_local_path` function is called. + downloader: Optional[partial] = None + # If the "url" field is specified, the resoruce must be downloaded. if "url" in resource_json and resource_json["url"]: # If the `to_path` parameter is set, we use that as the path to which @@ -917,7 +974,8 @@ def obtain_resource( ) # Download the resource if it does not already exist. - get_resource( + downloader = partial( + get_resource, resource_name=resource_id, to_path=to_path, download_md5_mismatch=download_md5_mismatch, @@ -941,9 +999,10 @@ def obtain_resource( return DiskImageResource( local_path=to_path, root_partition=root_partition, + downloader=downloader, **resource_json, ) - return CustomResource(local_path=to_path) + return CustomResource(local_path=to_path, downloader=downloader) assert resources_category in _get_resource_json_type_map resource_class = _get_resource_json_type_map[resources_category] @@ -991,7 +1050,9 @@ def obtain_resource( # Once we know what AbstractResource subclass we are using, we create it. # The fields in the JSON object are assumed to map like-for-like to the # subclass contructor, so we can pass the resource_json map directly. - return resource_class(local_path=to_path, **resource_json) + return resource_class( + local_path=to_path, downloader=downloader, **resource_json + ) def _get_default_resource_dir() -> str: