Commit 7c56a81c authored by Martin

add cached api client

parent d6cf517d
Pipeline #74321 failed
@@ -24,7 +24,7 @@ from kaldi_data_generator.image_utils import (
     rotate,
     trim,
 )
-from kaldi_data_generator.utils import TranscriptionData, logger, write_file
+from kaldi_data_generator.utils import TranscriptionData, logger, write_file, CachedApiClient
 SEED = 42
 random.seed(SEED)
@@ -41,9 +41,10 @@ ROTATION_CLASSES_TO_ANGLES = {
 }
-def create_api_client():
+def create_api_client(cache_dir=None):
     logger.info("Creating API client")
-    return ArkindexClient(**options_from_env())
+    # return ArkindexClient(**options_from_env())
+    return CachedApiClient(cache_root=cache_dir, **options_from_env())
 class Extraction(Enum):
@@ -164,7 +165,7 @@ class HTRDataGenerator:
     def get_accepted_zones(self, page_id: str):
         try:
             accepted_zones = []
-            for elt in self.api_client.paginate(
+            for elt in self.api_client.cached_paginate(
                 "ListElementChildren", id=page_id, with_best_classes=True
             ):
                 printed = True
@@ -213,7 +214,7 @@ class HTRDataGenerator:
     def get_transcriptions(self, page_id: str, accepted_zones):
         lines = []
         try:
-            for res in self.api_client.paginate(
+            for res in self.api_client.cached_paginate(
                 "ListTranscriptions", id=page_id, recursive=True
             ):
                 if (
@@ -298,7 +299,7 @@ class HTRDataGenerator:
                 for best_class in elem["best_classes"]
                 if best_class["state"] != "rejected"
             ]
-            for elem in self.api_client.paginate(
+            for elem in self.api_client.cached_paginate(
                 "ListElementChildren",
                 id=page_id,
                 recursive=True,
@@ -502,7 +503,7 @@ class HTRDataGenerator:
         logger.info(f"Volume {volume_id}")
         pages = [
             page
-            for page in self.api_client.paginate(
+            for page in self.api_client.cached_paginate(
                 "ListElementChildren", id=volume_id, recursive=True, type="page"
             )
         ]
@@ -513,7 +514,7 @@ class HTRDataGenerator:
         logger.info(f"Folder {elem_id}")
         vol_ids = [
             page["id"]
-            for page in self.api_client.paginate(
+            for page in self.api_client.cached_paginate(
                 "ListElementChildren", id=elem_id, recursive=True, type=volume_type
             )
         ]
@@ -524,7 +525,7 @@ class HTRDataGenerator:
         logger.info(f"Corpus {corpus_id}")
         vol_ids = [
             vol["id"]
-            for vol in self.api_client.paginate(
+            for vol in self.api_client.cached_paginate(
                 "ListElements", corpus=corpus_id, type=volume_type
             )
         ]
@@ -812,7 +813,7 @@ def main():
     logger.info(f"ARGS {args} \n")
-    api_client = create_api_client()
+    api_client = create_api_client(args.cache_dir)
     if not args.split_only:
         data_generator = HTRDataGenerator(
kaldi_data_generator/utils.py
 # -*- coding: utf-8 -*-
 import logging
 import os
+from pathlib import Path
 from typing import NamedTuple
+import sys
 import cv2
 import numpy as np
+from arkindex import ArkindexClient
+import json
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s %(levelname)s/%(name)s: %(message)s"
@@ -15,6 +20,11 @@ BoundingBox = NamedTuple(
     "BoundingBox", [("x", int), ("y", int), ("width", int), ("height", int)]
 )
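+# NOTE: hash() on strings is salted per process unless PYTHONHASHSEED is set,
+# so the re-exec below pins it to keep CachedApiClient's cache file names
+# stable across runs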
+if not os.environ.get('PYTHONHASHSEED'):
+    logger.info("Starting again with PYTHONHASHSEED for reproducibility")
+    os.environ['PYTHONHASHSEED'] = "0"
+    os.execv(sys.executable, ['python3'] + sys.argv)
 class TranscriptionData:
     def __init__(
@@ -65,3 +75,60 @@ class TranscriptionData:
 def write_file(file_name, content):
     with open(file_name, "w") as f:
         f.write(content)
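+# Drop-in wrapper around ArkindexClient that memoizes paginated API responses
+# on disk, keyed by operation id, element/corpus id, and a hash of the query
+# kwargs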
+class CachedApiClient(ArkindexClient):
+    def __init__(self, cache_root: Path = None, **kwargs):
+        logger.info("Creating cached api client")
+        super().__init__(**kwargs)
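+        # NOTE: cache_root is expected to be a real Path; the path joins
+        # below would fail on the None default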
+        self.cache_root = cache_root
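+        # namespace the cache per server, using the first DNS label of the
+        # API URL's host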
+        self.server_name = self.document.url.split("//")[-1].split(".")[0]
+        logger.info(f"Server is {self.server_name}")
+        self.cache_location = self.cache_root / "requests" / self.server_name
+        self.cache_location.mkdir(parents=True, exist_ok=True)
+
+    def cached_paginate(self, *args, **kwargs):
+        logger.info(f"Params: args: {args} ---- kwargs: {kwargs}")
+        # expecting only one positional argument - operation_id
+        if len(args) == 1:
+            operation_id = args[0]
+            logger.info(f"op: {operation_id}")
+        else:
+            raise ValueError(f"Unexpected number of positional arguments: {args}")
+        if "id" in kwargs:
+            ark_id = kwargs["id"]
+        elif "corpus" in kwargs:
+            ark_id = kwargs["corpus"]
+        else:
+            raise ValueError(f"Id or corpus must be defined: {kwargs}")
+        kwargs_hash = hash(json.dumps(kwargs, sort_keys=True))
+        logger.info(f"Kwargs hash: {kwargs_hash}")
+        request_cache = (
+            self.cache_location / operation_id / ark_id / f"hash_{kwargs_hash}.json"
+        )
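+        # cache hit: replay the stored results verbatim (entries are never
+        # expired or invalidated)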
+        if request_cache.exists():
+            logger.info(f"Loading from cache: {operation_id} - {ark_id}")
+            cached_results = json.loads(request_cache.read_text())
+            for res in cached_results:
+                yield res
+            logger.info("Used cached results")
+            return
logger.info(f"Running actual query: {operation_id} - {ark_id}")
results = []
paginator = self.paginate(*args, **kwargs)
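+        # stream each result to the caller while accumulating the full list
+        # for the cache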
+        for res in paginator:
+            # logger.info(f"Res: {res}")
+            results.append(res)
+            yield res
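+        # NOTE: the cache file is only written once the caller has exhausted
+        # the generator; breaking out of the iteration early caches nothing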
logger.info(f"Saving to cache: {operation_id} - {ark_id}")
request_cache.parent.mkdir(parents=True, exist_ok=True)
request_cache.write_text(json.dumps(results))
logger.info("Saved")