From 6da2f4dd995d0e0062ebc6af9b2001ecc2c124a3 Mon Sep 17 00:00:00 2001
From: ml bonhomme <bonhomme@teklia.com>
Date: Tue, 26 Mar 2024 15:42:31 +0000
Subject: [PATCH] Update load_export to handle new Dataset Set model

---
 .../management/commands/load_export.py        | 27 ++++++++++++++++---
 .../tests/commands/test_load_export.py        |  7 +++++
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/arkindex/documents/management/commands/load_export.py b/arkindex/documents/management/commands/load_export.py
index 46bb32cb4c..b7ac5a6899 100644
--- a/arkindex/documents/management/commands/load_export.py
+++ b/arkindex/documents/management/commands/load_export.py
@@ -37,7 +37,7 @@ from arkindex.process.models import (
     WorkerType,
     WorkerVersion,
 )
-from arkindex.training.models import Dataset, DatasetElement, Model
+from arkindex.training.models import Dataset, DatasetElement, DatasetSet, Model
 from arkindex.users.models import Role, User
 
 EXPORT_VERSION = 8
@@ -320,17 +320,30 @@ class Command(BaseCommand):
             id=row["id"],
             corpus=self.corpus,
             name=row["name"],
-            sets=[r.strip() for r in row["sets"].split(",")],
             creator=self.user,
             description="Imported dataset",
         )]
 
+    def convert_dataset_sets(self, row):
+        return [
+            DatasetSet(
+                name=set_name.strip(),
+                dataset_id=row["id"]
+            )
+            for set_name in row["sets"].split(",")
+        ]
+
+    def map_dataset_sets(self):
+        return {
+            (str(set.dataset_id), set.name): set.id
+            for set in DatasetSet.objects.filter(dataset__corpus=self.corpus)
+        }
+
     def convert_dataset_elements(self, row):
         return [DatasetElement(
             id=row["id"],
             element_id=row["element_id"],
-            dataset_id=row["dataset_id"],
-            set=row["set_name"],
+            set_id=self.dataset_sets_map[(row["dataset_id"], row["set_name"])]
         )]
 
     def bulk_create_objects(self, ModelClass, convert_method, sql_query, ignore_conflicts=True):
@@ -603,6 +616,12 @@ class Command(BaseCommand):
         # Create datasets
         self.bulk_create_objects(Dataset, self.convert_datasets, SQL_DATASET_QUERY)
 
+        # Create dataset sets
+        self.bulk_create_objects(DatasetSet, self.convert_dataset_sets, SQL_DATASET_QUERY)
+
+        # Create dataset sets mapping
+        self.dataset_sets_map = self.map_dataset_sets()
+
         # Create dataset elements
         self.bulk_create_objects(DatasetElement, self.convert_dataset_elements, SQL_ELEMENT_DATASET_QUERY)
 
diff --git a/arkindex/documents/tests/commands/test_load_export.py b/arkindex/documents/tests/commands/test_load_export.py
index b9aaf6e699..121f27855b 100644
--- a/arkindex/documents/tests/commands/test_load_export.py
+++ b/arkindex/documents/tests/commands/test_load_export.py
@@ -14,6 +14,7 @@ from arkindex.documents.tasks import corpus_delete
 from arkindex.images.models import Image, ImageServer
 from arkindex.process.models import ProcessMode, Repository, Worker, WorkerRun, WorkerType, WorkerVersion
 from arkindex.project.tests import FixtureTestCase
+from arkindex.training.models import Dataset, DatasetElement
 
 BASE_DIR = Path(__file__).absolute().parent
 
@@ -132,6 +133,9 @@ class TestLoadExport(FixtureTestCase):
         dla_version = WorkerVersion.objects.get(worker__slug="dla")
         dla_run = dla_version.worker_runs.get(process__mode=ProcessMode.Workers)
+        dataset_set = Dataset.objects.first().sets.first()
+        DatasetElement.objects.create(set=dataset_set, element=element)
+
         element.classifications.create(
             ml_class=self.corpus.ml_classes.create(name="Blah"),
             confidence=.55555555,
         )
@@ -266,6 +270,9 @@ class TestLoadExport(FixtureTestCase):
             confidence=.55555555,
         )
 
+        dataset_set = Dataset.objects.first().sets.first()
+        DatasetElement.objects.create(set=dataset_set, element=element)
+
         person_type = EntityType.objects.get(
             name="person",
             corpus=self.corpus
--
GitLab
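
Note (reviewer sketch, not part of the patch): the export keeps a dataset's
sets in a single comma-separated "sets" column, while the new schema stores
each set as a DatasetSet row and links DatasetElement through set_id. The
loader therefore runs two passes: it bulk-creates DatasetSet rows from the
split column, then builds a (dataset id, set name) -> set id map so each
exported dataset element row can be resolved to its set. The standalone
Python sketch below mirrors that flow with plain dicts standing in for the
exported SQLite rows; the row values are hypothetical.

    import uuid

    def convert_dataset_sets(row):
        # One set per comma-separated name in the export's "sets" column,
        # mirroring Command.convert_dataset_sets().
        return [
            {"id": uuid.uuid4(), "dataset_id": row["id"], "name": name.strip()}
            for name in row["sets"].split(",")
        ]

    def map_dataset_sets(dataset_sets):
        # (dataset id, set name) -> set id, mirroring Command.map_dataset_sets().
        return {(str(s["dataset_id"]), s["name"]): s["id"] for s in dataset_sets}

    # Hypothetical rows standing in for SQL_DATASET_QUERY and
    # SQL_ELEMENT_DATASET_QUERY results.
    dataset_row = {"id": "dataset-1", "sets": "train, validation, test"}
    element_row = {"dataset_id": "dataset-1", "set_name": "validation"}

    sets_map = map_dataset_sets(convert_dataset_sets(dataset_row))
    # The same lookup convert_dataset_elements() performs to fill set_id.
    print(sets_map[(element_row["dataset_id"], element_row["set_name"])])

The ordering in the patch follows from this: DatasetSet rows must be in the
database before map_dataset_sets() runs, since the map reads the set ids back
from the database rather than from the converted rows.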