Merge request !8: New DatasetExtractor using a DatasetWorker

Merged: Eva Bardou requested to merge dataset-worker into main 1 year ago
Overview (13) · Commits (8) · Pipelines (12) · Changes (3)
All threads resolved!
Depends on workers/base-worker!411 (merged)
Closes #3 (closed), #6 (closed)
Edited 1 year ago by Eva Bardou
Merge request reports

Viewing commit a520d362 · 3 files · +159 −527

Move code · a520d362 · Eva Bardou authored 1 year ago

worker_generic_training_dataset/dataset_worker.py · deleted 100644 → 0 · +0 −458
# -*- coding: utf-8 -*-
import logging
import sys
import tempfile
from argparse import Namespace
from itertools import groupby
from operator import itemgetter
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from typing import Iterator, List, Optional, Tuple
from uuid import UUID

from apistar.exceptions import ErrorResponse
from arkindex_export import Element, open_database
from arkindex_export.queries import list_children
from arkindex_worker.cache import (
    CachedClassification,
    CachedElement,
    CachedEntity,
    CachedImage,
    CachedTranscription,
    CachedTranscriptionEntity,
    create_tables,
    create_version_table,
)
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.image import download_image
from arkindex_worker.models import Dataset
from arkindex_worker.utils import create_tar_zst_archive
from arkindex_worker.worker.base import BaseWorker
from arkindex_worker.worker.dataset import DatasetMixin, DatasetState

from worker_generic_training_dataset.db import (
    list_classifications,
    list_transcription_entities,
    list_transcriptions,
)
from worker_generic_training_dataset.utils import build_image_url
from worker_generic_training_dataset.worker import (
    BULK_BATCH_SIZE,
    DEFAULT_TRANSCRIPTION_ORIENTATION,
)

logger: logging.Logger = logging.getLogger(__name__)


class DatasetWorker(BaseWorker, DatasetMixin):
    def __init__(
        self,
        description: str = "Arkindex Elements Worker",
        support_cache: bool = False,
        generator: bool = False,
    ):
        super().__init__(description, support_cache)

        self.parser.add_argument(
            "--dataset",
            type=UUID,
            nargs="+",
            help="One or more Arkindex dataset ID",
        )

        self.generator = generator

    def list_dataset_elements_per_set(
        self, dataset: Dataset
    ) -> Iterator[Tuple[str, Element]]:
        """
        Calls `list_dataset_elements` but returns results grouped by Set
        """

        def format_element(element):
            return Element.get(Element.id == element[1].id)

        def format_set(set):
            return (set[0], list(map(format_element, list(set[1]))))

        return list(
            map(
                format_set,
                groupby(
                    sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
                    key=itemgetter(0),
                ),
            )
        )

    def process_dataset(self, dataset: Dataset):
        """
        Override this method to implement your worker and process a single Arkindex dataset at once.

        :param dataset: The dataset to process.
        """

    def list_datasets(self) -> List[Dataset] | List[str]:
        """
        Calls `list_process_datasets` if not is_read_only,
        else simply give the list of IDs provided via CLI
        """
        if self.is_read_only:
            return list(map(str, self.args.dataset))

        return self.list_process_datasets()

    def run(self):
        self.configure()

        datasets: List[Dataset] | List[str] = self.list_datasets()
        if not datasets:
            logger.warning("No datasets to process, stopping.")
            sys.exit(1)

        # Process every dataset
        count = len(datasets)
        failed = 0
        for i, item in enumerate(datasets, start=1):
            dataset = None
            try:
                if not self.is_read_only:
                    # Just use the result of list_datasets as the dataset
                    dataset = item
                else:
                    # Load dataset using the Arkindex API
                    dataset = Dataset(**self.request("RetrieveDataset", id=item))

                if self.generator:
                    assert (
                        dataset.state == DatasetState.Open.value
                    ), "When generating a new dataset, its state should be Open"
                else:
                    assert (
                        dataset.state == DatasetState.Complete.value
                    ), "When processing an existing dataset, its state should be Complete"

                if self.generator:
                    # Update the dataset state to Building
                    logger.info(f"Building {dataset} ({i}/{count})")
                    self.update_dataset_state(dataset, DatasetState.Building)

                # Process the dataset
                self.process_dataset(dataset)

                if self.generator:
                    # Update the dataset state to Complete
                    logger.info(f"Completed {dataset} ({i}/{count})")
                    self.update_dataset_state(dataset, DatasetState.Complete)
            except Exception as e:
                # Handle errors occurring while retrieving, processing or patching the state for this dataset.
                failed += 1

                # Handle the case where we failed retrieving the dataset
                dataset_id = dataset.id if dataset else item

                if isinstance(e, ErrorResponse):
                    message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}"
                else:
                    message = (
                        f"Failed running worker on dataset {dataset_id}: {repr(e)}"
                    )

                logger.warning(
                    message,
                    exc_info=e if self.args.verbose else None,
                )
                if dataset and self.generator:
                    # Try to update the state to Error regardless of the response
                    try:
                        self.update_dataset_state(dataset, DatasetState.Error)
                    except Exception:
                        pass

        if failed:
            logger.error(
                "Ran on {} dataset: {} completed, {} failed".format(
                    count, count - failed, failed
                )
            )
            if failed >= count:
                # Everything failed!
                sys.exit(1)


class DatasetExtractor(DatasetWorker):
    def configure(self) -> None:
        self.args: Namespace = self.parser.parse_args()

        if self.is_read_only:
            super().configure_for_developers()
        else:
            super().configure()

        if self.user_configuration:
            logger.info("Overriding with user_configuration")
            self.config.update(self.user_configuration)

        # Download corpus
        self.download_latest_export()

        # Initialize db that will be written
        self.configure_cache()

        # CachedImage downloaded and created in DB
        self.cached_images = dict()

        # Where to save the downloaded images
        self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
        logger.info(f"Images will be saved at `{self.image_folder}`.")

    def configure_cache(self) -> None:
        """
        Create an SQLite database compatible with base-worker cache and initialize it.
        """
        self.use_cache = True
        self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite"
        # Remove previous execution result if present
        self.cache_path.unlink(missing_ok=True)

        init_cache_db(self.cache_path)

        create_version_table()

        create_tables()

    def download_latest_export(self) -> None:
        """
        Download the latest export of the current corpus.
        Export must be in `"done"` state.
        """
        try:
            exports = list(
                self.api_client.paginate(
                    "ListExports",
                    id=self.corpus_id,
                )
            )
        except ErrorResponse as e:
            logger.error(
                f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise e

        # Find the latest that is in "done" state
        exports: List[dict] = sorted(
            list(filter(lambda exp: exp["state"] == "done", exports)),
            key=itemgetter("updated"),
            reverse=True,
        )
        assert (
            len(exports) > 0
        ), f"No available exports found for the corpus {self.corpus_id}."

        # Download latest export
        try:
            export_id: str = exports[0]["id"]
            logger.info(f"Downloading export ({export_id})...")
            self.export: _TemporaryFileWrapper = self.api_client.request(
                "DownloadExport",
                id=export_id,
            )
            logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
            open_database(self.export.name)
        except ErrorResponse as e:
            logger.error(
                f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise e

    def insert_classifications(self, element: CachedElement) -> None:
        logger.info("Listing classifications")
        classifications: list[CachedClassification] = [
            CachedClassification(
                id=classification.id,
                element=element,
                class_name=classification.class_name,
                confidence=classification.confidence,
                state=classification.state,
                worker_run_id=classification.worker_run,
            )
            for classification in list_classifications(element.id)
        ]
        if classifications:
            logger.info(f"Inserting {len(classifications)} classification(s)")
            with cache_database.atomic():
                CachedClassification.bulk_create(
                    model_list=classifications,
                    batch_size=BULK_BATCH_SIZE,
                )

    def insert_transcriptions(
        self, element: CachedElement
    ) -> List[CachedTranscription]:
        logger.info("Listing transcriptions")
        transcriptions: list[CachedTranscription] = [
            CachedTranscription(
                id=transcription.id,
                element=element,
                text=transcription.text,
                confidence=transcription.confidence,
                orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
                worker_version_id=transcription.worker_version,
                worker_run_id=transcription.worker_run,
            )
            for transcription in list_transcriptions(element.id)
        ]
        if transcriptions:
            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
            with cache_database.atomic():
                CachedTranscription.bulk_create(
                    model_list=transcriptions,
                    batch_size=BULK_BATCH_SIZE,
                )
        return transcriptions

    def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
        logger.info("Listing entities")
        entities: List[CachedEntity] = []
        transcription_entities: List[CachedTranscriptionEntity] = []
        for transcription in transcriptions:
            for transcription_entity in list_transcription_entities(transcription.id):
                entity = CachedEntity(
                    id=transcription_entity.entity.id,
                    type=transcription_entity.entity.type.name,
                    name=transcription_entity.entity.name,
                    validated=transcription_entity.entity.validated,
                    metas=transcription_entity.entity.metas,
                    worker_run_id=transcription_entity.entity.worker_run,
                )
                entities.append(entity)
                transcription_entities.append(
                    CachedTranscriptionEntity(
                        id=transcription_entity.id,
                        transcription=transcription,
                        entity=entity,
                        offset=transcription_entity.offset,
                        length=transcription_entity.length,
                        confidence=transcription_entity.confidence,
                        worker_run_id=transcription_entity.worker_run,
                    )
                )
        if entities:
            # First insert entities since they are foreign keys on transcription entities
            logger.info(f"Inserting {len(entities)} entities")
            with cache_database.atomic():
                CachedEntity.bulk_create(
                    model_list=entities,
                    batch_size=BULK_BATCH_SIZE,
                )

        if transcription_entities:
            # Insert transcription entities
            logger.info(
                f"Inserting {len(transcription_entities)} transcription entities"
            )
            with cache_database.atomic():
                CachedTranscriptionEntity.bulk_create(
                    model_list=transcription_entities,
                    batch_size=BULK_BATCH_SIZE,
                )

    def insert_element(
        self, element: Element, parent_id: Optional[UUID] = None
    ) -> None:
        """
        Insert the given element in the cache database.
        Its image will also be saved to disk, if it wasn't already.

        The insertion of an element includes:
        - its classifications
        - its transcriptions
        - its transcriptions' entities (both Entity and TranscriptionEntity)

        :param element: Element to insert.
        :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
        """
        logger.info(f"Processing element ({element.id})")
        if element.image and element.image.id not in self.cached_images:
            # Download image
            logger.info("Downloading image")
            download_image(url=build_image_url(element)).save(
                self.image_folder / f"{element.image.id}.jpg"
            )
            # Insert image
            logger.info("Inserting image")
            # Store images in case some other elements use it as well
            with cache_database.atomic():
                self.cached_images[element.image.id] = CachedImage.create(
                    id=element.image.id,
                    width=element.image.width,
                    height=element.image.height,
                    url=element.image.url,
                )

        # Insert element
        logger.info("Inserting element")
        with cache_database.atomic():
            cached_element: CachedElement = CachedElement.create(
                id=element.id,
                parent_id=parent_id,
                type=element.type,
                image=self.cached_images[element.image.id] if element.image else None,
                polygon=element.polygon,
                rotation_angle=element.rotation_angle,
                mirrored=element.mirrored,
                worker_version_id=element.worker_version,
                worker_run_id=element.worker_run,
                confidence=element.confidence,
            )

        # Insert classifications
        self.insert_classifications(cached_element)

        # Insert transcriptions
        transcriptions: List[CachedTranscription] = self.insert_transcriptions(
            cached_element
        )

        # Insert entities
        self.insert_entities(transcriptions)

    def process_set(self, set_name: str, elements: List[Element]) -> None:
        logger.info(
            f"Filling the cache with information from elements in the set {set_name}"
        )

        # First list all pages
        nb_elements: int = len(elements)
        for idx, element in enumerate(elements, start=1):
            logger.info(f"Processing `{set_name}` element ({idx}/{nb_elements})")

            # Insert page
            self.insert_element(element)

            # List children
            children = list_children(element.id)
            nb_children: int = children.count()
            for child_idx, child in enumerate(children, start=1):
                logger.info(f"Processing child ({child_idx}/{nb_children})")

                # Insert child
                self.insert_element(child, parent_id=element.id)

    def process_dataset(self, dataset: Dataset):
        # Iterate over given sets
        for set_name, elements in self.list_dataset_elements_per_set(dataset):
            self.process_set(set_name, elements)

        # TAR + ZSTD Image folder and store as task artifact
        zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
        logger.info(f"Compressing the images to {zstd_archive_path}")
        create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)


def main():
    DatasetExtractor(
        description="Fill base-worker cache with information about dataset and extract images",
        generator=True,
    ).run()


if __name__ == "__main__":
    main()
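For context, `DatasetWorker.process_dataset` above is the extension point: its docstring says to override it to process a single Arkindex dataset at once, which is exactly what `DatasetExtractor` does. Below is a minimal sketch of the same override pattern, assuming the module path of this file (before the move) is importable; the `MyDatasetWorker` class, its description string, and its logging are illustrative only and not part of this merge request.

# Minimal sketch of the DatasetWorker override pattern (illustrative only).
import logging

from arkindex_worker.models import Dataset

from worker_generic_training_dataset.dataset_worker import DatasetWorker

logger = logging.getLogger(__name__)


class MyDatasetWorker(DatasetWorker):
    def process_dataset(self, dataset: Dataset):
        # list_dataset_elements_per_set is provided by DatasetWorker above and
        # returns (set_name, elements) pairs for the given dataset.
        for set_name, elements in self.list_dataset_elements_per_set(dataset):
            logger.info("Set %s contains %d element(s)", set_name, len(elements))


if __name__ == "__main__":
    MyDatasetWorker(
        description="Example dataset worker",  # hypothetical description
        generator=False,  # process existing (Complete) datasets, do not build new ones
    ).run()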