Commit 7c56a81c authored by Martin

add cached api client

parent d6cf517d
@@ -24,7 +24,7 @@ from kaldi_data_generator.image_utils import (
     rotate,
     trim,
 )
-from kaldi_data_generator.utils import TranscriptionData, logger, write_file
+from kaldi_data_generator.utils import TranscriptionData, logger, write_file, CachedApiClient
 
 SEED = 42
 random.seed(SEED)
@@ -41,9 +41,10 @@ ROTATION_CLASSES_TO_ANGLES = {
 }
 
 
-def create_api_client():
+def create_api_client(cache_dir=None):
     logger.info("Creating API client")
-    return ArkindexClient(**options_from_env())
+    # return ArkindexClient(**options_from_env())
+    return CachedApiClient(cache_root=cache_dir, **options_from_env())
 
 
 class Extraction(Enum):
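The factory now hands back a CachedApiClient and keeps the old ArkindexClient construction as a commented-out reference. Below is a minimal caller-side sketch of the new signature, assuming the usual Arkindex credentials are available in the environment for options_from_env(); note that cache_dir needs to be an actual path, since CachedApiClient.__init__ immediately computes cache_root / "requests" and the default None would raise a TypeError.

# Hypothetical caller-side sketch, not part of this commit.
from pathlib import Path

# cache_dir must not be None: __init__ joins it with "requests" right away.
api_client = create_api_client(cache_dir=Path("./cache"))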
@@ -164,7 +165,7 @@ class HTRDataGenerator:
     def get_accepted_zones(self, page_id: str):
         try:
             accepted_zones = []
-            for elt in self.api_client.paginate(
+            for elt in self.api_client.cached_paginate(
                 "ListElementChildren", id=page_id, with_best_classes=True
             ):
                 printed = True
@@ -213,7 +214,7 @@
     def get_transcriptions(self, page_id: str, accepted_zones):
         lines = []
         try:
-            for res in self.api_client.paginate(
+            for res in self.api_client.cached_paginate(
                 "ListTranscriptions", id=page_id, recursive=True
             ):
                 if (
@@ -298,7 +299,7 @@
                 for best_class in elem["best_classes"]
                 if best_class["state"] != "rejected"
             ]
-            for elem in self.api_client.paginate(
+            for elem in self.api_client.cached_paginate(
                 "ListElementChildren",
                 id=page_id,
                 recursive=True,
@@ -502,7 +503,7 @@
         logger.info(f"Volume {volume_id}")
         pages = [
             page
-            for page in self.api_client.paginate(
+            for page in self.api_client.cached_paginate(
                 "ListElementChildren", id=volume_id, recursive=True, type="page"
             )
         ]
@@ -513,7 +514,7 @@
         logger.info(f"Folder {elem_id}")
         vol_ids = [
             page["id"]
-            for page in self.api_client.paginate(
+            for page in self.api_client.cached_paginate(
                 "ListElementChildren", id=elem_id, recursive=True, type=volume_type
             )
         ]
@@ -524,7 +525,7 @@
         logger.info(f"Corpus {corpus_id}")
         vol_ids = [
             vol["id"]
-            for vol in self.api_client.paginate(
+            for vol in self.api_client.cached_paginate(
                 "ListElements", corpus=corpus_id, type=volume_type
             )
         ]
@@ -812,7 +813,7 @@ def main():
     logger.info(f"ARGS {args} \n")
 
-    api_client = create_api_client()
+    api_client = create_api_client(args.cache_dir)
 
     if not args.split_only:
         data_generator = HTRDataGenerator(
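main() threads args.cache_dir through to the factory; the matching argument-parser change is outside the visible hunks. It would presumably be a flag along these lines (hypothetical sketch, option name and help text assumed):

# Hypothetical argparse addition; the real definition is not shown in this diff.
parser.add_argument(
    "--cache-dir",
    type=Path,
    help="root directory under which API responses are cached",
)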
kaldi_data_generator/utils.py
 # -*- coding: utf-8 -*-
 import logging
 import os
 from pathlib import Path
 from typing import NamedTuple
+import sys
 
 import cv2
 import numpy as np
+from arkindex import ArkindexClient
+import json
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s %(levelname)s/%(name)s: %(message)s"
@@ -15,6 +20,11 @@ BoundingBox = NamedTuple(
     "BoundingBox", [("x", int), ("y", int), ("width", int), ("height", int)]
 )
 
+if not os.environ.get('PYTHONHASHSEED'):
+    logger.info("Starting again with PYTHONHASHSEED for reproducibility")
+    os.environ['PYTHONHASHSEED'] = "0"
+    os.execv(sys.executable, ['python3'] + sys.argv)
+
 
 class TranscriptionData:
     def __init__(
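Why restart the interpreter: the cache file name further down is derived with the built-in hash(), and Python has randomized string hashing per process since 3.3 unless PYTHONHASHSEED is set. Without a fixed seed, the same request would hash to a different file name on every run and the cache would never hit. A small standalone demonstration of the effect:

# Shows that hash('...') differs between interpreter processes
# unless PYTHONHASHSEED is pinned.
import os
import subprocess
import sys

cmd = [sys.executable, "-c", "print(hash('ListElements'))"]
# Varies between runs (assuming PYTHONHASHSEED is not already set):
print(subprocess.check_output(cmd).decode().strip())
# Stable across runs once the seed is pinned:
env = {**os.environ, "PYTHONHASHSEED": "0"}
print(subprocess.check_output(cmd, env=env).decode().strip())

A seed-independent digest, e.g. hashlib.sha256(json.dumps(kwargs, sort_keys=True).encode()).hexdigest(), would give stable names without re-executing the process; the os.execv approach keeps the change small but re-runs every import on startup.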
@@ -65,3 +75,60 @@ class TranscriptionData:
 def write_file(file_name, content):
     with open(file_name, "w") as f:
         f.write(content)
+
+
+class CachedApiClient(ArkindexClient):
+    def __init__(self, cache_root: Path = None, **kwargs):
+        logger.info("Creating cached api client")
+        super().__init__(**kwargs)
+        self.cache_root = cache_root
+        self.server_name = self.document.url.split("//")[-1].split(".")[0]
+        logger.info(f"Server is {self.server_name}")
+        self.cache_location = self.cache_root / "requests" / self.server_name
+        self.cache_location.mkdir(parents=True, exist_ok=True)
+
+    def cached_paginate(self, *args, **kwargs):
+        logger.info(f"Params: args: {args} ---- kwargs: {kwargs}")
+        # expecting only one positional argument - operation_id
+        if len(args) == 1:
+            operation_id = args[0]
+            logger.info(f"op: {operation_id}")
+        else:
+            raise ValueError(f"Unexpected number of positional arguments: {args}")
+
+        if "id" in kwargs:
+            ark_id = kwargs["id"]
+        elif "corpus" in kwargs:
+            ark_id = kwargs["corpus"]
+        else:
+            raise ValueError(f"Id or corpus must be defined: {kwargs}")
+
+        kwargs_hash = hash(json.dumps(kwargs, sort_keys=True))
+        logger.info(f"Kwargs hash: {kwargs_hash}")
+        request_cache = self.cache_location / operation_id / ark_id / f"hash_{kwargs_hash}.json"
+
+        if request_cache.exists():
+            logger.info(f"Loading from cache: {operation_id} - {ark_id}")
+            cached_results = json.loads(request_cache.read_text())
+            for res in cached_results:
+                yield res
+            logger.info("Used cached results")
+            return
+
+        logger.info(f"Running actual query: {operation_id} - {ark_id}")
+        results = []
+        paginator = self.paginate(*args, **kwargs)
+        for res in paginator:
+            # logger.info(f"Res: {res}")
+            results.append(res)
+            yield res
+
+        logger.info(f"Saving to cache: {operation_id} - {ark_id}")
+        request_cache.parent.mkdir(parents=True, exist_ok=True)
+        request_cache.write_text(json.dumps(results))
+        logger.info("Saved")