From 8ef51f368f7c970ed95f9d9e1d6de9609b78c08c Mon Sep 17 00:00:00 2001
From: Bastien Abadie <bastien@nextcairn.com>
Date: Mon, 29 Mar 2021 15:09:12 +0000
Subject: [PATCH] Retry all managed API calls that result in a 50x

---
 arkindex_worker/worker.py                     | 71 ++++++++++++++-----
 requirements.txt                              |  2 +-
 tests/conftest.py                             | 11 +++
 .../test_classifications.py                   |  7 +-
 tests/test_elements_worker/test_elements.py   | 14 +++-
 tests/test_elements_worker/test_entities.py   |  7 +-
 tests/test_elements_worker/test_metadata.py   |  7 +-
 .../test_transcriptions.py                    | 21 +++++-
 8 files changed, 112 insertions(+), 28 deletions(-)

diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index 031b9da3..9d89c6fa 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -13,6 +13,13 @@ import apistar
 import gnupg
 import yaml
 from apistar.exceptions import ErrorResponse
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception,
+    stop_after_attempt,
+    wait_exponential,
+)
 
 from arkindex import ArkindexClient, options_from_env
 from arkindex_worker import logger
@@ -25,6 +32,17 @@ MANUAL_SLUG = "manual"
 CACHE_DIR = f"/data/{os.environ.get('TASK_ID')}"
 
 
+def _is_500_error(exc):
+    """
+    Check if an Arkindex API error is a 50x
+    This is used to retry most API calls implemented here
+    """
+    if not isinstance(exc, ErrorResponse):
+        return False
+
+    return 500 <= exc.status_code < 600
+
+
 class BaseWorker(object):
     def __init__(self, description="Arkindex Base Worker", use_cache=False):
         self.parser = argparse.ArgumentParser(description=description)
@@ -100,13 +118,13 @@ class BaseWorker(object):
         logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}")
 
         # Load features available on backend, and check authentication
-        user = self.api_client.request("RetrieveUser")
+        user = self.request("RetrieveUser")
         logger.debug(f"Connected as {user['display_name']} - {user['email']}")
         self.features = user["features"]
 
         if self.worker_version_id:
             # Retrieve initial configuration from API
-            worker_version = self.api_client.request(
+            worker_version = self.request(
                 "RetrieveWorkerVersion", id=self.worker_version_id
             )
             logger.info(
@@ -135,7 +153,7 @@ class BaseWorker(object):
 
         # Load from the backend
         try:
-            resp = self.api_client.request("RetrieveSecret", name=name)
+            resp = self.request("RetrieveSecret", name=name)
             secret = resp["content"]
             logging.info(f"Loaded API secret {name}")
         except ErrorResponse as e:
@@ -175,6 +193,25 @@ class BaseWorker(object):
         # By default give raw secret payload
         return secret
 
+    @retry(
+        retry=retry_if_exception(_is_500_error),
+        wait=wait_exponential(multiplier=2, min=3),
+        reraise=True,
+        stop=stop_after_attempt(5),
+        before_sleep=before_sleep_log(logger, logging.INFO),
+    )
+    def request(self, *args, **kwargs):
+        """
+        Proxy all Arkindex API requests with a retry mechanism
+        in case of 50X errors
+        The same API call will be retried 5 times, with an exponential sleep time
+        going through 3, 4, 8 and 16 seconds of wait between call.
+        If the 5th call still gives a 50x, the exception is re-raised
+        and the caller should catch it
+        Log messages are displayed before sleeping (when at least one exception occurred)
+        """
+        return self.api_client.request(*args, **kwargs)
+
     def add_arguments(self):
         """Override this method to add argparse argument to this worker"""
 
@@ -269,9 +306,7 @@ class ElementsWorker(BaseWorker):
         for i, element_id in enumerate(elements, start=1):
             try:
                 # Load element using Arkindex API
-                element = Element(
-                    **self.api_client.request("RetrieveElement", id=element_id)
-                )
+                element = Element(**self.request("RetrieveElement", id=element_id))
                 logger.info(f"Processing {element} ({i}/{count})")
 
                 # Report start of process, run process, then report end of process
@@ -335,7 +370,7 @@ class ElementsWorker(BaseWorker):
         if ml_class_id is None:
             logger.info(f"Creating ML class {ml_class} on corpus {corpus_id}")
             try:
-                response = self.api_client.request(
+                response = self.request(
                     "CreateMLClass", id=corpus_id, body={"name": ml_class}
                 )
                 ml_class_id = self.classes[corpus_id][ml_class] = response["id"]
@@ -385,7 +420,7 @@ class ElementsWorker(BaseWorker):
             logger.warning("Cannot create element as this worker is in read-only mode")
             return
 
-        sub_element = self.api_client.request(
+        sub_element = self.request(
             "CreateElement",
             body={
                 "type": type,
@@ -446,7 +481,7 @@ class ElementsWorker(BaseWorker):
             logger.warning("Cannot create elements as this worker is in read-only mode")
             return
 
-        created_ids = self.api_client.request(
+        created_ids = self.request(
             "CreateElements",
             id=parent.id,
             body={
@@ -500,7 +535,7 @@ class ElementsWorker(BaseWorker):
             )
             return
 
-        created = self.api_client.request(
+        created = self.request(
             "CreateTranscription",
             id=element.id,
             body={
@@ -557,7 +592,7 @@ class ElementsWorker(BaseWorker):
                 score is not None and isinstance(score, float) and 0 <= score <= 1
             ), f"Transcription at index {index} in transcriptions: score shouldn't be null and should be a float in [0..1] range"
 
-        created_trs = self.api_client.request(
+        created_trs = self.request(
             "CreateTranscriptions",
             body={
                 "worker_version": self.worker_version_id,
@@ -614,7 +649,7 @@ class ElementsWorker(BaseWorker):
             return
 
         try:
-            self.api_client.request(
+            self.request(
                 "CreateClassification",
                 body={
                     "element": element.id,
@@ -668,7 +703,7 @@ class ElementsWorker(BaseWorker):
             logger.warning("Cannot create entity as this worker is in read-only mode")
             return
 
-        entity = self.api_client.request(
+        entity = self.request(
             "CreateEntity",
             body={
                 "name": name,
@@ -727,7 +762,7 @@ class ElementsWorker(BaseWorker):
             )
             return
 
-        annotations = self.api_client.request(
+        annotations = self.request(
             "CreateElementTranscriptions",
             id=element.id,
             body={
@@ -813,7 +848,7 @@ class ElementsWorker(BaseWorker):
             logger.warning("Cannot create metadata as this worker is in read-only mode")
             return
 
-        metadata = self.api_client.request(
+        metadata = self.request(
             "CreateMetaData",
             id=element.id,
             body={
@@ -835,9 +870,7 @@ class ElementsWorker(BaseWorker):
         if worker_version_id in self._worker_version_cache:
             return self._worker_version_cache[worker_version_id]
 
-        worker_version = self.api_client.request(
-            "RetrieveWorkerVersion", id=worker_version_id
-        )
+        worker_version = self.request("RetrieveWorkerVersion", id=worker_version_id)
         self._worker_version_cache[worker_version_id] = worker_version
 
         return worker_version
@@ -1012,7 +1045,7 @@ class ElementsWorker(BaseWorker):
             return
 
         try:
-            out = self.api_client.request(
+            out = self.request(
                 "UpdateWorkerActivity",
                 id=self.worker_version_id,
                 body={
diff --git a/requirements.txt b/requirements.txt
index 10792bb4..c5dd3cbd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,4 @@ Pillow==8.1.0
 python-gitlab==2.6.0
 python-gnupg==0.4.6
 sh==1.14.1
-tenacity==6.3.1
+tenacity==7.0.0
diff --git a/tests/conftest.py b/tests/conftest.py
index b7ff454b..de8ed278 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,6 +3,7 @@ import hashlib
 import json
 import os
 import sys
+import time
 from pathlib import Path
 
 import pytest
@@ -19,6 +20,16 @@ CACHE_FILE = os.path.join(CACHE_DIR, "db.sqlite")
 __yaml_cache = {}
 
 
+@pytest.fixture(autouse=True)
+def disable_sleep(monkeypatch):
+    """
+    Do not sleep at all in between API executions
+    when errors occur in unit tests.
+    This speeds up the test execution a lot
+    """
+    monkeypatch.setattr(time, "sleep", lambda x: None)
+
+
 @pytest.fixture
 def cache_yaml(monkeypatch):
     """
diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py
index 5c79d53c..05e8d8d6 100644
--- a/tests/test_elements_worker/test_classifications.py
+++ b/tests/test_elements_worker/test_classifications.py
@@ -362,10 +362,15 @@ def test_create_classification_api_error(responses, mock_elements_worker):
             high_confidence=True,
         )
 
-    assert len(responses.calls) == 3
+    assert len(responses.calls) == 7
     assert [call.request.url for call in responses.calls] == [
         "http://testserver/api/v1/user/",
         "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        # We retry 5 times the API call
+        "http://testserver/api/v1/classifications/",
+        "http://testserver/api/v1/classifications/",
+        "http://testserver/api/v1/classifications/",
+        "http://testserver/api/v1/classifications/",
         "http://testserver/api/v1/classifications/",
     ]
 
diff --git a/tests/test_elements_worker/test_elements.py b/tests/test_elements_worker/test_elements.py
index 03028cad..7bf069ad 100644
--- a/tests/test_elements_worker/test_elements.py
+++ b/tests/test_elements_worker/test_elements.py
@@ -378,10 +378,15 @@ def test_create_sub_element_api_error(responses, mock_elements_worker):
             polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
         )
 
-    assert len(responses.calls) == 3
+    assert len(responses.calls) == 7
     assert [call.request.url for call in responses.calls] == [
         "http://testserver/api/v1/user/",
         "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        # We retry 5 times the API call
+        "http://testserver/api/v1/elements/create/",
+        "http://testserver/api/v1/elements/create/",
+        "http://testserver/api/v1/elements/create/",
+        "http://testserver/api/v1/elements/create/",
         "http://testserver/api/v1/elements/create/",
     ]
 
@@ -663,10 +668,15 @@ def test_create_elements_api_error(responses, mock_elements_worker):
             ],
         )
 
-    assert len(responses.calls) == 3
+    assert len(responses.calls) == 7
     assert [call.request.url for call in responses.calls] == [
         "http://testserver/api/v1/user/",
         "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        # We retry 5 times the API call
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
         "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
     ]
 
diff --git a/tests/test_elements_worker/test_entities.py b/tests/test_elements_worker/test_entities.py
index c8752a0d..be513d6b 100644
--- a/tests/test_elements_worker/test_entities.py
+++ b/tests/test_elements_worker/test_entities.py
@@ -147,10 +147,15 @@ def test_create_entity_api_error(responses, mock_elements_worker):
             corpus="12341234-1234-1234-1234-123412341234",
         )
 
-    assert len(responses.calls) == 3
+    assert len(responses.calls) == 7
     assert [call.request.url for call in responses.calls] == [
         "http://testserver/api/v1/user/",
         "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        # We retry 5 times the API call
+        "http://testserver/api/v1/entity/",
+        "http://testserver/api/v1/entity/",
+        "http://testserver/api/v1/entity/",
+        "http://testserver/api/v1/entity/",
         "http://testserver/api/v1/entity/",
     ]
 
diff --git a/tests/test_elements_worker/test_metadata.py b/tests/test_elements_worker/test_metadata.py
index f2a79903..e62c06b7 100644
--- a/tests/test_elements_worker/test_metadata.py
+++ b/tests/test_elements_worker/test_metadata.py
@@ -133,10 +133,15 @@ def test_create_metadata_api_error(responses, mock_elements_worker):
             value="La Turbine, Grenoble 38000",
         )
 
-    assert len(responses.calls) == 3
+    assert len(responses.calls) == 7
     assert [call.request.url for call in responses.calls] == [
         "http://testserver/api/v1/user/",
         "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        # We retry 5 times the API call
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
+        "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
         "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/metadata/",
     ]
 
diff --git a/tests/test_elements_worker/test_transcriptions.py b/tests/test_elements_worker/test_transcriptions.py
index 3f66c63d..c3c43af0 100644
--- a/tests/test_elements_worker/test_transcriptions.py
+++ b/tests/test_elements_worker/test_transcriptions.py
@@ -127,10 +127,15 @@ def test_create_transcription_api_error(responses, mock_elements_worker):
             score=0.42,
         )
 
-    assert len(responses.calls) == 3
+    assert len(responses.calls) == 7
     assert [call.request.url for call in responses.calls] == [
         "http://testserver/api/v1/user/",
         "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        # We retry 5 times the API call
+        f"http://testserver/api/v1/element/{elt.id}/transcription/",
+        f"http://testserver/api/v1/element/{elt.id}/transcription/",
+        f"http://testserver/api/v1/element/{elt.id}/transcription/",
+        f"http://testserver/api/v1/element/{elt.id}/transcription/",
         f"http://testserver/api/v1/element/{elt.id}/transcription/",
     ]
 
@@ -442,10 +447,15 @@ def test_create_transcriptions_api_error(responses, mock_elements_worker):
     with pytest.raises(ErrorResponse):
         mock_elements_worker.create_transcriptions(transcriptions=trans)
 
-    assert len(responses.calls) == 3
+    assert len(responses.calls) == 7
     assert [call.request.url for call in responses.calls] == [
         "http://testserver/api/v1/user/",
         "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        # We retry 5 times the API call
+        "http://testserver/api/v1/transcription/bulk/",
+        "http://testserver/api/v1/transcription/bulk/",
+        "http://testserver/api/v1/transcription/bulk/",
+        "http://testserver/api/v1/transcription/bulk/",
         "http://testserver/api/v1/transcription/bulk/",
     ]
 
@@ -917,10 +927,15 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker
             transcriptions=TRANSCRIPTIONS_SAMPLE,
         )
 
-    assert len(responses.calls) == 3
+    assert len(responses.calls) == 7
     assert [call.request.url for call in responses.calls] == [
         "http://testserver/api/v1/user/",
         "http://testserver/api/v1/workers/versions/12341234-1234-1234-1234-123412341234/",
+        # We retry 5 times the API call
+        f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
+        f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
+        f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
+        f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
         f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
     ]
 
-- 
GitLab