Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • workers/base-worker
1 result
Show changes
Commits on Source (3)
0.3.0-rc5 0.3.0-rc6
...@@ -229,12 +229,3 @@ class Transcription(MagicDict): ...@@ -229,12 +229,3 @@ class Transcription(MagicDict):
def __str__(self): def __str__(self):
return "Transcription ({})".format(self.id) return "Transcription ({})".format(self.id)
class Corpus(MagicDict):
"""
Describes an Arkindex corpus.
"""
def __str__(self):
return "Corpus {} ({})".format(self.name, self.id)
...@@ -119,6 +119,7 @@ class BaseWorker(object): ...@@ -119,6 +119,7 @@ class BaseWorker(object):
logger.info(f"Worker will use {self.work_dir} as working directory") logger.info(f"Worker will use {self.work_dir} as working directory")
self.process_information = None self.process_information = None
self.corpus_id = None
self.user_configuration = {} self.user_configuration = {}
self.support_cache = support_cache self.support_cache = support_cache
# use_cache will be updated in configure() if the cache is supported and if there # use_cache will be updated in configure() if the cache is supported and if there
...@@ -188,6 +189,9 @@ class BaseWorker(object): ...@@ -188,6 +189,9 @@ class BaseWorker(object):
# Load process information # Load process information
self.process_information = worker_run["process"] self.process_information = worker_run["process"]
# Load corpus id
self.corpus_id = worker_run["process"]["corpus"]
# Load worker version information # Load worker version information
worker_version = worker_run["worker_version"] worker_version = worker_run["worker_version"]
self.worker_details = worker_version["worker"] self.worker_details = worker_version["worker"]
...@@ -211,15 +215,14 @@ class BaseWorker(object): ...@@ -211,15 +215,14 @@ class BaseWorker(object):
# Load worker run configuration when available # Load worker run configuration when available
worker_configuration = worker_run.get("configuration") worker_configuration = worker_run.get("configuration")
self.user_configuration = ( if worker_configuration and worker_configuration.get("configuration"):
worker_configuration.get("configuration") if worker_configuration else None
)
if self.user_configuration:
logger.info("Loaded user configuration from WorkerRun") logger.info("Loaded user configuration from WorkerRun")
# if debug mode is set to true activate debug mode in logger self.user_configuration.update(worker_configuration.get("configuration"))
if self.user_configuration.get("debug"):
logger.setLevel(logging.DEBUG) # if debug mode is set to true activate debug mode in logger
logger.debug("Debug output enabled") if self.user_configuration.get("debug"):
logger.setLevel(logging.DEBUG)
logger.debug("Debug output enabled")
def configure_cache(self): def configure_cache(self):
task_id = os.environ.get("PONOS_TASK") task_id = os.environ.get("PONOS_TASK")
......
...@@ -2,9 +2,6 @@ ...@@ -2,9 +2,6 @@
""" """
ElementsWorker methods for classifications and ML classes. ElementsWorker methods for classifications and ML classes.
""" """
import os
from apistar.exceptions import ErrorResponse from apistar.exceptions import ErrorResponse
from peewee import IntegrityError from peewee import IntegrityError
...@@ -18,46 +15,38 @@ class ClassificationMixin(object): ...@@ -18,46 +15,38 @@ class ClassificationMixin(object):
Mixin for the :class:`ElementsWorker` to add ``MLClass`` and ``Classification`` helpers. Mixin for the :class:`ElementsWorker` to add ``MLClass`` and ``Classification`` helpers.
""" """
def load_corpus_classes(self, corpus_id): def load_corpus_classes(self):
""" """
Load all ML classes for the given corpus ID and store them in the ``self.classes`` cache. Load all ML classes for the given corpus ID and store them in the ``self.classes`` cache.
:param corpus_id str: ID of the corpus.
""" """
corpus_classes = self.api_client.paginate( corpus_classes = self.api_client.paginate(
"ListCorpusMLClasses", "ListCorpusMLClasses",
id=corpus_id, id=self.corpus_id,
) )
self.classes[corpus_id] = { self.classes[self.corpus_id] = {
ml_class["name"]: ml_class["id"] for ml_class in corpus_classes ml_class["name"]: ml_class["id"] for ml_class in corpus_classes
} }
logger.info(f"Loaded {len(self.classes[corpus_id])} ML classes") logger.info(f"Loaded {len(self.classes[self.corpus_id])} ML classes")
def get_ml_class_id(self, corpus_id, ml_class): def get_ml_class_id(self, ml_class):
""" """
Return the MLClass ID corresponding to the given class name on a specific corpus. Return the MLClass ID corresponding to the given class name on a specific corpus.
If no MLClass exists for this class name, a new one is created. If no MLClass exists for this class name, a new one is created.
:param corpus_id: ID of the corpus, or None to use the ``ARKINDEX_CORPUS_ID`` environment variable.
:type corpus_id: str or None
:param ml_class str: Name of the MLClass. :param ml_class str: Name of the MLClass.
:returns str: ID of the retrieved or created MLClass. :returns str: ID of the retrieved or created MLClass.
""" """
if corpus_id is None: if not self.classes.get(self.corpus_id):
corpus_id = os.environ.get("ARKINDEX_CORPUS_ID") self.load_corpus_classes()
if not self.classes.get(corpus_id):
self.load_corpus_classes(corpus_id)
ml_class_id = self.classes[corpus_id].get(ml_class) ml_class_id = self.classes[self.corpus_id].get(ml_class)
if ml_class_id is None: if ml_class_id is None:
logger.info(f"Creating ML class {ml_class} on corpus {corpus_id}") logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
try: try:
response = self.request( response = self.request(
"CreateMLClass", id=corpus_id, body={"name": ml_class} "CreateMLClass", id=self.corpus_id, body={"name": ml_class}
) )
ml_class_id = self.classes[corpus_id][ml_class] = response["id"] ml_class_id = self.classes[self.corpus_id][ml_class] = response["id"]
logger.debug(f"Created ML class {response['id']}") logger.debug(f"Created ML class {response['id']}")
except ErrorResponse as e: except ErrorResponse as e:
# Only reload for 400 errors # Only reload for 400 errors
...@@ -68,11 +57,11 @@ class ClassificationMixin(object): ...@@ -68,11 +57,11 @@ class ClassificationMixin(object):
logger.info( logger.info(
f"Reloading corpus classes to see if {ml_class} already exists" f"Reloading corpus classes to see if {ml_class} already exists"
) )
self.load_corpus_classes(corpus_id) self.load_corpus_classes()
assert ( assert (
ml_class in self.classes[corpus_id] ml_class in self.classes[self.corpus_id]
), "Missing class {ml_class} even after reloading" ), "Missing class {ml_class} even after reloading"
ml_class_id = self.classes[corpus_id][ml_class] ml_class_id = self.classes[self.corpus_id][ml_class]
return ml_class_id return ml_class_id
...@@ -112,7 +101,7 @@ class ClassificationMixin(object): ...@@ -112,7 +101,7 @@ class ClassificationMixin(object):
"CreateClassification", "CreateClassification",
body={ body={
"element": str(element.id), "element": str(element.id),
"ml_class": self.get_ml_class_id(None, ml_class), "ml_class": self.get_ml_class_id(ml_class),
"worker_version": self.worker_version_id, "worker_version": self.worker_version_id,
"confidence": confidence, "confidence": confidence,
"high_confidence": high_confidence, "high_confidence": high_confidence,
......
...@@ -2,15 +2,13 @@ ...@@ -2,15 +2,13 @@
""" """
ElementsWorker methods for elements and element types. ElementsWorker methods for elements and element types.
""" """
import uuid
from typing import Dict, Iterable, List, NamedTuple, Optional, Union from typing import Dict, Iterable, List, NamedTuple, Optional, Union
from peewee import IntegrityError from peewee import IntegrityError
from arkindex_worker import logger from arkindex_worker import logger
from arkindex_worker.cache import CachedElement, CachedImage from arkindex_worker.cache import CachedElement, CachedImage
from arkindex_worker.models import Corpus, Element from arkindex_worker.models import Element
ElementType = NamedTuple("ElementType", name=str, slug=str, is_folder=bool) ElementType = NamedTuple("ElementType", name=str, slug=str, is_folder=bool)
...@@ -26,7 +24,7 @@ class ElementMixin(object): ...@@ -26,7 +24,7 @@ class ElementMixin(object):
Mixin for the :class:`ElementsWorker` to provide ``Element`` helpers. Mixin for the :class:`ElementsWorker` to provide ``Element`` helpers.
""" """
def create_required_types(self, corpus: Corpus, element_types: List[ElementType]): def create_required_types(self, element_types: List[ElementType]):
"""Creates given element types in the corpus. """Creates given element types in the corpus.
:param Corpus corpus: The corpus to create types on. :param Corpus corpus: The corpus to create types on.
...@@ -39,47 +37,42 @@ class ElementMixin(object): ...@@ -39,47 +37,42 @@ class ElementMixin(object):
"slug": element_type.slug, "slug": element_type.slug,
"display_name": element_type.name, "display_name": element_type.name,
"folder": element_type.is_folder, "folder": element_type.is_folder,
"corpus": corpus.id, "corpus": self.corpus_id,
}, },
) )
logger.info(f"Created a new element type with slug {element_type.slug}") logger.info(f"Created a new element type with slug {element_type.slug}")
def check_required_types( def check_required_types(
self, corpus_id: str, *type_slugs: str, create_missing: bool = False self, *type_slugs: str, create_missing: bool = False
) -> bool: ) -> bool:
""" """
Check that a corpus has a list of required element types, Check that a corpus has a list of required element types,
and raise an exception if any of them are missing. and raise an exception if any of them are missing.
:param str corpus_id: ID of the corpus to check types on.
:param str \\*type_slugs: Type slugs to look for. :param str \\*type_slugs: Type slugs to look for.
:param bool create_missing: Whether missing types should be created. :param bool create_missing: Whether missing types should be created.
:returns bool: True if all of the specified type slugs have been found. :returns bool: True if all of the specified type slugs have been found.
:raises MissingTypeError: If any of the specified type slugs were not found. :raises MissingTypeError: If any of the specified type slugs were not found.
""" """
assert isinstance(
corpus_id, (uuid.UUID, str)
), "Corpus ID should be a string or UUID"
assert len(type_slugs), "At least one element type slug is required." assert len(type_slugs), "At least one element type slug is required."
assert all( assert all(
isinstance(slug, str) for slug in type_slugs isinstance(slug, str) for slug in type_slugs
), "Element type slugs must be strings." ), "Element type slugs must be strings."
corpus = Corpus(self.request("RetrieveCorpus", id=corpus_id)) corpus = self.request("RetrieveCorpus", id=self.corpus_id)
available_slugs = {element_type.slug for element_type in corpus.types} available_slugs = {element_type["slug"] for element_type in corpus["types"]}
missing_slugs = set(type_slugs) - available_slugs missing_slugs = set(type_slugs) - available_slugs
if missing_slugs: if missing_slugs:
if create_missing: if create_missing:
self.create_required_types( self.create_required_types(
corpus,
element_types=[ element_types=[
ElementType(slug, slug, False) for slug in missing_slugs ElementType(slug, slug, False) for slug in missing_slugs
], ],
) )
else: else:
raise MissingTypeError( raise MissingTypeError(
f'Element type(s) {", ".join(sorted(missing_slugs))} were not found in the {corpus.name} corpus ({corpus.id}).' f'Element type(s) {", ".join(sorted(missing_slugs))} were not found in the {corpus["name"]} corpus ({corpus["id"]}).'
) )
return True return True
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
""" """
ElementsWorker methods for entities. ElementsWorker methods for entities.
""" """
import os
from enum import Enum from enum import Enum
from peewee import IntegrityError from peewee import IntegrityError
...@@ -32,9 +30,7 @@ class EntityMixin(object): ...@@ -32,9 +30,7 @@ class EntityMixin(object):
Mixin for the :class:`ElementsWorker` to add ``Entity`` helpers. Mixin for the :class:`ElementsWorker` to add ``Entity`` helpers.
""" """
def create_entity( def create_entity(self, element, name, type, metas=dict(), validated=None):
self, element, name, type, corpus=None, metas=dict(), validated=None
):
""" """
Create an entity on the given corpus. Create an entity on the given corpus.
If cache support is enabled, a :class:`CachedEntity` will also be created. If cache support is enabled, a :class:`CachedEntity` will also be created.
...@@ -44,13 +40,7 @@ class EntityMixin(object): ...@@ -44,13 +40,7 @@ class EntityMixin(object):
:type element: Element or CachedElement :type element: Element or CachedElement
:param name str: Name of the entity. :param name str: Name of the entity.
:param type EntityType: Type of the entity. :param type EntityType: Type of the entity.
:param corpus: UUID of the corpus to create an entity on, or None to use the
value of the ``ARKINDEX_CORPUS_ID`` environment variable.
:type corpus: str or None
""" """
if corpus is None:
corpus = os.environ.get("ARKINDEX_CORPUS_ID")
assert element and isinstance( assert element and isinstance(
element, (Element, CachedElement) element, (Element, CachedElement)
), "element shouldn't be null and should be an Element or CachedElement" ), "element shouldn't be null and should be an Element or CachedElement"
...@@ -60,9 +50,6 @@ class EntityMixin(object): ...@@ -60,9 +50,6 @@ class EntityMixin(object):
assert type and isinstance( assert type and isinstance(
type, EntityType type, EntityType
), "type shouldn't be null and should be of type EntityType" ), "type shouldn't be null and should be of type EntityType"
assert corpus and isinstance(
corpus, str
), "corpus shouldn't be null and should be of type str"
if metas: if metas:
assert isinstance(metas, dict), "metas should be of type dict" assert isinstance(metas, dict), "metas should be of type dict"
if validated is not None: if validated is not None:
...@@ -78,7 +65,7 @@ class EntityMixin(object): ...@@ -78,7 +65,7 @@ class EntityMixin(object):
"type": type.value, "type": type.value,
"metas": metas, "metas": metas,
"validated": validated, "validated": validated,
"corpus": corpus, "corpus": self.corpus_id,
"worker_version": self.worker_version_id, "worker_version": self.worker_version_id,
}, },
) )
......
...@@ -106,7 +106,6 @@ def give_env_variable(request, monkeypatch): ...@@ -106,7 +106,6 @@ def give_env_variable(request, monkeypatch):
"""Defines required environment variables""" """Defines required environment variables"""
monkeypatch.setenv("WORKER_VERSION_ID", "12341234-1234-1234-1234-123412341234") monkeypatch.setenv("WORKER_VERSION_ID", "12341234-1234-1234-1234-123412341234")
monkeypatch.setenv("ARKINDEX_WORKER_RUN_ID", "56785678-5678-5678-5678-567856785678") monkeypatch.setenv("ARKINDEX_WORKER_RUN_ID", "56785678-5678-5678-5678-567856785678")
monkeypatch.setenv("ARKINDEX_CORPUS_ID", "11111111-1111-1111-1111-111111111111")
@pytest.fixture @pytest.fixture
...@@ -207,7 +206,6 @@ def mock_activity_calls(responses): ...@@ -207,7 +206,6 @@ def mock_activity_calls(responses):
def mock_elements_worker(monkeypatch, mock_worker_run_api): def mock_elements_worker(monkeypatch, mock_worker_run_api):
"""Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest""" """Build and configure an ElementsWorker with fixed CLI parameters to avoid issues with pytest"""
monkeypatch.setattr(sys, "argv", ["worker"]) monkeypatch.setattr(sys, "argv", ["worker"])
worker = ElementsWorker() worker = ElementsWorker()
worker.configure() worker.configure()
return worker return worker
......
...@@ -194,7 +194,10 @@ def test_configure_worker_run(mocker, monkeypatch, responses): ...@@ -194,7 +194,10 @@ def test_configure_worker_run(mocker, monkeypatch, responses):
"configuration": {"configuration": {}}, "configuration": {"configuration": {}},
}, },
"configuration": user_configuration, "configuration": user_configuration,
"process": {"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"}, "process": {
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"corpus": "11111111-1111-1111-1111-111111111111",
},
} }
responses.add( responses.add(
...@@ -246,7 +249,15 @@ def test_configure_user_configuration_defaults( ...@@ -246,7 +249,15 @@ def test_configure_user_configuration_defaults(
}, },
"revision": {"hash": "deadbeef1234"}, "revision": {"hash": "deadbeef1234"},
"configuration": { "configuration": {
"configuration": {"param_1": "/some/path/file.pth", "param_2": 12} "configuration": {"param_1": "/some/path/file.pth", "param_2": 12},
"user_configuration": {
"integer_parameter": {
"type": "int",
"title": "Lambda",
"default": 0,
"required": False,
}
},
}, },
}, },
"configuration": { "configuration": {
...@@ -257,7 +268,10 @@ def test_configure_user_configuration_defaults( ...@@ -257,7 +268,10 @@ def test_configure_user_configuration_defaults(
"param_5": True, "param_5": True,
}, },
}, },
"process": {"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"}, "process": {
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"corpus": "11111111-1111-1111-1111-111111111111",
},
} }
responses.add( responses.add(
responses.GET, responses.GET,
...@@ -271,6 +285,7 @@ def test_configure_user_configuration_defaults( ...@@ -271,6 +285,7 @@ def test_configure_user_configuration_defaults(
assert worker.config == {"param_1": "/some/path/file.pth", "param_2": 12} assert worker.config == {"param_1": "/some/path/file.pth", "param_2": 12}
assert worker.user_configuration == { assert worker.user_configuration == {
"integer_parameter": 0,
"param_3": "Animula vagula blandula", "param_3": "Animula vagula blandula",
"param_5": True, "param_5": True,
} }
...@@ -310,7 +325,10 @@ def test_configure_user_config_debug(mocker, monkeypatch, responses, debug): ...@@ -310,7 +325,10 @@ def test_configure_user_config_debug(mocker, monkeypatch, responses, debug):
"name": "BBB", "name": "BBB",
"configuration": {"debug": debug}, "configuration": {"debug": debug},
}, },
"process": {"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"}, "process": {
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"corpus": "11111111-1111-1111-1111-111111111111",
},
} }
responses.add( responses.add(
responses.GET, responses.GET,
...@@ -357,7 +375,10 @@ def test_configure_worker_run_missing_conf(mocker, monkeypatch, responses): ...@@ -357,7 +375,10 @@ def test_configure_worker_run_missing_conf(mocker, monkeypatch, responses):
"configuration": {"configuration": {}}, "configuration": {"configuration": {}},
}, },
"configuration": {"id": "bbbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", "name": "BBB"}, "configuration": {"id": "bbbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", "name": "BBB"},
"process": {"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"}, "process": {
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"corpus": "11111111-1111-1111-1111-111111111111",
},
} }
responses.add( responses.add(
responses.GET, responses.GET,
...@@ -369,7 +390,7 @@ def test_configure_worker_run_missing_conf(mocker, monkeypatch, responses): ...@@ -369,7 +390,7 @@ def test_configure_worker_run_missing_conf(mocker, monkeypatch, responses):
worker.args = worker.parser.parse_args() worker.args = worker.parser.parse_args()
worker.configure() worker.configure()
assert worker.user_configuration is None assert worker.user_configuration == {}
def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses): def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses):
...@@ -404,7 +425,10 @@ def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses) ...@@ -404,7 +425,10 @@ def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses)
"configuration": {}, "configuration": {},
}, },
"configuration": None, "configuration": None,
"process": {"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"}, "process": {
"id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
"corpus": "11111111-1111-1111-1111-111111111111",
},
} }
responses.add( responses.add(
responses.GET, responses.GET,
...@@ -416,7 +440,7 @@ def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses) ...@@ -416,7 +440,7 @@ def test_configure_worker_run_no_worker_run_conf(mocker, monkeypatch, responses)
worker.args = worker.parser.parse_args() worker.args = worker.parser.parse_args()
worker.configure() worker.configure()
assert worker.user_configuration is None assert worker.user_configuration == {}
def test_load_missing_secret(): def test_load_missing_secret():
......
...@@ -12,7 +12,7 @@ from . import BASE_API_CALLS ...@@ -12,7 +12,7 @@ from . import BASE_API_CALLS
def test_get_ml_class_id_load_classes(responses, mock_elements_worker): def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234" corpus_id = "11111111-1111-1111-1111-111111111111"
responses.add( responses.add(
responses.GET, responses.GET,
f"http://testserver/api/v1/corpus/{corpus_id}/classes/", f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
...@@ -30,25 +30,28 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker): ...@@ -30,25 +30,28 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
) )
assert not mock_elements_worker.classes assert not mock_elements_worker.classes
ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good") ml_class_id = mock_elements_worker.get_ml_class_id("good")
assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [ assert [
(call.request.method, call.request.url) for call in responses.calls (call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [ ] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"), (
"GET",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
),
] ]
assert mock_elements_worker.classes == { assert mock_elements_worker.classes == {
"12341234-1234-1234-1234-123412341234": {"good": "0000"} "11111111-1111-1111-1111-111111111111": {"good": "0000"}
} }
assert ml_class_id == "0000" assert ml_class_id == "0000"
def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses): def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
# A missing class is now created automatically # A missing class is now created automatically
corpus_id = "12341234-1234-1234-1234-123412341234" corpus_id = "11111111-1111-1111-1111-111111111111"
mock_elements_worker.classes = { mock_elements_worker.classes = {
"12341234-1234-1234-1234-123412341234": {"good": "0000"} "11111111-1111-1111-1111-111111111111": {"good": "0000"}
} }
responses.add( responses.add(
...@@ -60,15 +63,15 @@ def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses): ...@@ -60,15 +63,15 @@ def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
# Missing class at first # Missing class at first
assert mock_elements_worker.classes == { assert mock_elements_worker.classes == {
"12341234-1234-1234-1234-123412341234": {"good": "0000"} "11111111-1111-1111-1111-111111111111": {"good": "0000"}
} }
ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "bad") ml_class_id = mock_elements_worker.get_ml_class_id("bad")
assert ml_class_id == "new-ml-class-1234" assert ml_class_id == "new-ml-class-1234"
# Now it's available # Now it's available
assert mock_elements_worker.classes == { assert mock_elements_worker.classes == {
"12341234-1234-1234-1234-123412341234": { "11111111-1111-1111-1111-111111111111": {
"good": "0000", "good": "0000",
"bad": "new-ml-class-1234", "bad": "new-ml-class-1234",
} }
...@@ -76,17 +79,16 @@ def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses): ...@@ -76,17 +79,16 @@ def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
def test_get_ml_class_id(mock_elements_worker): def test_get_ml_class_id(mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234"
mock_elements_worker.classes = { mock_elements_worker.classes = {
"12341234-1234-1234-1234-123412341234": {"good": "0000"} "11111111-1111-1111-1111-111111111111": {"good": "0000"}
} }
ml_class_id = mock_elements_worker.get_ml_class_id(corpus_id, "good") ml_class_id = mock_elements_worker.get_ml_class_id("good")
assert ml_class_id == "0000" assert ml_class_id == "0000"
def test_get_ml_class_reload(responses, mock_elements_worker): def test_get_ml_class_reload(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234" corpus_id = "11111111-1111-1111-1111-111111111111"
# Add some initial classes # Add some initial classes
responses.add( responses.add(
...@@ -133,7 +135,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker): ...@@ -133,7 +135,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
) )
# Simply request class 2, it should be reloaded # Simply request class 2, it should be reloaded
assert mock_elements_worker.get_ml_class_id(corpus_id, "class2") == "class2_id" assert mock_elements_worker.get_ml_class_id("class2") == "class2_id"
assert len(responses.calls) == len(BASE_API_CALLS) + 3 assert len(responses.calls) == len(BASE_API_CALLS) + 3
assert mock_elements_worker.classes == { assert mock_elements_worker.classes == {
...@@ -145,9 +147,18 @@ def test_get_ml_class_reload(responses, mock_elements_worker): ...@@ -145,9 +147,18 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
assert [ assert [
(call.request.method, call.request.url) for call in responses.calls (call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [ ] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"), (
("POST", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"), "GET",
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"), f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
),
(
"POST",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
),
(
"GET",
f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
),
] ]
......
...@@ -22,20 +22,17 @@ from . import BASE_API_CALLS ...@@ -22,20 +22,17 @@ from . import BASE_API_CALLS
def test_check_required_types_argument_types(mock_elements_worker): def test_check_required_types_argument_types(mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234"
worker = ElementsWorker()
with pytest.raises(AssertionError) as e: with pytest.raises(AssertionError) as e:
worker.check_required_types(corpus_id) mock_elements_worker.check_required_types()
assert str(e.value) == "At least one element type slug is required." assert str(e.value) == "At least one element type slug is required."
with pytest.raises(AssertionError) as e: with pytest.raises(AssertionError) as e:
worker.check_required_types(corpus_id, "lol", 42) mock_elements_worker.check_required_types("lol", 42)
assert str(e.value) == "Element type slugs must be strings." assert str(e.value) == "Element type slugs must be strings."
def test_check_required_types(responses): def test_check_required_types(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234" corpus_id = "11111111-1111-1111-1111-111111111111"
responses.add( responses.add(
responses.GET, responses.GET,
f"http://testserver/api/v1/corpus/{corpus_id}/", f"http://testserver/api/v1/corpus/{corpus_id}/",
...@@ -45,22 +42,21 @@ def test_check_required_types(responses): ...@@ -45,22 +42,21 @@ def test_check_required_types(responses):
"types": [{"slug": "folder"}, {"slug": "page"}], "types": [{"slug": "folder"}, {"slug": "page"}],
}, },
) )
worker = ElementsWorker() mock_elements_worker.setup_api_client()
worker.setup_api_client()
assert worker.check_required_types(corpus_id, "page") assert mock_elements_worker.check_required_types("page")
assert worker.check_required_types(corpus_id, "page", "folder") assert mock_elements_worker.check_required_types("page", "folder")
with pytest.raises(MissingTypeError) as e: with pytest.raises(MissingTypeError) as e:
assert worker.check_required_types(corpus_id, "page", "text_line", "act") assert mock_elements_worker.check_required_types("page", "text_line", "act")
assert ( assert (
str(e.value) str(e.value)
== "Element type(s) act, text_line were not found in the Some Corpus corpus (12341234-1234-1234-1234-123412341234)." == "Element type(s) act, text_line were not found in the Some Corpus corpus (11111111-1111-1111-1111-111111111111)."
) )
def test_create_missing_types(responses): def test_create_missing_types(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234" corpus_id = "11111111-1111-1111-1111-111111111111"
responses.add( responses.add(
responses.GET, responses.GET,
...@@ -99,11 +95,10 @@ def test_create_missing_types(responses): ...@@ -99,11 +95,10 @@ def test_create_missing_types(responses):
) )
], ],
) )
worker = ElementsWorker() mock_elements_worker.setup_api_client()
worker.setup_api_client()
assert worker.check_required_types( assert mock_elements_worker.check_required_types(
corpus_id, "page", "text_line", "act", create_missing=True "page", "text_line", "act", create_missing=True
) )
...@@ -276,10 +271,10 @@ def test_database_arg_cache_missing_version_table( ...@@ -276,10 +271,10 @@ def test_database_arg_cache_missing_version_table(
def test_load_corpus_classes_api_error(responses, mock_elements_worker): def test_load_corpus_classes_api_error(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234" mock_elements_worker.corpus_id = "12341234-1234-1234-1234-123412341234"
responses.add( responses.add(
responses.GET, responses.GET,
f"http://testserver/api/v1/corpus/{corpus_id}/classes/", f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
status=500, status=500,
) )
...@@ -287,27 +282,42 @@ def test_load_corpus_classes_api_error(responses, mock_elements_worker): ...@@ -287,27 +282,42 @@ def test_load_corpus_classes_api_error(responses, mock_elements_worker):
with pytest.raises( with pytest.raises(
Exception, match="Stopping pagination as data will be incomplete" Exception, match="Stopping pagination as data will be incomplete"
): ):
mock_elements_worker.load_corpus_classes(corpus_id) mock_elements_worker.load_corpus_classes()
assert len(responses.calls) == len(BASE_API_CALLS) + 5 assert len(responses.calls) == len(BASE_API_CALLS) + 5
assert [ assert [
(call.request.method, call.request.url) for call in responses.calls (call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [ ] == BASE_API_CALLS + [
# We do 5 retries # We do 5 retries
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"), (
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"), "GET",
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"), f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"), ),
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"), (
"GET",
f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
),
(
"GET",
f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
),
(
"GET",
f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
),
(
"GET",
f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
),
] ]
assert not mock_elements_worker.classes assert not mock_elements_worker.classes
def test_load_corpus_classes(responses, mock_elements_worker): def test_load_corpus_classes(responses, mock_elements_worker):
corpus_id = "12341234-1234-1234-1234-123412341234" mock_elements_worker.corpus_id = "12341234-1234-1234-1234-123412341234"
responses.add( responses.add(
responses.GET, responses.GET,
f"http://testserver/api/v1/corpus/{corpus_id}/classes/", f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
status=200, status=200,
json={ json={
"count": 3, "count": 3,
...@@ -330,13 +340,16 @@ def test_load_corpus_classes(responses, mock_elements_worker): ...@@ -330,13 +340,16 @@ def test_load_corpus_classes(responses, mock_elements_worker):
) )
assert not mock_elements_worker.classes assert not mock_elements_worker.classes
mock_elements_worker.load_corpus_classes(corpus_id) mock_elements_worker.load_corpus_classes()
assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert len(responses.calls) == len(BASE_API_CALLS) + 1
assert [ assert [
(call.request.method, call.request.url) for call in responses.calls (call.request.method, call.request.url) for call in responses.calls
] == BASE_API_CALLS + [ ] == BASE_API_CALLS + [
("GET", f"http://testserver/api/v1/corpus/{corpus_id}/classes/"), (
"GET",
f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
),
] ]
assert mock_elements_worker.classes == { assert mock_elements_worker.classes == {
"12341234-1234-1234-1234-123412341234": { "12341234-1234-1234-1234-123412341234": {
......
...@@ -24,7 +24,6 @@ def test_create_entity_wrong_element(mock_elements_worker): ...@@ -24,7 +24,6 @@ def test_create_entity_wrong_element(mock_elements_worker):
element=None, element=None,
name="Bob Bob", name="Bob Bob",
type=EntityType.Person, type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
) )
assert ( assert (
str(e.value) str(e.value)
...@@ -36,7 +35,6 @@ def test_create_entity_wrong_element(mock_elements_worker): ...@@ -36,7 +35,6 @@ def test_create_entity_wrong_element(mock_elements_worker):
element="not element type", element="not element type",
name="Bob Bob", name="Bob Bob",
type=EntityType.Person, type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
) )
assert ( assert (
str(e.value) str(e.value)
...@@ -52,7 +50,6 @@ def test_create_entity_wrong_name(mock_elements_worker): ...@@ -52,7 +50,6 @@ def test_create_entity_wrong_name(mock_elements_worker):
element=elt, element=elt,
name=None, name=None,
type=EntityType.Person, type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
) )
assert str(e.value) == "name shouldn't be null and should be of type str" assert str(e.value) == "name shouldn't be null and should be of type str"
...@@ -61,7 +58,6 @@ def test_create_entity_wrong_name(mock_elements_worker): ...@@ -61,7 +58,6 @@ def test_create_entity_wrong_name(mock_elements_worker):
element=elt, element=elt,
name=1234, name=1234,
type=EntityType.Person, type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
) )
assert str(e.value) == "name shouldn't be null and should be of type str" assert str(e.value) == "name shouldn't be null and should be of type str"
...@@ -74,7 +70,6 @@ def test_create_entity_wrong_type(mock_elements_worker): ...@@ -74,7 +70,6 @@ def test_create_entity_wrong_type(mock_elements_worker):
element=elt, element=elt,
name="Bob Bob", name="Bob Bob",
type=None, type=None,
corpus="12341234-1234-1234-1234-123412341234",
) )
assert str(e.value) == "type shouldn't be null and should be of type EntityType" assert str(e.value) == "type shouldn't be null and should be of type EntityType"
...@@ -83,7 +78,6 @@ def test_create_entity_wrong_type(mock_elements_worker): ...@@ -83,7 +78,6 @@ def test_create_entity_wrong_type(mock_elements_worker):
element=elt, element=elt,
name="Bob Bob", name="Bob Bob",
type=1234, type=1234,
corpus="12341234-1234-1234-1234-123412341234",
) )
assert str(e.value) == "type shouldn't be null and should be of type EntityType" assert str(e.value) == "type shouldn't be null and should be of type EntityType"
...@@ -92,7 +86,6 @@ def test_create_entity_wrong_type(mock_elements_worker): ...@@ -92,7 +86,6 @@ def test_create_entity_wrong_type(mock_elements_worker):
element=elt, element=elt,
name="Bob Bob", name="Bob Bob",
type="not_an_entity_type", type="not_an_entity_type",
corpus="12341234-1234-1234-1234-123412341234",
) )
assert str(e.value) == "type shouldn't be null and should be of type EntityType" assert str(e.value) == "type shouldn't be null and should be of type EntityType"
...@@ -111,26 +104,6 @@ def test_create_entity_wrong_corpus(monkeypatch, mock_elements_worker): ...@@ -111,26 +104,6 @@ def test_create_entity_wrong_corpus(monkeypatch, mock_elements_worker):
) )
assert str(e.value) == "metas should be of type dict" assert str(e.value) == "metas should be of type dict"
# Removing ARKINDEX_CORPUS_ID environment variable should give an error when corpus=None
monkeypatch.delenv("ARKINDEX_CORPUS_ID")
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_entity(
element=elt,
name="Bob Bob",
type=EntityType.Person,
corpus=None,
)
assert str(e.value) == "corpus shouldn't be null and should be of type str"
with pytest.raises(AssertionError) as e:
mock_elements_worker.create_entity(
element=elt,
name="Bob Bob",
type=EntityType.Person,
corpus=1234,
)
assert str(e.value) == "corpus shouldn't be null and should be of type str"
def test_create_entity_wrong_metas(mock_elements_worker): def test_create_entity_wrong_metas(mock_elements_worker):
elt = Element({"id": "12341234-1234-1234-1234-123412341234"}) elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
...@@ -140,7 +113,6 @@ def test_create_entity_wrong_metas(mock_elements_worker): ...@@ -140,7 +113,6 @@ def test_create_entity_wrong_metas(mock_elements_worker):
element=elt, element=elt,
name="Bob Bob", name="Bob Bob",
type=EntityType.Person, type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
metas="wrong metas", metas="wrong metas",
) )
assert str(e.value) == "metas should be of type dict" assert str(e.value) == "metas should be of type dict"
...@@ -154,7 +126,6 @@ def test_create_entity_wrong_validated(mock_elements_worker): ...@@ -154,7 +126,6 @@ def test_create_entity_wrong_validated(mock_elements_worker):
element=elt, element=elt,
name="Bob Bob", name="Bob Bob",
type=EntityType.Person, type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
validated="wrong validated", validated="wrong validated",
) )
assert str(e.value) == "validated should be of type bool" assert str(e.value) == "validated should be of type bool"
...@@ -173,7 +144,6 @@ def test_create_entity_api_error(responses, mock_elements_worker): ...@@ -173,7 +144,6 @@ def test_create_entity_api_error(responses, mock_elements_worker):
element=elt, element=elt,
name="Bob Bob", name="Bob Bob",
type=EntityType.Person, type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
) )
assert len(responses.calls) == len(BASE_API_CALLS) + 5 assert len(responses.calls) == len(BASE_API_CALLS) + 5
...@@ -202,7 +172,6 @@ def test_create_entity(responses, mock_elements_worker): ...@@ -202,7 +172,6 @@ def test_create_entity(responses, mock_elements_worker):
element=elt, element=elt,
name="Bob Bob", name="Bob Bob",
type=EntityType.Person, type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
) )
assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert len(responses.calls) == len(BASE_API_CALLS) + 1
...@@ -216,7 +185,7 @@ def test_create_entity(responses, mock_elements_worker): ...@@ -216,7 +185,7 @@ def test_create_entity(responses, mock_elements_worker):
"type": "person", "type": "person",
"metas": {}, "metas": {},
"validated": None, "validated": None,
"corpus": "12341234-1234-1234-1234-123412341234", "corpus": "11111111-1111-1111-1111-111111111111",
"worker_version": "12341234-1234-1234-1234-123412341234", "worker_version": "12341234-1234-1234-1234-123412341234",
} }
assert entity_id == "12345678-1234-1234-1234-123456789123" assert entity_id == "12345678-1234-1234-1234-123456789123"
...@@ -235,7 +204,6 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache): ...@@ -235,7 +204,6 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache):
element=elt, element=elt,
name="Bob Bob", name="Bob Bob",
type=EntityType.Person, type=EntityType.Person,
corpus="12341234-1234-1234-1234-123412341234",
) )
assert len(responses.calls) == len(BASE_API_CALLS) + 1 assert len(responses.calls) == len(BASE_API_CALLS) + 1
...@@ -250,7 +218,7 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache): ...@@ -250,7 +218,7 @@ def test_create_entity_with_cache(responses, mock_elements_worker_with_cache):
"type": "person", "type": "person",
"metas": {}, "metas": {},
"validated": None, "validated": None,
"corpus": "12341234-1234-1234-1234-123412341234", "corpus": "11111111-1111-1111-1111-111111111111",
"worker_version": "12341234-1234-1234-1234-123412341234", "worker_version": "12341234-1234-1234-1234-123412341234",
} }
assert entity_id == "12345678-1234-1234-1234-123456789123" assert entity_id == "12345678-1234-1234-1234-123456789123"
......