Skip to content
Snippets Groups Projects
Commit 6da2f4dd authored by ml bonhomme's avatar ml bonhomme :bee: Committed by Erwan Rouchet
Browse files

Update load_export to handle new Dataset Set model

parent 0e6e04f6
No related branches found
No related tags found
1 merge request!2269Update load_export to handle new Dataset Set model
......@@ -37,7 +37,7 @@ from arkindex.process.models import (
WorkerType,
WorkerVersion,
)
from arkindex.training.models import Dataset, DatasetElement, Model
from arkindex.training.models import Dataset, DatasetElement, DatasetSet, Model
from arkindex.users.models import Role, User
EXPORT_VERSION = 8
......@@ -320,17 +320,30 @@ class Command(BaseCommand):
id=row["id"],
corpus=self.corpus,
name=row["name"],
sets=[r.strip() for r in row["sets"].split(",")],
creator=self.user,
description="Imported dataset",
)]
def convert_dataset_sets(self, row):
return [
DatasetSet(
name=set_name.strip(),
dataset_id=row["id"]
)
for set_name in row["sets"].split(",")
]
def map_dataset_sets(self):
return {
(str(set.dataset_id), set.name): set.id
for set in DatasetSet.objects.filter(dataset__corpus=self.corpus)
}
def convert_dataset_elements(self, row):
return [DatasetElement(
id=row["id"],
element_id=row["element_id"],
dataset_id=row["dataset_id"],
set=row["set_name"],
set_id=self.dataset_sets_map[(row["dataset_id"], row["set_name"])]
)]
def bulk_create_objects(self, ModelClass, convert_method, sql_query, ignore_conflicts=True):
......@@ -603,6 +616,12 @@ class Command(BaseCommand):
# Create datasets
self.bulk_create_objects(Dataset, self.convert_datasets, SQL_DATASET_QUERY)
# Create dataset sets
self.bulk_create_objects(DatasetSet, self.convert_dataset_sets, SQL_DATASET_QUERY)
# Create dataset sets mapping
self.dataset_sets_map = self.map_dataset_sets()
# Create dataset elements
self.bulk_create_objects(DatasetElement, self.convert_dataset_elements, SQL_ELEMENT_DATASET_QUERY)
......
......@@ -14,6 +14,7 @@ from arkindex.documents.tasks import corpus_delete
from arkindex.images.models import Image, ImageServer
from arkindex.process.models import ProcessMode, Repository, Worker, WorkerRun, WorkerType, WorkerVersion
from arkindex.project.tests import FixtureTestCase
from arkindex.training.models import Dataset, DatasetElement
BASE_DIR = Path(__file__).absolute().parent
......@@ -132,6 +133,9 @@ class TestLoadExport(FixtureTestCase):
dla_version = WorkerVersion.objects.get(worker__slug="dla")
dla_run = dla_version.worker_runs.get(process__mode=ProcessMode.Workers)
dataset_set = Dataset.objects.first().sets.first()
DatasetElement.objects.create(set=dataset_set, element=element)
element.classifications.create(
ml_class=self.corpus.ml_classes.create(name="Blah"),
confidence=.55555555,
......@@ -266,6 +270,9 @@ class TestLoadExport(FixtureTestCase):
confidence=.55555555,
)
dataset_set = Dataset.objects.first().sets.first()
DatasetElement.objects.create(set=dataset_set, element=element)
person_type = EntityType.objects.get(
name="person",
corpus=self.corpus
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment