
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: atr/dan
Commits on Source (4)
@@ -4,7 +4,6 @@
 For more details about this package, make sure to see the documentation available at https://teklia.gitlab.io/atr/dan/.
 ## Installation
 To use DAN in your own scripts, install it using pip:
@@ -55,7 +54,9 @@ See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/form
 See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/train/) on the official DAN documentation.
 ### Synthetic data generation
 See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/generate/) on the official DAN documentation.
 ### Model prediction
 See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/predict/) on the official DAN documentation.
@@ -272,13 +272,6 @@ class GenericDataset(Dataset):
                     )
                     if load_in_memory:
                         samples[-1]["img"] = GenericDataset.load_image(filename)
-                    if type(gt[filename]) is dict:
-                        if "lines" in gt[filename].keys():
-                            samples[-1]["raw_line_seg_label"] = gt[filename]["lines"]
-                        if "paragraphs" in gt[filename].keys():
-                            samples[-1]["paragraphs_label"] = gt[filename]["paragraphs"]
-                        if "pages" in gt[filename].keys():
-                            samples[-1]["pages_label"] = gt[filename]["pages"]
         return samples

     def apply_preprocessing(self, preprocessings):
@@ -421,15 +414,6 @@ def apply_preprocessing(sample, preprocessings):
                 temp_img = np.expand_dims(temp_img, axis=2)
             img = temp_img
             resize_ratio = [ratio, ratio]
-    if resize_ratio != [1, 1] and "raw_line_seg_label" in sample:
-        for li in range(len(sample["raw_line_seg_label"])):
-            for side, ratio in zip(
-                (["bottom", "top"], ["right", "left"]), resize_ratio
-            ):
-                for s in side:
-                    sample["raw_line_seg_label"][li][s] = (
-                        sample["raw_line_seg_label"][li][s] * ratio
-                    )
     sample["img"] = img
     sample["resize_ratio"] = resize_ratio
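The block removed above rescaled the line segmentation ground truth whenever the image itself was resized. Below is a minimal standalone sketch of that coordinate scaling, assuming each line label is a dictionary with "top"/"bottom"/"left"/"right" pixel values and that the resize ratio is given as [vertical_ratio, horizontal_ratio]; the helper name `scale_line_labels` and the example values are hypothetical and only used for illustration.

```python
# Hypothetical helper illustrating the coordinate scaling the removed block performed.
# Assumes resize_ratio = [vertical_ratio, horizontal_ratio] and that each line label
# is a dict with "top"/"bottom"/"left"/"right" pixel coordinates.
def scale_line_labels(line_labels, resize_ratio):
    vertical_ratio, horizontal_ratio = resize_ratio
    for label in line_labels:
        # "top"/"bottom" follow the vertical ratio, "left"/"right" the horizontal one.
        for key in ("top", "bottom"):
            label[key] = label[key] * vertical_ratio
        for key in ("left", "right"):
            label[key] = label[key] * horizontal_ratio
    return line_labels


# Example with made-up coordinates and a uniform 0.5 resize ratio.
labels = [{"top": 10, "bottom": 50, "left": 5, "right": 200}]
print(scale_line_labels(labels, [0.5, 0.5]))
# [{'top': 5.0, 'bottom': 25.0, 'left': 2.5, 'right': 100.0}]
```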
@@ -152,8 +152,6 @@ class MetricManager:
                 "loss",
                 "loss_ctc",
                 "loss_ce",
-                "syn_max_lines",
-                "syn_prob_lines",
             ]:
                 value = float(
                     np.average(
@@ -225,8 +223,6 @@ class MetricManager:
                 "loss_ctc",
                 "loss_ce",
                 "loss",
-                "syn_max_lines",
-                "syn_prob_lines",
             ]:
                 metrics[metric_name] = [
                     values[metric_name],
@@ -1188,12 +1188,6 @@ class Manager(OCRManager):
             "str_x": str_x,
             "loss": sum_loss.item(),
             "loss_ce": loss_ce.item(),
-            "syn_max_lines": self.dataset.train_dataset.get_syn_max_lines()
-            if self.params["dataset_params"]["config"]["synthetic_data"]
-            else 0,
-            "syn_prob_lines": self.dataset.train_dataset.get_syn_proba_lines()
-            if self.params["dataset_params"]["config"]["synthetic_data"]
-            else 0,
         }

         return values
@@ -1247,10 +1241,6 @@ class Manager(OCRManager):
         else:
             features = self.models["encoder"](x)
         features_size = features.size()
-        coverage_vector = torch.zeros(
-            (features.size(0), 1, features.size(2), features.size(3)),
-            device=self.device,
-        )
         pos_features = self.models["decoder"].features_updater.get_pos_features(
             features
         )
@@ -1279,7 +1269,6 @@ class Manager(OCRManager):
             confidence_scores.append(
                 torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
             )
-            coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
             predicted_tokens = torch.cat(
                 [
                     predicted_tokens,
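The `coverage_vector` lines dropped above (and again in the `DAN` prediction class below) kept a per-position accumulator over the encoder feature map: at each decoding step the attention weights were added to it and the result clamped so every position saturates at 1. Below is a minimal, self-contained sketch of that bookkeeping with made-up tensor shapes; the random `weights` tensor stands in for the attention map the decoder produces and is not part of the original code.

```python
import torch

# Made-up shapes standing in for the encoder output: batch of 2, an 8x16 feature map.
batch_size, height, width, decoding_steps = 2, 8, 16, 5

# Same initialisation as the removed lines: one coverage channel per feature position.
coverage_vector = torch.zeros((batch_size, 1, height, width))

for _ in range(decoding_steps):
    # In the real decoding loop these weights are the attention map returned by the
    # decoder; random values are used here purely for illustration.
    weights = torch.rand((batch_size, 1, height, width))
    # Accumulate attention mass and clamp so each position saturates at 1.
    coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)

print(coverage_vector.shape)  # torch.Size([2, 1, 8, 16])
```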
@@ -212,8 +212,6 @@ def get_config():
                "cer",
                "wer",
                "wer_no_punct",
-                "syn_max_lines",
-                "syn_prob_lines",
            ],  # Metrics name for training
            "eval_metrics": [
                "cer",
@@ -148,10 +148,6 @@ class DAN:
         features = self.encoder(input_tensor.float())
         features_size = features.size()
-        coverage_vector = torch.zeros(
-            (features.size(0), 1, features.size(2), features.size(3)),
-            device=self.device,
-        )
         pos_features = self.decoder.features_updater.get_pos_features(features)
         features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
             2, 0, 1
@@ -179,7 +175,6 @@ class DAN:
             confidence_scores.append(
                 torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
             )
-            coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
             predicted_tokens = torch.cat(
                 [
                     predicted_tokens,