From 7c56a81cdea18d287a39db56df8570ca5f812e20 Mon Sep 17 00:00:00 2001
From: Martin <maarand@teklia.com>
Date: Fri, 5 Nov 2021 17:54:17 +0100
Subject: [PATCH] add cached api client

---
 kaldi_data_generator/main.py  | 21 +++++------
 kaldi_data_generator/utils.py | 67 +++++++++++++++++++++++++++++++++++
 2 files changed, 78 insertions(+), 10 deletions(-)

diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py
index 2a61077..ccfc8c4 100644
--- a/kaldi_data_generator/main.py
+++ b/kaldi_data_generator/main.py
@@ -24,7 +24,7 @@ from kaldi_data_generator.image_utils import (
     rotate,
     trim,
 )
-from kaldi_data_generator.utils import TranscriptionData, logger, write_file
+from kaldi_data_generator.utils import TranscriptionData, logger, write_file, CachedApiClient
 
 SEED = 42
 random.seed(SEED)
@@ -41,9 +41,10 @@ ROTATION_CLASSES_TO_ANGLES = {
 }
 
 
-def create_api_client():
+def create_api_client(cache_dir=None):
     logger.info("Creating API client")
-    return ArkindexClient(**options_from_env())
+    # return ArkindexClient(**options_from_env())
+    return CachedApiClient(cache_root=cache_dir, **options_from_env())
 
 
 class Extraction(Enum):
@@ -164,7 +165,7 @@ class HTRDataGenerator:
     def get_accepted_zones(self, page_id: str):
         try:
             accepted_zones = []
-            for elt in self.api_client.paginate(
+            for elt in self.api_client.cached_paginate(
                 "ListElementChildren", id=page_id, with_best_classes=True
             ):
                 printed = True
@@ -213,7 +214,7 @@ class HTRDataGenerator:
     def get_transcriptions(self, page_id: str, accepted_zones):
         lines = []
         try:
-            for res in self.api_client.paginate(
+            for res in self.api_client.cached_paginate(
                 "ListTranscriptions", id=page_id, recursive=True
             ):
                 if (
@@ -298,7 +299,7 @@ class HTRDataGenerator:
                 for best_class in elem["best_classes"]
                 if best_class["state"] != "rejected"
             ]
-            for elem in self.api_client.paginate(
+            for elem in self.api_client.cached_paginate(
                 "ListElementChildren",
                 id=page_id,
                 recursive=True,
@@ -502,7 +503,7 @@ class HTRDataGenerator:
             logger.info(f"Volume {volume_id}")
             pages = [
                 page
-                for page in self.api_client.paginate(
+                for page in self.api_client.cached_paginate(
                     "ListElementChildren", id=volume_id, recursive=True, type="page"
                 )
             ]
@@ -513,7 +514,7 @@ class HTRDataGenerator:
             logger.info(f"Folder {elem_id}")
             vol_ids = [
                 page["id"]
-                for page in self.api_client.paginate(
+                for page in self.api_client.cached_paginate(
                     "ListElementChildren", id=elem_id, recursive=True, type=volume_type
                 )
             ]
@@ -524,7 +525,7 @@ class HTRDataGenerator:
             logger.info(f"Corpus {corpus_id}")
             vol_ids = [
                 vol["id"]
-                for vol in self.api_client.paginate(
+                for vol in self.api_client.cached_paginate(
                     "ListElements", corpus=corpus_id, type=volume_type
                 )
             ]
@@ -812,7 +813,7 @@ def main():
 
     logger.info(f"ARGS {args} \n")
 
-    api_client = create_api_client()
+    api_client = create_api_client(args.cache_dir)
 
     if not args.split_only:
         data_generator = HTRDataGenerator(
diff --git a/kaldi_data_generator/utils.py b/kaldi_data_generator/utils.py
index 59a4761..7e88f5f 100644
--- a/kaldi_data_generator/utils.py
+++ b/kaldi_data_generator/utils.py
@@ -1,10 +1,15 @@
 # -*- coding: utf-8 -*-
 import logging
 import os
+from pathlib import Path
 from typing import NamedTuple
 
+import sys
 import cv2
 import numpy as np
+from arkindex import ArkindexClient
+
+import json
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s %(levelname)s/%(name)s: %(message)s"
@@ -15,6 +20,11 @@ BoundingBox = NamedTuple(
     "BoundingBox", [("x", int), ("y", int), ("width", int), ("height", int)]
 )
 
+if not os.environ.get('PYTHONHASHSEED'):
+    logger.info("Starting again with PYTHONHASHSEED for reproducibility")
+    os.environ['PYTHONHASHSEED'] = "0"
+    os.execv(sys.executable, ['python3'] + sys.argv)
+
 
 class TranscriptionData:
     def __init__(
@@ -65,3 +75,60 @@ class TranscriptionData:
 def write_file(file_name, content):
     with open(file_name, "w") as f:
         f.write(content)
+
+
+class CachedApiClient(ArkindexClient):
+
+    def __init__(self, cache_root: Path = None, **kwargs):
+        logger.info("Creating cached api client")
+        super().__init__(**kwargs)
+        self.cache_root = cache_root
+
+        self.server_name = self.document.url.split("//")[-1].split(".")[0]
+        logger.info(f"Server is {self.server_name}")
+
+        self.cache_location = self.cache_root / "requests" / self.server_name
+
+        self.cache_location.mkdir(parents=True, exist_ok=True)
+
+    def cached_paginate(self, *args, **kwargs):
+        logger.info(f"Params: args: {args} ---- kwargs: {kwargs}")
+
+        # expecting only one positional argument - operation_id
+        if len(args) == 1:
+            operation_id = args[0]
+            logger.info(f"op: {operation_id}")
+        else:
+            raise ValueError(f"Unexpected number of positional arguments: {args}")
+
+        if "id" in kwargs:
+            ark_id = kwargs["id"]
+        elif "corpus" in kwargs:
+            ark_id = kwargs["corpus"]
+        else:
+            raise ValueError(f"Id or corpus must be defined: {kwargs}")
+
+        kwargs_hash = hash(json.dumps(kwargs, sort_keys=True))
+        logger.info(f"Kwargs hash: {kwargs_hash}")
+
+        request_cache = self.cache_location / operation_id / ark_id / f"hash_{kwargs_hash}.json"
+        if request_cache.exists():
+            logger.info(f"Loading from cache: {operation_id} - {ark_id}")
+            cached_results = json.loads(request_cache.read_text())
+            for res in cached_results:
+                yield res
+            logger.info("Used cached results")
+            return
+
+        logger.info(f"Running actual query: {operation_id} - {ark_id}")
+        results = []
+        paginator = self.paginate(*args, **kwargs)
+        for res in paginator:
+            # logger.info(f"Res: {res}")
+            results.append(res)
+            yield res
+
+        logger.info(f"Saving to cache: {operation_id} - {ark_id}")
+        request_cache.parent.mkdir(parents=True, exist_ok=True)
+        request_cache.write_text(json.dumps(results))
+        logger.info("Saved")
-- 
GitLab