diff --git a/sickle/app.py b/sickle/app.py index 5beacb6..6fa3b1b 100644 --- a/sickle/app.py +++ b/sickle/app.py @@ -12,9 +12,9 @@ import time import requests + from sickle.iterator import BaseOAIIterator, OAIItemIterator from sickle.response import OAIResponse - from .models import (Set, Record, Header, MetadataFormat, Identify) @@ -52,8 +52,15 @@ class Sickle(object): :type protocol_version: str :param iterator: The type of the returned iterator (default: :class:`sickle.iterator.OAIItemIterator`) - :param max_retries: Number of retries if HTTP request fails. + :param max_retries: Number of retry attempts if an HTTP request fails (default: 0 = request only once). Sickle will + use the value from the retry-after header (if present) and will wait the specified number of + seconds between retries. :type max_retries: int + :param retry_status_codes: HTTP status codes to retry (default will only retry on 503) + :type retry_status_codes: iterable + :param default_retry_after: default number of seconds to wait between retries in case no retry-after header is found + on the response (defaults to 60 seconds) + :type default_retry_after: int :type protocol_version: str :param class_mapping: A dictionary that maps OAI verbs to classes representing OAI items. If not provided, @@ -73,9 +80,17 @@ class Sickle(object): for all available parameters. """ - def __init__(self, endpoint, http_method='GET', protocol_version='2.0', - iterator=OAIItemIterator, max_retries=5, - class_mapping=None, encoding=None, **request_args): + def __init__(self, endpoint, + http_method='GET', + protocol_version='2.0', + iterator=OAIItemIterator, + max_retries=0, + retry_status_codes=None, + default_retry_after=60, + class_mapping=None, + encoding=None, + **request_args): + self.endpoint = endpoint if http_method not in ['GET', 'POST']: raise ValueError("Invalid HTTP method: %s! Must be GET or POST.") @@ -90,6 +105,8 @@ def __init__(self, endpoint, http_method='GET', protocol_version='2.0', raise TypeError( "Argument 'iterator' must be subclass of %s" % BaseOAIIterator.__name__) self.max_retries = max_retries + self.retry_status_codes = retry_status_codes or [503] + self.default_retry_after = default_retry_after self.oai_namespace = OAI_NAMESPACE % self.protocol_version self.class_mapping = class_mapping or DEFAULT_CLASS_MAP self.encoding = encoding @@ -101,26 +118,24 @@ def harvest(self, **kwargs): # pragma: no cover :param kwargs: OAI HTTP parameters. :rtype: :class:`sickle.OAIResponse` """ + http_response = self._request(kwargs) for _ in range(self.max_retries): - if self.http_method == 'GET': - http_response = requests.get(self.endpoint, params=kwargs, - **self.request_args) - else: - http_response = requests.post(self.endpoint, data=kwargs, - **self.request_args) - if http_response.status_code == 503: - try: - retry_after = int(http_response.headers.get('retry-after')) - except TypeError: - retry_after = 20 - logger.info( - "HTTP 503! Retrying after %d seconds..." % retry_after) + if self._is_error_code(http_response.status_code) \ + and http_response.status_code in self.retry_status_codes: + retry_after = self.get_retry_after(http_response) + logger.warning( + "HTTP %d! Retrying after %d seconds..." % (http_response.status_code, retry_after)) time.sleep(retry_after) - else: - http_response.raise_for_status() - if self.encoding: - http_response.encoding = self.encoding - return OAIResponse(http_response, params=kwargs) + http_response = self._request(kwargs) + http_response.raise_for_status() + if self.encoding: + http_response.encoding = self.encoding + return OAIResponse(http_response, params=kwargs) + + def _request(self, kwargs): + if self.http_method == 'GET': + return requests.get(self.endpoint, params=kwargs, **self.request_args) + return requests.post(self.endpoint, data=kwargs, **self.request_args) def ListRecords(self, ignore_deleted=False, **kwargs): """Issue a ListRecords request. @@ -178,3 +193,15 @@ def ListMetadataFormats(self, **kwargs): params = kwargs params.update({'verb': 'ListMetadataFormats'}) return self.iterator(self, params) + + def get_retry_after(self, http_response): + if http_response.status_code == 503: + try: + return int(http_response.headers.get('retry-after')) + except TypeError: + return self.default_retry_after + return self.default_retry_after + + @staticmethod + def _is_error_code(status_code): + return status_code >= 400 diff --git a/sickle/tests/test_sickle.py b/sickle/tests/test_sickle.py index ab94063..62b75f5 100644 --- a/sickle/tests/test_sickle.py +++ b/sickle/tests/test_sickle.py @@ -10,6 +10,8 @@ from mock import patch, Mock from nose.tools import raises +from requests import HTTPError + from sickle import Sickle this_dir, this_filename = os.path.split(__file__) @@ -29,7 +31,7 @@ def test_invalid_iterator(self): Sickle("http://localhost", iterator=None) def test_pass_request_args(self): - mock_response = Mock(text=u'', content='') + mock_response = Mock(text=u'', content='', status_code=200) mock_get = Mock(return_value=mock_response) with patch('sickle.app.requests.get', mock_get): sickle = Sickle('url', timeout=10, proxies=dict(), @@ -41,9 +43,56 @@ def test_pass_request_args(self): auth=('user', 'password')) def test_override_encoding(self): - mock_response = Mock(text='', content='') + mock_response = Mock(text='', content='', status_code=200) mock_get = Mock(return_value=mock_response) with patch('sickle.app.requests.get', mock_get): sickle = Sickle('url', encoding='encoding') sickle.ListSets() - self.assertEqual(mock_response.encoding, 'encoding') + mock_get.assert_called_once_with('url', + params={'verb': 'ListSets'}) + + def test_no_retry(self): + mock_response = Mock(status_code=503, + headers={'retry-after': '10'}, + raise_for_status=Mock(side_effect=HTTPError)) + mock_get = Mock(return_value=mock_response) + with patch('sickle.app.requests.get', mock_get): + sickle = Sickle('url') + try: + sickle.ListRecords() + except HTTPError: + pass + self.assertEqual(1, mock_get.call_count) + + def test_retry_on_503(self): + mock_response = Mock(status_code=503, + headers={'retry-after': '10'}, + raise_for_status=Mock(side_effect=HTTPError)) + mock_get = Mock(return_value=mock_response) + sleep_mock = Mock() + with patch('time.sleep', sleep_mock): + with patch('sickle.app.requests.get', mock_get): + sickle = Sickle('url', max_retries=3, default_retry_after=0) + try: + sickle.ListRecords() + except HTTPError: + pass + mock_get.assert_called_with('url', + params={'verb': 'ListRecords'}) + self.assertEqual(4, mock_get.call_count) + self.assertEqual(3, sleep_mock.call_count) + sleep_mock.assert_called_with(10) + + def test_retry_on_custom_code(self): + mock_response = Mock(status_code=500, + raise_for_status=Mock(side_effect=HTTPError)) + mock_get = Mock(return_value=mock_response) + with patch('sickle.app.requests.get', mock_get): + sickle = Sickle('url', max_retries=3, default_retry_after=0, retry_status_codes=(503, 500)) + try: + sickle.ListRecords() + except HTTPError: + pass + mock_get.assert_called_with('url', + params={'verb': 'ListRecords'}) + self.assertEqual(4, mock_get.call_count)