Commit 7c56a81c authored by Martin

    add cached api client

parent d6cf517d
Pipeline #74321 failed
@@ -24,7 +24,7 @@ from kaldi_data_generator.image_utils import (
     rotate,
     trim,
 )
-from kaldi_data_generator.utils import TranscriptionData, logger, write_file
+from kaldi_data_generator.utils import TranscriptionData, logger, write_file, CachedApiClient

 SEED = 42
 random.seed(SEED)
@@ -41,9 +41,10 @@ ROTATION_CLASSES_TO_ANGLES = {
 }

-def create_api_client():
+def create_api_client(cache_dir=None):
     logger.info("Creating API client")
-    return ArkindexClient(**options_from_env())
+    # return ArkindexClient(**options_from_env())
+    return CachedApiClient(cache_root=cache_dir, **options_from_env())


 class Extraction(Enum):
@@ -164,7 +165,7 @@ class HTRDataGenerator:
     def get_accepted_zones(self, page_id: str):
         try:
             accepted_zones = []
-            for elt in self.api_client.paginate(
+            for elt in self.api_client.cached_paginate(
                 "ListElementChildren", id=page_id, with_best_classes=True
             ):
                 printed = True
@@ -213,7 +214,7 @@ class HTRDataGenerator:
     def get_transcriptions(self, page_id: str, accepted_zones):
         lines = []
         try:
-            for res in self.api_client.paginate(
+            for res in self.api_client.cached_paginate(
                 "ListTranscriptions", id=page_id, recursive=True
             ):
                 if (
@@ -298,7 +299,7 @@ class HTRDataGenerator:
                 for best_class in elem["best_classes"]
                 if best_class["state"] != "rejected"
             ]
-            for elem in self.api_client.paginate(
+            for elem in self.api_client.cached_paginate(
                 "ListElementChildren",
                 id=page_id,
                 recursive=True,
@@ -502,7 +503,7 @@ class HTRDataGenerator:
         logger.info(f"Volume {volume_id}")
         pages = [
             page
-            for page in self.api_client.paginate(
+            for page in self.api_client.cached_paginate(
                 "ListElementChildren", id=volume_id, recursive=True, type="page"
             )
         ]
@@ -513,7 +514,7 @@ class HTRDataGenerator:
         logger.info(f"Folder {elem_id}")
         vol_ids = [
             page["id"]
-            for page in self.api_client.paginate(
+            for page in self.api_client.cached_paginate(
                 "ListElementChildren", id=elem_id, recursive=True, type=volume_type
             )
         ]
@@ -524,7 +525,7 @@ class HTRDataGenerator:
         logger.info(f"Corpus {corpus_id}")
         vol_ids = [
             vol["id"]
-            for vol in self.api_client.paginate(
+            for vol in self.api_client.cached_paginate(
                 "ListElements", corpus=corpus_id, type=volume_type
             )
         ]
@@ -812,7 +813,7 @@ def main():
     logger.info(f"ARGS {args} \n")

-    api_client = create_api_client()
+    api_client = create_api_client(args.cache_dir)

     if not args.split_only:
         data_generator = HTRDataGenerator(
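
For context, a minimal sketch (not part of the commit) of how the cache_dir value could reach create_api_client() from the command line. Only args.cache_dir is visible in the diff, so the --cache-dir flag name and its default are assumptions:

    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cache-dir",
        type=Path,
        default=Path("cache"),  # assumed default; not shown in the diff
        help="root directory for cached Arkindex API responses",
    )
    args = parser.parse_args()

    api_client = create_api_client(args.cache_dir)  # as called in main()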
kaldi_data_generator/utils.py
 # -*- coding: utf-8 -*-
 import logging
+import os
+from pathlib import Path
 from typing import NamedTuple
+import sys

 import cv2
 import numpy as np
+from arkindex import ArkindexClient
+import json

 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s %(levelname)s/%(name)s: %(message)s"
@@ -15,6 +20,11 @@ BoundingBox = NamedTuple(
     "BoundingBox", [("x", int), ("y", int), ("width", int), ("height", int)]
 )

+if not os.environ.get('PYTHONHASHSEED'):
+    logger.info("Starting again with PYTHONHASHSEED for reproducibility")
+    os.environ['PYTHONHASHSEED'] = "0"
+    os.execv(sys.executable, ['python3'] + sys.argv)
+

 class TranscriptionData:
     def __init__(
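
A note on the PYTHONHASHSEED re-exec added above: the cache file names in CachedApiClient below are derived with the built-in hash(), whose string hashing is salted per process, so without a pinned PYTHONHASHSEED the same query would land in a different hash_<n>.json on every run. A deterministic alternative (a sketch only, not what this commit does) is hashlib, which would make the re-exec unnecessary:

    import hashlib
    import json

    def stable_kwargs_hash(kwargs: dict) -> str:
        # sha256 of the canonical JSON is stable across processes,
        # unlike the per-process salted built-in hash()
        payload = json.dumps(kwargs, sort_keys=True).encode("utf-8")
        return hashlib.sha256(payload).hexdigest()[:16]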
@@ -65,3 +75,60 @@ class TranscriptionData:

 def write_file(file_name, content):
     with open(file_name, "w") as f:
         f.write(content)
+
+
+class CachedApiClient(ArkindexClient):
+    # NOTE: cache_root is required in practice; None would fail at the
+    # path join below
+    def __init__(self, cache_root: Path = None, **kwargs):
+        logger.info("Creating cached api client")
+        super().__init__(**kwargs)
+        self.cache_root = cache_root
+        # e.g. "https://demo.arkindex.org/api/v1/" -> "demo"
+        self.server_name = self.document.url.split("//")[-1].split(".")[0]
+        logger.info(f"Server is {self.server_name}")
+        self.cache_location = self.cache_root / "requests" / self.server_name
+        self.cache_location.mkdir(parents=True, exist_ok=True)
+
+    def cached_paginate(self, *args, **kwargs):
+        logger.info(f"Params: args: {args} ---- kwargs: {kwargs}")
+        # expecting only one positional argument - operation_id
+        if len(args) == 1:
+            operation_id = args[0]
+            logger.info(f"op: {operation_id}")
+        else:
+            raise ValueError(f"Unexpected number of positional arguments: {args}")
+
+        if "id" in kwargs:
+            ark_id = kwargs["id"]
+        elif "corpus" in kwargs:
+            ark_id = kwargs["corpus"]
+        else:
+            raise ValueError(f"Id or corpus must be defined: {kwargs}")
+
+        kwargs_hash = hash(json.dumps(kwargs, sort_keys=True))
+        logger.info(f"Kwargs hash: {kwargs_hash}")
+        request_cache = (
+            self.cache_location / operation_id / ark_id / f"hash_{kwargs_hash}.json"
+        )
+
+        if request_cache.exists():
+            logger.info(f"Loading from cache: {operation_id} - {ark_id}")
+            cached_results = json.loads(request_cache.read_text())
+            for res in cached_results:
+                yield res
+            logger.info("Used cached results")
+            return
+
+        logger.info(f"Running actual query: {operation_id} - {ark_id}")
+        results = []
+        paginator = self.paginate(*args, **kwargs)
+        for res in paginator:
+            # logger.info(f"Res: {res}")
+            results.append(res)
+            yield res
+
+        logger.info(f"Saving to cache: {operation_id} - {ark_id}")
+        request_cache.parent.mkdir(parents=True, exist_ok=True)
+        request_cache.write_text(json.dumps(results))
+        logger.info("Saved")