Skip to content
Snippets Groups Projects
Commit 5667f19a authored by Yoann Schneider's avatar Yoann Schneider :tennis:
Browse files

Merge branch 'bump-arkindex-export' into 'main'

Bump Python requirement arkindex-export to 0.1.12

Closes #300

See merge request !422
parents 0a2ea185 092e47aa
No related branches found
No related tags found
1 merge request!422Bump Python requirement arkindex-export to 0.1.12
......@@ -4,7 +4,9 @@
# -*- coding: utf-8 -*-
from typing import List
from arkindex_export import Image
from peewee import JOIN
from arkindex_export import Image, WorkerRun
from arkindex_export.models import (
Dataset,
DatasetElement,
......@@ -58,13 +60,16 @@ def build_worker_version_filter(ArkindexModel, worker_versions: List[str | bool]
"""
`False` worker version means `manual` worker_version -> null field.
"""
condition = None
for worker_version in worker_versions:
condition |= (
ArkindexModel.worker_version == worker_version
if worker_version
else ArkindexModel.worker_version.is_null()
# Filter `manual` worker version
condition = ArkindexModel.worker_run.is_null() if False in worker_versions else None
# Filter other worker versions
worker_versions = list(filter(None, worker_versions))
if worker_versions:
condition |= ArkindexModel.worker_run.in_(
WorkerRun.select().where(WorkerRun.worker_version.in_(worker_versions))
)
return condition
......@@ -90,12 +95,16 @@ def get_transcriptions(
"""
Retrieve transcriptions from an SQLite export of an Arkindex corpus
"""
query = Transcription.select(
Transcription.id,
Transcription.text,
Transcription.worker_version,
Transcription.worker_run,
).where((Transcription.element == element_id))
query = (
Transcription.select(
Transcription.id,
Transcription.text,
WorkerRun.id.alias("worker_run"),
WorkerRun.worker_version,
)
.join(WorkerRun, JOIN.LEFT_OUTER, on=Transcription.worker_run)
.where((Transcription.element == element_id))
)
if transcription_worker_versions:
query = query.where(
......@@ -129,11 +138,13 @@ def get_transcription_entities(
Entity.name.alias("name"),
TranscriptionEntity.offset,
TranscriptionEntity.length,
TranscriptionEntity.worker_version,
TranscriptionEntity.worker_run,
WorkerRun.id.alias("worker_run"),
WorkerRun.worker_version,
)
.join(Entity, on=TranscriptionEntity.entity)
.join(EntityType, on=Entity.type)
.switch(TranscriptionEntity)
.join(WorkerRun, JOIN.LEFT_OUTER, on=TranscriptionEntity.worker_run)
.where(
TranscriptionEntity.transcription == transcription_id,
EntityType.name.in_(supported_types),
......
......@@ -13,7 +13,7 @@ authors = [
]
dependencies = [
"albumentations==1.3.1",
"arkindex-export==0.1.9",
"arkindex-export==0.1.12",
"flashlight-text==0.0.7",
"imageio==2.26.1",
"imagesize==1.4.1",
......
......@@ -49,6 +49,7 @@ def mock_database(tmp_path_factory):
worker_run=worker_run,
)
TranscriptionEntity.create(
id=str(uuid.uuid4()),
entity=entity,
length=len(name),
offset=offset,
......@@ -203,6 +204,7 @@ def mock_database(tmp_path_factory):
name="Dataset",
state="complete",
sets=",".join(split_names),
description="My Dataset",
)
# Create dataset elements
......
......@@ -85,8 +85,8 @@ def test_get_transcriptions(sources, mock_database):
expected_transcriptions.append(
{
"text": "Laulont Francois 8",
"worker_version_id": None,
"worker_run_id": None,
"worker_version": None,
"worker_run": None,
}
)
......@@ -94,29 +94,17 @@ def test_get_transcriptions(sources, mock_database):
expected_transcriptions.append(
{
"text": "laulont francois 8",
"worker_version_id": "worker_version_id",
"worker_run_id": "worker_run_id",
"worker_version": "worker_version_id",
"worker_run": "worker_run_id",
}
)
assert (
sorted(
[
{
"text": transcription.text,
"worker_version_id": transcription.worker_version.id
if transcription.worker_version
else None,
"worker_run_id": transcription.worker_run.id
if transcription.worker_run
else None,
}
for transcription in transcriptions
],
key=itemgetter("text"),
)
== expected_transcriptions
)
# Do not compare IDs
transcriptions = transcriptions.dicts()
for transcription in transcriptions:
del transcription["id"]
assert sorted(transcriptions, key=itemgetter("text")) == expected_transcriptions
@pytest.mark.parametrize("source", (False, "id", None))
......@@ -130,8 +118,8 @@ def test_get_transcription_entities(source, mock_database, supported_types):
transcription_id = "train-page_1-line_1" + ("source" if source else "")
entities = get_transcription_entities(
transcription_id=transcription_id,
entity_worker_versions=[worker_version],
entity_worker_runs=[worker_run],
entity_worker_versions=[worker_version] if worker_version is not None else [],
entity_worker_runs=[worker_run] if worker_run is not None else [],
supported_types=supported_types,
)
......@@ -165,10 +153,4 @@ def test_get_transcription_entities(source, mock_database, supported_types):
entity["worker_version"] = worker_version or None
entity["worker_run"] = worker_run or None
assert (
sorted(
entities,
key=itemgetter("offset"),
)
== expected_entities
)
assert sorted(entities, key=itemgetter("offset")) == expected_entities
......@@ -196,7 +196,7 @@ def test_extract_transcription_no_translation(mock_database, tokens, tmp_path):
# Deleting one of the two transcriptions from the element
Transcription.get(
Transcription.element == element,
Transcription.worker_version_id == "worker_version_id",
Transcription.worker_run == "worker_run_id",
).delete_instance(recursive=True)
# Deleting all entities on the element remaining transcription while leaving the transcription intact
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment