diff --git a/arkindex/project/settings.py b/arkindex/project/settings.py index 6fa4d7b6e9bea7b8686bfc1a1633dc8a801a6ab4..f476cc5fc6c096b63f6de5392172204da3f438b9 100644 --- a/arkindex/project/settings.py +++ b/arkindex/project/settings.py @@ -119,6 +119,7 @@ INSTALLED_APPS = [ 'arkindex.documents', 'arkindex.users', 'arkindex.dataimport', + 'arkindex.training', ] MIDDLEWARE = [ @@ -250,6 +251,7 @@ SPECTACULAR_SETTINGS = { 'ClassificationState': 'arkindex.documents.models.ClassificationState', 'PonosState': 'ponos.models.State', 'WorkerVersionState': 'arkindex.dataimport.models.WorkerVersionState', + 'ModelVersionState': 'arkindex.training.models.ModelVersionState', }, 'TAGS': [ {'name': 'classifications'}, diff --git a/arkindex/training/admin.py b/arkindex/training/admin.py new file mode 100644 index 0000000000000000000000000000000000000000..1d38c2bef5d57b4e022e24ad3a47dd52737b3cf3 --- /dev/null +++ b/arkindex/training/admin.py @@ -0,0 +1,21 @@ +from django.contrib import admin +from enumfields.admin import EnumFieldListFilter + +from arkindex.training.models import Model, ModelVersion + + +class ModelAdmin(admin.ModelAdmin): + list_display = ('name', 'created', ) + search_fields = ('name', 'description', ) + fields = ('name', 'description', 'public', 'compatible_workers') + + +class ModelVersionAdmin(admin.ModelAdmin): + list_display = ('id', 'model', 'tag', 'size', 'state') + list_filter = ('model__name', ('state', EnumFieldListFilter), ) + fields = ('model', 'parent', 'description', 'state', 'tag', 'hash', 'size', 'configuration',) + readonly_fields = ('hash', 'size', ) + + +admin.site.register(Model, ModelAdmin) +admin.site.register(ModelVersion, ModelVersionAdmin) diff --git a/arkindex/training/migrations/0001_initial.py b/arkindex/training/migrations/0001_initial.py new file mode 100644 index 0000000000000000000000000000000000000000..20d99a11b5246eac7475665a4ed666d88551354b --- /dev/null +++ b/arkindex/training/migrations/0001_initial.py @@ -0,0 +1,56 @@ +# 
Generated by Django 4.0.2 on 2022-03-16 09:50 + +import uuid + +import django.db.models.deletion +import enumfields.fields +from django.db import migrations, models + +import arkindex.project.fields +import arkindex.training.models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('dataimport', '0045_remove_dataimport_best_class'), + ] + + operations = [ + migrations.CreateModel( + name='Model', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('created', models.DateTimeField(auto_now_add=True)), + ('updated', models.DateTimeField(auto_now=True)), + ('name', models.CharField(max_length=100, unique=True)), + ('description', models.TextField(default='')), + ('public', models.BooleanField(default=False)), + ('compatible_workers', models.ManyToManyField(related_name='models', to='dataimport.Worker')), + ], + options={ + 'abstract': False, + }, + ), + migrations.CreateModel( + name='ModelVersion', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('created', models.DateTimeField(auto_now_add=True)), + ('updated', models.DateTimeField(auto_now=True)), + ('description', models.TextField(default='')), + ('tag', models.CharField(blank=True, max_length=50, null=True)), + ('state', enumfields.fields.EnumField(default='created', enum=arkindex.training.models.ModelVersionState, max_length=10)), + ('hash', arkindex.project.fields.MD5HashField(max_length=32)), + ('size', models.PositiveIntegerField(help_text='file size in bytes')), + ('configuration', models.JSONField(default=dict)), + ('model', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='versions', to='training.model')), + ('parent', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='children', to='training.modelversion')), + ], + options={ + 'unique_together': {('model', 'tag')}, + }, + 
), + ] diff --git a/arkindex/training/migrations/__init__.py b/arkindex/training/migrations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/arkindex/training/models.py b/arkindex/training/models.py new file mode 100644 index 0000000000000000000000000000000000000000..1a778ed0927fda8e06d1da5ddd97ff90b75d9432 --- /dev/null +++ b/arkindex/training/models.py @@ -0,0 +1,61 @@ +from django.db import models +from enumfields import Enum, EnumField + +from arkindex.project.fields import MD5HashField +from arkindex.project.models import IndexableModel + + +class Model(IndexableModel): + """ + An evolving Machine Learning model + """ + # Name of the model, unique + name = models.CharField(max_length=100, unique=True) + + description = models.TextField(default="") + + public = models.BooleanField(default=False) + + # Link to the workers that are able to use this model + compatible_workers = models.ManyToManyField('dataimport.Worker', related_name='models') + + def __str__(self): + return self.name + + +class ModelVersionState(Enum): + """ + State of a model version; Available means the version has been checked by the backend + """ + Created = 'created' + Available = 'available' + Error = 'error' + + +class ModelVersion(IndexableModel): + """ + A specific Model version + """ + model = models.ForeignKey('training.Model', related_name='versions', on_delete=models.CASCADE) + + parent = models.ForeignKey('self', related_name='children', null=True, blank=True, on_delete=models.CASCADE) + + description = models.TextField(default="") + + tag = models.CharField(null=True, max_length=50, blank=True) + + state = EnumField(ModelVersionState, default=ModelVersionState.Created) + + # MD5 hash of the archive + hash = MD5HashField() + + # Size of the archive, in bytes. NOTE(review): PositiveIntegerField is only safe up to 2147483647 (about 2 GiB); confirm archives cannot exceed that + size = models.PositiveIntegerField(help_text='file size in bytes') + + # Store dictionary of parameters given by the ML developer + configuration = 
models.JSONField(default=dict) + + class Meta: + unique_together = ( + ('model', 'tag'), + )