Skip to content
Snippets Groups Projects
Commit 6da2f4dd authored by ml bonhomme's avatar ml bonhomme :bee: Committed by Erwan Rouchet
Browse files

Update load_export to handle new Dataset Set model

parent 0e6e04f6
No related branches found
No related tags found
1 merge request!2269Update load_export to handle new Dataset Set model
...@@ -37,7 +37,7 @@ from arkindex.process.models import ( ...@@ -37,7 +37,7 @@ from arkindex.process.models import (
WorkerType, WorkerType,
WorkerVersion, WorkerVersion,
) )
from arkindex.training.models import Dataset, DatasetElement, Model from arkindex.training.models import Dataset, DatasetElement, DatasetSet, Model
from arkindex.users.models import Role, User from arkindex.users.models import Role, User
EXPORT_VERSION = 8 EXPORT_VERSION = 8
...@@ -320,17 +320,30 @@ class Command(BaseCommand): ...@@ -320,17 +320,30 @@ class Command(BaseCommand):
id=row["id"], id=row["id"],
corpus=self.corpus, corpus=self.corpus,
name=row["name"], name=row["name"],
sets=[r.strip() for r in row["sets"].split(",")],
creator=self.user, creator=self.user,
description="Imported dataset", description="Imported dataset",
)] )]
def convert_dataset_sets(self, row):
    """Build unsaved DatasetSet instances from an exported dataset row.

    The row's "sets" column is a comma-separated list of set names; each
    name is trimmed and attached to the dataset identified by row["id"].
    """
    set_names = row["sets"].split(",")
    return [
        DatasetSet(name=name.strip(), dataset_id=row["id"])
        for name in set_names
    ]
def map_dataset_sets(self):
    """Return a mapping of (dataset ID string, set name) -> DatasetSet ID.

    Used to resolve the set_id of imported dataset elements from the
    (dataset_id, set_name) pair present in the export rows.
    """
    return {
        # `dataset_set` instead of `set`: the original shadowed the builtin.
        (str(dataset_set.dataset_id), dataset_set.name): dataset_set.id
        for dataset_set in DatasetSet.objects.filter(dataset__corpus=self.corpus)
    }
def convert_dataset_elements(self, row): def convert_dataset_elements(self, row):
return [DatasetElement( return [DatasetElement(
id=row["id"], id=row["id"],
element_id=row["element_id"], element_id=row["element_id"],
dataset_id=row["dataset_id"], set_id=self.dataset_sets_map[(row["dataset_id"], row["set_name"])]
set=row["set_name"],
)] )]
def bulk_create_objects(self, ModelClass, convert_method, sql_query, ignore_conflicts=True): def bulk_create_objects(self, ModelClass, convert_method, sql_query, ignore_conflicts=True):
...@@ -603,6 +616,12 @@ class Command(BaseCommand): ...@@ -603,6 +616,12 @@ class Command(BaseCommand):
# Create datasets # Create datasets
self.bulk_create_objects(Dataset, self.convert_datasets, SQL_DATASET_QUERY) self.bulk_create_objects(Dataset, self.convert_datasets, SQL_DATASET_QUERY)
# Create dataset sets
self.bulk_create_objects(DatasetSet, self.convert_dataset_sets, SQL_DATASET_QUERY)
# Create dataset sets mapping
self.dataset_sets_map = self.map_dataset_sets()
# Create dataset elements # Create dataset elements
self.bulk_create_objects(DatasetElement, self.convert_dataset_elements, SQL_ELEMENT_DATASET_QUERY) self.bulk_create_objects(DatasetElement, self.convert_dataset_elements, SQL_ELEMENT_DATASET_QUERY)
......
...@@ -14,6 +14,7 @@ from arkindex.documents.tasks import corpus_delete ...@@ -14,6 +14,7 @@ from arkindex.documents.tasks import corpus_delete
from arkindex.images.models import Image, ImageServer from arkindex.images.models import Image, ImageServer
from arkindex.process.models import ProcessMode, Repository, Worker, WorkerRun, WorkerType, WorkerVersion from arkindex.process.models import ProcessMode, Repository, Worker, WorkerRun, WorkerType, WorkerVersion
from arkindex.project.tests import FixtureTestCase from arkindex.project.tests import FixtureTestCase
from arkindex.training.models import Dataset, DatasetElement
BASE_DIR = Path(__file__).absolute().parent BASE_DIR = Path(__file__).absolute().parent
...@@ -132,6 +133,9 @@ class TestLoadExport(FixtureTestCase): ...@@ -132,6 +133,9 @@ class TestLoadExport(FixtureTestCase):
dla_version = WorkerVersion.objects.get(worker__slug="dla") dla_version = WorkerVersion.objects.get(worker__slug="dla")
dla_run = dla_version.worker_runs.get(process__mode=ProcessMode.Workers) dla_run = dla_version.worker_runs.get(process__mode=ProcessMode.Workers)
dataset_set = Dataset.objects.first().sets.first()
DatasetElement.objects.create(set=dataset_set, element=element)
element.classifications.create( element.classifications.create(
ml_class=self.corpus.ml_classes.create(name="Blah"), ml_class=self.corpus.ml_classes.create(name="Blah"),
confidence=.55555555, confidence=.55555555,
...@@ -266,6 +270,9 @@ class TestLoadExport(FixtureTestCase): ...@@ -266,6 +270,9 @@ class TestLoadExport(FixtureTestCase):
confidence=.55555555, confidence=.55555555,
) )
dataset_set = Dataset.objects.first().sets.first()
DatasetElement.objects.create(set=dataset_set, element=element)
person_type = EntityType.objects.get( person_type = EntityType.objects.get(
name="person", name="person",
corpus=self.corpus corpus=self.corpus
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment