Arkindex / Workers / Generic Training Dataset · Merge request !25

Draft: Refactor and implement API version of the worker

Open · Yoann Schneider requested to merge new-api-worker into main, 10 months ago
Closes #17

Overview (2) · Commits (9) · Pipelines (11) · Changes (4)
Viewing commit 3a76f403 (Verified): "Fix API worker", authored by Yoann Schneider, 10 months ago
4 files changed: +45 −46

worker_generic_training_dataset/__init__.py: +352 −0
# -*- coding: utf-8 -*-
import contextlib
import json
import logging
import sys
import tempfile
import uuid
from collections.abc import Iterable
from itertools import groupby
from operator import attrgetter
from pathlib import Path
from typing import List, Optional

from apistar.exceptions import ErrorResponse
from arkindex_export import Element, WorkerRun, WorkerVersion
from arkindex_worker.cache import (
    CachedClassification,
    CachedDataset,
    CachedDatasetElement,
    CachedElement,
    CachedEntity,
    CachedImage,
    CachedTranscription,
    CachedTranscriptionEntity,
    create_tables,
    create_version_table,
)
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.image import download_image
from arkindex_worker.models import Dataset
from arkindex_worker.models import Element as ArkindexElement
from arkindex_worker.models import Set
from arkindex_worker.utils import create_tar_zst_archive
from arkindex_worker.worker import DatasetWorker
from arkindex_worker.worker.dataset import DatasetState
from peewee import CharField

from worker_generic_training_dataset.utils import build_image_url

logger: logging.Logger = logging.getLogger(__name__)

BULK_BATCH_SIZE = 50
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"


def get_object_id(instance: WorkerVersion | WorkerRun | None) -> CharField | None:
    return instance.id if instance else None


class Extractor(DatasetWorker):
    def configure_storage(self) -> None:
        self.data_folder = tempfile.TemporaryDirectory(suffix="-arkindex-data")
        self.data_folder_path = Path(self.data_folder.name)

        # Initialize db that will be written
        self.configure_cache()

        # CachedImage downloaded and created in DB
        self.cached_images = dict()

        # Where to save the downloaded images
        self.images_folder = self.data_folder_path / "images"
        self.images_folder.mkdir(parents=True)
        logger.info(f"Images will be saved at `{self.images_folder}`.")

    def configure_cache(self) -> None:
        """
        Create an SQLite database compatible with base-worker cache and initialize it.
        """
        self.cache_path: Path = self.data_folder_path / "db.sqlite"
        logger.info(f"Cached database will be saved at `{self.cache_path}`.")

        init_cache_db(self.cache_path)

        create_version_table()

        create_tables()

    def insert_classifications(
        self, element: CachedElement, classifications: list[dict]
    ) -> None:
        logger.info("Listing classifications")
        element_classifications: list[CachedClassification] = self.get_classifications(
            element, classifications
        )
        if element_classifications:
            logger.info(f"Inserting {len(element_classifications)} classification(s)")
            with cache_database.atomic():
                CachedClassification.bulk_create(
                    model_list=element_classifications,
                    batch_size=BULK_BATCH_SIZE,
                )

    def insert_transcriptions(
        self, element: CachedElement
    ) -> List[CachedTranscription]:
        logger.info("Listing transcriptions")
        transcriptions: list[CachedTranscription] = self.get_transcriptions(element)
        if transcriptions:
            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
            with cache_database.atomic():
                CachedTranscription.bulk_create(
                    model_list=transcriptions,
                    batch_size=BULK_BATCH_SIZE,
                )
        return transcriptions

    def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
        logger.info("Listing entities")
        entities: List[CachedEntity] = []
        transcription_entities: List[CachedTranscriptionEntity] = []
        for transcription in transcriptions:
            parsed_entities = self.get_transcription_entities(transcription)
            entities.extend(parsed_entities[0])
            transcription_entities.extend(parsed_entities[1])

        if entities:
            # First insert entities since they are foreign keys on transcription entities
            logger.info(f"Inserting {len(entities)} entities")
            with cache_database.atomic():
                CachedEntity.bulk_create(
                    model_list=entities,
                    batch_size=BULK_BATCH_SIZE,
                )

        if transcription_entities:
            # Insert transcription entities
            logger.info(
                f"Inserting {len(transcription_entities)} transcription entities"
            )
            with cache_database.atomic():
                CachedTranscriptionEntity.bulk_create(
                    model_list=transcription_entities,
                    batch_size=BULK_BATCH_SIZE,
                )

    def insert_element(
        self,
        element: Element | ArkindexElement,
        split_name: Optional[str] = None,
        parent_id: Optional[str] = None,
    ) -> None:
        """
        Insert the given element in the cache database.
        Its image will also be saved to disk, if it wasn't already.

        The insertion of an element includes:
        - its classifications
        - its transcriptions
        - its transcriptions' entities (both Entity and TranscriptionEntity)

        The element will also be linked to the appropriate split in the current dataset.

        :param element: Element to insert.
        :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
        """
        logger.info(f"Processing element ({element})")

        polygon = element.polygon
        if isinstance(element, Element):
            # SQL result
            image = element.image
            wk_version = get_object_id(element.worker_version)
            wk_run = get_object_id(element.worker_run)
        else:
            # API result
            polygon = json.dumps(polygon)
            image = element.zone.image
            wk_version = (
                element.worker_version
                if hasattr(element, "worker_version")
                else element.worker_version_id
            )
            wk_run = element.worker_run.id if element.worker_run else None

        if image and image.id not in self.cached_images:
            # Download image
            logger.info("Downloading image")
            download_image(url=build_image_url(image, polygon)).save(
                self.images_folder / f"{image.id}.jpg"
            )

            # Insert image
            logger.info("Inserting image")
            # Store images in case some other elements use it as well
            with cache_database.atomic():
                self.cached_images[image.id] = CachedImage.create(
                    id=image.id,
                    width=image.width,
                    height=image.height,
                    url=image.url,
                )

        # Insert element
        logger.info("Inserting element")
        with cache_database.atomic():
            cached_element: CachedElement = CachedElement.create(
                id=element.id,
                parent_id=parent_id,
                type=element.type,
                image=self.cached_images[image.id] if image else None,
                polygon=polygon,
                rotation_angle=element.rotation_angle,
                mirrored=element.mirrored,
                worker_version_id=wk_version,
                worker_run_id=wk_run,
                confidence=element.confidence,
            )

        # Insert classifications
        classifications = []
        if isinstance(element, ArkindexElement):
            classifications = (
                element.classifications
                if hasattr(element, "classifications")
                else element.classes
            )
        self.insert_classifications(cached_element, classifications=classifications)

        # Insert transcriptions
        transcriptions: List[CachedTranscription] = self.insert_transcriptions(
            cached_element
        )

        # Insert entities
        self.insert_entities(transcriptions)

        # Link the element to the dataset split
        if split_name:
            logger.info(
                f"Linking element {cached_element.id} to dataset ({self.cached_dataset.id})"
            )
            with cache_database.atomic():
                CachedDatasetElement.create(
                    id=uuid.uuid4(),
                    element=cached_element,
                    dataset=self.cached_dataset,
                    set_name=split_name,
                )

    def process_split(
        self, split_name: str, elements: Iterable[Element | ArkindexElement]
    ) -> None:
        logger.info(
            f"Filling the cache with information from elements in the split {split_name}"
        )
        for idx, element in enumerate(elements, start=1):
            logger.info(f"Processing `{split_name}` element (n°{idx})")

            # Insert page
            self.insert_element(element, split_name=split_name)

            # List children
            children = self.list_element_children(element)
            for child_idx, child in enumerate(children, start=1):
                logger.info(f"Processing {child} (n°{child_idx})")
                # Insert child
                self.insert_element(child, parent_id=element.id)

    def insert_dataset(self, dataset: Dataset) -> None:
        """
        Insert the given dataset in the cache database.

        :param dataset: Dataset to insert.
        """
        logger.info(f"Inserting dataset ({dataset.id})")
        with cache_database.atomic():
            self.cached_dataset = CachedDataset.create(
                id=dataset.id,
                name=dataset.name,
                state=dataset.state,
                sets=json.dumps(dataset.sets),
            )

    def process_dataset(self, dataset: Dataset, sets: list[Set]):
        # Configure temporary storage for the dataset data (cache + images)
        self.configure_storage()

        # Insert dataset in cache database
        self.insert_dataset(dataset)

        # Iterate over given splits
        for dataset_set in sets:
            elements = self.list_set_elements(dataset_set)
            self.process_split(dataset_set.name, elements)

        # TAR + ZST the cache and the images folder, and store as task artifact
        zst_archive_path: Path = self.work_dir / dataset.filepath
        logger.info(f"Compressing the images to {zst_archive_path}")
        create_tar_zst_archive(
            source=self.data_folder_path, destination=zst_archive_path
        )

        self.data_folder.cleanup()

    def run(self):
        self.configure()

        dataset_sets: list[Set] = list(self.list_sets())
        grouped_sets: list[tuple[Dataset, list[Set]]] = [
            (dataset, list(sets))
            for dataset, sets in groupby(dataset_sets, attrgetter("dataset"))
        ]

        if not grouped_sets:
            logger.warning("No datasets to process, stopping.")
            sys.exit(1)

        # Process every dataset
        count = len(grouped_sets)
        failed = 0
        for i, (dataset, sets) in enumerate(grouped_sets, start=1):
            try:
                # assert dataset.state in [
                #     DatasetState.Open.value,
                #     DatasetState.Error.value,
                # ], "When generating a new dataset, its state should be Open or Error."

                # Update the dataset state to Building
                logger.info(f"Building {dataset} ({i}/{count})")
                self.update_dataset_state(dataset, DatasetState.Building)

                logger.info(f"Processing {dataset} ({i}/{count})")
                self.process_dataset(dataset, sets)

                # Update the dataset state to Complete
                logger.info(f"Completed {dataset} ({i}/{count})")
                self.update_dataset_state(dataset, DatasetState.Complete)
            except Exception as e:
                # Handle errors occurring while processing or patching the state for this dataset
                failed += 1

                import traceback

                traceback.print_exc()

                if isinstance(e, ErrorResponse):
                    message = f"An API error occurred while processing {dataset}: {e.title} - {e.content}"
                else:
                    message = f"Failed running worker on {dataset}: {repr(e)}"

                logger.warning(
                    message,
                    exc_info=e if self.args.verbose else None,
                )

                # Try to update the state to Error regardless of the response
                with contextlib.suppress(Exception):
                    self.update_dataset_state(dataset, DatasetState.Error)

        message = f'Ran on {count} dataset{"s"[:count > 1]}: {count - failed} completed, {failed} failed'
        if failed:
            logger.error(message)
            if failed >= count:
                # Everything failed!
                sys.exit(1)
        else:
            logger.info(message)
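
Note for trying the branch locally: a minimal sketch of how this Extractor could be wired up as an entry point. The `main()` function and the `description` keyword are assumptions based on the usual arkindex-base-worker layout; neither is defined in this diff.

# Hypothetical entry point, not part of this diff. Assumes the DatasetWorker
# base class accepts a `description` keyword like other arkindex-base-worker
# classes.
def main() -> None:
    Extractor(description="Generic training dataset extractor").run()


if __name__ == "__main__":
    main()

The archive produced by process_dataset bundles `db.sqlite` and the `images/` folder. A minimal sketch for inspecting an extracted cache, reusing the same helpers this file already imports (`init_cache_db`, `CachedElement`); the helper name is made up for the example.

# Hypothetical inspection helper, not part of this diff: open the extracted
# db.sqlite with the same peewee cache models the worker used to write it.
from pathlib import Path

from arkindex_worker.cache import CachedElement, init_cache_db


def count_cached_elements(cache_path: Path) -> int:
    init_cache_db(cache_path)
    return CachedElement.select().count()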