Skip to content
Snippets Groups Projects
Commit d413646a authored by Bastien Abadie's avatar Bastien Abadie
Browse files

Merge branch 'workerversion-fk-on-entity' into 'master'

Add a FK on Entity towards a WorkerVersion

Closes #368

See merge request !875
parents be833e3c 5191da48
No related branches found
No related tags found
1 merge request!875Add a FK on Entity towards a WorkerVersion
......@@ -137,12 +137,14 @@ class EntityCreate(CreateAPIView):
corpus = serializer.validated_data['corpus']
metas = serializer.validated_data['metas'] if 'metas' in serializer.data else None
source = serializer.validated_data['ner']
worker_version = serializer.validated_data['worker_version']
return Entity.objects.create(
name=name,
type=type,
corpus=corpus,
metas=metas,
source=source
source=source,
worker_version=worker_version
)
def create(self, request, *args, **kwargs):
......
......@@ -53,4 +53,14 @@ class Migration(migrations.Migration):
model_name='classification',
constraint=models.CheckConstraint(check=models.Q(models.Q(('source_id__isnull', False), ('worker_version_id__isnull', True)), models.Q(('source_id__isnull', True), ('worker_version_id__isnull', False)), _connector='OR'), name='classification_source_xor_workerversion'),
),
migrations.AddField(
model_name='entity',
name='worker_version',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='entities', to='dataimport.WorkerVersion'),
),
migrations.AlterField(
model_name='entity',
name='source',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='entities', to='documents.DataSource'),
),
]
......@@ -341,7 +341,16 @@ class Entity(InterpretedDateMixin, models.Model):
source = models.ForeignKey(
DataSource,
on_delete=models.CASCADE,
related_name='entities'
related_name='entities',
null=True,
blank=True,
)
worker_version = models.ForeignKey(
'dataimport.WorkerVersion',
on_delete=models.CASCADE,
related_name='entities',
null=True,
blank=True,
)
class Meta:
......
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from arkindex_common.ml_tool import MLToolType
from arkindex.project.serializer_fields import EnumField, DataSourceSlugField
from arkindex.dataimport.models import WorkerVersion
from arkindex.documents.models import \
Element, Corpus, Entity, EntityLink, EntityRole, Transcription, TranscriptionEntity
from arkindex_common.enums import EntityType
......@@ -99,7 +101,8 @@ class EntityCreateSerializer(EntityLightSerializer):
metas = serializers.HStoreField(child=serializers.CharField(), required=False)
children = EntityLinkSerializer(many=True, read_only=True)
parents = EntityLinkSerializer(many=True, read_only=True)
ner = DataSourceSlugField(tool_type=MLToolType.NER)
ner = DataSourceSlugField(tool_type=MLToolType.NER, default=None)
worker_version = serializers.PrimaryKeyRelatedField(queryset=WorkerVersion.objects.all(), default=None)
class Meta:
model = Entity
......@@ -112,7 +115,8 @@ class EntityCreateSerializer(EntityLightSerializer):
'corpus',
'parents',
'children',
'ner'
'ner',
'worker_version'
)
read_only_fields = (
'id',
......@@ -128,6 +132,22 @@ class EntityCreateSerializer(EntityLightSerializer):
corpora = Corpus.objects.writable(self.context['request'].user)
self.fields['corpus'].queryset = corpora
def validate(self, data):
ner = data.get('ner')
worker_version = data.get('worker_version')
if not ner and not worker_version:
raise ValidationError({
'ner': ['This field XOR worker_version field must be set to create an entity'],
'worker_version': ['This field XOR ner field must be set to create an entity']
})
elif ner and worker_version:
raise ValidationError({
'ner': ['You can only refer to a DataSource XOR a WorkerVersion on an entity'],
'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on an entity']
})
return data
class EntityLinkCreateSerializer(EntityLinkSerializer):
"""
......
......@@ -5,6 +5,7 @@ from rest_framework import status
from arkindex_common.enums import MetaType
from arkindex.project.polygon import Polygon
from arkindex.project.tests import FixtureAPITestCase
from arkindex.dataimport.models import Worker, WorkerVersion
from arkindex.documents.models import (
Corpus, Element, TranscriptionType, DataSource, MLToolType,
Entity, EntityType, EntityRole, EntityLink, TranscriptionEntity,
......@@ -27,6 +28,9 @@ class TestEntitiesAPI(FixtureAPITestCase):
)
cls.source = DataSource.objects.get(slug='test')
cls.private_corpus = Corpus.objects.create(name='private')
cls.creds = cls.user.credentials.get()
cls.repo = cls.creds.repos.get()
cls.rev = cls.repo.revisions.get()
def setUp(self):
super().setUp()
......@@ -245,6 +249,7 @@ class TestEntitiesAPI(FixtureAPITestCase):
self.assertEqual(entity.name, 'entity')
self.assertEqual(entity.raw_dates, None)
self.assertEqual(entity.source, self.entity_source)
self.assertEqual(entity.worker_version, None)
@patch('arkindex.project.serializer_fields.MLTool.get')
def test_create_entity_number(self, ml_get_mock):
......@@ -267,6 +272,8 @@ class TestEntitiesAPI(FixtureAPITestCase):
entity = Entity.objects.get(id=response.json()['id'])
self.assertEqual(entity.name, '300g')
self.assertEqual(entity.raw_dates, None)
self.assertEqual(entity.source, self.entity_source)
self.assertEqual(entity.worker_version, None)
@patch('arkindex.project.serializer_fields.MLTool.get')
def test_create_entity_date(self, ml_get_mock):
......@@ -289,6 +296,8 @@ class TestEntitiesAPI(FixtureAPITestCase):
entity = Entity.objects.get(id=response.json()['id'])
self.assertEqual(entity.name, '1789')
self.assertEqual(entity.raw_dates, entity.name)
self.assertEqual(entity.source, self.entity_source)
self.assertEqual(entity.worker_version, None)
def test_create_entity_requires_login(self):
data = {
......@@ -304,6 +313,85 @@ class TestEntitiesAPI(FixtureAPITestCase):
response = self.client.post(reverse('api:entity-create'), data=data, format='json')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_create_entity_no_source_no_worker_version(self):
data = {
'name': '1789',
'type': EntityType.Date.value,
'corpus': str(self.corpus.id),
'metas': {
'key': 'value',
'other key': 'other value'
},
}
self.client.force_login(self.user)
response = self.client.post(reverse('api:entity-create'), data=data, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'ner': ['This field XOR worker_version field must be set to create an entity'],
'worker_version': ['This field XOR ner field must be set to create an entity']
})
def test_create_entity_with_source_and_worker_version_returns_error(self):
worker = Worker.objects.create(
repository=self.repo,
name='Worker 1',
slug='worker_1',
type=MLToolType.Classifier
)
version = WorkerVersion.objects.create(
worker=worker,
revision=self.rev,
configuration={"test": "test1"}
)
data = {
'name': '1789',
'type': EntityType.Date.value,
'corpus': str(self.corpus.id),
'metas': {
'key': 'value',
'other key': 'other value'
},
'ner': self.entity_source.slug,
'worker_version': str(version.id)
}
self.client.force_login(self.user)
response = self.client.post(reverse('api:entity-create'), data=data, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {
'ner': ['You can only refer to a DataSource XOR a WorkerVersion on an entity'],
'worker_version': ['You can only refer to a DataSource XOR a WorkerVersion on an entity']
})
def test_create_entity_with_worker_version(self):
worker = Worker.objects.create(
repository=self.repo,
name='Worker 1',
slug='worker_1',
type=MLToolType.Classifier
)
version = WorkerVersion.objects.create(
worker=worker,
revision=self.rev,
configuration={"test": "test1"}
)
data = {
'name': '1789',
'type': EntityType.Date.value,
'corpus': str(self.corpus.id),
'metas': {
'key': 'value',
'other key': 'other value'
},
'worker_version': str(version.id)
}
self.client.force_login(self.user)
response = self.client.post(reverse('api:entity-create'), data=data, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
entity = Entity.objects.get(id=response.json()['id'])
self.assertEqual(entity.name, '1789')
self.assertEqual(entity.source, None)
self.assertEqual(entity.worker_version, version)
def test_create_link(self):
child = Entity.objects.create(
type=EntityType.Location,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment