import json import re from argparse import Namespace from uuid import UUID import pytest from apistar.exceptions import ErrorResponse from responses import matchers from arkindex_worker.cache import ( SQL_VERSION, CachedElement, CachedImage, create_version_table, init_cache_db, ) from arkindex_worker.models import Element from arkindex_worker.worker import ElementsWorker from arkindex_worker.worker.element import MissingTypeError from . import BASE_API_CALLS def test_check_required_types_argument_types(mock_elements_worker): with pytest.raises( AssertionError, match="At least one element type slug is required." ): mock_elements_worker.check_required_types() with pytest.raises(AssertionError, match="Element type slugs must be strings."): mock_elements_worker.check_required_types("lol", 42) def test_check_required_types(responses, mock_elements_worker): corpus_id = "11111111-1111-1111-1111-111111111111" responses.add( responses.GET, f"http://testserver/api/v1/corpus/{corpus_id}/", json={ "id": corpus_id, "name": "Some Corpus", "types": [{"slug": "folder"}, {"slug": "page"}], }, ) mock_elements_worker.setup_api_client() assert mock_elements_worker.check_required_types("page") assert mock_elements_worker.check_required_types("page", "folder") with pytest.raises( MissingTypeError, match=re.escape( "Element type(s) act, text_line were not found in the Some Corpus corpus (11111111-1111-1111-1111-111111111111)." ), ): assert mock_elements_worker.check_required_types("page", "text_line", "act") def test_create_missing_types(responses, mock_elements_worker): corpus_id = "11111111-1111-1111-1111-111111111111" responses.add( responses.GET, f"http://testserver/api/v1/corpus/{corpus_id}/", json={ "id": corpus_id, "name": "Some Corpus", "types": [{"slug": "folder"}, {"slug": "page"}], }, ) responses.add( responses.POST, "http://testserver/api/v1/elements/type/", match=[ matchers.json_params_matcher( { "slug": "text_line", "display_name": "text_line", "folder": False, "corpus": corpus_id, } ) ], ) responses.add( responses.POST, "http://testserver/api/v1/elements/type/", match=[ matchers.json_params_matcher( { "slug": "act", "display_name": "act", "folder": False, "corpus": corpus_id, } ) ], ) mock_elements_worker.setup_api_client() assert mock_elements_worker.check_required_types( "page", "text_line", "act", create_missing=True ) def test_list_elements_elements_list_arg_wrong_type( monkeypatch, tmp_path, mock_elements_worker ): elements_path = tmp_path / "elements.json" elements_path.write_text("{}") monkeypatch.setenv("TASK_ELEMENTS", str(elements_path)) worker = ElementsWorker() worker.configure() with pytest.raises(AssertionError, match="Elements list must be a list"): worker.list_elements() def test_list_elements_elements_list_arg_empty_list( monkeypatch, tmp_path, mock_elements_worker ): elements_path = tmp_path / "elements.json" elements_path.write_text("[]") monkeypatch.setenv("TASK_ELEMENTS", str(elements_path)) worker = ElementsWorker() worker.configure() with pytest.raises(AssertionError, match="No elements in elements list"): worker.list_elements() def test_list_elements_elements_list_arg_missing_id( monkeypatch, tmp_path, mock_elements_worker ): elements_path = tmp_path / "elements.json" with elements_path.open("w") as f: json.dump([{"type": "volume"}], f) monkeypatch.setenv("TASK_ELEMENTS", str(elements_path)) worker = ElementsWorker() worker.configure() elt_list = worker.list_elements() assert elt_list == [] def test_list_elements_elements_list_arg(monkeypatch, tmp_path, mock_elements_worker): elements_path = tmp_path / "elements.json" with elements_path.open("w") as f: json.dump( [ {"id": "volumeid", "type": "volume"}, {"id": "pageid", "type": "page"}, {"id": "actid", "type": "act"}, {"id": "surfaceid", "type": "surface"}, ], f, ) monkeypatch.setenv("TASK_ELEMENTS", str(elements_path)) worker = ElementsWorker() worker.configure() elt_list = worker.list_elements() assert elt_list == ["volumeid", "pageid", "actid", "surfaceid"] def test_list_elements_element_arg(mocker, mock_elements_worker): mocker.patch( "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args", return_value=Namespace( element=["volumeid", "pageid"], verbose=False, elements_list=None, database=None, dev=False, ), ) worker = ElementsWorker() worker.configure() elt_list = worker.list_elements() assert elt_list == ["volumeid", "pageid"] def test_list_elements_both_args_error(mocker, mock_elements_worker, tmp_path): elements_path = tmp_path / "elements.json" with elements_path.open("w") as f: json.dump( [ {"id": "volumeid", "type": "volume"}, {"id": "pageid", "type": "page"}, {"id": "actid", "type": "act"}, {"id": "surfaceid", "type": "surface"}, ], f, ) mocker.patch( "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args", return_value=Namespace( element=["anotherid", "againanotherid"], verbose=False, elements_list=elements_path.open(), database=None, dev=False, ), ) worker = ElementsWorker() worker.configure() with pytest.raises( AssertionError, match="elements-list and element CLI args shouldn't be both set" ): worker.list_elements() def test_database_arg(mocker, mock_elements_worker, tmp_path): database_path = tmp_path / "my_database.sqlite" init_cache_db(database_path) create_version_table() mocker.patch( "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args", return_value=Namespace( element=["volumeid", "pageid"], verbose=False, elements_list=None, database=database_path, dev=False, ), ) worker = ElementsWorker(support_cache=True) worker.configure() assert worker.use_cache is True assert worker.cache_path == database_path def test_database_arg_cache_missing_version_table( mocker, mock_elements_worker, tmp_path ): database_path = tmp_path / "my_database.sqlite" database_path.touch() mocker.patch( "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args", return_value=Namespace( element=["volumeid", "pageid"], verbose=False, elements_list=None, database=database_path, dev=False, ), ) worker = ElementsWorker(support_cache=True) with pytest.raises( AssertionError, match=f"The SQLite database {database_path} does not have the correct cache version, it should be {SQL_VERSION}", ): worker.configure() def test_load_corpus_classes_api_error(responses, mock_elements_worker): responses.add( responses.GET, "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/", status=500, ) assert not mock_elements_worker.classes with pytest.raises( Exception, match="Stopping pagination as data will be incomplete" ): mock_elements_worker.load_corpus_classes() assert len(responses.calls) == len(BASE_API_CALLS) + 5 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ # We do 5 retries ( "GET", "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/", ), ( "GET", "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/", ), ( "GET", "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/", ), ( "GET", "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/", ), ( "GET", "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/", ), ] assert not mock_elements_worker.classes def test_load_corpus_classes(responses, mock_elements_worker): responses.add( responses.GET, "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/", status=200, json={ "count": 3, "next": None, "results": [ { "id": "0000", "name": "good", }, { "id": "1111", "name": "average", }, { "id": "2222", "name": "bad", }, ], }, ) assert not mock_elements_worker.classes mock_elements_worker.load_corpus_classes() assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "GET", "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/", ), ] assert mock_elements_worker.classes == { "good": "0000", "average": "1111", "bad": "2222", } def test_create_sub_element_wrong_element(mock_elements_worker): with pytest.raises( AssertionError, match="element shouldn't be null and should be of type Element" ): mock_elements_worker.create_sub_element( element=None, type="something", name="0", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], ) with pytest.raises( AssertionError, match="element shouldn't be null and should be of type Element" ): mock_elements_worker.create_sub_element( element="not element type", type="something", name="0", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], ) def test_create_sub_element_wrong_type(mock_elements_worker): elt = Element({"zone": None}) with pytest.raises( AssertionError, match="type shouldn't be null and should be of type str" ): mock_elements_worker.create_sub_element( element=elt, type=None, name="0", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], ) with pytest.raises( AssertionError, match="type shouldn't be null and should be of type str" ): mock_elements_worker.create_sub_element( element=elt, type=1234, name="0", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], ) def test_create_sub_element_wrong_name(mock_elements_worker): elt = Element({"zone": None}) with pytest.raises( AssertionError, match="name shouldn't be null and should be of type str" ): mock_elements_worker.create_sub_element( element=elt, type="something", name=None, polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], ) with pytest.raises( AssertionError, match="name shouldn't be null and should be of type str" ): mock_elements_worker.create_sub_element( element=elt, type="something", name=1234, polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], ) def test_create_sub_element_wrong_polygon(mock_elements_worker): elt = Element({"zone": None}) with pytest.raises( AssertionError, match="polygon shouldn't be null and should be of type list" ): mock_elements_worker.create_sub_element( element=elt, type="something", name="0", polygon=None, ) with pytest.raises( AssertionError, match="polygon shouldn't be null and should be of type list" ): mock_elements_worker.create_sub_element( element=elt, type="something", name="O", polygon="not a polygon", ) with pytest.raises( AssertionError, match="polygon should have at least three points" ): mock_elements_worker.create_sub_element( element=elt, type="something", name="O", polygon=[[1, 1], [2, 2]], ) with pytest.raises( AssertionError, match="polygon points should be lists of two items" ): mock_elements_worker.create_sub_element( element=elt, type="something", name="O", polygon=[[1, 1, 1], [2, 2, 1], [2, 1, 1], [1, 2, 1]], ) with pytest.raises( AssertionError, match="polygon points should be lists of two items" ): mock_elements_worker.create_sub_element( element=elt, type="something", name="O", polygon=[[1], [2], [2], [1]], ) with pytest.raises( AssertionError, match="polygon points should be lists of two numbers" ): mock_elements_worker.create_sub_element( element=elt, type="something", name="O", polygon=[["not a coord", 1], [2, 2], [2, 1], [1, 2]], ) @pytest.mark.parametrize("confidence", ["lol", "0.2", -1.0, 1.42, float("inf")]) def test_create_sub_element_wrong_confidence(mock_elements_worker, confidence): with pytest.raises( AssertionError, match=re.escape("confidence should be None or a float in [0..1] range"), ): mock_elements_worker.create_sub_element( element=Element({"zone": None}), type="something", name="blah", polygon=[[0, 0], [0, 10], [10, 10], [10, 0], [0, 0]], confidence=confidence, ) def test_create_sub_element_api_error(responses, mock_elements_worker): elt = Element( { "id": "12341234-1234-1234-1234-123412341234", "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, "zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}}, } ) responses.add( responses.POST, "http://testserver/api/v1/elements/create/", status=500, ) with pytest.raises(ErrorResponse): mock_elements_worker.create_sub_element( element=elt, type="something", name="0", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], ) assert len(responses.calls) == len(BASE_API_CALLS) + 5 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ # We retry 5 times the API call ("POST", "http://testserver/api/v1/elements/create/"), ("POST", "http://testserver/api/v1/elements/create/"), ("POST", "http://testserver/api/v1/elements/create/"), ("POST", "http://testserver/api/v1/elements/create/"), ("POST", "http://testserver/api/v1/elements/create/"), ] @pytest.mark.parametrize("slim_output", [True, False]) def test_create_sub_element(responses, mock_elements_worker, slim_output): elt = Element( { "id": "12341234-1234-1234-1234-123412341234", "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, "zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}}, } ) child_elt = { "id": "12345678-1234-1234-1234-123456789123", "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, "zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}}, } responses.add( responses.POST, "http://testserver/api/v1/elements/create/", status=200, json=child_elt, ) element_creation_response = mock_elements_worker.create_sub_element( element=elt, type="something", name="0", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], slim_output=slim_output, ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "POST", "http://testserver/api/v1/elements/create/", ), ] assert json.loads(responses.calls[-1].request.body) == { "type": "something", "name": "0", "image": "22222222-2222-2222-2222-222222222222", "corpus": "11111111-1111-1111-1111-111111111111", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], "parent": "12341234-1234-1234-1234-123412341234", "worker_run_id": "56785678-5678-5678-5678-567856785678", "confidence": None, } if slim_output: assert element_creation_response == "12345678-1234-1234-1234-123456789123" else: assert Element(element_creation_response) == Element(child_elt) def test_create_sub_element_confidence(responses, mock_elements_worker): elt = Element( { "id": "12341234-1234-1234-1234-123412341234", "corpus": {"id": "11111111-1111-1111-1111-111111111111"}, "zone": {"image": {"id": "22222222-2222-2222-2222-222222222222"}}, } ) responses.add( responses.POST, "http://testserver/api/v1/elements/create/", status=200, json={"id": "12345678-1234-1234-1234-123456789123"}, ) sub_element_id = mock_elements_worker.create_sub_element( element=elt, type="something", name="0", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], confidence=0.42, ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ("POST", "http://testserver/api/v1/elements/create/"), ] assert json.loads(responses.calls[-1].request.body) == { "type": "something", "name": "0", "image": "22222222-2222-2222-2222-222222222222", "corpus": "11111111-1111-1111-1111-111111111111", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], "parent": "12341234-1234-1234-1234-123412341234", "worker_run_id": "56785678-5678-5678-5678-567856785678", "confidence": 0.42, } assert sub_element_id == "12345678-1234-1234-1234-123456789123" def test_create_elements_wrong_parent(mock_elements_worker): with pytest.raises( TypeError, match="Parent element should be an Element or CachedElement instance" ): mock_elements_worker.create_elements( parent=None, elements=[], ) with pytest.raises( TypeError, match="Parent element should be an Element or CachedElement instance" ): mock_elements_worker.create_elements( parent="not element type", elements=[], ) def test_create_elements_no_zone(mock_elements_worker): elt = Element({"zone": None}) with pytest.raises( AssertionError, match="create_elements cannot be used on parents without zones" ): mock_elements_worker.create_elements( parent=elt, elements=None, ) elt = CachedElement( id="11111111-1111-1111-1111-1111111111", name="blah", type="blah" ) with pytest.raises( AssertionError, match="create_elements cannot be used on parents without images" ): mock_elements_worker.create_elements( parent=elt, elements=None, ) def test_create_elements_wrong_elements(mock_elements_worker): elt = Element({"zone": {"image": {"id": "image_id"}}}) with pytest.raises( AssertionError, match="elements shouldn't be null and should be of type list" ): mock_elements_worker.create_elements( parent=elt, elements=None, ) with pytest.raises( AssertionError, match="elements shouldn't be null and should be of type list" ): mock_elements_worker.create_elements( parent=elt, elements="not a list", ) def test_create_elements_wrong_elements_instance(mock_elements_worker): elt = Element({"zone": {"image": {"id": "image_id"}}}) with pytest.raises( AssertionError, match="Element at index 0 in elements: Should be of type dict" ): mock_elements_worker.create_elements( parent=elt, elements=["not a dict"], ) def test_create_elements_wrong_elements_name(mock_elements_worker): elt = Element({"zone": {"image": {"id": "image_id"}}}) with pytest.raises( AssertionError, match="Element at index 0 in elements: name shouldn't be null and should be of type str", ): mock_elements_worker.create_elements( parent=elt, elements=[ { "name": None, "type": "something", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], } ], ) with pytest.raises( AssertionError, match="Element at index 0 in elements: name shouldn't be null and should be of type str", ): mock_elements_worker.create_elements( parent=elt, elements=[ { "name": 1234, "type": "something", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], } ], ) def test_create_elements_wrong_elements_type(mock_elements_worker): elt = Element({"zone": {"image": {"id": "image_id"}}}) with pytest.raises( AssertionError, match="Element at index 0 in elements: type shouldn't be null and should be of type str", ): mock_elements_worker.create_elements( parent=elt, elements=[ { "name": "0", "type": None, "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], } ], ) with pytest.raises( AssertionError, match="Element at index 0 in elements: type shouldn't be null and should be of type str", ): mock_elements_worker.create_elements( parent=elt, elements=[ { "name": "0", "type": 1234, "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], } ], ) def test_create_elements_wrong_elements_polygon(mock_elements_worker): elt = Element({"zone": {"image": {"id": "image_id"}}}) with pytest.raises( AssertionError, match="Element at index 0 in elements: polygon shouldn't be null and should be of type list", ): mock_elements_worker.create_elements( parent=elt, elements=[ { "name": "0", "type": "something", "polygon": None, } ], ) with pytest.raises( AssertionError, match="Element at index 0 in elements: polygon shouldn't be null and should be of type list", ): mock_elements_worker.create_elements( parent=elt, elements=[ { "name": "0", "type": "something", "polygon": "not a polygon", } ], ) with pytest.raises( AssertionError, match="Element at index 0 in elements: polygon should have at least three points", ): mock_elements_worker.create_elements( parent=elt, elements=[ { "name": "0", "type": "something", "polygon": [[1, 1], [2, 2]], } ], ) with pytest.raises( AssertionError, match="Element at index 0 in elements: polygon points should be lists of two items", ): mock_elements_worker.create_elements( parent=elt, elements=[ { "name": "0", "type": "something", "polygon": [[1, 1, 1], [2, 2, 1], [2, 1, 1], [1, 2, 1]], } ], ) with pytest.raises( AssertionError, match="Element at index 0 in elements: polygon points should be lists of two items", ): mock_elements_worker.create_elements( parent=elt, elements=[ { "name": "0", "type": "something", "polygon": [[1], [2], [2], [1]], } ], ) with pytest.raises( AssertionError, match="Element at index 0 in elements: polygon points should be lists of two numbers", ): mock_elements_worker.create_elements( parent=elt, elements=[ { "name": "0", "type": "something", "polygon": [["not a coord", 1], [2, 2], [2, 1], [1, 2]], } ], ) @pytest.mark.parametrize("confidence", ["lol", "0.2", -1.0, 1.42, float("inf")]) def test_create_elements_wrong_elements_confidence(mock_elements_worker, confidence): with pytest.raises( AssertionError, match=re.escape( "Element at index 0 in elements: confidence should be None or a float in [0..1] range" ), ): mock_elements_worker.create_elements( parent=Element({"zone": {"image": {"id": "image_id"}}}), elements=[ { "name": "a", "type": "something", "polygon": [[0, 0], [0, 10], [10, 10], [10, 0], [0, 0]], "confidence": confidence, } ], ) def test_create_elements_api_error(responses, mock_elements_worker): elt = Element( { "id": "12341234-1234-1234-1234-123412341234", "zone": { "image": { "id": "c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe", "width": 42, "height": 42, "url": "http://aaaa", } }, } ) responses.add( responses.POST, "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", status=500, ) with pytest.raises(ErrorResponse): mock_elements_worker.create_elements( parent=elt, elements=[ { "name": "0", "type": "something", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], } ], ) assert len(responses.calls) == len(BASE_API_CALLS) + 5 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ # We retry 5 times the API call ( "POST", "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", ), ( "POST", "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", ), ( "POST", "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", ), ( "POST", "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", ), ( "POST", "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", ), ] def test_create_elements_cached_element(responses, mock_elements_worker_with_cache): image = CachedImage.create( id=UUID("c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe"), width=42, height=42, url="http://aaaa", ) elt = CachedElement.create( id=UUID("12341234-1234-1234-1234-123412341234"), type="parent", image_id=image.id, polygon="[[0, 0], [0, 1000], [1000, 1000], [1000, 0], [0, 0]]", ) responses.add( responses.POST, "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", status=200, json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}], ) created_ids = mock_elements_worker_with_cache.create_elements( parent=elt, elements=[ { "name": "0", "type": "something", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], } ], ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "POST", "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", ), ] assert json.loads(responses.calls[-1].request.body) == { "elements": [ { "name": "0", "type": "something", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], } ], "worker_run_id": "56785678-5678-5678-5678-567856785678", } assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}] # Check that created elements were properly stored in SQLite cache assert list(CachedElement.select().order_by(CachedElement.id)) == [ elt, CachedElement( id=UUID("497f6eca-6276-4993-bfeb-53cbbbba6f08"), parent_id=elt.id, type="something", image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], worker_run_id=UUID("56785678-5678-5678-5678-567856785678"), ), ] def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path): elt = Element( { "id": "12341234-1234-1234-1234-123412341234", "zone": { "image": { "id": "c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe", "width": 42, "height": 42, "url": "http://aaaa", } }, } ) responses.add( responses.POST, "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", status=200, json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}], ) created_ids = mock_elements_worker_with_cache.create_elements( parent=elt, elements=[ { "name": "0", "type": "something", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], } ], ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "POST", "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", ), ] assert json.loads(responses.calls[-1].request.body) == { "elements": [ { "name": "0", "type": "something", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], } ], "worker_run_id": "56785678-5678-5678-5678-567856785678", } assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}] # Check that created elements were properly stored in SQLite cache assert (tmp_path / "db.sqlite").is_file() assert list(CachedElement.select()) == [ CachedElement( id=UUID("497f6eca-6276-4993-bfeb-53cbbbba6f08"), parent_id=UUID("12341234-1234-1234-1234-123412341234"), type="something", image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], worker_run_id=UUID("56785678-5678-5678-5678-567856785678"), confidence=None, ) ] def test_create_elements_confidence( responses, mock_elements_worker_with_cache, tmp_path ): elt = Element( { "id": "12341234-1234-1234-1234-123412341234", "zone": { "image": { "id": "c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe", "width": 42, "height": 42, "url": "http://aaaa", } }, } ) responses.add( responses.POST, "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", status=200, json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}], ) created_ids = mock_elements_worker_with_cache.create_elements( parent=elt, elements=[ { "name": "0", "type": "something", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], "confidence": 0.42, } ], ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "POST", "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", ), ] assert json.loads(responses.calls[-1].request.body) == { "elements": [ { "name": "0", "type": "something", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], "confidence": 0.42, } ], "worker_run_id": "56785678-5678-5678-5678-567856785678", } assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}] # Check that created elements were properly stored in SQLite cache assert (tmp_path / "db.sqlite").is_file() assert list(CachedElement.select()) == [ CachedElement( id=UUID("497f6eca-6276-4993-bfeb-53cbbbba6f08"), parent_id=UUID("12341234-1234-1234-1234-123412341234"), type="something", image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], worker_run_id=UUID("56785678-5678-5678-5678-567856785678"), confidence=0.42, ) ] def test_create_elements_integrity_error( responses, mock_elements_worker_with_cache, caplog ): elt = Element( { "id": "12341234-1234-1234-1234-123412341234", "zone": { "image": { "id": "c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe", "width": 42, "height": 42, "url": "http://aaaa", } }, } ) responses.add( responses.POST, "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/", status=200, json=[ # Duplicate IDs, which will cause an IntegrityError when stored in the cache {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}, {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}, ], ) elements = [ { "name": "0", "type": "something", "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]], }, { "name": "1", "type": "something", "polygon": [[1, 1], [3, 3], [3, 1], [1, 3]], }, ] created_ids = mock_elements_worker_with_cache.create_elements( parent=elt, elements=elements, ) assert created_ids == [ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}, {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}, ] assert len(caplog.records) == 1 assert caplog.records[0].levelname == "WARNING" assert caplog.records[0].message.startswith( "Couldn't save created elements in local cache:" ) assert list(CachedElement.select()) == [] @pytest.mark.parametrize( "payload, error", ( # Element ( {"element": None}, "element shouldn't be null and should be an Element or CachedElement", ), ( {"element": "not element type"}, "element shouldn't be null and should be an Element or CachedElement", ), ), ) def test_partial_update_element_wrong_param_element( mock_elements_worker, payload, error ): api_payload = { "element": Element({"zone": None}), **payload, } with pytest.raises(AssertionError, match=error): mock_elements_worker.partial_update_element( **api_payload, ) @pytest.mark.parametrize( "payload, error", ( # Type ({"type": 1234}, "type should be a str"), ({"type": None}, "type should be a str"), ), ) def test_partial_update_element_wrong_param_type(mock_elements_worker, payload, error): api_payload = { "element": Element({"zone": None}), **payload, } with pytest.raises(AssertionError, match=error): mock_elements_worker.partial_update_element( **api_payload, ) @pytest.mark.parametrize( "payload, error", ( # Name ({"name": 1234}, "name should be a str"), ({"name": None}, "name should be a str"), ), ) def test_partial_update_element_wrong_param_name(mock_elements_worker, payload, error): api_payload = { "element": Element({"zone": None}), **payload, } with pytest.raises(AssertionError, match=error): mock_elements_worker.partial_update_element( **api_payload, ) @pytest.mark.parametrize( "payload, error", ( # Polygon ({"polygon": "not a polygon"}, "polygon should be a list"), ({"polygon": None}, "polygon should be a list"), ({"polygon": [[1, 1], [2, 2]]}, "polygon should have at least three points"), ( {"polygon": [[1, 1, 1], [2, 2, 1], [2, 1, 1], [1, 2, 1]]}, "polygon points should be lists of two items", ), ( {"polygon": [[1], [2], [2], [1]]}, "polygon points should be lists of two items", ), ( {"polygon": [["not a coord", 1], [2, 2], [2, 1], [1, 2]]}, "polygon points should be lists of two numbers", ), ), ) def test_partial_update_element_wrong_param_polygon( mock_elements_worker, payload, error ): api_payload = { "element": Element({"zone": None}), **payload, } with pytest.raises(AssertionError, match=error): mock_elements_worker.partial_update_element( **api_payload, ) @pytest.mark.parametrize( "payload, error", ( # Confidence ({"confidence": "lol"}, "confidence should be None or a float in [0..1] range"), ({"confidence": "0.2"}, "confidence should be None or a float in [0..1] range"), ({"confidence": -1.0}, "confidence should be None or a float in [0..1] range"), ({"confidence": 1.42}, "confidence should be None or a float in [0..1] range"), ( {"confidence": float("inf")}, "confidence should be None or a float in [0..1] range", ), ), ) def test_partial_update_element_wrong_param_conf(mock_elements_worker, payload, error): api_payload = { "element": Element({"zone": None}), **payload, } with pytest.raises(AssertionError, match=re.escape(error)): mock_elements_worker.partial_update_element( **api_payload, ) @pytest.mark.parametrize( "payload, error", ( # Rotation angle ({"rotation_angle": "lol"}, "rotation_angle should be a positive integer"), ({"rotation_angle": -1}, "rotation_angle should be a positive integer"), ({"rotation_angle": 0.5}, "rotation_angle should be a positive integer"), ({"rotation_angle": None}, "rotation_angle should be a positive integer"), ), ) def test_partial_update_element_wrong_param_rota(mock_elements_worker, payload, error): api_payload = { "element": Element({"zone": None}), **payload, } with pytest.raises(AssertionError, match=error): mock_elements_worker.partial_update_element( **api_payload, ) @pytest.mark.parametrize( "payload, error", ( # Mirrored ({"mirrored": "lol"}, "mirrored should be a boolean"), ({"mirrored": 1234}, "mirrored should be a boolean"), ({"mirrored": None}, "mirrored should be a boolean"), ), ) def test_partial_update_element_wrong_param_mir(mock_elements_worker, payload, error): api_payload = { "element": Element({"zone": None}), **payload, } with pytest.raises(AssertionError, match=error): mock_elements_worker.partial_update_element( **api_payload, ) @pytest.mark.parametrize( "payload, error", ( # Image ({"image": "lol"}, "image should be a UUID"), ({"image": 1234}, "image should be a UUID"), ({"image": None}, "image should be a UUID"), ), ) def test_partial_update_element_wrong_param_image(mock_elements_worker, payload, error): api_payload = { "element": Element({"zone": None}), **payload, } with pytest.raises(AssertionError, match=error): mock_elements_worker.partial_update_element( **api_payload, ) def test_partial_update_element_api_error(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.PATCH, f"http://testserver/api/v1/element/{elt.id}/", status=500, ) with pytest.raises(ErrorResponse): mock_elements_worker.partial_update_element( element=elt, type="something", name="0", polygon=[[1, 1], [2, 2], [2, 1], [1, 2]], ) assert len(responses.calls) == len(BASE_API_CALLS) + 5 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ # We retry 5 times the API call ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"), ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"), ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"), ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"), ("PATCH", f"http://testserver/api/v1/element/{elt.id}/"), ] @pytest.mark.parametrize( "payload", ( ( { "polygon": [[10, 10], [20, 20], [20, 10], [10, 20]], "confidence": None, } ), ( { "rotation_angle": 45, "mirrored": False, } ), ( { "polygon": [[10, 10], [20, 20], [20, 10], [10, 20]], "confidence": None, "rotation_angle": 45, "mirrored": False, } ), ), ) def test_partial_update_element( responses, mock_elements_worker_with_cache, mock_cached_elements, mock_cached_images, payload, ): elt = CachedElement.select().first() new_image = CachedImage.select().first() elt_response = { "image": str(new_image.id), **payload, } responses.add( responses.PATCH, f"http://testserver/api/v1/element/{elt.id}/", status=200, # UUID not allowed in JSON json=elt_response, ) element_update_response = mock_elements_worker_with_cache.partial_update_element( element=elt, **{**elt_response, "image": new_image.id}, ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "PATCH", f"http://testserver/api/v1/element/{elt.id}/", ), ] assert json.loads(responses.calls[-1].request.body) == elt_response assert element_update_response == elt_response cached_element = CachedElement.get(CachedElement.id == elt.id) # Always present in payload assert str(cached_element.image_id) == elt_response["image"] # Optional params if "polygon" in payload: # Cast to string as this is the only difference compared to model elt_response["polygon"] = str(elt_response["polygon"]) for param in payload: assert getattr(cached_element, param) == elt_response[param] @pytest.mark.parametrize("confidence", (None, 0.42)) def test_partial_update_element_confidence( responses, mock_elements_worker_with_cache, mock_cached_elements, confidence ): elt = CachedElement.select().first() elt_response = { "polygon": [[10, 10], [20, 20], [20, 10], [10, 20]], "confidence": confidence, } responses.add( responses.PATCH, f"http://testserver/api/v1/element/{elt.id}/", status=200, json=elt_response, ) element_update_response = mock_elements_worker_with_cache.partial_update_element( element=elt, **elt_response, ) assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "PATCH", f"http://testserver/api/v1/element/{elt.id}/", ), ] assert json.loads(responses.calls[-1].request.body) == elt_response assert element_update_response == elt_response cached_element = CachedElement.get(CachedElement.id == elt.id) assert cached_element.polygon == str(elt_response["polygon"]) assert cached_element.confidence == confidence def test_list_element_children_wrong_element(mock_elements_worker): with pytest.raises( AssertionError, match="element shouldn't be null and should be an Element or CachedElement", ): mock_elements_worker.list_element_children(element=None) with pytest.raises( AssertionError, match="element shouldn't be null and should be an Element or CachedElement", ): mock_elements_worker.list_element_children(element="not element type") def test_list_element_children_wrong_folder(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="folder should be of type bool"): mock_elements_worker.list_element_children( element=elt, folder="not bool", ) def test_list_element_children_wrong_name(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="name should be of type str"): mock_elements_worker.list_element_children( element=elt, name=1234, ) def test_list_element_children_wrong_recursive(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="recursive should be of type bool"): mock_elements_worker.list_element_children( element=elt, recursive="not bool", ) def test_list_element_children_wrong_type(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="type should be of type str"): mock_elements_worker.list_element_children( element=elt, type=1234, ) def test_list_element_children_wrong_with_classes(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="with_classes should be of type bool"): mock_elements_worker.list_element_children( element=elt, with_classes="not bool", ) def test_list_element_children_wrong_with_corpus(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="with_corpus should be of type bool"): mock_elements_worker.list_element_children( element=elt, with_corpus="not bool", ) def test_list_element_children_wrong_with_has_children(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises( AssertionError, match="with_has_children should be of type bool" ): mock_elements_worker.list_element_children( element=elt, with_has_children="not bool", ) def test_list_element_children_wrong_with_zone(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="with_zone should be of type bool"): mock_elements_worker.list_element_children( element=elt, with_zone="not bool", ) def test_list_element_children_wrong_with_metadata(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="with_metadata should be of type bool"): mock_elements_worker.list_element_children( element=elt, with_metadata="not bool", ) @pytest.mark.parametrize( "param, value", ( ("worker_version", 1234), ("worker_run", 1234), ("transcription_worker_version", 1234), ("transcription_worker_run", 1234), ), ) def test_list_element_children_wrong_worker_version(mock_elements_worker, param, value): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match=f"{param} should be of type str or bool"): mock_elements_worker.list_element_children( element=elt, **{param: value}, ) @pytest.mark.parametrize( "param", ( ("worker_version"), ("worker_run"), ("transcription_worker_version"), ("transcription_worker_run"), ), ) def test_list_element_children_wrong_bool_worker_version(mock_elements_worker, param): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises( AssertionError, match=f"if of type bool, {param} can only be set to False" ): mock_elements_worker.list_element_children( element=elt, **{param: True}, ) def test_list_element_children_api_error(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.GET, "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", status=500, ) with pytest.raises( Exception, match="Stopping pagination as data will be incomplete" ): next(mock_elements_worker.list_element_children(element=elt)) assert len(responses.calls) == len(BASE_API_CALLS) + 5 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ # We do 5 retries ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", ), ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", ), ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", ), ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", ), ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", ), ] def test_list_element_children(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) expected_children = [ { "id": "0000", "type": "page", "name": "Test", "corpus": {}, "thumbnail_url": None, "zone": {}, "best_classes": None, "has_children": None, "worker_version_id": None, "worker_run_id": None, }, { "id": "1111", "type": "page", "name": "Test 2", "corpus": {}, "thumbnail_url": None, "zone": {}, "best_classes": None, "has_children": None, "worker_version_id": None, "worker_run_id": None, }, { "id": "2222", "type": "page", "name": "Test 3", "corpus": {}, "thumbnail_url": None, "zone": {}, "best_classes": None, "has_children": None, "worker_version_id": None, "worker_run_id": None, }, ] responses.add( responses.GET, "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", status=200, json={ "count": 3, "next": None, "results": expected_children, }, ) for idx, child in enumerate( mock_elements_worker.list_element_children(element=elt) ): assert child == expected_children[idx] assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/", ), ] def test_list_element_children_manual_worker_version(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) expected_children = [ { "id": "0000", "type": "page", "name": "Test", "corpus": {}, "thumbnail_url": None, "zone": {}, "best_classes": None, "has_children": None, "worker_version_id": None, "worker_run_id": None, } ] responses.add( responses.GET, "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_version=False", status=200, json={ "count": 1, "next": None, "results": expected_children, }, ) for idx, child in enumerate( mock_elements_worker.list_element_children(element=elt, worker_version=False) ): assert child == expected_children[idx] assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_version=False", ), ] def test_list_element_children_manual_worker_run(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) expected_children = [ { "id": "0000", "type": "page", "name": "Test", "corpus": {}, "thumbnail_url": None, "zone": {}, "best_classes": None, "has_children": None, "worker_version_id": None, "worker_run_id": None, } ] responses.add( responses.GET, "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_run=False", status=200, json={ "count": 1, "next": None, "results": expected_children, }, ) for idx, child in enumerate( mock_elements_worker.list_element_children(element=elt, worker_run=False) ): assert child == expected_children[idx] assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/children/?worker_run=False", ), ] def test_list_element_children_with_cache_unhandled_param( mock_elements_worker_with_cache, ): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises( AssertionError, match="When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'", ): mock_elements_worker_with_cache.list_element_children( element=elt, with_corpus=True ) @pytest.mark.parametrize( "filters, expected_ids", ( # Filter on element should give all elements inserted ( { "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), }, ( "11111111-1111-1111-1111-111111111111", "22222222-2222-2222-2222-222222222222", "33333333-3333-3333-3333-333333333333", ), ), # Filter on element and page should give the second element ( { "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "type": "page", }, ("22222222-2222-2222-2222-222222222222",), ), # Filter on element and worker version should give first two elements ( { "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "worker_version": "56785678-5678-5678-5678-567856785678", }, ( "11111111-1111-1111-1111-111111111111", "22222222-2222-2222-2222-222222222222", ), ), # Filter on element, type something and worker version should give first ( { "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "type": "something", "worker_version": "56785678-5678-5678-5678-567856785678", }, ("11111111-1111-1111-1111-111111111111",), ), # Filter on element, manual worker version should give third ( { "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "worker_version": False, }, ("33333333-3333-3333-3333-333333333333",), ), # Filter on element and worker run should give second ( { "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "worker_run": "56785678-5678-5678-5678-567856785678", }, ("22222222-2222-2222-2222-222222222222",), ), # Filter on element, manual worker run should give first and third ( { "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "worker_run": False, }, ( "11111111-1111-1111-1111-111111111111", "33333333-3333-3333-3333-333333333333", ), ), ), ) def test_list_element_children_with_cache( responses, mock_elements_worker_with_cache, mock_cached_elements, filters, expected_ids, ): # Check we have 5 elements already present in database assert CachedElement.select().count() == 5 # Query database through cache elements = mock_elements_worker_with_cache.list_element_children(**filters) assert elements.count() == len(expected_ids) for child, expected_id in zip(elements.order_by("id"), expected_ids): assert child.id == UUID(expected_id) # Check the worker never hits the API for elements assert len(responses.calls) == len(BASE_API_CALLS) assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS def test_list_element_parents_wrong_element(mock_elements_worker): with pytest.raises( AssertionError, match="element shouldn't be null and should be an Element or CachedElement", ): mock_elements_worker.list_element_parents(element=None) with pytest.raises( AssertionError, match="element shouldn't be null and should be an Element or CachedElement", ): mock_elements_worker.list_element_parents(element="not element type") def test_list_element_parents_wrong_folder(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="folder should be of type bool"): mock_elements_worker.list_element_parents( element=elt, folder="not bool", ) def test_list_element_parents_wrong_name(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="name should be of type str"): mock_elements_worker.list_element_parents( element=elt, name=1234, ) def test_list_element_parents_wrong_recursive(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="recursive should be of type bool"): mock_elements_worker.list_element_parents( element=elt, recursive="not bool", ) def test_list_element_parents_wrong_type(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="type should be of type str"): mock_elements_worker.list_element_parents( element=elt, type=1234, ) def test_list_element_parents_wrong_with_classes(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="with_classes should be of type bool"): mock_elements_worker.list_element_parents( element=elt, with_classes="not bool", ) def test_list_element_parents_wrong_with_corpus(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="with_corpus should be of type bool"): mock_elements_worker.list_element_parents( element=elt, with_corpus="not bool", ) def test_list_element_parents_wrong_with_has_children(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises( AssertionError, match="with_has_children should be of type bool" ): mock_elements_worker.list_element_parents( element=elt, with_has_children="not bool", ) def test_list_element_parents_wrong_with_zone(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="with_zone should be of type bool"): mock_elements_worker.list_element_parents( element=elt, with_zone="not bool", ) def test_list_element_parents_wrong_with_metadata(mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match="with_metadata should be of type bool"): mock_elements_worker.list_element_parents( element=elt, with_metadata="not bool", ) @pytest.mark.parametrize( "param, value", ( ("worker_version", 1234), ("worker_run", 1234), ("transcription_worker_version", 1234), ("transcription_worker_run", 1234), ), ) def test_list_element_parents_wrong_worker_version(mock_elements_worker, param, value): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises(AssertionError, match=f"{param} should be of type str or bool"): mock_elements_worker.list_element_parents( element=elt, **{param: value}, ) @pytest.mark.parametrize( "param", ( ("worker_version"), ("worker_run"), ("transcription_worker_version"), ("transcription_worker_run"), ), ) def test_list_element_parents_wrong_bool_worker_version(mock_elements_worker, param): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises( AssertionError, match=f"if of type bool, {param} can only be set to False" ): mock_elements_worker.list_element_parents( element=elt, **{param: True}, ) def test_list_element_parents_api_error(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) responses.add( responses.GET, "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", status=500, ) with pytest.raises( Exception, match="Stopping pagination as data will be incomplete" ): next(mock_elements_worker.list_element_parents(element=elt)) assert len(responses.calls) == len(BASE_API_CALLS) + 5 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ # We do 5 retries ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", ), ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", ), ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", ), ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", ), ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", ), ] def test_list_element_parents(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) expected_parents = [ { "id": "0000", "type": "page", "name": "Test", "corpus": {}, "thumbnail_url": None, "zone": {}, "best_classes": None, "has_children": None, "worker_version_id": None, "worker_run_id": None, }, { "id": "1111", "type": "page", "name": "Test 2", "corpus": {}, "thumbnail_url": None, "zone": {}, "best_classes": None, "has_children": None, "worker_version_id": None, "worker_run_id": None, }, { "id": "2222", "type": "page", "name": "Test 3", "corpus": {}, "thumbnail_url": None, "zone": {}, "best_classes": None, "has_children": None, "worker_version_id": None, "worker_run_id": None, }, ] responses.add( responses.GET, "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", status=200, json={ "count": 3, "next": None, "results": expected_parents, }, ) for idx, parent in enumerate( mock_elements_worker.list_element_parents(element=elt) ): assert parent == expected_parents[idx] assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/", ), ] def test_list_element_parents_manual_worker_version(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) expected_parents = [ { "id": "0000", "type": "page", "name": "Test", "corpus": {}, "thumbnail_url": None, "zone": {}, "best_classes": None, "has_children": None, "worker_version_id": None, "worker_run_id": None, } ] responses.add( responses.GET, "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/?worker_version=False", status=200, json={ "count": 1, "next": None, "results": expected_parents, }, ) for idx, parent in enumerate( mock_elements_worker.list_element_parents(element=elt, worker_version=False) ): assert parent == expected_parents[idx] assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/?worker_version=False", ), ] def test_list_element_parents_manual_worker_run(responses, mock_elements_worker): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) expected_parents = [ { "id": "0000", "type": "page", "name": "Test", "corpus": {}, "thumbnail_url": None, "zone": {}, "best_classes": None, "has_children": None, "worker_version_id": None, "worker_run_id": None, } ] responses.add( responses.GET, "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/?worker_run=False", status=200, json={ "count": 1, "next": None, "results": expected_parents, }, ) for idx, parent in enumerate( mock_elements_worker.list_element_parents(element=elt, worker_run=False) ): assert parent == expected_parents[idx] assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS + [ ( "GET", "http://testserver/api/v1/elements/12341234-1234-1234-1234-123412341234/parents/?worker_run=False", ), ] def test_list_element_parents_with_cache_unhandled_param( mock_elements_worker_with_cache, ): elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) with pytest.raises( AssertionError, match="When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'", ): mock_elements_worker_with_cache.list_element_parents( element=elt, with_corpus=True ) @pytest.mark.parametrize( "filters, expected_id", ( # Filter on element ( { "element": CachedElement(id="11111111-1111-1111-1111-111111111111"), }, "12341234-1234-1234-1234-123412341234", ), # Filter on element and double_page ( { "element": CachedElement(id="22222222-2222-2222-2222-222222222222"), "type": "double_page", }, "12341234-1234-1234-1234-123412341234", ), # Filter on element and worker version ( { "element": CachedElement(id="33333333-3333-3333-3333-333333333333"), "worker_version": "56785678-5678-5678-5678-567856785678", }, "12341234-1234-1234-1234-123412341234", ), # Filter on element, type double_page and worker version ( { "element": CachedElement(id="11111111-1111-1111-1111-111111111111"), "type": "double_page", "worker_version": "56785678-5678-5678-5678-567856785678", }, "12341234-1234-1234-1234-123412341234", ), # Filter on element, manual worker version ( { "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "worker_version": False, }, "99999999-9999-9999-9999-999999999999", ), # Filter on element and worker run ( { "element": CachedElement(id="22222222-2222-2222-2222-222222222222"), "worker_run": "56785678-5678-5678-5678-567856785678", }, "12341234-1234-1234-1234-123412341234", ), # Filter on element, manual worker run ( { "element": CachedElement(id="12341234-1234-1234-1234-123412341234"), "worker_run": False, }, "99999999-9999-9999-9999-999999999999", ), ), ) def test_list_element_parents_with_cache( responses, mock_elements_worker_with_cache, mock_cached_elements, filters, expected_id, ): # Check we have 5 elements already present in database assert CachedElement.select().count() == 5 # Query database through cache elements = mock_elements_worker_with_cache.list_element_parents(**filters) assert elements.count() == 1 for parent in elements.order_by("id"): assert parent.id == UUID(expected_id) # Check the worker never hits the API for elements assert len(responses.calls) == len(BASE_API_CALLS) assert [ (call.request.method, call.request.url) for call in responses.calls ] == BASE_API_CALLS