From 6da2f4dd995d0e0062ebc6af9b2001ecc2c124a3 Mon Sep 17 00:00:00 2001
From: ml bonhomme <bonhomme@teklia.com>
Date: Tue, 26 Mar 2024 15:42:31 +0000
Subject: [PATCH] Update load_export to handle new Dataset Set model

---
 .../management/commands/load_export.py        | 27 ++++++++++++++++---
 .../tests/commands/test_load_export.py        |  7 +++++
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/arkindex/documents/management/commands/load_export.py b/arkindex/documents/management/commands/load_export.py
index 46bb32cb4c..b7ac5a6899 100644
--- a/arkindex/documents/management/commands/load_export.py
+++ b/arkindex/documents/management/commands/load_export.py
@@ -37,7 +37,7 @@ from arkindex.process.models import (
     WorkerType,
     WorkerVersion,
 )
-from arkindex.training.models import Dataset, DatasetElement, Model
+from arkindex.training.models import Dataset, DatasetElement, DatasetSet, Model
 from arkindex.users.models import Role, User
 
 EXPORT_VERSION = 8
@@ -320,17 +320,30 @@ class Command(BaseCommand):
             id=row["id"],
             corpus=self.corpus,
             name=row["name"],
-            sets=[r.strip() for r in row["sets"].split(",")],
             creator=self.user,
             description="Imported dataset",
         )]
 
+    def convert_dataset_sets(self, row):
+        return [
+            DatasetSet(
+                name=set_name.strip(),
+                dataset_id=row["id"]
+            )
+            for set_name in row["sets"].split(",")
+        ]
+
+    def map_dataset_sets(self):
+        return {
+            (str(set.dataset_id), set.name): set.id
+            for set in DatasetSet.objects.filter(dataset__corpus=self.corpus)
+        }
+
     def convert_dataset_elements(self, row):
         return [DatasetElement(
             id=row["id"],
             element_id=row["element_id"],
-            dataset_id=row["dataset_id"],
-            set=row["set_name"],
+            set_id=self.dataset_sets_map[(row["dataset_id"], row["set_name"])]
         )]
 
     def bulk_create_objects(self, ModelClass, convert_method, sql_query, ignore_conflicts=True):
@@ -603,6 +616,12 @@ class Command(BaseCommand):
             # Create datasets
             self.bulk_create_objects(Dataset, self.convert_datasets, SQL_DATASET_QUERY)
 
+            # Create dataset sets
+            self.bulk_create_objects(DatasetSet, self.convert_dataset_sets, SQL_DATASET_QUERY)
+
+            # Create dataset sets mapping
+            self.dataset_sets_map = self.map_dataset_sets()
+
             # Create dataset elements
             self.bulk_create_objects(DatasetElement, self.convert_dataset_elements, SQL_ELEMENT_DATASET_QUERY)
 
diff --git a/arkindex/documents/tests/commands/test_load_export.py b/arkindex/documents/tests/commands/test_load_export.py
index b9aaf6e699..121f27855b 100644
--- a/arkindex/documents/tests/commands/test_load_export.py
+++ b/arkindex/documents/tests/commands/test_load_export.py
@@ -14,6 +14,7 @@ from arkindex.documents.tasks import corpus_delete
 from arkindex.images.models import Image, ImageServer
 from arkindex.process.models import ProcessMode, Repository, Worker, WorkerRun, WorkerType, WorkerVersion
 from arkindex.project.tests import FixtureTestCase
+from arkindex.training.models import Dataset, DatasetElement
 
 BASE_DIR = Path(__file__).absolute().parent
 
@@ -132,6 +133,9 @@ class TestLoadExport(FixtureTestCase):
         dla_version = WorkerVersion.objects.get(worker__slug="dla")
         dla_run = dla_version.worker_runs.get(process__mode=ProcessMode.Workers)
 
+        dataset_set = Dataset.objects.first().sets.first()
+        DatasetElement.objects.create(set=dataset_set, element=element)
+
         element.classifications.create(
             ml_class=self.corpus.ml_classes.create(name="Blah"),
             confidence=.55555555,
@@ -266,6 +270,9 @@ class TestLoadExport(FixtureTestCase):
             confidence=.55555555,
         )
 
+        dataset_set = Dataset.objects.first().sets.first()
+        DatasetElement.objects.create(set=dataset_set, element=element)
+
         person_type = EntityType.objects.get(
             name="person",
             corpus=self.corpus
-- 
GitLab