From dcd12d6cf6a5a93b74931743b7a62227ebf3e435 Mon Sep 17 00:00:00 2001
From: Yoann Schneider <yschneider@teklia.com>
Date: Wed, 14 Dec 2022 13:12:48 +0100
Subject: [PATCH] assert ml_class_id is a valid uuid

---
 arkindex_worker/worker/classification.py      |  9 ++
 .../test_classifications.py                   | 83 ++++++++++++-------
 2 files changed, 63 insertions(+), 29 deletions(-)

diff --git a/arkindex_worker/worker/classification.py b/arkindex_worker/worker/classification.py
index 72eb646a..98b5db4b 100644
--- a/arkindex_worker/worker/classification.py
+++ b/arkindex_worker/worker/classification.py
@@ -4,6 +4,7 @@ ElementsWorker methods for classifications and ML classes.
 """
 
 from typing import Dict, List, Optional, Union
+from uuid import UUID
 
 from apistar.exceptions import ErrorResponse
 from peewee import IntegrityError
@@ -209,6 +210,14 @@ class ClassificationMixin(object):
                 ml_class_id, str
             ), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"
 
+            # Make sure it's a valid UUID
+            try:
+                UUID(ml_class_id)
+            except ValueError:
+                raise ValueError(
+                    f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
+                )
+
             confidence = classification.get("confidence")
             assert (
                 confidence is not None
diff --git a/tests/test_elements_worker/test_classifications.py b/tests/test_elements_worker/test_classifications.py
index 8a108997..63b06a7a 100644
--- a/tests/test_elements_worker/test_classifications.py
+++ b/tests/test_elements_worker/test_classifications.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 import json
-from uuid import UUID
+from uuid import UUID, uuid4
 
 import pytest
 from apistar.exceptions import ErrorResponse
@@ -624,7 +624,7 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
             element=elt,
             classifications=[
                 {
-                    "ml_class_id": "uuid1",
+                    "ml_class_id": str(uuid4()),
                     "confidence": 0.75,
                     "high_confidence": False,
                 },
@@ -644,7 +644,7 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
             element=elt,
             classifications=[
                 {
-                    "ml_class_id": "uuid1",
+                    "ml_class_id": str(uuid4()),
                     "confidence": 0.75,
                     "high_confidence": False,
                 },
@@ -665,7 +665,7 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
             element=elt,
             classifications=[
                 {
-                    "ml_class_id": "uuid1",
+                    "ml_class_id": str(uuid4()),
                     "confidence": 0.75,
                     "high_confidence": False,
                 },
@@ -681,17 +681,38 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
         == "Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str"
     )
 
+    with pytest.raises(ValueError) as e:
+        mock_elements_worker.create_classifications(
+            element=elt,
+            classifications=[
+                {
+                    "ml_class_id": str(uuid4()),
+                    "confidence": 0.75,
+                    "high_confidence": False,
+                },
+                {
+                    "ml_class_id": "not_an_uuid",
+                    "confidence": 0.25,
+                    "high_confidence": False,
+                },
+            ],
+        )
+    assert (
+        str(e.value)
+        == "Classification at index 1 in classifications: ml_class_id is not a valid uuid."
+    )
+
     with pytest.raises(AssertionError) as e:
         mock_elements_worker.create_classifications(
             element=elt,
             classifications=[
                 {
-                    "ml_class_id": "uuid1",
+                    "ml_class_id": str(uuid4()),
                     "confidence": 0.75,
                     "high_confidence": False,
                 },
                 {
-                    "ml_class_id": "uuid2",
+                    "ml_class_id": str(uuid4()),
                     "high_confidence": False,
                 },
             ],
@@ -706,12 +727,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
             element=elt,
             classifications=[
                 {
-                    "ml_class_id": "uuid1",
+                    "ml_class_id": str(uuid4()),
                     "confidence": 0.75,
                     "high_confidence": False,
                 },
                 {
-                    "ml_class_id": "uuid2",
+                    "ml_class_id": str(uuid4()),
                     "confidence": None,
                     "high_confidence": False,
                 },
@@ -727,12 +748,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
             element=elt,
             classifications=[
                 {
-                    "ml_class_id": "uuid1",
+                    "ml_class_id": str(uuid4()),
                     "confidence": 0.75,
                     "high_confidence": False,
                 },
                 {
-                    "ml_class_id": "uuid2",
+                    "ml_class_id": str(uuid4()),
                     "confidence": "wrong confidence",
                     "high_confidence": False,
                 },
@@ -748,12 +769,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
             element=elt,
             classifications=[
                 {
-                    "ml_class_id": "uuid1",
+                    "ml_class_id": str(uuid4()),
                     "confidence": 0.75,
                     "high_confidence": False,
                 },
                 {
-                    "ml_class_id": "uuid2",
+                    "ml_class_id": str(uuid4()),
                     "confidence": 0,
                     "high_confidence": False,
                 },
@@ -769,12 +790,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
             element=elt,
             classifications=[
                 {
-                    "ml_class_id": "uuid1",
+                    "ml_class_id": str(uuid4()),
                     "confidence": 0.75,
                     "high_confidence": False,
                 },
                 {
-                    "ml_class_id": "uuid2",
+                    "ml_class_id": str(uuid4()),
                     "confidence": 2.00,
                     "high_confidence": False,
                 },
@@ -790,12 +811,12 @@ def test_create_classifications_wrong_classifications(mock_elements_worker):
             element=elt,
             classifications=[
                 {
-                    "ml_class_id": "uuid1",
+                    "ml_class_id": str(uuid4()),
                     "confidence": 0.75,
                     "high_confidence": False,
                 },
                 {
-                    "ml_class_id": "uuid2",
+                    "ml_class_id": str(uuid4()),
                     "confidence": 0.25,
                     "high_confidence": "wrong high_confidence",
                 },
@@ -816,12 +837,12 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
     elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     classes = [
         {
-            "ml_class_id": "uuid1",
+            "ml_class_id": str(uuid4()),
             "confidence": 0.75,
             "high_confidence": False,
         },
         {
-            "ml_class_id": "uuid2",
+            "ml_class_id": str(uuid4()),
             "confidence": 0.25,
             "high_confidence": False,
         },
@@ -847,19 +868,21 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
 
 def test_create_classifications(responses, mock_elements_worker_with_cache):
     # Set MLClass in cache
+    portrait_uuid = str(uuid4())
+    landscape_uuid = str(uuid4())
     mock_elements_worker_with_cache.classes[
         mock_elements_worker_with_cache.corpus_id
-    ] = {"portrait": "uuid1", "landscape": "uuid2"}
+    ] = {"portrait": portrait_uuid, "landscape": landscape_uuid}
 
     elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
     classes = [
         {
-            "ml_class_id": "uuid1",
+            "ml_class_id": portrait_uuid,
             "confidence": 0.75,
             "high_confidence": False,
         },
         {
-            "ml_class_id": "uuid2",
+            "ml_class_id": landscape_uuid,
             "confidence": 0.25,
             "high_confidence": False,
         },
@@ -875,14 +898,14 @@ def test_create_classifications(responses, mock_elements_worker_with_cache):
             "classifications": [
                 {
                     "id": "00000000-0000-0000-0000-000000000000",
-                    "ml_class": "uuid1",
+                    "ml_class": portrait_uuid,
                     "confidence": 0.75,
                     "high_confidence": False,
                     "state": "pending",
                 },
                 {
                     "id": "11111111-1111-1111-1111-111111111111",
-                    "ml_class": "uuid2",
+                    "ml_class": landscape_uuid,
                     "confidence": 0.25,
                     "high_confidence": False,
                     "state": "pending",
@@ -936,15 +959,17 @@ def test_create_classifications_not_in_cache(
     CreateClassifications using ID that are not in `.classes` attribute.
     Will load corpus MLClass to insert the corresponding name in Cache.
     """
+    portrait_uuid = str(uuid4())
+    landscape_uuid = str(uuid4())
     elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
     classes = [
         {
-            "ml_class_id": "uuid1",
+            "ml_class_id": portrait_uuid,
             "confidence": 0.75,
             "high_confidence": False,
         },
         {
-            "ml_class_id": "uuid2",
+            "ml_class_id": landscape_uuid,
             "confidence": 0.25,
             "high_confidence": False,
         },
@@ -960,14 +985,14 @@ def test_create_classifications_not_in_cache(
             "classifications": [
                 {
                     "id": "00000000-0000-0000-0000-000000000000",
-                    "ml_class": "uuid1",
+                    "ml_class": portrait_uuid,
                     "confidence": 0.75,
                     "high_confidence": False,
                     "state": "pending",
                 },
                 {
                     "id": "11111111-1111-1111-1111-111111111111",
-                    "ml_class": "uuid2",
+                    "ml_class": landscape_uuid,
                     "confidence": 0.25,
                     "high_confidence": False,
                     "state": "pending",
@@ -984,10 +1009,10 @@ def test_create_classifications_not_in_cache(
             "next": None,
             "results": [
                 {
-                    "id": "uuid1",
+                    "id": portrait_uuid,
                     "name": "portrait",
                 },
-                {"id": "uuid2", "name": "landscape"},
+                {"id": landscape_uuid, "name": "landscape"},
             ],
         },
     )
-- 
GitLab