Arkindex / Workers / Generic Training Dataset

Merge request !2: Implement worker (Merged)
Yoann Schneider requested to merge implem into main · 1 year ago
Closes #2 (closed)
Overview: 80 · Commits: 25 · Pipelines: 1 · Changes: 2

Viewing commit ffc88c55 (Verified) · "fix typo" · authored by Yoann Schneider, 1 year ago
2 files · +3 −3

worker_generic_training_dataset/worker.py · +351 −6
# -*- coding: utf-8 -*-
import logging
import operator
import tempfile
from argparse import Namespace
from pathlib import Path
from tempfile import _TemporaryFileWrapper
from typing import List, Optional
from uuid import UUID

from apistar.exceptions import ErrorResponse
from arkindex_export import Element, Image, open_database
from arkindex_export.queries import list_children
from arkindex_worker.cache import (
    CachedClassification,
    CachedElement,
    CachedEntity,
    CachedImage,
    CachedTranscription,
    CachedTranscriptionEntity,
    create_tables,
    create_version_table,
)
from arkindex_worker.cache import db as cache_database
from arkindex_worker.cache import init_cache_db
from arkindex_worker.image import download_image
from arkindex_worker.utils import create_tar_zst_archive
from arkindex_worker.worker.base import BaseWorker

from worker_generic_training_dataset.db import (
    list_classifications,
    list_transcription_entities,
    list_transcriptions,
    retrieve_element,
)
from worker_generic_training_dataset.utils import build_image_url


logger: logging.Logger = logging.getLogger(__name__)

BULK_BATCH_SIZE = 50
DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"


class DatasetExtractor(BaseWorker):
    def configure(self) -> None:
        self.args: Namespace = self.parser.parse_args()

        if self.is_read_only:
            super().configure_for_developers()
        else:
            super().configure()

        if self.user_configuration:
            logger.info("Overriding with user_configuration")
            self.config.update(self.user_configuration)

        # Read process information
        self.read_training_related_information()

        # Download corpus
        self.download_latest_export()

        # Initialize db that will be written
        self.configure_cache()

        # CachedImage downloaded and created in DB
        self.cached_images = dict()

        # Where to save the downloaded images
        self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
        logger.info(f"Images will be saved at `{self.image_folder}`.")

    def read_training_related_information(self) -> None:
        """
        Read from process information
            - train_folder_id
            - validation_folder_id
            - test_folder_id (optional)
        """
        logger.info("Retrieving information from process_information")

        train_folder_id = self.process_information.get("train_folder_id")
        assert train_folder_id, "A training folder id is necessary to use this worker"
        self.training_folder_id = UUID(train_folder_id)

        val_folder_id = self.process_information.get("validation_folder_id")
        assert val_folder_id, "A validation folder id is necessary to use this worker"
        self.validation_folder_id = UUID(val_folder_id)

        test_folder_id = self.process_information.get("test_folder_id")
        self.testing_folder_id: UUID | None = (
            UUID(test_folder_id) if test_folder_id else None
        )
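        # Note: the three keys above come from the Arkindex process information as
        # string UUIDs; only "test_folder_id" may be missing, in which case the
        # "Test" split is simply skipped by `run`.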

    def configure_cache(self) -> None:
        """
        Create an SQLite database compatible with base-worker cache and initialize it.
        """
        self.use_cache = True
        self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite"
        # Remove previous execution result if present
        self.cache_path.unlink(missing_ok=True)

        init_cache_db(self.cache_path)

        create_version_table()

        create_tables()

    def download_latest_export(self) -> None:
        """
        Download the latest export of the current corpus.
        Export must be in `"done"` state.
        """
        try:
            exports = list(
                self.api_client.paginate(
                    "ListExports",
                    id=self.corpus_id,
                )
            )
        except ErrorResponse as e:
            logger.error(
                f"Could not list exports of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise e

        # Find the latest that is in "done" state
        exports: List[dict] = sorted(
            list(filter(lambda exp: exp["state"] == "done", exports)),
            key=operator.itemgetter("updated"),
            reverse=True,
        )
        assert (
            len(exports) > 0
        ), f"No available exports found for the corpus {self.corpus_id}."

        # Download latest export
        try:
            export_id: str = exports[0]["id"]
            logger.info(f"Downloading export ({export_id})...")
            self.export: _TemporaryFileWrapper = self.api_client.request(
                "DownloadExport",
                id=export_id,
            )
            logger.info(f"Downloaded export ({export_id}) @ `{self.export.name}`")
            open_database(self.export.name)
        except ErrorResponse as e:
            logger.error(
                f"Could not download export ({export_id}) of corpus ({self.corpus_id}): {str(e.content)}"
            )
            raise e

    def insert_classifications(self, element: CachedElement) -> None:
        logger.info("Listing classifications")
        classifications: list[CachedClassification] = [
            CachedClassification(
                id=classification.id,
                element=element,
                class_name=classification.class_name,
                confidence=classification.confidence,
                state=classification.state,
            )
            for classification in list_classifications(element.id)
        ]
        if classifications:
            logger.info(f"Inserting {len(classifications)} classification(s)")
            with cache_database.atomic():
                CachedClassification.bulk_create(
                    model_list=classifications,
                    batch_size=BULK_BATCH_SIZE,
                )

    def insert_transcriptions(
        self, element: CachedElement
    ) -> List[CachedTranscription]:
        logger.info("Listing transcriptions")
        transcriptions: list[CachedTranscription] = [
            CachedTranscription(
                id=transcription.id,
                element=element,
                text=transcription.text,
                # Dodge not-null constraint for now
                confidence=transcription.confidence or 1.0,
                orientation=DEFAULT_TRANSCRIPTION_ORIENTATION,
                worker_version_id=transcription.worker_version.id
                if transcription.worker_version
                else None,
            )
            for transcription in list_transcriptions(element.id)
        ]
        if transcriptions:
            logger.info(f"Inserting {len(transcriptions)} transcription(s)")
            with cache_database.atomic():
                CachedTranscription.bulk_create(
                    model_list=transcriptions,
                    batch_size=BULK_BATCH_SIZE,
                )
        return transcriptions

    def insert_entities(self, transcriptions: List[CachedTranscription]) -> None:
        logger.info("Listing entities")
        entities: List[CachedEntity] = []
        transcription_entities: List[CachedTranscriptionEntity] = []
        for transcription in transcriptions:
            for transcription_entity in list_transcription_entities(transcription.id):
                entity = CachedEntity(
                    id=transcription_entity.entity.id,
                    type=transcription_entity.entity.type.name,
                    name=transcription_entity.entity.name,
                    validated=transcription_entity.entity.validated,
                    metas=transcription_entity.entity.metas,
                )
                entities.append(entity)
                transcription_entities.append(
                    CachedTranscriptionEntity(
                        id=transcription_entity.id,
                        transcription=transcription,
                        entity=entity,
                        offset=transcription_entity.offset,
                        length=transcription_entity.length,
                        confidence=transcription_entity.confidence,
                    )
                )
        if entities:
            # First insert entities since they are foreign keys on transcription entities
            logger.info(f"Inserting {len(entities)} entities")
            with cache_database.atomic():
                CachedEntity.bulk_create(
                    model_list=entities,
                    batch_size=BULK_BATCH_SIZE,
                )

        if transcription_entities:
            # Insert transcription entities
            logger.info(
                f"Inserting {len(transcription_entities)} transcription entities"
            )
            with cache_database.atomic():
                CachedTranscriptionEntity.bulk_create(
                    model_list=transcription_entities,
                    batch_size=BULK_BATCH_SIZE,
                )

    def insert_element(
        self, element: Element, parent_id: Optional[UUID] = None
    ) -> None:
        """
        Insert the given element in the cache database.
        Its image will also be saved to disk, if it wasn't already.

        The insertion of an element includes:
        - its classifications
        - its transcriptions
        - its transcriptions' entities (both Entity and TranscriptionEntity)

        :param element: Element to insert.
        :param parent_id: ID of the parent to use when creating the CachedElement. Do not specify for top-level elements.
        """
        logger.info(f"Processing element ({element.id})")
        if element.image and element.image.id not in self.cached_images:
            # Download image
            logger.info("Downloading image")
            download_image(url=build_image_url(element)).save(
                self.image_folder / f"{element.image.id}.jpg"
            )
            # Insert image
            logger.info("Inserting image")
            # Store images in case some other elements use it as well
            with cache_database.atomic():
                self.cached_images[element.image.id] = CachedImage.create(
                    id=element.image.id,
                    width=element.image.width,
                    height=element.image.height,
                    url=element.image.url,
                )

        # Insert element
        logger.info("Inserting element")
        with cache_database.atomic():
            cached_element: CachedElement = CachedElement.create(
                id=element.id,
                parent_id=parent_id,
                type=element.type,
                image=self.cached_images[element.image.id]
                if element.image
                else None,
                polygon=element.polygon,
                rotation_angle=element.rotation_angle,
                mirrored=element.mirrored,
                worker_version_id=element.worker_version.id
                if element.worker_version
                else None,
                confidence=element.confidence,
            )

        # Insert classifications
        self.insert_classifications(cached_element)

        # Insert transcriptions
        transcriptions: List[CachedTranscription] = self.insert_transcriptions(
            cached_element
        )

        # Insert entities
        self.insert_entities(transcriptions)

    def process_split(self, split_name: str, split_id: UUID) -> None:
        """
        Insert all elements under the given parent folder (all queries are recursive).
        - `page` elements are linked to this folder (via parent_id foreign key)
        - `page` element children are linked to their `page` parent (via parent_id foreign key)
        """
        logger.info(
            f"Filling the Base-Worker cache with information from children under element ({split_id})"
        )
        # Fill cache
        # Retrieve parent and create parent
        parent: Element = retrieve_element(split_id)
        self.insert_element(parent)

        # First list all pages
        pages = list_children(split_id).join(Image).where(Element.type == "page")
        nb_pages: int = pages.count()
        for idx, page in enumerate(pages, start=1):
            logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")

            # Insert page
            self.insert_element(page, parent_id=split_id)

            # List children
            children = list_children(page.id)
            nb_children: int = children.count()
            for child_idx, child in enumerate(children, start=1):
                logger.info(f"Processing child ({child_idx}/{nb_children})")
                # Insert child
                self.insert_element(child, parent_id=page.id)

    def run(self):
        self.configure()

        # Iterate over given split
        for split_name, split_id in [
            ("Train", self.training_folder_id),
            ("Validation", self.validation_folder_id),
            ("Test", self.testing_folder_id),
        ]:
            if not split_id:
                continue
            self.process_split(split_name, split_id)

        # TAR + ZSTD Image folder and store as task artifact
        zstd_archive_path: Path = self.work_dir / "arkindex_data.zstd"
        logger.info(f"Compressing the images to {zstd_archive_path}")
        create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)
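
# After `run`, the worker's work_dir holds both the SQLite cache (db.sqlite by
# default, see configure_cache) and the compressed image archive
# (arkindex_data.zstd) produced as a task artifact.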


def main():
    DatasetExtractor(
        description="Fill base-worker cache with information about dataset and extract images",
        support_cache=True,
    ).run()
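
Below is a minimal sketch, not part of this merge request, of how the cache produced by a run could be inspected afterwards. It assumes the peewee-style API of the arkindex_worker.cache models used above (select/where/count) and that init_cache_db can reopen an existing database file; the db.sqlite path is a placeholder for the file written by configure_cache.

from pathlib import Path

from arkindex_worker.cache import CachedElement, CachedTranscription, init_cache_db

# Reopen the cache written by the worker (placeholder path; configure_cache stores
# it under the worker's work_dir as `db.sqlite` unless --database was given).
init_cache_db(Path("db.sqlite"))

# Walk the cached pages and count the transcriptions attached to each one,
# relying on the fields the worker populates: type, parent_id and the
# CachedTranscription.element foreign key.
for page in CachedElement.select().where(CachedElement.type == "page"):
    n_transcriptions = (
        CachedTranscription.select().where(CachedTranscription.element == page).count()
    )
    print(page.id, page.parent_id, n_transcriptions)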