From 40c071f53a6db16633cab2d50d8aad4f61a02703 Mon Sep 17 00:00:00 2001 From: Eva Bardou <ebardou@teklia.com> Date: Wed, 19 Aug 2020 14:06:00 +0000 Subject: [PATCH] Add few more helpers for developers + Report entity creation --- arkindex_worker/reporting.py | 10 +- arkindex_worker/worker.py | 113 ++++++++ tests/test_elements_worker.py | 497 +++++++++++++++++++++++++++++++++- tests/test_reporting.py | 31 +++ 4 files changed, 648 insertions(+), 3 deletions(-) diff --git a/arkindex_worker/reporting.py b/arkindex_worker/reporting.py index 3e7f248b..1f0736f1 100644 --- a/arkindex_worker/reporting.py +++ b/arkindex_worker/reporting.py @@ -34,6 +34,8 @@ class Reporter(object): "transcriptions": {}, # Created classification counts, by class "classifications": {}, + # Created entities ({"id": "", "type": "", "name": ""}) from this element + "entities": [], "errors": [], }, ) @@ -99,8 +101,12 @@ class Reporter(object): counter.update([transcription["type"] for transcription in transcriptions]) element["transcriptions"] = dict(counter) - def add_entity(self, *args, **kwargs): - raise NotImplementedError + def add_entity(self, element_id, entity_id, type, name): + """ + Report creating an entity from an element. 
+        """
+        entities = self._get_element(element_id)["entities"]
+        entities.append({"id": entity_id, "type": type, "name": name})
 
     def add_entity_link(self, *args, **kwargs):
         raise NotImplementedError
diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py
index 46aab692..e13c5fe0 100644
--- a/arkindex_worker/worker.py
+++ b/arkindex_worker/worker.py
@@ -5,6 +5,7 @@ import logging
 import os
 import sys
 import uuid
+from enum import Enum
 
 from apistar.exceptions import ErrorResponse
 
@@ -72,6 +73,24 @@ class BaseWorker(object):
         """Override this method to implement your own process"""
 
 
+class TranscriptionType(Enum):
+    Page = "page"
+    Paragraph = "paragraph"
+    Line = "line"
+    Word = "word"
+    Character = "character"
+
+
+class EntityType(Enum):
+    Person = "person"
+    Location = "location"
+    Subject = "subject"
+    Organization = "organization"
+    Misc = "misc"
+    Number = "number"
+    Date = "date"
+
+
 class ElementsWorker(BaseWorker):
     def __init__(self, description="Arkindex Elements Worker"):
         super().__init__(description)
@@ -202,3 +221,97 @@ class ElementsWorker(BaseWorker):
             },
         )
         self.report.add_element(element.id, type)
+
+    def create_transcription(self, element, text, type, score):
+        """
+        Create a transcription on the given element through the API and report it
+        """
+        assert element and isinstance(
+            element, Element
+        ), "element shouldn't be null and should be of type Element"
+        assert type and isinstance(
+            type, TranscriptionType
+        ), "type shouldn't be null and should be of type TranscriptionType"
+        assert text and isinstance(
+            text, str
+        ), "text shouldn't be null and should be of type str"
+        assert (
+            isinstance(score, float) and 0 <= score <= 1  # 0.0 is a valid score
+        ), "score shouldn't be null and should be a float in [0..1] range"
+
+        self.api_client.request(
+            "CreateTranscription",
+            id=element.id,
+            body={
+                "text": text,
+                "type": type.value,
+                "worker_version": self.worker_version_id,
+                "score": score,
+            },
+        )
+        self.report.add_transcription(element.id, type.value)
+
+    def create_classification(
+        self, element, ml_class, confidence, high_confidence=False
+    ):
+        """
+        Create a classification on the given element through the API and report it
+        """
+        assert element and isinstance(
+            element, Element
+        ), "element shouldn't be null and should be of type Element"
+        assert ml_class and isinstance(
+            ml_class, str
+        ), "ml_class shouldn't be null and should be of type str"
+        assert (
+            isinstance(confidence, float) and 0 <= confidence <= 1  # 0.0 is valid
+        ), "confidence shouldn't be null and should be a float in [0..1] range"
+        assert isinstance(
+            high_confidence, bool  # False (the default) is a valid value
+        ), "high_confidence shouldn't be null and should be of type bool"
+
+        self.api_client.request(
+            "CreateClassification",
+            body={
+                "element": element.id,
+                "ml_class": ml_class,
+                "worker_version": self.worker_version_id,
+                "confidence": confidence,
+                "high_confidence": high_confidence,
+            },
+        )
+        self.report.add_classification(element.id, ml_class)
+
+    def create_entity(self, element, name, type, corpus, metas=None, validated=None):
+        """
+        Create an entity in the given corpus through the API and report it
+        """
+        assert element and isinstance(
+            element, Element
+        ), "element shouldn't be null and should be of type Element"
+        assert name and isinstance(
+            name, str
+        ), "name shouldn't be null and should be of type str"
+        assert type and isinstance(
+            type, EntityType
+        ), "type shouldn't be null and should be of type EntityType"
+        assert corpus and isinstance(
+            corpus, str
+        ), "corpus shouldn't be null and should be of type str"
+        if metas is not None:
+            assert isinstance(metas, dict), "metas should be of type dict"
+        if validated is not None:
+            assert isinstance(validated, bool), "validated should be of type bool"
+
+        entity = self.api_client.request(
+            "CreateEntity",
+            body={
+                "name": name,
+                "type": type.value,
+                "metas": metas,
+                "validated": validated,
+                "corpus": corpus,
+                "worker_version": self.worker_version_id,
+            },
+        )
+        self.report.add_entity(element.id, entity["id"], type.value, name)
diff
--git a/tests/test_elements_worker.py b/tests/test_elements_worker.py index b94312fb..c3618b4b 100644 --- a/tests/test_elements_worker.py +++ b/tests/test_elements_worker.py @@ -10,7 +10,7 @@ import pytest from apistar.exceptions import ErrorResponse from arkindex_worker.models import Element -from arkindex_worker.worker import ElementsWorker +from arkindex_worker.worker import ElementsWorker, EntityType, TranscriptionType def test_cli_default(monkeypatch): @@ -379,3 +379,498 @@ def test_create_sub_element(responses): responses.calls[0].request.url == "https://arkindex.teklia.com/api/v1/elements/create/" ) + assert json.loads(responses.calls[0].request.body) == { + "type": "something", + "name": "0", + "image": "22222222-2222-2222-2222-222222222222", + "corpus": "11111111-1111-1111-1111-111111111111", + "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], + "parent": "12341234-1234-1234-1234-123412341234", + "worker_version": "12341234-1234-1234-1234-123412341234", + } + + +def test_create_transcription_wrong_element(): + worker = ElementsWorker() + with pytest.raises(AssertionError) as e: + worker.create_transcription( + element=None, text="i am a line", type=TranscriptionType.Line, score=0.42, + ) + assert str(e.value) == "element shouldn't be null and should be of type Element" + + with pytest.raises(AssertionError) as e: + worker.create_transcription( + element="not element type", + text="i am a line", + type=TranscriptionType.Line, + score=0.42, + ) + assert str(e.value) == "element shouldn't be null and should be of type Element" + + +def test_create_transcription_wrong_type(): + worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + worker.create_transcription( + element=elt, text="i am a line", type=None, score=0.42, + ) + assert ( + str(e.value) == "type shouldn't be null and should be of type TranscriptionType" + ) + + with pytest.raises(AssertionError) as e: + 
worker.create_transcription( + element=elt, text="i am a line", type=1234, score=0.42, + ) + assert ( + str(e.value) == "type shouldn't be null and should be of type TranscriptionType" + ) + + with pytest.raises(AssertionError) as e: + worker.create_transcription( + element=elt, + text="i am a line", + type="not_a_transcription_type", + score=0.42, + ) + assert ( + str(e.value) == "type shouldn't be null and should be of type TranscriptionType" + ) + + +def test_create_transcription_wrong_text(): + worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + worker.create_transcription( + element=elt, text=None, type=TranscriptionType.Line, score=0.42, + ) + assert str(e.value) == "text shouldn't be null and should be of type str" + + with pytest.raises(AssertionError) as e: + worker.create_transcription( + element=elt, text=1234, type=TranscriptionType.Line, score=0.42, + ) + assert str(e.value) == "text shouldn't be null and should be of type str" + + +def test_create_transcription_wrong_score(): + worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + worker.create_transcription( + element=elt, text="i am a line", type=TranscriptionType.Line, score=None, + ) + assert ( + str(e.value) == "score shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + worker.create_transcription( + element=elt, + text="i am a line", + type=TranscriptionType.Line, + score="wrong score", + ) + assert ( + str(e.value) == "score shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + worker.create_transcription( + element=elt, text="i am a line", type=TranscriptionType.Line, score=0, + ) + assert ( + str(e.value) == "score shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + 
worker.create_transcription( + element=elt, text="i am a line", type=TranscriptionType.Line, score=2.00, + ) + assert ( + str(e.value) == "score shouldn't be null and should be a float in [0..1] range" + ) + + +def test_create_transcription_api_error(responses): + worker = ElementsWorker() + worker.configure() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + f"https://arkindex.teklia.com/api/v1/element/{elt.id}/transcription/", + status=500, + ) + + with pytest.raises(ErrorResponse): + worker.create_transcription( + element=elt, text="i am a line", type=TranscriptionType.Line, score=0.42, + ) + + assert len(responses.calls) == 1 + assert ( + responses.calls[0].request.url + == f"https://arkindex.teklia.com/api/v1/element/{elt.id}/transcription/" + ) + + +def test_create_transcription(responses): + worker = ElementsWorker() + worker.configure() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + f"https://arkindex.teklia.com/api/v1/element/{elt.id}/transcription/", + status=200, + ) + + worker.create_transcription( + element=elt, text="i am a line", type=TranscriptionType.Line, score=0.42, + ) + + assert len(responses.calls) == 1 + assert ( + responses.calls[0].request.url + == f"https://arkindex.teklia.com/api/v1/element/{elt.id}/transcription/" + ) + assert json.loads(responses.calls[0].request.body) == { + "text": "i am a line", + "type": "line", + "worker_version": "12341234-1234-1234-1234-123412341234", + "score": 0.42, + } + + +def test_create_classification_wrong_element(): + worker = ElementsWorker() + with pytest.raises(AssertionError) as e: + worker.create_classification( + element=None, ml_class="a_class", confidence=0.42, high_confidence=True, + ) + assert str(e.value) == "element shouldn't be null and should be of type Element" + + with pytest.raises(AssertionError) as e: + worker.create_classification( + element="not element type", + ml_class="a_class", 
+ confidence=0.42, + high_confidence=True, + ) + assert str(e.value) == "element shouldn't be null and should be of type Element" + + +def test_create_classification_wrong_ml_class(): + worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + worker.create_classification( + element=elt, ml_class=None, confidence=0.42, high_confidence=True, + ) + assert str(e.value) == "ml_class shouldn't be null and should be of type str" + + with pytest.raises(AssertionError) as e: + worker.create_classification( + element=elt, ml_class=1234, confidence=0.42, high_confidence=True, + ) + assert str(e.value) == "ml_class shouldn't be null and should be of type str" + + +def test_create_classification_wrong_confidence(): + worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + worker.create_classification( + element=elt, ml_class="a_class", confidence=None, high_confidence=True, + ) + assert ( + str(e.value) + == "confidence shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + worker.create_classification( + element=elt, + ml_class="a_class", + confidence="wrong confidence", + high_confidence=True, + ) + assert ( + str(e.value) + == "confidence shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + worker.create_classification( + element=elt, ml_class="a_class", confidence=0, high_confidence=True, + ) + assert ( + str(e.value) + == "confidence shouldn't be null and should be a float in [0..1] range" + ) + + with pytest.raises(AssertionError) as e: + worker.create_classification( + element=elt, ml_class="a_class", confidence=2.00, high_confidence=True, + ) + assert ( + str(e.value) + == "confidence shouldn't be null and should be a float in [0..1] range" + ) + + +def test_create_classification_wrong_high_confidence(): + 
worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + worker.create_classification( + element=elt, ml_class="a_class", confidence=0.42, high_confidence=None, + ) + assert ( + str(e.value) == "high_confidence shouldn't be null and should be of type bool" + ) + + with pytest.raises(AssertionError) as e: + worker.create_classification( + element=elt, + ml_class="a_class", + confidence=0.42, + high_confidence="wrong high_confidence", + ) + assert ( + str(e.value) == "high_confidence shouldn't be null and should be of type bool" + ) + + +def test_create_classification_api_error(responses): + worker = ElementsWorker() + worker.configure() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + "https://arkindex.teklia.com/api/v1/classifications/", + status=500, + ) + + with pytest.raises(ErrorResponse): + worker.create_classification( + element=elt, ml_class="a_class", confidence=0.42, high_confidence=True, + ) + + assert len(responses.calls) == 1 + assert ( + responses.calls[0].request.url + == "https://arkindex.teklia.com/api/v1/classifications/" + ) + + +def test_create_classification(responses): + worker = ElementsWorker() + worker.configure() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + "https://arkindex.teklia.com/api/v1/classifications/", + status=200, + ) + + worker.create_classification( + element=elt, ml_class="a_class", confidence=0.42, high_confidence=True, + ) + + assert len(responses.calls) == 1 + assert ( + responses.calls[0].request.url + == "https://arkindex.teklia.com/api/v1/classifications/" + ) + assert json.loads(responses.calls[0].request.body) == { + "element": "12341234-1234-1234-1234-123412341234", + "ml_class": "a_class", + "worker_version": "12341234-1234-1234-1234-123412341234", + "confidence": 0.42, + "high_confidence": True, + } + + +def 
test_create_entity_wrong_element():
+    worker = ElementsWorker()
+    with pytest.raises(AssertionError) as e:
+        worker.create_entity(
+            element=None,
+            name="Bob Bob",
+            type=EntityType.Person,
+            corpus="12341234-1234-1234-1234-123412341234",
+        )
+    assert str(e.value) == "element shouldn't be null and should be of type Element"
+
+    with pytest.raises(AssertionError) as e:
+        worker.create_entity(
+            element="not element type",
+            name="Bob Bob",
+            type=EntityType.Person,
+            corpus="12341234-1234-1234-1234-123412341234",
+        )
+    assert str(e.value) == "element shouldn't be null and should be of type Element"
+
+
+def test_create_entity_wrong_name():
+    worker = ElementsWorker()
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+
+    with pytest.raises(AssertionError) as e:
+        worker.create_entity(
+            element=elt,
+            name=None,
+            type=EntityType.Person,
+            corpus="12341234-1234-1234-1234-123412341234",
+        )
+    assert str(e.value) == "name shouldn't be null and should be of type str"
+
+    with pytest.raises(AssertionError) as e:
+        worker.create_entity(
+            element=elt,
+            name=1234,
+            type=EntityType.Person,
+            corpus="12341234-1234-1234-1234-123412341234",
+        )
+    assert str(e.value) == "name shouldn't be null and should be of type str"
+
+
+def test_create_entity_wrong_type():
+    worker = ElementsWorker()
+    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
+
+    with pytest.raises(AssertionError) as e:
+        worker.create_entity(
+            element=elt,
+            name="Bob Bob",
+            type=None,
+            corpus="12341234-1234-1234-1234-123412341234",
+        )
+    assert str(e.value) == "type shouldn't be null and should be of type EntityType"
+
+    with pytest.raises(AssertionError) as e:
+        worker.create_entity(
+            element=elt,
+            name="Bob Bob",
+            type=1234,
+            corpus="12341234-1234-1234-1234-123412341234",
+        )
+    assert str(e.value) == "type shouldn't be null and should be of type EntityType"
+
+    with pytest.raises(AssertionError) as e:
+        worker.create_entity(
+            element=elt,
+            name="Bob Bob",
+
type="not_an_entity_type", + corpus="12341234-1234-1234-1234-123412341234", + ) + assert str(e.value) == "type shouldn't be null and should be of type EntityType" + + +def test_create_entity_wrong_corpus(): + worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + worker.create_entity( + element=elt, name="Bob Bob", type=EntityType.Person, corpus=None, + ) + assert str(e.value) == "corpus shouldn't be null and should be of type str" + + with pytest.raises(AssertionError) as e: + worker.create_entity( + element=elt, name="Bob Bob", type=EntityType.Person, corpus=1234, + ) + assert str(e.value) == "corpus shouldn't be null and should be of type str" + + +def test_create_entity_wrong_metas(): + worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + worker.create_entity( + element=elt, + name="Bob Bob", + type=EntityType.Person, + corpus="12341234-1234-1234-1234-123412341234", + metas="wrong metas", + ) + assert str(e.value) == "metas should be of type dict" + + +def test_create_entity_wrong_validated(): + worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + + with pytest.raises(AssertionError) as e: + worker.create_entity( + element=elt, + name="Bob Bob", + type=EntityType.Person, + corpus="12341234-1234-1234-1234-123412341234", + validated="wrong validated", + ) + assert str(e.value) == "validated should be of type bool" + + +def test_create_entity_api_error(responses): + worker = ElementsWorker() + worker.configure() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, "https://arkindex.teklia.com/api/v1/entity/", status=500, + ) + + with pytest.raises(ErrorResponse): + worker.create_entity( + element=elt, + name="Bob Bob", + type=EntityType.Person, + corpus="12341234-1234-1234-1234-123412341234", + ) + + assert 
len(responses.calls) == 1 + assert ( + responses.calls[0].request.url == "https://arkindex.teklia.com/api/v1/entity/" + ) + + +def test_create_entity(responses): + worker = ElementsWorker() + worker.configure() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + responses.add( + responses.POST, + "https://arkindex.teklia.com/api/v1/entity/", + status=200, + json={"id": "12345678-1234-1234-1234-123456789123"}, + ) + + worker.create_entity( + element=elt, + name="Bob Bob", + type=EntityType.Person, + corpus="12341234-1234-1234-1234-123412341234", + ) + + assert len(responses.calls) == 1 + assert ( + responses.calls[0].request.url == "https://arkindex.teklia.com/api/v1/entity/" + ) + assert json.loads(responses.calls[0].request.body) == { + "name": "Bob Bob", + "type": "person", + "metas": None, + "validated": None, + "corpus": "12341234-1234-1234-1234-123412341234", + "worker_version": "12341234-1234-1234-1234-123412341234", + } diff --git a/tests/test_reporting.py b/tests/test_reporting.py index 16d66635..7f1c1cb2 100644 --- a/tests/test_reporting.py +++ b/tests/test_reporting.py @@ -23,6 +23,7 @@ def test_process(): "elements": {}, "transcriptions": {}, "classifications": {}, + "entities": [], "errors": [], } @@ -37,6 +38,7 @@ def test_add_element(): "elements": {"text_line": 1}, "transcriptions": {}, "classifications": {}, + "entities": [], "errors": [], } @@ -54,6 +56,7 @@ def test_add_element_count(): "elements": {"text_line": 42}, "transcriptions": {}, "classifications": {}, + "entities": [], "errors": [], } @@ -68,6 +71,7 @@ def test_add_classification(): "elements": {}, "transcriptions": {}, "classifications": {"three": 1}, + "entities": [], "errors": [], } @@ -96,6 +100,7 @@ def test_add_classifications(): "elements": {}, "transcriptions": {}, "classifications": {"three": 3, "two": 2}, + "entities": [], "errors": [], } @@ -110,6 +115,7 @@ def test_add_transcription(): "elements": {}, "transcriptions": {"word": 1}, "classifications": {}, + 
"entities": [], "errors": [], } @@ -127,6 +133,7 @@ def test_add_transcription_count(): "elements": {}, "transcriptions": {"word": 1337}, "classifications": {}, + "entities": [], "errors": [], } @@ -153,6 +160,30 @@ def test_add_transcriptions(): "elements": {}, "transcriptions": {"word": 3, "line": 2}, "classifications": {}, + "entities": [], + "errors": [], + } + + +def test_add_entity(): + reporter = Reporter("worker") + reporter.add_entity( + "myelement", "12341234-1234-1234-1234-123412341234", "person", "Bob Bob" + ) + assert "myelement" in reporter.report_data["elements"] + element_data = reporter.report_data["elements"]["myelement"] + del element_data["started"] + assert element_data == { + "elements": {}, + "transcriptions": {}, + "classifications": {}, + "entities": [ + { + "id": "12341234-1234-1234-1234-123412341234", + "type": "person", + "name": "Bob Bob", + } + ], "errors": [], } -- GitLab