Skip to content
Snippets Groups Projects
Commit 4aa0478d authored by Thibault Lavigne's avatar Thibault Lavigne Committed by Yoann Schneider
Browse files

Parse a tokens yaml instead of hardcoding the tokens

parent 8675a210
No related branches found
No related tags found
1 merge request!211Parse a tokens yaml instead of hardcoding the tokens
...@@ -51,6 +51,12 @@ def add_predict_parser(subcommands) -> None: ...@@ -51,6 +51,12 @@ def add_predict_parser(subcommands) -> None:
help="Path to the output folder.", help="Path to the output folder.",
required=True, required=True,
) )
parser.add_argument(
"--tokens",
type=pathlib.Path,
required=True,
help="Path to a yaml file containing a mapping between starting tokens and end tokens. Needed for entities.",
)
# Optional arguments. # Optional arguments.
parser.add_argument( parser.add_argument(
"--image-extension", "--image-extension",
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
import yaml import yaml
from dan import logger from dan import logger
from dan.datasets.extract.utils import save_json from dan.datasets.extract.utils import parse_tokens, save_json
from dan.decoder import GlobalHTADecoder from dan.decoder import GlobalHTADecoder
from dan.encoder import FCN_Encoder from dan.encoder import FCN_Encoder
from dan.predict.attention import ( from dan.predict.attention import (
...@@ -256,6 +256,16 @@ class DAN: ...@@ -256,6 +256,16 @@ class DAN:
return out return out
def parse_ner_predictions(text, char_confidences, predictions):
    """Build per-entity confidence entries from entity boundary pairs.

    :param text: Predicted text containing NER start tokens.
    :param char_confidences: Per-character confidence scores aligned with
        ``text`` (sequence of floats).
    :param predictions: Iterable of ``(start, end)`` index pairs delimiting
        each entity in ``text``; ``end`` may be ``None`` so the last entity
        extends to the end of the text.
    :return: One dict per entity with its text (newlines flattened to spaces)
        and its mean character confidence rounded to 2 decimals, as a string.
    """
    return [
        {
            # A str slice is already a str: no need to wrap it in an f-string.
            "text": text[current:next_token].replace("\n", " "),
            # Stringified so the value serializes consistently in the JSON output.
            "confidence_ner": str(
                np.around(np.mean(char_confidences[current:next_token]), 2)
            ),
        }
        for current, next_token in predictions
    ]
def process_batch( def process_batch(
image_batch, image_batch,
dan_model, dan_model,
...@@ -271,6 +281,7 @@ def process_batch( ...@@ -271,6 +281,7 @@ def process_batch(
predict_objects, predict_objects,
threshold_method, threshold_method,
threshold_value, threshold_value,
tokens,
): ):
input_images, visu_images, input_sizes = [], [], [] input_images, visu_images, input_sizes = [], [], []
logger.info("Loading images...") logger.info("Loading images...")
...@@ -317,23 +328,30 @@ def process_batch( ...@@ -317,23 +328,30 @@ def process_batch(
# Return mean confidence score # Return mean confidence score
if confidence_score: if confidence_score:
result["confidences"] = {} result["confidences"] = {}
char_confidences = prediction["confidences"][idx] char_confidences = prediction["confidences"][0]
# retrieve the index of the token ner text = result["text"]
index = [ tokens = parse_tokens(tokens)
pos start_tokens, end_tokens = zip(*list(tokens.values()))
for pos, char in enumerate(predicted_text) end_tokens = list(filter(bool, end_tokens))
if char in ["", "", "", ""]
] if len(end_tokens):
assert len(start_tokens) == len(
end_tokens
), "You don't have the same number of starting tokens and ending tokens"
indices = [
[pos_start, pos_end]
for pos_start, pos_end in zip(start_tokens, end_tokens)
]
else:
indices = pairwise(
[pos for pos, char in enumerate(text) if char in start_tokens]
+ [None]
)
result["confidences"]["by ner token"] = parse_ner_predictions(
text, char_confidences, indices
)
# calculates scores by token
result["confidences"]["by ner token"] = [
{
"text": f"{predicted_text[current: next_token]}".replace("\n", " "),
"confidence_ner": f"{np.around(np.mean(char_confidences[current : next_token]), 2)}",
}
# We go up to -1 so that the last token matches until the end of the text
for current, next_token in pairwise(index + [-1])
]
result["confidences"]["total"] = np.around(np.mean(char_confidences), 2) result["confidences"]["total"] = np.around(np.mean(char_confidences), 2)
for level in confidence_score_levels: for level in confidence_score_levels:
...@@ -398,6 +416,7 @@ def run( ...@@ -398,6 +416,7 @@ def run(
image_extension, image_extension,
gpu_device, gpu_device,
batch_size, batch_size,
tokens,
): ):
""" """
Predict a single image save the output Predict a single image save the output
...@@ -446,4 +465,5 @@ def run( ...@@ -446,4 +465,5 @@ def run(
predict_objects, predict_objects,
threshold_method, threshold_method,
threshold_value, threshold_value,
tokens,
) )
# Mapping between named-entity categories and the tokens delimiting them
# in the predicted text. `start` marks the beginning of an entity; `end`
# is empty when an entity simply runs until the next start token.
Surname:
  start: "Ⓢ"
  end: ""
Firstname:
  start: "Ⓕ"
  end: ""
Batiment:
  start: "Ⓑ"
  end: ""
Louche:
  start: "Ⓛ"
  end: ""
Chalumeau:
  start: "Ⓒ"
  end: ""
Koala:
  start: "Ⓚ"
  end: ""
Operai:
  start: "Ⓞ"
  end: ""
Patron:
  start: "Ⓟ"
  end: ""
...@@ -73,7 +73,16 @@ def test_predict( ...@@ -73,7 +73,16 @@ def test_predict(
{ {
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": { "confidences": {
"by ner token": [], "by ner token": [
{"text": "ⓈBellisson ", "confidence_ner": "1.0"},
{"text": "ⒻGeorges ", "confidence_ner": "1.0"},
{"text": "Ⓑ91 ", "confidence_ner": "1.0"},
{"text": "ⓁP ", "confidence_ner": "1.0"},
{"text": "ⒸM ", "confidence_ner": "1.0"},
{"text": "ⓀCh ", "confidence_ner": "1.0"},
{"text": "ⓄPlombier ", "confidence_ner": "1.0"},
{"text": "ⓅPatron?12241", "confidence_ner": "1.0"},
],
"total": 1.0, "total": 1.0,
"word": [ "word": [
{"text": "ⓈBellisson", "confidence": 1.0}, {"text": "ⓈBellisson", "confidence": 1.0},
...@@ -95,7 +104,16 @@ def test_predict( ...@@ -95,7 +104,16 @@ def test_predict(
{ {
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": { "confidences": {
"by ner token": [], "by ner token": [
{"text": "ⓈBellisson ", "confidence_ner": "0.92"},
{"text": "ⒻGeorges ", "confidence_ner": "0.94"},
{"text": "Ⓑ91 ", "confidence_ner": "0.93"},
{"text": "ⓁP ", "confidence_ner": "0.92"},
{"text": "ⒸM ", "confidence_ner": "0.93"},
{"text": "ⓀCh ", "confidence_ner": "0.95"},
{"text": "ⓄPlombier ", "confidence_ner": "0.93"},
{"text": "ⓅPatron?12241", "confidence_ner": "0.93"},
],
"total": 0.93, "total": 0.93,
"word": [ "word": [
{"text": "ⓈBellisson", "confidence": 0.93}, {"text": "ⓈBellisson", "confidence": 0.93},
...@@ -117,7 +135,16 @@ def test_predict( ...@@ -117,7 +135,16 @@ def test_predict(
{ {
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": { "confidences": {
"by ner token": [], "by ner token": [
{"text": "ⓈBellisson ", "confidence_ner": "1.0"},
{"text": "ⒻGeorges ", "confidence_ner": "1.0"},
{"text": "Ⓑ91 ", "confidence_ner": "1.0"},
{"text": "ⓁP ", "confidence_ner": "1.0"},
{"text": "ⒸM ", "confidence_ner": "1.0"},
{"text": "ⓀCh ", "confidence_ner": "1.0"},
{"text": "ⓄPlombier ", "confidence_ner": "1.0"},
{"text": "ⓅPatron?12241", "confidence_ner": "1.0"},
],
"total": 1.0, "total": 1.0,
"line": [ "line": [
{ {
...@@ -135,7 +162,16 @@ def test_predict( ...@@ -135,7 +162,16 @@ def test_predict(
{ {
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": { "confidences": {
"by ner token": [], "by ner token": [
{"text": "ⓈBellisson ", "confidence_ner": "0.92"},
{"text": "ⒻGeorges ", "confidence_ner": "0.94"},
{"text": "Ⓑ91 ", "confidence_ner": "0.93"},
{"text": "ⓁP ", "confidence_ner": "0.92"},
{"text": "ⒸM ", "confidence_ner": "0.93"},
{"text": "ⓀCh ", "confidence_ner": "0.95"},
{"text": "ⓄPlombier ", "confidence_ner": "0.93"},
{"text": "ⓅPatron?12241", "confidence_ner": "0.93"},
],
"total": 0.93, "total": 0.93,
"line": [ "line": [
{ {
...@@ -159,7 +195,15 @@ def test_predict( ...@@ -159,7 +195,15 @@ def test_predict(
{ {
"text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376", "text": "ⓈTemplié ⒻMarcelle Ⓑ93 ⓁS Ⓚch ⓄE dactylo Ⓟ18376",
"confidences": { "confidences": {
"by ner token": [], "by ner token": [
{"text": "ⓈTemplié ", "confidence_ner": "0.98"},
{"text": "ⒻMarcelle ", "confidence_ner": "1.0"},
{"text": "Ⓑ93 ", "confidence_ner": "1.0"},
{"text": "ⓁS ", "confidence_ner": "1.0"},
{"text": "Ⓚch ", "confidence_ner": "1.0"},
{"text": "ⓄE dactylo ", "confidence_ner": "1.0"},
{"text": "Ⓟ18376", "confidence_ner": "1.0"},
],
"total": 1.0, "total": 1.0,
"char": [ "char": [
{"text": "", "confidence": 1.0}, {"text": "", "confidence": 1.0},
...@@ -272,6 +316,7 @@ def test_run_prediction( ...@@ -272,6 +316,7 @@ def test_run_prediction(
image_extension=None, image_extension=None,
gpu_device=None, gpu_device=None,
batch_size=1, batch_size=1,
tokens=prediction_data_path / "tokens.yml",
) )
prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text()) prediction = json.loads((tmp_path / image_name).with_suffix(".json").read_text())
...@@ -295,7 +340,16 @@ def test_run_prediction( ...@@ -295,7 +340,16 @@ def test_run_prediction(
{ {
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": { "confidences": {
"by ner token": [], "by ner token": [
{"text": "ⓈBellisson ", "confidence_ner": "1.0"},
{"text": "ⒻGeorges ", "confidence_ner": "1.0"},
{"text": "Ⓑ91 ", "confidence_ner": "1.0"},
{"text": "ⓁP ", "confidence_ner": "1.0"},
{"text": "ⒸM ", "confidence_ner": "1.0"},
{"text": "ⓀCh ", "confidence_ner": "1.0"},
{"text": "ⓄPlombier ", "confidence_ner": "1.0"},
{"text": "ⓅPatron?12241", "confidence_ner": "1.0"},
],
"total": 1.0, "total": 1.0,
"word": [ "word": [
{"text": "ⓈBellisson", "confidence": 1.0}, {"text": "ⓈBellisson", "confidence": 1.0},
...@@ -322,7 +376,16 @@ def test_run_prediction( ...@@ -322,7 +376,16 @@ def test_run_prediction(
{ {
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": { "confidences": {
"by ner token": [], "by ner token": [
{"text": "ⓈBellisson ", "confidence_ner": "1.0"},
{"text": "ⒻGeorges ", "confidence_ner": "1.0"},
{"text": "Ⓑ91 ", "confidence_ner": "1.0"},
{"text": "ⓁP ", "confidence_ner": "1.0"},
{"text": "ⒸM ", "confidence_ner": "1.0"},
{"text": "ⓀCh ", "confidence_ner": "1.0"},
{"text": "ⓄPlombier ", "confidence_ner": "1.0"},
{"text": "ⓅPatron?12241", "confidence_ner": "1.0"},
],
"total": 1.0, "total": 1.0,
"word": [ "word": [
{"text": "ⓈBellisson", "confidence": 1.0}, {"text": "ⓈBellisson", "confidence": 1.0},
...@@ -339,7 +402,16 @@ def test_run_prediction( ...@@ -339,7 +402,16 @@ def test_run_prediction(
{ {
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": { "confidences": {
"by ner token": [], "by ner token": [
{"text": "ⓈBellisson ", "confidence_ner": "1.0"},
{"text": "ⒻGeorges ", "confidence_ner": "1.0"},
{"text": "Ⓑ91 ", "confidence_ner": "1.0"},
{"text": "ⓁP ", "confidence_ner": "1.0"},
{"text": "ⒸM ", "confidence_ner": "1.0"},
{"text": "ⓀCh ", "confidence_ner": "1.0"},
{"text": "ⓄPlombier ", "confidence_ner": "1.0"},
{"text": "ⓅPatron?12241", "confidence_ner": "1.0"},
],
"total": 1.0, "total": 1.0,
"word": [ "word": [
{"text": "ⓈBellisson", "confidence": 1.0}, {"text": "ⓈBellisson", "confidence": 1.0},
...@@ -363,7 +435,16 @@ def test_run_prediction( ...@@ -363,7 +435,16 @@ def test_run_prediction(
{ {
"text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241", "text": "ⓈBellisson ⒻGeorges Ⓑ91 ⓁP ⒸM ⓀCh ⓄPlombier ⓅPatron?12241",
"confidences": { "confidences": {
"by ner token": [], "by ner token": [
{"text": "ⓈBellisson ", "confidence_ner": "1.0"},
{"text": "ⒻGeorges ", "confidence_ner": "1.0"},
{"text": "Ⓑ91 ", "confidence_ner": "1.0"},
{"text": "ⓁP ", "confidence_ner": "1.0"},
{"text": "ⒸM ", "confidence_ner": "1.0"},
{"text": "ⓀCh ", "confidence_ner": "1.0"},
{"text": "ⓄPlombier ", "confidence_ner": "1.0"},
{"text": "ⓅPatron?12241", "confidence_ner": "1.0"},
],
"total": 1.0, "total": 1.0,
"word": [ "word": [
{"text": "ⓈBellisson", "confidence": 1.0}, {"text": "ⓈBellisson", "confidence": 1.0},
...@@ -433,6 +514,7 @@ def test_run_prediction_batch( ...@@ -433,6 +514,7 @@ def test_run_prediction_batch(
image_extension=".png", image_extension=".png",
gpu_device=None, gpu_device=None,
batch_size=batch_size, batch_size=batch_size,
tokens=prediction_data_path / "tokens.yml",
) )
for image_name, expected_prediction in zip(image_names, expected_predictions): for image_name, expected_prediction in zip(image_names, expected_predictions):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment