Commit e276294e authored by Solene Tarride

Simplify and document data extraction

parent da4652c8
1 merge request: !287 Support subword and word language models
@@ -393,13 +393,13 @@ class ArkindexExtractor:
                 indent=4,
             )
         )
-        (self.output / "language_corpus.txt").write_text(
+        (self.output / "language_model" / "corpus.txt").write_text(
             "\n".join(self.language_corpus)
         )
-        (self.output / "language_tokens.txt").write_text(
+        (self.output / "language_model" / "tokens.txt").write_text(
             "\n".join(self.language_tokens)
         )
-        (self.output / "language_lexicon.txt").write_text(
+        (self.output / "language_model" / "lexicon.txt").write_text(
             "\n".join(self.language_lexicon)
         )
         (self.output / "charset.pkl").write_bytes(
...
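The writes above assume the `language_model` directory already exists under the output root. A minimal standalone sketch of the new layout; the helper name and signature are illustrative, not part of this commit:

```python
from pathlib import Path

def write_language_resources(
    output: Path, corpus: list[str], tokens: list[str], lexicon: list[str]
) -> None:
    """Hypothetical helper mirroring the change above: language-model
    resources now live under output/language_model/ with shorter names."""
    lm_dir = output / "language_model"
    # The folder must exist before the write_text() calls can succeed.
    lm_dir.mkdir(parents=True, exist_ok=True)
    (lm_dir / "corpus.txt").write_text("\n".join(corpus))
    (lm_dir / "tokens.txt").write_text("\n".join(tokens))
    (lm_dir / "lexicon.txt").write_text("\n".join(lexicon))
```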
@@ -92,12 +92,12 @@ class DAN:
         self.decoder = decoder
         self.lm_decoder = None
-        if use_language_model and parameters["lm_decoder"]["language_model_weight"] > 0:
+        if use_language_model and parameters["language_model"]["weight"] > 0:
             self.lm_decoder = CTCLanguageDecoder(
-                language_model_path=parameters["lm_decoder"]["language_model_path"],
-                lexicon_path=parameters["lm_decoder"]["lexicon_path"],
-                tokens_path=parameters["lm_decoder"]["tokens_path"],
-                language_model_weight=parameters["lm_decoder"]["language_model_weight"],
+                language_model_path=parameters["language_model"]["model"],
+                lexicon_path=parameters["language_model"]["lexicon"],
+                tokens_path=parameters["language_model"]["tokens"],
+                language_model_weight=parameters["language_model"]["weight"],
             )
         self.mean, self.std = (
...
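For reference, the renamed parameter block maps onto `CTCLanguageDecoder` as follows. A minimal sketch, assuming `parameters.yml` follows the layout shown further down in this commit; the file path and loading code are illustrative:

```python
import yaml

# Load the prediction parameters (path is illustrative).
with open("parameters.yml") as f:
    parameters = yaml.safe_load(f)["parameters"]

lm = parameters["language_model"]
# A weight of 0 disables language-model decoding entirely: the decoder is
# only built when use_language_model is set and the weight is positive.
if lm["weight"] > 0:
    decoder_kwargs = dict(
        language_model_path=lm["model"],
        lexicon_path=lm["lexicon"],
        tokens_path=lm["tokens"],
        language_model_weight=lm["weight"],
    )
```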
@@ -20,10 +20,10 @@ output/
 │   ├── train
 │   ├── val
 │   └── test
 └── language_model
-    ├── language_corpus.txt
-    ├── language_lexicon.txt
-    └── language_tokens.txt
+    ├── corpus.txt
+    ├── lexicon.txt
+    └── tokens.txt
 ```

 ## 2. Train
...
@@ -22,8 +22,8 @@ parameters:
     - type: "max_resize"
       max_height: 1500
       max_width: 1500
-  lm_decoder:
-    language_model_path: tests/data/prediction/language_model.arpa
-    lexicon_path: tests/data/prediction/language_lexicon.txt
-    tokens_path: tests/data/prediction/language_tokens.txt
-    language_model_weight: 1.0
+  language_model:
+    model: tests/data/prediction/language_model.arpa
+    lexicon: tests/data/prediction/language_lexicon.txt
+    tokens: tests/data/prediction/language_tokens.txt
+    weight: 1.0
@@ -396,9 +396,10 @@ def test_extract(
         VAL_DIR / "val-page_1-line_2.jpg",
         VAL_DIR / "val-page_1-line_3.jpg",
         output / "labels.json",
-        output / "language_corpus.txt",
-        output / "language_lexicon.txt",
-        output / "language_tokens.txt",
+        # Language resources
+        output / "language_model" / "corpus.txt",
+        output / "language_model" / "lexicon.txt",
+        output / "language_model" / "tokens.txt",
     ]
     assert sorted(filter(methodcaller("is_file"), output.rglob("*"))) == expected_paths
...
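The assertion compares against every regular file found under the output root. A standalone equivalent of that `filter`/`rglob` pattern, with an illustrative path:

```python
from operator import methodcaller
from pathlib import Path

output = Path("output")  # illustrative path
# methodcaller("is_file") calls p.is_file() on each path, keeping only
# regular files; rglob("*") walks the tree recursively.
files = sorted(filter(methodcaller("is_file"), output.rglob("*")))
assert files == sorted(p for p in output.rglob("*") if p.is_file())
```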
@@ -487,20 +488,22 @@ def test_extract(
         "", expected_language_corpus
     )

-    assert (output / "language_corpus.txt").read_text() == expected_language_corpus
+    assert (
+        output / "language_model" / "corpus.txt"
+    ).read_text() == expected_language_corpus

     # Check "language_tokens.txt"
     expected_language_tokens = [
         t if t != " " else "▁" for t in sorted(list(expected_charset))
     ]
     expected_language_tokens.append("◌")

-    assert (output / "language_tokens.txt").read_text() == "\n".join(
+    assert (output / "language_model" / "tokens.txt").read_text() == "\n".join(
         expected_language_tokens
     )

     # Check "language_lexicon.txt"
     expected_language_lexicon = [f"{t} {t}" for t in expected_language_tokens]

-    assert (output / "language_lexicon.txt").read_text() == "\n".join(
+    assert (output / "language_model" / "lexicon.txt").read_text() == "\n".join(
         expected_language_lexicon
     )
...
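The lexicon check encodes a one-to-one mapping: decoding is character-level here, so each token spells itself, one `<word> <spelling>` pair per line. A small illustration of the expected file contents, with illustrative tokens:

```python
tokens = ["a", "b", "c"]  # lines of language_model/tokens.txt
lexicon = [f"{t} {t}" for t in tokens]  # lines of language_model/lexicon.txt
print("\n".join(lexicon))
# a a
# b b
# c c
```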
@@ -739,11 +739,15 @@ def test_run_prediction_language_model(
     # Update language_model_weight in parameters.yml
     params = read_yaml(PREDICTION_DATA_PATH / "parameters.yml")
+<<<<<<< HEAD
 <<<<<<< HEAD
     params["parameters"]["language_model"]["weight"] = language_model_weight
 =======
     params["parameters"]["lm_decoder"]["language_model_weight"] = language_model_weight
 >>>>>>> c80c413 (Write tests for LM decoding)
+=======
+    params["parameters"]["language_model"]["weight"] = language_model_weight
+>>>>>>> 57684ef (Simplify and document data extraction)
     yaml.dump(params, (tmp_path / "parameters.yml").open("w"))

     run_prediction(
...
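Setting the conflict markers aside, the override pattern this test relies on is: read `parameters.yml`, set the language-model weight, and dump the modified copy into the test's temporary directory. A sketch, with paths and the weight value illustrative:

```python
import yaml

# Read the shipped parameters, override the LM weight, write a test copy.
params = yaml.safe_load(open("parameters.yml"))
params["parameters"]["language_model"]["weight"] = 1.0  # value under test
yaml.dump(params, open("parameters_override.yml", "w"))
```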