From 7c56a81cdea18d287a39db56df8570ca5f812e20 Mon Sep 17 00:00:00 2001
From: Martin <maarand@teklia.com>
Date: Fri, 5 Nov 2021 17:54:17 +0100
Subject: [PATCH] Add cached API client

Cache paginated Arkindex API responses on disk, so that repeated runs
over the same corpora do not query the server again.

---
 kaldi_data_generator/main.py  | 25 ++++++++-----
 kaldi_data_generator/utils.py | 72 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 87 insertions(+), 10 deletions(-)

diff --git a/kaldi_data_generator/main.py b/kaldi_data_generator/main.py
index 2a61077..ccfc8c4 100644
--- a/kaldi_data_generator/main.py
+++ b/kaldi_data_generator/main.py
@@ -24,7 +24,12 @@ from kaldi_data_generator.image_utils import (
     rotate,
     trim,
 )
-from kaldi_data_generator.utils import TranscriptionData, logger, write_file
+from kaldi_data_generator.utils import (
+    CachedApiClient,
+    TranscriptionData,
+    logger,
+    write_file,
+)
 
 SEED = 42
 random.seed(SEED)
@@ -41,9 +46,9 @@ ROTATION_CLASSES_TO_ANGLES = {
 }
 
 
-def create_api_client():
+def create_api_client(cache_dir=None):
     logger.info("Creating API client")
-    return ArkindexClient(**options_from_env())
+    return CachedApiClient(cache_root=cache_dir, **options_from_env())
 
 
 class Extraction(Enum):
@@ -164,7 +169,7 @@ class HTRDataGenerator:
     def get_accepted_zones(self, page_id: str):
         try:
             accepted_zones = []
-            for elt in self.api_client.paginate(
+            for elt in self.api_client.cached_paginate(
                 "ListElementChildren", id=page_id, with_best_classes=True
             ):
                 printed = True
@@ -213,7 +218,7 @@ class HTRDataGenerator:
     def get_transcriptions(self, page_id: str, accepted_zones):
         lines = []
         try:
-            for res in self.api_client.paginate(
+            for res in self.api_client.cached_paginate(
                 "ListTranscriptions", id=page_id, recursive=True
             ):
                 if (
@@ -298,7 +303,7 @@ class HTRDataGenerator:
                 for best_class in elem["best_classes"]
                 if best_class["state"] != "rejected"
             ]
-            for elem in self.api_client.paginate(
+            for elem in self.api_client.cached_paginate(
                 "ListElementChildren",
                 id=page_id,
                 recursive=True,
@@ -502,7 +507,7 @@ class HTRDataGenerator:
         logger.info(f"Volume {volume_id}")
         pages = [
             page
-            for page in self.api_client.paginate(
+            for page in self.api_client.cached_paginate(
                 "ListElementChildren", id=volume_id, recursive=True, type="page"
             )
         ]
@@ -513,7 +518,7 @@ class HTRDataGenerator:
         logger.info(f"Folder {elem_id}")
         vol_ids = [
             page["id"]
-            for page in self.api_client.paginate(
+            for page in self.api_client.cached_paginate(
                 "ListElementChildren", id=elem_id, recursive=True, type=volume_type
             )
         ]
@@ -524,7 +529,7 @@ class HTRDataGenerator:
         logger.info(f"Corpus {corpus_id}")
         vol_ids = [
             vol["id"]
-            for vol in self.api_client.paginate(
+            for vol in self.api_client.cached_paginate(
                 "ListElements", corpus=corpus_id, type=volume_type
             )
         ]
@@ -812,7 +817,7 @@ def main():
 
     logger.info(f"ARGS {args} \n")
 
-    api_client = create_api_client()
+    api_client = create_api_client(args.cache_dir)
 
     if not args.split_only:
         data_generator = HTRDataGenerator(
diff --git a/kaldi_data_generator/utils.py b/kaldi_data_generator/utils.py
index 59a4761..7e88f5f 100644
--- a/kaldi_data_generator/utils.py
+++ b/kaldi_data_generator/utils.py
@@ -1,10 +1,14 @@
 # -*- coding: utf-8 -*-
+import json
 import logging
 import os
+import sys
+from pathlib import Path
 from typing import NamedTuple
 
 import cv2
 import numpy as np
+from arkindex import ArkindexClient
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s %(levelname)s/%(name)s: %(message)s"
@@ -15,6 +19,14 @@
 BoundingBox = NamedTuple(
     "BoundingBox", [("x", int), ("y", int), ("width", int), ("height", int)]
 )
 
+# Builtin hash() of strings is salted per process unless PYTHONHASHSEED is
+# set, so restart the interpreter with a fixed seed to keep the cache keys
+# computed in CachedApiClient stable across runs.
+if not os.environ.get("PYTHONHASHSEED"):
+    logger.info("Restarting with PYTHONHASHSEED=0 for reproducible hashing")
+    os.environ["PYTHONHASHSEED"] = "0"
+    os.execv(sys.executable, [sys.executable] + sys.argv)
+
 class TranscriptionData:
     def __init__(
@@ -65,3 +77,63 @@
 def write_file(file_name, content):
     with open(file_name, "w") as f:
         f.write(content)
+
+
+class CachedApiClient(ArkindexClient):
+    """Arkindex client that caches paginated API responses on disk."""
+
+    def __init__(self, cache_root: Path = None, **kwargs):
+        logger.info("Creating cached API client")
+        super().__init__(**kwargs)
+        self.cache_root = cache_root
+        self.cache_location = None
+        if self.cache_root is not None:
+            # One cache directory per server, e.g. "arkindex" for
+            # https://arkindex.teklia.com
+            self.server_name = self.document.url.split("//")[-1].split(".")[0]
+            logger.info(f"Server is {self.server_name}")
+            self.cache_location = Path(self.cache_root) / "requests" / self.server_name
+            self.cache_location.mkdir(parents=True, exist_ok=True)
+
+    def cached_paginate(self, *args, **kwargs):
+        """Paginate through the API, reusing cached results when they exist."""
+        logger.info(f"Params: args: {args} ---- kwargs: {kwargs}")
+        if self.cache_location is None:
+            # No cache directory configured: fall back to plain pagination
+            yield from self.paginate(*args, **kwargs)
+            return
+
+        # Expecting exactly one positional argument: the operation id
+        if len(args) == 1:
+            operation_id = args[0]
+        else:
+            raise ValueError(f"Unexpected number of positional arguments: {args}")
+
+        if "id" in kwargs:
+            ark_id = kwargs["id"]
+        elif "corpus" in kwargs:
+            ark_id = kwargs["corpus"]
+        else:
+            raise ValueError(f"Id or corpus must be defined: {kwargs}")
+
+        # Stable across runs thanks to the fixed PYTHONHASHSEED set above
+        kwargs_hash = hash(json.dumps(kwargs, sort_keys=True))
+        request_cache = (
+            self.cache_location / operation_id / ark_id / f"hash_{kwargs_hash}.json"
+        )
+
+        if request_cache.exists():
+            logger.info(f"Loading from cache: {operation_id} - {ark_id}")
+            yield from json.loads(request_cache.read_text())
+            return
+
+        logger.info(f"Running actual query: {operation_id} - {ark_id}")
+        results = []
+        for res in self.paginate(*args, **kwargs):
+            results.append(res)
+            yield res
+
+        # Only write the cache once the full result set has been fetched
+        logger.info(f"Saving to cache: {operation_id} - {ark_id}")
+        request_cache.parent.mkdir(parents=True, exist_ok=True)
+        request_cache.write_text(json.dumps(results))
--
GitLab
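
Usage note (not part of the patch): a minimal sketch of how the cached
client is meant to be driven, assuming the --cache-dir option read via
args.cache_dir is wired into main.py's argument parser (that parser
change is not shown in this diff), and using a placeholder element id:

    from pathlib import Path

    from arkindex import options_from_env

    from kaldi_data_generator.utils import CachedApiClient

    client = CachedApiClient(cache_root=Path("cache"), **options_from_env())
    # First call hits the API and writes cache/requests/<server>/...
    for page in client.cached_paginate(
        "ListElementChildren", id="some-element-uuid", type="page"
    ):
        print(page["id"])
    # An identical second call is served from the JSON file on disk.

Since the cache is only written after a pagination run completes, an
interrupted run is re-queried on the next call rather than served a
partial result.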