Arkindex / Workers / Generic Training Dataset / Commits

Commit f2988ecd

New DatasetExtractor using a DatasetWorker

Authored 1 year ago by Eva Bardou, committed by Yoann Schneider 1 year ago
Parent: 07969c74
Merge request: !8 (New DatasetExtractor using a DatasetWorker)
Showing 4 changed files, with 72 additions and 104 deletions:

- .arkindex.yml (+0 −13)
- requirements.txt (+1 −1)
- tests/test_worker.py (+25 −18)
- worker_generic_training_dataset/worker.py (+46 −72)
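In short, the commit replaces the BaseWorker-based extractor, which read train/validation/test folder IDs from user_configuration and walked each folder in run(), with a DatasetWorker subclass: the base class hands each Arkindex Dataset to process_dataset(), the dataset's own splits supply the elements, and the SQLite cache plus downloaded images are written to a single temporary folder that is archived per dataset and then cleaned up.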
.arkindex.yml (+0 −13)

```diff
@@ -9,16 +9,3 @@ workers:
     type: data-extract
     docker:
       build: Dockerfile
-    user_configuration:
-      train_folder_id:
-        type: string
-        title: ID of the training folder on Arkindex
-        required: true
-      validation_folder_id:
-        type: string
-        title: ID of the validation folder on Arkindex
-        required: true
-      test_folder_id:
-        type: string
-        title: ID of the testing folder on Arkindex
-        required: true
```
requirements.txt (+1 −1)

```diff
-arkindex-base-worker @ git+https://gitlab.teklia.com/workers/base-worker.git@master
+arkindex-base-worker==0.3.5rc4
 arkindex-export==0.1.7
```
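Pinning arkindex-base-worker to 0.3.5rc4 replaces the moving git+master dependency: builds now resolve a fixed base-worker release, one that ships the DatasetWorker API that worker.py imports below, instead of whatever master happens to contain.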
tests/test_worker.py (+25 −18)

```diff
@@ -11,34 +11,38 @@ from arkindex_worker.cache import (
     CachedTranscription,
     CachedTranscriptionEntity,
 )
 from worker_generic_training_dataset.db import retrieve_element
 from worker_generic_training_dataset.worker import DatasetExtractor


 def test_process_split(tmp_path, downloaded_images):
-    # Parent is train folder
-    parent_id: UUID = UUID("a0c4522d-2d80-4766-a01c-b9d686f41f6a")
-
     worker = DatasetExtractor()
     # Parse some arguments
     worker.args = Namespace(database=None)
+    worker.data_folder_path = tmp_path
     worker.configure_cache()
     worker.cached_images = dict()

     # Where to save the downloaded images
-    worker.image_folder = tmp_path
-
-    worker.process_split("train", parent_id)
-
-    # Should have created 20 elements in total
-    assert CachedElement.select().count() == 20
-
-    # Should have created two pages under root folder
-    assert (
-        CachedElement.select().where(CachedElement.parent_id == parent_id).count() == 2
-    )
+    worker.images_folder = tmp_path / "images"
+    worker.images_folder.mkdir(parents=True)

     first_page_id = UUID("e26e6803-18da-4768-be30-a0a68132107c")
     second_page_id = UUID("c673bd94-96b1-4a2e-8662-a4d806940b5f")

+    worker.process_split(
+        "train",
+        [
+            retrieve_element(first_page_id),
+            retrieve_element(second_page_id),
+        ],
+    )
+
+    # Should have created 19 elements in total
+    assert CachedElement.select().count() == 19
+
+    # Should have created two pages at root
+    assert CachedElement.select().where(CachedElement.parent_id.is_null()).count() == 2

     # Should have created 8 text_lines under first page
     assert (
```
```diff
@@ -78,11 +82,6 @@ def test_process_split(tmp_path, downloaded_images):
         == f"https://europe-gamma.iiif.teklia.com/iiif/2/public%2Fiam%2F{page_name}.png"
     )

-    assert sorted(tmp_path.rglob("*")) == [
-        tmp_path / f"{first_image_id}.jpg",
-        tmp_path / f"{second_image_id}.jpg",
-    ]
-
     # Should have created 17 transcriptions
     assert CachedTranscription.select().count() == 17

     # Check transcription of first line on first page
```
```diff
@@ -125,3 +124,11 @@ def test_process_split(tmp_path, downloaded_images):
     assert tr_entity.length == 23
     assert tr_entity.confidence == 1.0
     assert tr_entity.worker_run_id is None
+
+    # Full structure of the archive
+    assert sorted(tmp_path.rglob("*")) == [
+        tmp_path / "db.sqlite",
+        tmp_path / "images",
+        tmp_path / "images" / f"{first_image_id}.jpg",
+        tmp_path / "images" / f"{second_image_id}.jpg",
+    ]
```
worker_generic_training_dataset/worker.py (+46 −72)

```diff
 # -*- coding: utf-8 -*-
 import logging
-import operator
 import tempfile
 from argparse import Namespace
+from operator import itemgetter
 from pathlib import Path
 from tempfile import _TemporaryFileWrapper
 from typing import List, Optional
 from uuid import UUID

 from apistar.exceptions import ErrorResponse
-from arkindex_export import Element, Image, open_database
+from arkindex_export import Element, open_database
 from arkindex_export.queries import list_children
 from arkindex_worker.cache import (
     CachedClassification,
@@ -24,8 +24,10 @@ from arkindex_worker.cache import (
 from arkindex_worker.cache import db as cache_database
 from arkindex_worker.cache import init_cache_db
 from arkindex_worker.image import download_image
+from arkindex_worker.models import Dataset
+from arkindex_worker.models import Element as WorkerElement
 from arkindex_worker.utils import create_tar_zst_archive
-from arkindex_worker.worker.base import BaseWorker
+from arkindex_worker.worker import DatasetWorker

 from worker_generic_training_dataset.db import (
     list_classifications,
     list_transcription_entities,
```
```diff
@@ -40,7 +42,11 @@ BULK_BATCH_SIZE = 50
 DEFAULT_TRANSCRIPTION_ORIENTATION = "horizontal-lr"


-class DatasetExtractor(BaseWorker):
+def _format_element(element: WorkerElement) -> Element:
+    return retrieve_element(element.id)
+
+
+class DatasetExtractor(DatasetWorker):
     def configure(self) -> None:
         self.args: Namespace = self.parser.parse_args()

         if self.is_read_only:
```
```diff
@@ -52,12 +58,13 @@ class DatasetExtractor(BaseWorker):
             logger.info("Overriding with user_configuration")
             self.config.update(self.user_configuration)

-        # Read process information
-        self.read_training_related_information()
-
         # Download corpus
         self.download_latest_export()

+    def configure_storage(self) -> None:
+        self.data_folder = tempfile.TemporaryDirectory(suffix="-arkindex-data")
+        self.data_folder_path = Path(self.data_folder.name)
+
         # Initialize db that will be written
         self.configure_cache()
```
```diff
@@ -65,39 +72,17 @@ class DatasetExtractor(BaseWorker):
         self.cached_images = dict()

         # Where to save the downloaded images
-        self.image_folder = Path(tempfile.mkdtemp(suffix="-arkindex-data"))
-        logger.info(f"Images will be saved at `{self.image_folder}`.")
-
-    def read_training_related_information(self) -> None:
-        """
-        Read from process information
-        - train_folder_id
-        - validation_folder_id
-        - test_folder_id (optional)
-        """
-        logger.info("Retrieving information from process_information")
-
-        train_folder_id = self.process_information.get("train_folder_id")
-        assert train_folder_id, "A training folder id is necessary to use this worker"
-        self.training_folder_id = UUID(train_folder_id)
-
-        val_folder_id = self.process_information.get("validation_folder_id")
-        assert val_folder_id, "A validation folder id is necessary to use this worker"
-        self.validation_folder_id = UUID(val_folder_id)
-
-        test_folder_id = self.process_information.get("test_folder_id")
-        self.testing_folder_id: UUID | None = (
-            UUID(test_folder_id) if test_folder_id else None
-        )
+        self.images_folder = self.data_folder_path / "images"
+        self.images_folder.mkdir(parents=True)
+        logger.info(f"Images will be saved at `{self.images_folder}`.")

     def configure_cache(self) -> None:
         """
         Create an SQLite database compatible with base-worker cache and initialize it.
         """
         self.use_cache = True
-        self.cache_path: Path = self.args.database or self.work_dir / "db.sqlite"
-        # Remove previous execution result if present
-        self.cache_path.unlink(missing_ok=True)
+        self.cache_path: Path = self.data_folder_path / "db.sqlite"
+        logger.info(f"Cached database will be saved at `{self.cache_path}`.")

         init_cache_db(self.cache_path)
```
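Together with configure_storage() above, this hunk routes everything the worker produces (the SQLite cache and the downloaded images) into one temporary folder. A standalone sketch of that lifecycle, stdlib only; the db.sqlite and images names come from the diff, the rest is illustrative:

```python
# Sketch of the storage lifecycle adopted here: one temporary folder holds
# the SQLite cache and the images, and is deleted after archiving.
import tempfile
from pathlib import Path

data_folder = tempfile.TemporaryDirectory(suffix="-arkindex-data")
data_folder_path = Path(data_folder.name)

cache_path = data_folder_path / "db.sqlite"   # initialized by init_cache_db in the worker
images_folder = data_folder_path / "images"   # filled by download_image calls
images_folder.mkdir(parents=True)

# ... fill the cache and download images here ...

data_folder.cleanup()  # removes the folder and everything under it
```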
```diff
@@ -126,7 +111,7 @@ class DatasetExtractor(BaseWorker):
         # Find the latest that is in "done" state
         exports: List[dict] = sorted(
             list(filter(lambda exp: exp["state"] == "done", exports)),
-            key=operator.itemgetter("updated"),
+            key=itemgetter("updated"),
             reverse=True,
         )
         assert (
```
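For reference, a self-contained sketch of what this sort selects; the export payloads are hypothetical, only the state and updated keys appear in the diff:

```python
# Hypothetical export payloads; only "state" and "updated" come from the diff.
from operator import itemgetter

exports = [
    {"state": "done", "updated": "2023-05-10T09:00:00Z"},
    {"state": "error", "updated": "2023-05-12T09:00:00Z"},
    {"state": "done", "updated": "2023-05-11T09:00:00Z"},
]

# Keep finished exports only, newest first (ISO timestamps sort lexicographically)
done = sorted(
    (exp for exp in exports if exp["state"] == "done"),
    key=itemgetter("updated"),
    reverse=True,
)
print(done[0]["updated"])  # 2023-05-11T09:00:00Z
```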
```diff
@@ -261,7 +246,7 @@ class DatasetExtractor(BaseWorker):
         # Download image
         logger.info("Downloading image")
         download_image(url=build_image_url(element)).save(
-            self.image_folder / f"{element.image.id}.jpg"
+            self.images_folder / f"{element.image.id}.jpg"
         )
         # Insert image
         logger.info("Inserting image")
```
```diff
@@ -301,60 +286,49 @@ class DatasetExtractor(BaseWorker):
         # Insert entities
         self.insert_entities(transcriptions)

-    def process_split(self, split_name: str, split_id: UUID) -> None:
-        """
-        Insert all elements under the given parent folder (all queries are recursive).
-        - `page` elements are linked to this folder (via parent_id foreign key)
-        - `page` element children are linked to their `page` parent (via parent_id foreign key)
-        """
+    def process_split(self, split_name: str, elements: List[Element]) -> None:
         logger.info(
-            f"Filling the Base-Worker cache with information from children under element ({split_id})"
+            f"Filling the cache with information from elements in the split {split_name}"
         )

-        # Fill cache
-        # Retrieve parent and create parent
-        parent: Element = retrieve_element(split_id)
-        self.insert_element(parent)
-
-        # First list all pages
-        pages = list_children(split_id).join(Image).where(Element.type == "page")
-        nb_pages: int = pages.count()
-        for idx, page in enumerate(pages, start=1):
-            logger.info(f"Processing `{split_name}` page ({idx}/{nb_pages})")
-
-            # Insert page
-            self.insert_element(page, parent_id=split_id)
+        nb_elements: int = len(elements)
+        for idx, element in enumerate(elements, start=1):
+            logger.info(f"Processing `{split_name}` element ({idx}/{nb_elements})")
+
+            self.insert_element(element)

             # List children
-            children = list_children(page.id)
+            children = list_children(element.id)
             nb_children: int = children.count()
             for child_idx, child in enumerate(children, start=1):
                 logger.info(f"Processing child ({child_idx}/{nb_children})")
                 # Insert child
-                self.insert_element(child, parent_id=page.id)
-
-    def run(self):
-        self.configure()
-
-        # Iterate over given split
-        for split_name, split_id in [
-            ("Train", self.training_folder_id),
-            ("Validation", self.validation_folder_id),
-            ("Test", self.testing_folder_id),
-        ]:
-            if not split_id:
-                continue
-            self.process_split(split_name, split_id)
-
-        # TAR + ZSTD Image folder and store as task artifact
-        zstd_archive_path: Path = self.work_dir / "arkindex_data.zstd"
+                self.insert_element(child, parent_id=element.id)
+
+    def process_dataset(self, dataset: Dataset):
+        # Configure temporary storage for the dataset data (cache + images)
+        self.configure_storage()
+
+        # Iterate over given splits
+        for split_name, elements in self.list_dataset_elements_per_split(dataset):
+            casted_elements = list(map(_format_element, elements))
+            self.process_split(split_name, casted_elements)
+
+        # TAR + ZSTD the cache and the images folder, and store as task artifact
+        zstd_archive_path: Path = self.work_dir / f"{dataset.id}.zstd"
         logger.info(f"Compressing the images to {zstd_archive_path}")
-        create_tar_zst_archive(source=self.image_folder, destination=zstd_archive_path)
+        create_tar_zst_archive(source=self.data_folder_path, destination=zstd_archive_path)
+        self.data_folder.cleanup()


 def main():
     DatasetExtractor(
         description="Fill base-worker cache with information about dataset and extract images",
         support_cache=True,
         generator=True,
     ).run()
```
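Taken together, the commit moves the worker from the BaseWorker pattern, with an explicit run() loop over folder IDs, to the DatasetWorker pattern, where the base class drives process_dataset() for each dataset assigned to the process. A minimal sketch of that pattern; DatasetWorker, Dataset, process_dataset and list_dataset_elements_per_split all appear in this diff, while the class body and names here are otherwise hypothetical:

```python
# Minimal sketch of the DatasetWorker pattern this commit adopts; the
# logging body is illustrative, not the worker's actual behavior.
import logging

from arkindex_worker.models import Dataset
from arkindex_worker.worker import DatasetWorker

logger = logging.getLogger(__name__)


class SketchDatasetWorker(DatasetWorker):
    def process_dataset(self, dataset: Dataset):
        # Called once per dataset handed over by the base class; splits come
        # from the dataset itself, so no folder IDs are read from the process.
        for split_name, elements in self.list_dataset_elements_per_split(dataset):
            elements = list(elements)
            logger.info("Split %s has %d elements", split_name, len(elements))


def main():
    SketchDatasetWorker(description="Sketch only").run()
```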