diff --git a/arkindex_worker/reporting.py b/arkindex_worker/reporting.py index 3e7f248bd12f86c8b8f4885a1d6aa2fda53c6811..1f0736f1912c179db2b4f58c2b80c5e7493db29a 100644 --- a/arkindex_worker/reporting.py +++ b/arkindex_worker/reporting.py @@ -34,6 +34,8 @@ class Reporter(object): "transcriptions": {}, # Created classification counts, by class "classifications": {}, + # Created entities ({"id": "", "type": "", "name": ""}) from this element + "entities": [], "errors": [], }, ) @@ -99,8 +101,12 @@ class Reporter(object): counter.update([transcription["type"] for transcription in transcriptions]) element["transcriptions"] = dict(counter) - def add_entity(self, *args, **kwargs): - raise NotImplementedError + def add_entity(self, element_id, entity_id, type, name): + """ + Report creating an entity from an element. + """ + entities = self._get_element(element_id)["entities"] + entities.append({"id": entity_id, "type": type, "name": name}) def add_entity_link(self, *args, **kwargs): raise NotImplementedError diff --git a/arkindex_worker/worker.py b/arkindex_worker/worker.py index 488bc8fe2503f10c160a5e643ae2baa205c033be..e13c5fe0357571b1e8cd5b22b903583d0f481120 100644 --- a/arkindex_worker/worker.py +++ b/arkindex_worker/worker.py @@ -249,7 +249,7 @@ class ElementsWorker(BaseWorker): "score": score, }, ) - self.report.add_transcription(element.id, type) + self.report.add_transcription(element.id, type.value) def create_classification( self, element, ml_class, confidence, high_confidence=False @@ -282,10 +282,13 @@ class ElementsWorker(BaseWorker): ) self.report.add_classification(element.id, ml_class) - def create_entity(self, name, type, corpus, metas=None, validated=None): + def create_entity(self, element, name, type, corpus, metas=None, validated=None): """ Create an entity on the given corpus through API """ + assert element and isinstance( + element, Element + ), "element shouldn't be null and should be of type Element" assert name and isinstance( name, str ), "name shouldn't be null and should be of type str" @@ -300,7 +303,7 @@ class ElementsWorker(BaseWorker): if validated: assert isinstance(validated, bool), "validated should be of type bool" - self.api_client.request( + entity = self.api_client.request( "CreateEntity", body={ "name": name, @@ -311,5 +314,4 @@ class ElementsWorker(BaseWorker): "worker_version": self.worker_version_id, }, ) - # TODO: Uncomment this when Reporter add_entity() method is implemented - # self.report.add_entity(element.id, type) + self.report.add_entity(element.id, entity["id"], type.value, name) diff --git a/tests/test_elements_worker.py b/tests/test_elements_worker.py index 48baa6852dd6f53eff36c214a54fee1c0a9e094a..ba88d25c8425b39f5aa95b924f955e8449d80544 100644 --- a/tests/test_elements_worker.py +++ b/tests/test_elements_worker.py @@ -677,10 +677,34 @@ def test_create_classification(responses): ) +def test_create_entity_wrong_element(): + worker = ElementsWorker() + with pytest.raises(AssertionError) as e: + worker.create_entity( + element="not element type", + name="Bob Bob", + type=EntityType.Person, + corpus="12341234-1234-1234-1234-123412341234", + ) + assert str(e.value) == "element shouldn't be null and should be of type Element" + + with pytest.raises(AssertionError) as e: + worker.create_entity( + element="not element type", + name="Bob Bob", + type=EntityType.Person, + corpus="12341234-1234-1234-1234-123412341234", + ) + assert str(e.value) == "element shouldn't be null and should be of type Element" + + def test_create_entity_wrong_name(): worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + with pytest.raises(AssertionError) as e: worker.create_entity( + element=elt, name=None, type=EntityType.Person, corpus="12341234-1234-1234-1234-123412341234", @@ -689,6 +713,7 @@ def test_create_entity_wrong_name(): with pytest.raises(AssertionError) as e: worker.create_entity( + element=elt, name=1234, type=EntityType.Person, corpus="12341234-1234-1234-1234-123412341234", @@ -698,21 +723,29 @@ def test_create_entity_wrong_name(): def test_create_entity_wrong_type(): worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError) as e: worker.create_entity( - name="Bob Bob", type=None, corpus="12341234-1234-1234-1234-123412341234", + element=elt, + name="Bob Bob", + type=None, + corpus="12341234-1234-1234-1234-123412341234", ) assert str(e.value) == "type shouldn't be null and should be of type EntityType" with pytest.raises(AssertionError) as e: worker.create_entity( - name="Bob Bob", type=1234, corpus="12341234-1234-1234-1234-123412341234", + element=elt, + name="Bob Bob", + type=1234, + corpus="12341234-1234-1234-1234-123412341234", ) assert str(e.value) == "type shouldn't be null and should be of type EntityType" with pytest.raises(AssertionError) as e: worker.create_entity( + element=elt, name="Bob Bob", type="not_an_entity_type", corpus="12341234-1234-1234-1234-123412341234", @@ -722,23 +755,28 @@ def test_create_entity_wrong_type(): def test_create_entity_wrong_corpus(): worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + with pytest.raises(AssertionError) as e: worker.create_entity( - name="Bob Bob", type=EntityType.Person, corpus=None, + element=elt, name="Bob Bob", type=EntityType.Person, corpus=None, ) assert str(e.value) == "corpus shouldn't be null and should be of type str" with pytest.raises(AssertionError) as e: worker.create_entity( - name="Bob Bob", type=EntityType.Person, corpus=1234, + element=elt, name="Bob Bob", type=EntityType.Person, corpus=1234, ) assert str(e.value) == "corpus shouldn't be null and should be of type str" def test_create_entity_wrong_metas(): worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + with pytest.raises(AssertionError) as e: worker.create_entity( + element=elt, name="Bob Bob", type=EntityType.Person, corpus="12341234-1234-1234-1234-123412341234", @@ -749,8 +787,11 @@ def test_create_entity_wrong_metas(): def test_create_entity_wrong_validated(): worker = ElementsWorker() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) + with pytest.raises(AssertionError) as e: worker.create_entity( + element=elt, name="Bob Bob", type=EntityType.Person, corpus="12341234-1234-1234-1234-123412341234", @@ -762,12 +803,14 @@ def test_create_entity_wrong_validated(): def test_create_entity_api_error(responses): worker = ElementsWorker() worker.configure() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.POST, "https://arkindex.teklia.com/api/v1/entity/", status=500, ) with pytest.raises(ErrorResponse): worker.create_entity( + element=elt, name="Bob Bob", type=EntityType.Person, corpus="12341234-1234-1234-1234-123412341234", @@ -782,11 +825,16 @@ def test_create_entity_api_error(responses): def test_create_entity(responses): worker = ElementsWorker() worker.configure() + elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( - responses.POST, "https://arkindex.teklia.com/api/v1/entity/", status=200, + responses.POST, + "https://arkindex.teklia.com/api/v1/entity/", + status=200, + json={"id": "12345678-1234-1234-1234-123456789123"}, ) worker.create_entity( + element=elt, name="Bob Bob", type=EntityType.Person, corpus="12341234-1234-1234-1234-123412341234", diff --git a/tests/test_reporting.py b/tests/test_reporting.py index 16d66635908c8fa3ad57e33debe821d3ab2a57b2..7f1c1cb299b82827cf2a49e25d4371fa8b777685 100644 --- a/tests/test_reporting.py +++ b/tests/test_reporting.py @@ -23,6 +23,7 @@ def test_process(): "elements": {}, "transcriptions": {}, "classifications": {}, + "entities": [], "errors": [], } @@ -37,6 +38,7 @@ def test_add_element(): "elements": {"text_line": 1}, "transcriptions": {}, "classifications": {}, + "entities": [], "errors": [], } @@ -54,6 +56,7 @@ def test_add_element_count(): "elements": {"text_line": 42}, "transcriptions": {}, "classifications": {}, + "entities": [], "errors": [], } @@ -68,6 +71,7 @@ def test_add_classification(): "elements": {}, "transcriptions": {}, "classifications": {"three": 1}, + "entities": [], "errors": [], } @@ -96,6 +100,7 @@ def test_add_classifications(): "elements": {}, "transcriptions": {}, "classifications": {"three": 3, "two": 2}, + "entities": [], "errors": [], } @@ -110,6 +115,7 @@ def test_add_transcription(): "elements": {}, "transcriptions": {"word": 1}, "classifications": {}, + "entities": [], "errors": [], } @@ -127,6 +133,7 @@ def test_add_transcription_count(): "elements": {}, "transcriptions": {"word": 1337}, "classifications": {}, + "entities": [], "errors": [], } @@ -153,6 +160,30 @@ def test_add_transcriptions(): "elements": {}, "transcriptions": {"word": 3, "line": 2}, "classifications": {}, + "entities": [], + "errors": [], + } + + +def test_add_entity(): + reporter = Reporter("worker") + reporter.add_entity( + "myelement", "12341234-1234-1234-1234-123412341234", "person", "Bob Bob" + ) + assert "myelement" in reporter.report_data["elements"] + element_data = reporter.report_data["elements"]["myelement"] + del element_data["started"] + assert element_data == { + "elements": {}, + "transcriptions": {}, + "classifications": {}, + "entities": [ + { + "id": "12341234-1234-1234-1234-123412341234", + "type": "person", + "name": "Bob Bob", + } + ], "errors": [], }