Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# -*- coding: utf-8 -*-
import logging
from operator import itemgetter
from tempfile import _TemporaryFileWrapper
from typing import List
from uuid import UUID
from apistar.exceptions import ErrorResponse
from arkindex_export import (
Classification,
DatasetElement,
Element,
Entity,
EntityType,
Image,
Transcription,
TranscriptionEntity,
WorkerRun,
WorkerVersion,
open_database,
)
from arkindex_export.queries import list_children
from arkindex_worker.models import Element as ArkindexElement
from arkindex_worker.models import Transcription as ArkindexTranscription
from peewee import CharField
from worker_generic_training_dataset import Extractor
# Module-level logger, named after this module per stdlib convention.
logger: logging.Logger = logging.getLogger(__name__)
# NOTE(review): unused in this chunk — presumably the row count per bulk DB write; confirm at call sites.
BULK_BATCH_SIZE = 50
# NOTE(review): unused in this chunk — presumably the fallback orientation for transcriptions; confirm at call sites.
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"
def get_object_id(instance: WorkerVersion | WorkerRun | None) -> CharField | None:
    """Return the ``id`` of the given worker object, or None when absent.

    NOTE(review): the return annotation says ``CharField`` but the value read
    off a peewee model instance is the stored field value — confirm intent.
    """
    if not instance:
        return None
    return instance.id
class DatasetExtractorFromSQL(Extractor):
    """Extractor that reads elements, transcriptions and entities from a
    downloaded SQLite export of the corpus instead of calling the Arkindex
    API for each listing.
    """

    def configure(self):
        """Parse the worker configuration, then fetch and open the corpus export."""
        super().configure()
        self.download_latest_export()

    def download_latest_export(self) -> None:
        """
        Download the latest export of the current corpus.
        Export must be in `"done"` state.

        Raises:
            ErrorResponse: when exports cannot be listed or downloaded.
            AssertionError: when the corpus has no export in `"done"` state.
        """
        try:
            exports = list(
                self.api_client.paginate(
                    "ListExports",
                    id=self.corpus_id,
                )
            )
        except ErrorResponse as e:
            logger.error(
                f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise e
        # Keep only finished exports, most recently updated first
        exports: list[dict] = sorted(
            (exp for exp in exports if exp["state"] == "done"),
            key=itemgetter("updated"),
            reverse=True,
        )
        # Explicit check instead of a bare `assert`: assertions are stripped
        # under `python -O`, which would silently skip this validation.
        if not exports:
            raise AssertionError(
                f"No available exports found for the corpus {self.corpus_id}."
            )
        # Download latest export. `export_id` is bound before the `try` so the
        # error message below can never reference an unbound name.
        export_id: str = exports[0]["id"]
        try:
            logger.info(f"Downloading export ({export_id})...")
            # The export is streamed to a temporary file kept alive on `self`
            # so the SQLite database stays on disk for the worker's lifetime.
            self.export: _TemporaryFileWrapper = self.api_client.request(
                "DownloadExport",
                id=export_id,
            )
            logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
            open_database(self.export.name)
        except ErrorResponse as e:
            logger.error(
                f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise e

    def list_set_elements(self, dataset_id: UUID, set_name: str):
        """Query the elements of set `set_name` in dataset `dataset_id`,
        joined to their image."""
        return (
            Element.select()
            .join(Image)
            .switch(Element)
            .join(DatasetElement, on=DatasetElement.element)
            .where(
                DatasetElement.dataset == dataset_id,
                DatasetElement.set_name == set_name,
            )
        )

    def list_classifications(self, element_id: UUID):
        """Iterate over the classifications of the given element."""
        return (
            Classification.select()
            .where(Classification.element == element_id)
            .iterator()
        )

    def list_transcriptions(self, element: ArkindexElement, **kwargs):
        """Iterate over the transcriptions of the given element."""
        return (
            Transcription.select().where(Transcription.element == element.id).iterator()
        )

    def list_transcription_entities(
        self, transcription: ArkindexTranscription, **kwargs
    ):
        """Iterate over the entities of the given transcription, joined to
        their entity type."""
        return (
            TranscriptionEntity.select()
            .where(TranscriptionEntity.transcription == transcription.id)
            .join(Entity, on=TranscriptionEntity.entity)
            .join(EntityType, on=Entity.type)
        ).iterator()

    def list_element_children(self, element: ArkindexElement, **kwargs):
        """Iterate over all children of the given element."""
        return list_children(element.id).iterator()
def main():
    """Entry point: build and run the SQL-based dataset extraction worker."""
    worker = DatasetExtractorFromSQL(
        description="Fill base-worker cache with information about dataset and extract images",
    )
    worker.run()


if __name__ == "__main__":
    main()