Skip to content
Snippets Groups Projects
Commit bf129fae authored by Erwan Rouchet's avatar Erwan Rouchet Committed by Bastien Abadie
Browse files

Handle zones from worker-ml

parent 2ab4c91c
No related branches found
No related tags found
No related merge requests found
......@@ -6,11 +6,13 @@ from django.conf import settings
from django.db import transaction
from django.core.exceptions import ValidationError
from arkindex.project.celery import ReportingTask
from arkindex.documents.models import Element, ElementType, Page
from arkindex.documents.models import Element, ElementType, Page, Transcription, TranscriptionType
from arkindex.documents.importer import import_page
from arkindex.documents.indexer import Indexer
from arkindex.documents.tei import TeiParser
from arkindex.documents.tasks import generate_thumbnail
from arkindex.images.models import ImageServer, ImageStatus
from arkindex.images.importer import build_transcriptions, save_transcriptions
from arkindex.dataimport.models import DataImport, DataImportState, DataImportMode, EventType
from arkindex.dataimport.config import ConfigFile
from arkindex.dataimport.filetypes import FileType
......@@ -214,11 +216,32 @@ def save_ml_results(self, results, **kwargs):
continue
page.classification = result['classification']
page.text = result['text']
page.save()
self.report_message("Updated ML results for {}".format(page))
tr_items = result['zones']
# Parse transcription types
for item in tr_items:
item['type'] = TranscriptionType(item['type'])
trpolygons = build_transcriptions(
parent=page,
image=page.zone.image,
items=tr_items,
)
if trpolygons:
transcriptions, _ = save_transcriptions(*trpolygons)
self.report_message('Saved transcriptions for {}'.format(page))
Indexer().run_index(
settings.ES_INDEX_TRANSCRIPTIONS,
Transcription.INDEX_TYPE,
Transcription.objects.filter(
id__in=[t[0] for t in transcriptions],
),
)
self.report_message('Indexed transcriptions for {}'.format(page))
return list(map(str, results.keys()))
......
from django.core.management import call_command
from arkindex.project.tests import RedisMockAPITestCase, FixtureMixin
from arkindex.dataimport.tasks import save_ml_results, dataimport_postrun
from arkindex.documents.models import Page, Element, ElementType
from arkindex.documents.models import Page, Element, ElementType, TranscriptionType
from arkindex.dataimport.models import DataImport, DataImportState
from unittest.mock import MagicMock, patch
from celery import states
......@@ -13,8 +13,14 @@ class TestTasks(FixtureMixin, RedisMockAPITestCase):
Test data imports tasks
"""
def test_save_ml_results(self):
dog = Page.objects.create(corpus=self.corpus, name='A dog')
cat = Page.objects.create(corpus=self.corpus, name='A cat')
dog_img = self.imgsrv.images.create(path='dog', width=100, height=100)
cat_img = self.imgsrv.images.create(path='cat', width=100, height=100)
dog_zone = dog_img.zones.create(polygon=[(0, 0), (0, 100), (100, 100), (100, 0), (0, 0)])
cat_zone = cat_img.zones.create(polygon=[(0, 0), (0, 100), (100, 100), (100, 0), (0, 0)])
dog = Page.objects.create(corpus=self.corpus, name='A dog', zone=dog_zone)
cat = Page.objects.create(corpus=self.corpus, name='A cat', zone=cat_zone)
classification = {
dog.id: {
......@@ -24,7 +30,14 @@ class TestTasks(FixtureMixin, RedisMockAPITestCase):
'probability': 0.9,
}
],
'text': 'This is a dog',
'zones': [
{
'type': 'word',
'polygon': [[0, 0], [0, 10], [10, 10], [10, 0]],
'text': 'woof',
'score': 0.8
}
],
},
cat.id: {
'classification': [
......@@ -33,7 +46,14 @@ class TestTasks(FixtureMixin, RedisMockAPITestCase):
'probability': 0.8,
}
],
'text': 'This is a cat. meow',
'zones': [
{
'type': 'line',
'polygon': [[0, 0], [0, 20], [20, 20], [20, 0]],
'text': 'meow',
'score': 0.9
}
],
},
}
save_ml_results(classification)
......@@ -43,14 +63,22 @@ class TestTasks(FixtureMixin, RedisMockAPITestCase):
'label': 'dog',
'probability': 0.9,
}])
self.assertEqual(dog.text, 'This is a dog')
dog_ts = dog.transcriptions.get()
self.assertEqual(dog_ts.type, TranscriptionType.Word)
self.assertEqual(dog_ts.text, 'woof')
self.assertEqual(dog_ts.score, 0.8)
self.assertEqual(dog_ts.zone.polygon, [(0, 0), (0, 10), (10, 10), (10, 0), (0, 0)])
cat.refresh_from_db()
self.assertEqual(cat.classification, [{
'label': 'cat',
'probability': 0.8,
}])
self.assertEqual(cat.text, 'This is a cat. meow')
cat_ts = cat.transcriptions.get()
self.assertEqual(cat_ts.type, TranscriptionType.Line)
self.assertEqual(cat_ts.text, 'meow')
self.assertEqual(cat_ts.score, 0.9)
self.assertEqual(cat_ts.zone.polygon, [(0, 0), (0, 20), (20, 20), (20, 0), (0, 0)])
def test_command(self):
# No tasks at first
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment