diff --git a/src/python/gem5/resources/client_api/abstract_client.py b/src/python/gem5/resources/client_api/abstract_client.py index 7f8ad6166e..0365b5ca60 100644 --- a/src/python/gem5/resources/client_api/abstract_client.py +++ b/src/python/gem5/resources/client_api/abstract_client.py @@ -30,26 +30,6 @@ import urllib.parse class AbstractClient(ABC): - def verify_status_code(self, status_code: int) -> None: - """ - Verifies that the status code is 200. - :param status_code: The status code to verify. - """ - if status_code == 200: - return - if status_code == 429: - raise Exception("Panic: Too many requests") - if status_code == 401: - raise Exception("Panic: Unauthorized") - if status_code == 404: - raise Exception("Panic: Not found") - if status_code == 400: - raise Exception("Panic: Bad request") - if status_code == 500: - raise Exception("Panic: Internal server error") - - raise Exception(f"Panic: Unknown status code {status_code}") - def _url_validator(self, url: str) -> bool: """ Validates the provided URL. diff --git a/src/python/gem5/resources/client_api/atlasclient.py b/src/python/gem5/resources/client_api/atlasclient.py index 7d2a27c3f7..511ef71fed 100644 --- a/src/python/gem5/resources/client_api/atlasclient.py +++ b/src/python/gem5/resources/client_api/atlasclient.py @@ -25,10 +25,39 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from urllib import request, parse -from urllib.error import HTTPError, URLError from typing import Optional, Dict, Union, Type, Tuple, List, Any import json +import time +import itertools from .abstract_client import AbstractClient +from urllib.error import HTTPError +from m5.util import warn + + +class AtlasClientHttpJsonRequestError(Exception): + def __init__( + self, + client: "AtlasClient", + data: Dict[str, Any], + purpose_of_request: Optional[str], + ): + """An exception raised when an HTTP request to Atlas MongoDB fails. + :param client: The AtlasClient instance that raised the exception. + :param purpose_of_request: A string describing the purpose of the + request. + """ + error_str = ( + f"Http Request to Atlas MongoDB failed.\n" + f"Atlas URL: {client.url}\n" + f"Auth URL: {client.authUrl}\n" + f"Database: {client.database}\n" + f"Collection: {client.collection}\n\n" + f"Data sent:\n\n{json.dumps(data,indent=4)}\n\n" + ) + + if purpose_of_request: + error_str += f"Purpose of Request: {purpose_of_request}\n\n" + super().__init__(error_str) class AtlasClient(AbstractClient): @@ -47,22 +76,69 @@ class AtlasClient(AbstractClient): self.authUrl = config["authUrl"] def get_token(self): - data = {"key": self.apiKey} - data = json.dumps(data).encode("utf-8") + return self._atlas_http_json_req( + self.authUrl, + data_json={"key": self.apiKey}, + headers={"Content-Type": "application/json"}, + purpose_of_request="Get Access Token with API key", + )["access_token"] + + def _atlas_http_json_req( + self, + url: str, + data_json: Dict[str, Any], + headers: Dict[str, str], + purpose_of_request: Optional[str], + max_failed_attempts: int = 4, + reattempt_pause_base: int = 2, + ) -> Dict[str, Any]: + """Sends a JSON object over HTTP to a given Atlas MongoDB server and + returns the response. This function will attempt to reconnect to the + server if the connection fails a set number of times before raising an + exception. + + :param url: The URL to open the connection. + :param data_json: The JSON object to send. + :param headers: The headers to send with the request. + :param purpose_of_request: A string describing the purpose of the + request. This is optional. It's used to give context to the user if an + exception is raised. + :param max_failed_attempts: The maximum number of times to an attempt + at making a request should be done before throwing an exception. + :param reattempt_pause_base: The base of the exponential backoff -- the + time between each attempt. + + **Warning**: This function assumes a JSON response. + """ + data = json.dumps(data_json).encode("utf-8") req = request.Request( - self.authUrl, + url, data=data, - headers={"Content-Type": "application/json"}, + headers=headers, ) - try: - response = request.urlopen(req) - except HTTPError as e: - self.verify_status_code(e.status) - return None - result = json.loads(response.read().decode("utf-8")) - token = result["access_token"] - return token + + for attempt in itertools.count(start=1): + try: + response = request.urlopen(req) + break + except HTTPError as e: + if attempt >= max_failed_attempts: + raise AtlasClientHttpJsonRequestError( + client=self, + data=data_json, + purpose_of_request=purpose_of_request, + ) + pause = reattempt_pause_base**attempt + warn( + f"Attempt {attempt} of Atlas HTTP Request failed.\n" + f"Purpose of Request: {purpose_of_request}.\n\n" + f"Failed with Exception:\n{e}\n\n" + f"Retrying after {pause} seconds..." + ) + time.sleep(pause) + + return json.loads(response.read().decode("utf-8")) def get_resources( self, @@ -84,21 +160,18 @@ class AtlasClient(AbstractClient): if filter: data["filter"] = filter - data = json.dumps(data).encode("utf-8") headers = { "Authorization": f"Bearer {self.get_token()}", "Content-Type": "application/json", } - req = request.Request(url, data=data, headers=headers) - try: - response = request.urlopen(req) - except HTTPError as e: - self.verify_status_code(e.status) - return None - result = json.loads(response.read().decode("utf-8")) - resources = result["documents"] + resources = self._atlas_http_json_req( + url, + data_json=data, + headers=headers, + purpose_of_request="Get Resources", + )["documents"] # I do this as a lazy post-processing step because I can't figure out # how to do this via an Atlas query, which may be more efficient. diff --git a/src/python/gem5/resources/client_api/client_wrapper.py b/src/python/gem5/resources/client_api/client_wrapper.py index ccb92cfb20..12030e1649 100644 --- a/src/python/gem5/resources/client_api/client_wrapper.py +++ b/src/python/gem5/resources/client_api/client_wrapper.py @@ -30,6 +30,7 @@ from _m5 import core from typing import Optional, Dict, List, Tuple import itertools from m5.util import warn +import sys class ClientWrapper: @@ -114,7 +115,12 @@ class ClientWrapper: self.clients[client].get_resources_by_id(resource_id) ) except Exception as e: - warn(f"Error getting resources from client {client}: {str(e)}") + print( + f"Exception thrown while getting resource '{resource_id}' " + f"from client '{client}'\n", + file=sys.stderr, + ) + raise e # check if no 2 resources have the same id and version for res1, res2 in itertools.combinations(resources, 2): if res1["resource_version"] == res2["resource_version"]: diff --git a/tests/pyunit/stdlib/resources/pyunit_client_wrapper_checks.py b/tests/pyunit/stdlib/resources/pyunit_client_wrapper_checks.py index f190b1ed5f..66b934a16f 100644 --- a/tests/pyunit/stdlib/resources/pyunit_client_wrapper_checks.py +++ b/tests/pyunit/stdlib/resources/pyunit_client_wrapper_checks.py @@ -34,6 +34,10 @@ import io import contextlib from pathlib import Path +from gem5.resources.client_api.atlasclient import ( + AtlasClientHttpJsonRequestError, +) + mock_json_path = Path(__file__).parent / "refs/resources.json" mock_config_json = { "sources": { @@ -419,21 +423,11 @@ class ClientWrapperTestSuite(unittest.TestCase): @patch("urllib.request.urlopen", side_effect=mocked_requests_post) def test_invalid_auth_url(self, mock_get): resource_id = "test-resource" - f = io.StringIO() - with self.assertRaises(Exception) as context: - with contextlib.redirect_stderr(f): - get_resource_json_obj( - resource_id, - gem5_version="develop", - ) - self.assertTrue( - "Error getting resources from client gem5-resources:" - " Panic: Not found" in str(f.getvalue()) - ) - self.assertTrue( - "Resource with ID 'test-resource' not found." - in str(context.exception) - ) + with self.assertRaises(AtlasClientHttpJsonRequestError) as context: + get_resource_json_obj( + resource_id, + gem5_version="develop", + ) @patch( "gem5.resources.client.clientwrapper", @@ -442,21 +436,11 @@ class ClientWrapperTestSuite(unittest.TestCase): @patch("urllib.request.urlopen", side_effect=mocked_requests_post) def test_invalid_url(self, mock_get): resource_id = "test-resource" - f = io.StringIO() - with self.assertRaises(Exception) as context: - with contextlib.redirect_stderr(f): - get_resource_json_obj( - resource_id, - gem5_version="develop", - ) - self.assertTrue( - "Error getting resources from client gem5-resources:" - " Panic: Not found" in str(f.getvalue()) - ) - self.assertTrue( - "Resource with ID 'test-resource' not found." - in str(context.exception) - ) + with self.assertRaises(AtlasClientHttpJsonRequestError) as context: + get_resource_json_obj( + resource_id, + gem5_version="develop", + ) @patch( "gem5.resources.client.clientwrapper", @@ -465,18 +449,8 @@ class ClientWrapperTestSuite(unittest.TestCase): @patch("urllib.request.urlopen", side_effect=mocked_requests_post) def test_invalid_url(self, mock_get): resource_id = "test-too-many" - f = io.StringIO() - with self.assertRaises(Exception) as context: - with contextlib.redirect_stderr(f): - get_resource_json_obj( - resource_id, - gem5_version="develop", - ) - self.assertTrue( - "Error getting resources from client gem5-resources:" - " Panic: Too many requests" in str(f.getvalue()) - ) - self.assertTrue( - "Resource with ID 'test-too-many' not found." - in str(context.exception) - ) + with self.assertRaises(AtlasClientHttpJsonRequestError) as context: + get_resource_json_obj( + resource_id, + gem5_version="develop", + )