diff --git a/sickle/app.py b/sickle/app.py
index 5beacb6..6fa3b1b 100644
--- a/sickle/app.py
+++ b/sickle/app.py
@@ -12,9 +12,9 @@
import time
import requests
+
from sickle.iterator import BaseOAIIterator, OAIItemIterator
from sickle.response import OAIResponse
-
from .models import (Set, Record, Header, MetadataFormat,
Identify)
@@ -52,8 +52,15 @@ class Sickle(object):
:type protocol_version: str
:param iterator: The type of the returned iterator
(default: :class:`sickle.iterator.OAIItemIterator`)
- :param max_retries: Number of retries if HTTP request fails.
+    :param max_retries: Number of retry attempts if an HTTP request fails (default: 0 = request only once). For a
+                        503 response Sickle waits the number of seconds given in its retry-after header (if
+                        present); otherwise it waits ``default_retry_after`` seconds between retries.
:type max_retries: int
+ :param retry_status_codes: HTTP status codes to retry (default will only retry on 503)
+ :type retry_status_codes: iterable
+ :param default_retry_after: default number of seconds to wait between retries in case no retry-after header is found
+ on the response (defaults to 60 seconds)
+ :type default_retry_after: int
:type protocol_version: str
:param class_mapping: A dictionary that maps OAI verbs to classes representing
OAI items. If not provided,
@@ -73,9 +80,17 @@ class Sickle(object):
for all available parameters.
"""
- def __init__(self, endpoint, http_method='GET', protocol_version='2.0',
- iterator=OAIItemIterator, max_retries=5,
- class_mapping=None, encoding=None, **request_args):
+ def __init__(self, endpoint,
+ http_method='GET',
+ protocol_version='2.0',
+ iterator=OAIItemIterator,
+ max_retries=0,
+ retry_status_codes=None,
+ default_retry_after=60,
+ class_mapping=None,
+ encoding=None,
+ **request_args):
+
self.endpoint = endpoint
if http_method not in ['GET', 'POST']:
raise ValueError("Invalid HTTP method: %s! Must be GET or POST.")
@@ -90,6 +105,8 @@ def __init__(self, endpoint, http_method='GET', protocol_version='2.0',
raise TypeError(
"Argument 'iterator' must be subclass of %s" % BaseOAIIterator.__name__)
self.max_retries = max_retries
+ self.retry_status_codes = retry_status_codes or [503]
+ self.default_retry_after = default_retry_after
self.oai_namespace = OAI_NAMESPACE % self.protocol_version
self.class_mapping = class_mapping or DEFAULT_CLASS_MAP
self.encoding = encoding
@@ -101,26 +118,24 @@ def harvest(self, **kwargs): # pragma: no cover
:param kwargs: OAI HTTP parameters.
:rtype: :class:`sickle.OAIResponse`
"""
+ http_response = self._request(kwargs)
for _ in range(self.max_retries):
- if self.http_method == 'GET':
- http_response = requests.get(self.endpoint, params=kwargs,
- **self.request_args)
- else:
- http_response = requests.post(self.endpoint, data=kwargs,
- **self.request_args)
- if http_response.status_code == 503:
- try:
- retry_after = int(http_response.headers.get('retry-after'))
- except TypeError:
- retry_after = 20
- logger.info(
- "HTTP 503! Retrying after %d seconds..." % retry_after)
+ if self._is_error_code(http_response.status_code) \
+ and http_response.status_code in self.retry_status_codes:
+ retry_after = self.get_retry_after(http_response)
+ logger.warning(
+ "HTTP %d! Retrying after %d seconds..." % (http_response.status_code, retry_after))
time.sleep(retry_after)
- else:
- http_response.raise_for_status()
- if self.encoding:
- http_response.encoding = self.encoding
- return OAIResponse(http_response, params=kwargs)
+ http_response = self._request(kwargs)
+ http_response.raise_for_status()
+ if self.encoding:
+ http_response.encoding = self.encoding
+ return OAIResponse(http_response, params=kwargs)
+
+    def _request(self, kwargs):
+        """Issue a single HTTP request to the endpoint.
+
+        :param kwargs: OAI HTTP parameters; passed as ``params`` for GET and
+                       as ``data`` for POST.
+        :return: The object returned by ``requests.get`` / ``requests.post``.
+        """
+        if self.http_method != 'GET':
+            return requests.post(self.endpoint, data=kwargs, **self.request_args)
+        return requests.get(self.endpoint, params=kwargs, **self.request_args)
def ListRecords(self, ignore_deleted=False, **kwargs):
"""Issue a ListRecords request.
@@ -178,3 +193,15 @@ def ListMetadataFormats(self, **kwargs):
params = kwargs
params.update({'verb': 'ListMetadataFormats'})
return self.iterator(self, params)
+
+    def get_retry_after(self, http_response):
+        """Return the number of seconds to wait before retrying a request.
+
+        Only a 503 response is inspected for a ``retry-after`` header; every
+        other status code gets ``self.default_retry_after``.
+
+        :param http_response: The failed response (needs ``status_code`` and
+                              ``headers``).
+        :return: The header's delay in seconds, or ``self.default_retry_after``
+                 when the header is missing, is not a plain integer (the HTTP
+                 spec also allows an HTTP-date there) or is negative.
+        """
+        if http_response.status_code != 503:
+            return self.default_retry_after
+        try:
+            retry_after = int(http_response.headers.get('retry-after'))
+        except (TypeError, ValueError):
+            # TypeError: header absent (None); ValueError: non-numeric value
+            # such as an HTTP-date.
+            return self.default_retry_after
+        # time.sleep() rejects negative delays, so treat them as unusable.
+        if retry_after < 0:
+            return self.default_retry_after
+        return retry_after
+
+    @staticmethod
+    def _is_error_code(status_code):
+        """Tell whether ``status_code`` is an HTTP error status (400 or above)."""
+        return 400 <= status_code
diff --git a/sickle/tests/test_sickle.py b/sickle/tests/test_sickle.py
index ab94063..62b75f5 100644
--- a/sickle/tests/test_sickle.py
+++ b/sickle/tests/test_sickle.py
@@ -10,6 +10,8 @@
from mock import patch, Mock
from nose.tools import raises
+from requests import HTTPError
+
from sickle import Sickle
this_dir, this_filename = os.path.split(__file__)
@@ -29,7 +31,7 @@ def test_invalid_iterator(self):
Sickle("http://localhost", iterator=None)
def test_pass_request_args(self):
- mock_response = Mock(text=u'', content='')
+ mock_response = Mock(text=u'', content='', status_code=200)
mock_get = Mock(return_value=mock_response)
with patch('sickle.app.requests.get', mock_get):
sickle = Sickle('url', timeout=10, proxies=dict(),
@@ -41,9 +43,56 @@ def test_pass_request_args(self):
auth=('user', 'password'))
def test_override_encoding(self):
- mock_response = Mock(text='', content='')
+ mock_response = Mock(text='', content='', status_code=200)
mock_get = Mock(return_value=mock_response)
with patch('sickle.app.requests.get', mock_get):
sickle = Sickle('url', encoding='encoding')
sickle.ListSets()
+ # Keep checking that the configured encoding is stamped onto the response;
+ # the request assertion below does not cover that.
self.assertEqual(mock_response.encoding, 'encoding')
+ mock_get.assert_called_once_with('url',
+ params={'verb': 'ListSets'})
+
+    def test_no_retry(self):
+        """With the default ``max_retries=0`` a failing request is sent exactly once."""
+        mock_response = Mock(status_code=503,
+                             headers={'retry-after': '10'},
+                             raise_for_status=Mock(side_effect=HTTPError))
+        mock_get = Mock(return_value=mock_response)
+        with patch('sickle.app.requests.get', mock_get):
+            sickle = Sickle('url')
+            # assertRaises (rather than try/except-pass) makes the test fail
+            # if no HTTPError surfaces from the 503 response.
+            with self.assertRaises(HTTPError):
+                sickle.ListRecords()
+        self.assertEqual(1, mock_get.call_count)
+
+    def test_retry_on_503(self):
+        """A persistent 503 is retried ``max_retries`` times, waiting per retry-after."""
+        mock_response = Mock(status_code=503,
+                             headers={'retry-after': '10'},
+                             raise_for_status=Mock(side_effect=HTTPError))
+        mock_get = Mock(return_value=mock_response)
+        sleep_mock = Mock()
+        with patch('time.sleep', sleep_mock), \
+                patch('sickle.app.requests.get', mock_get):
+            sickle = Sickle('url', max_retries=3, default_retry_after=0)
+            # The response never recovers, so the final raise_for_status()
+            # must surface; assertRaises fails the test if it does not.
+            with self.assertRaises(HTTPError):
+                sickle.ListRecords()
+        mock_get.assert_called_with('url',
+                                    params={'verb': 'ListRecords'})
+        # One initial request plus three retries, with a pause before each retry.
+        self.assertEqual(4, mock_get.call_count)
+        self.assertEqual(3, sleep_mock.call_count)
+        # The header value wins over default_retry_after for a 503.
+        sleep_mock.assert_called_with(10)
+
+    def test_retry_on_custom_code(self):
+        """Status codes listed in ``retry_status_codes`` are retried as well."""
+        mock_response = Mock(status_code=500,
+                             raise_for_status=Mock(side_effect=HTTPError))
+        mock_get = Mock(return_value=mock_response)
+        sleep_mock = Mock()
+        # Patch time.sleep so the test never really waits between retries.
+        with patch('time.sleep', sleep_mock), \
+                patch('sickle.app.requests.get', mock_get):
+            sickle = Sickle('url', max_retries=3, default_retry_after=0,
+                            retry_status_codes=(503, 500))
+            # assertRaises fails the test if the persistent 500 is swallowed.
+            with self.assertRaises(HTTPError):
+                sickle.ListRecords()
+        mock_get.assert_called_with('url',
+                                    params={'verb': 'ListRecords'})
+        # One initial request plus three retries.
+        self.assertEqual(4, mock_get.call_count)
+        self.assertEqual(3, sleep_mock.call_count)
+        # For a non-503 status the wait is default_retry_after.
+        sleep_mock.assert_called_with(0)