Compare revisions: workers/base-worker
Commits on Source (12)
@@ -148,6 +148,13 @@ class BaseWorker:
         # there is at least one available sqlite database either given or in the parent tasks
         self.use_cache = False
+        # model_version_id will be updated in configure() using the worker_run's model version
+        # or in configure_for_developers() from the environment
+        self.model_version_id = None
+        # model_details will be updated in configure() using the worker_run's model version
+        # or in configure_for_developers() from the environment
+        self.model_details = {}
         # task_parents will be updated in configure_cache() if the cache is supported,
         # if the task ID is set and if no database is passed as argument
         self.task_parents = []
@@ -257,15 +264,15 @@ class BaseWorker:
         # Load model version configuration when available
         model_version = worker_run.get("model_version")
-        if model_version and model_version.get("configuration"):
+        if model_version:
             logger.info("Loaded model version configuration from WorkerRun")
-            self.model_configuration.update(model_version.get("configuration"))
+            self.model_configuration.update(model_version["configuration"])
             # Set model_version ID as worker attribute
-            self.model_version_id = model_version.get("id")
+            self.model_version_id = model_version["id"]
             # Set model details as worker attribute
-            self.model_details = model_version.get("model")
+            self.model_details = model_version["model"]

         # Retrieve initial configuration from API
         self.config = worker_version["configuration"].get("configuration", {})
......
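Taken together, the two hunks above mean that every worker now always has model_version_id (None by default) and model_details (empty dict by default) attributes, and that a WorkerRun's model version is expected to carry its configuration, id and model keys: direct indexing replaces .get(), so a malformed payload fails fast with a KeyError. A minimal sketch of how a worker subclass might rely on the new attributes; the MyWorker class, its logging and its element handling are illustrative, not part of this diff:

    import logging

    from arkindex_worker.worker import ElementsWorker

    logger = logging.getLogger(__name__)

    class MyWorker(ElementsWorker):
        def process_element(self, element):
            # model_version_id stays None unless configure() found a model
            # version on the WorkerRun (or configure_for_developers() read
            # one from the environment), so it doubles as a feature flag.
            if self.model_version_id:
                logger.info("Running with model %s", self.model_details.get("name"))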
@@ -49,7 +49,10 @@ class DatasetMixin:
             "ListProcessDatasets", id=self.process_information["id"]
         )
-        return map(Dataset, list(results))
+        return map(
+            lambda result: Dataset(**result["dataset"], selected_sets=result["sets"]),
+            list(results),
+        )

     def list_dataset_elements(self, dataset: Dataset) -> Iterator[tuple[str, Element]]:
         """
@@ -65,9 +68,11 @@ class DatasetMixin:
         results = self.api_client.paginate("ListDatasetElements", id=dataset.id)

         def format_result(result):
+            if result["set"] not in dataset.selected_sets:
+                return
             return (result["set"], Element(**result["element"]))

-        return map(format_result, list(results))
+        return filter(None, map(format_result, list(results)))

     @unsupported_cache
     def update_dataset_state(self, dataset: Dataset, state: DatasetState) -> Dataset:
......
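With selected_sets stored on Dataset, list_process_datasets now yields only the sets chosen for the process, and list_dataset_elements silently drops elements from any other set (format_result returns None for them and filter(None, ...) discards those entries). A short sketch of a consumer, assuming a DatasetWorker subclass with the usual process_dataset override; the class body is illustrative only:

    from arkindex_worker.worker import DatasetWorker

    class MyDatasetWorker(DatasetWorker):
        def process_dataset(self, dataset):
            # Elements outside dataset.selected_sets were already filtered
            # out upstream, so every pair seen here belongs to a chosen set.
            for set_name, element in self.list_dataset_elements(dataset):
                print(set_name, element.id)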
@@ -81,6 +81,10 @@ class TrainingMixin:
     model_version = None

+    @property
+    def is_finetuning(self) -> bool:
+        return bool(self.model_version_id)
+
     @skip_if_read_only
     def publish_model_version(
         self,
......
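Because model_version_id is now always defined on BaseWorker, the new is_finetuning property gives training workers a cheap way to tell whether they were started from an existing model version. An illustrative use inside a TrainingMixin-based worker, after configure() has run; the surrounding logic is not part of this diff:

    if self.is_finetuning:
        # A model version was attached to the worker run, so treat this
        # training run as fine-tuning of self.model_version_id.
        parent_version = self.model_version_id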
-black==23.12.0
+black==24.2.0
 doc8==1.1.1
 mkdocs==1.5.3
-mkdocs-material==9.5.2
+mkdocs-material==9.5.10
 mkdocstrings==0.24.0
-mkdocstrings-python==1.7.5
+mkdocstrings-python==1.8.0
 recommonmark==0.7.1
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "arkindex-base-worker"
-version = "0.3.7-rc1"
+version = "0.3.7rc3"
 description = "Base Worker to easily build Arkindex ML workflows"
 license = { file = "LICENSE" }
 dynamic = ["dependencies", "optional-dependencies"]
......
 arkindex-client==1.0.14
 peewee==3.17.0
-Pillow==10.1.0
-pymdown-extensions==10.5
+Pillow==10.2.0
+pymdown-extensions==10.7
 python-gnupg==0.5.2
-shapely==2.0.2
+shapely==2.0.3
 tenacity==8.2.3
 zstandard==0.22.0
......
-pytest==7.4.3
+pytest==8.0.1
 pytest-mock==3.12.0
 pytest-responses==0.5.1
@@ -601,14 +601,15 @@ def default_dataset():
             "id": "dataset_id",
             "name": "My dataset",
             "description": "A super dataset built by me",
-            "sets": ["set_1", "set_2", "set_3"],
+            "sets": ["set_1", "set_2", "set_3", "set_4"],
             "state": DatasetState.Open.value,
             "corpus_id": "corpus_id",
             "creator": "creator@teklia.com",
             "task_id": "11111111-1111-1111-1111-111111111111",
             "created": "2000-01-01T00:00:00Z",
             "updated": "2000-01-01T00:00:00Z",
-        }
+        },
+        selected_sets=["set_1", "set_2", "set_3"],
     )
......
@@ -296,6 +296,22 @@ def test_list_dataset_elements_per_split(
                 "worker_run_id": None,
             },
         },
+        # `set_4` is not in `default_dataset.selected_sets`
+        {
+            "set": "set_4",
+            "element": {
+                "id": "4444",
+                "type": "page",
+                "name": "Test 5",
+                "corpus": {},
+                "thumbnail_url": None,
+                "zone": {},
+                "best_classes": None,
+                "has_children": None,
+                "worker_version_id": None,
+                "worker_run_id": None,
+            },
+        },
     ]
     responses.add(
         responses.GET,
@@ -362,34 +378,46 @@ def test_list_datasets_api_error(responses, mock_dataset_worker):
 def test_list_datasets(responses, mock_dataset_worker):
     expected_results = [
         {
-            "id": "dataset_1",
-            "name": "Dataset 1",
-            "description": "My first great dataset",
-            "sets": ["train", "val", "test"],
-            "state": "open",
-            "corpus_id": "corpus_id",
-            "creator": "test@teklia.com",
-            "task_id": "task_id_1",
+            "id": "process_dataset_1",
+            "dataset": {
+                "id": "dataset_1",
+                "name": "Dataset 1",
+                "description": "My first great dataset",
+                "sets": ["train", "val", "test"],
+                "state": "open",
+                "corpus_id": "corpus_id",
+                "creator": "test@teklia.com",
+                "task_id": "task_id_1",
+            },
+            "sets": ["test"],
         },
         {
-            "id": "dataset_2",
-            "name": "Dataset 2",
-            "description": "My second great dataset",
+            "id": "process_dataset_2",
+            "dataset": {
+                "id": "dataset_2",
+                "name": "Dataset 2",
+                "description": "My second great dataset",
+                "sets": ["train", "val"],
+                "state": "complete",
+                "corpus_id": "corpus_id",
+                "creator": "test@teklia.com",
+                "task_id": "task_id_2",
+            },
             "sets": ["train", "val"],
-            "state": "complete",
-            "corpus_id": "corpus_id",
-            "creator": "test@teklia.com",
-            "task_id": "task_id_2",
         },
         {
-            "id": "dataset_3",
-            "name": "Dataset 3 (TRASHME)",
-            "description": "My third dataset, in error",
-            "sets": ["nonsense", "random set"],
-            "state": "error",
-            "corpus_id": "corpus_id",
-            "creator": "test@teklia.com",
-            "task_id": "task_id_3",
+            "id": "process_dataset_3",
+            "dataset": {
+                "id": "dataset_3",
+                "name": "Dataset 3 (TRASHME)",
+                "description": "My third dataset, in error",
+                "sets": ["nonsense", "random set"],
+                "state": "error",
+                "corpus_id": "corpus_id",
+                "creator": "test@teklia.com",
+                "task_id": "task_id_3",
+            },
+            "sets": ["random set"],
         },
     ]
     responses.add(
@@ -403,7 +431,11 @@ def test_list_datasets(responses, mock_dataset_worker):
         },
     )

-    assert list(mock_dataset_worker.list_datasets()) == expected_results
+    for idx, dataset in enumerate(mock_dataset_worker.list_process_datasets()):
+        assert dataset == {
+            **expected_results[idx]["dataset"],
+            "selected_sets": expected_results[idx]["sets"],
+        }

     assert len(responses.calls) == len(BASE_API_CALLS) + 1
     assert [
......
@@ -52,34 +52,46 @@ def test_list_process_datasets(
 ):
     expected_results = [
         {
-            "id": "dataset_1",
-            "name": "Dataset 1",
-            "description": "My first great dataset",
-            "sets": ["train", "val", "test"],
-            "state": "open",
-            "corpus_id": "corpus_id",
-            "creator": "test@teklia.com",
-            "task_id": "task_id_1",
+            "id": "process_dataset_1",
+            "dataset": {
+                "id": "dataset_1",
+                "name": "Dataset 1",
+                "description": "My first great dataset",
+                "sets": ["train", "val", "test"],
+                "state": "open",
+                "corpus_id": "corpus_id",
+                "creator": "test@teklia.com",
+                "task_id": "task_id_1",
+            },
+            "sets": ["test"],
         },
         {
-            "id": "dataset_2",
-            "name": "Dataset 2",
-            "description": "My second great dataset",
+            "id": "process_dataset_2",
+            "dataset": {
+                "id": "dataset_2",
+                "name": "Dataset 2",
+                "description": "My second great dataset",
+                "sets": ["train", "val"],
+                "state": "complete",
+                "corpus_id": "corpus_id",
+                "creator": "test@teklia.com",
+                "task_id": "task_id_2",
+            },
             "sets": ["train", "val"],
-            "state": "complete",
-            "corpus_id": "corpus_id",
-            "creator": "test@teklia.com",
-            "task_id": "task_id_2",
         },
         {
-            "id": "dataset_3",
-            "name": "Dataset 3 (TRASHME)",
-            "description": "My third dataset, in error",
-            "sets": ["nonsense", "random set"],
-            "state": "error",
-            "corpus_id": "corpus_id",
-            "creator": "test@teklia.com",
-            "task_id": "task_id_3",
+            "id": "process_dataset_3",
+            "dataset": {
+                "id": "dataset_3",
+                "name": "Dataset 3 (TRASHME)",
+                "description": "My third dataset, in error",
+                "sets": ["nonsense", "random set"],
+                "state": "error",
+                "corpus_id": "corpus_id",
+                "creator": "test@teklia.com",
+                "task_id": "task_id_3",
+            },
+            "sets": ["random set"],
         },
     ]
     responses.add(
@@ -95,7 +107,10 @@ def test_list_process_datasets(
     for idx, dataset in enumerate(mock_dataset_worker.list_process_datasets()):
         assert isinstance(dataset, Dataset)
-        assert dataset == expected_results[idx]
+        assert dataset == {
+            **expected_results[idx]["dataset"],
+            "selected_sets": expected_results[idx]["sets"],
+        }

     assert len(responses.calls) == len(BASE_API_CALLS) + 1
     assert [
@@ -240,7 +255,25 @@ def test_list_dataset_elements(
         json={
             "count": 4,
             "next": None,
-            "results": expected_results,
+            "results": expected_results
+            # `set_4` is not in `default_dataset.selected_sets`
+            + [
+                {
+                    "set": "set_4",
+                    "element": {
+                        "id": "4444",
+                        "type": "page",
+                        "name": "Test 5",
+                        "corpus": {},
+                        "thumbnail_url": None,
+                        "zone": {},
+                        "best_classes": None,
+                        "has_children": None,
+                        "worker_version_id": None,
+                        "worker_run_id": None,
+                    },
+                }
+            ],
         },
     )
......
@@ -467,6 +467,10 @@ def test_worker_config_multiple_source(
                 "id": "12341234-1234-1234-1234-123412341234",
                 "name": "Model version 1337",
                 "configuration": model_config,
+                "model": {
+                    "id": "hahahaha-haha-haha-haha-hahahahahaha",
+                    "name": "My model",
+                },
             },
             "process": {
                 "name": None,
......