From fb222870bb3355a9c8ee898f27dc63e00d454019 Mon Sep 17 00:00:00 2001 From: Manon blanco <blanco@teklia.com> Date: Mon, 22 May 2023 11:00:32 +0000 Subject: [PATCH] Remove coverage_vector from predict.py --- README.md | 3 ++- dan/predict/prediction.py | 5 ----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 3f50eb52..0c78c71c 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,6 @@ For more details about this package, make sure to see the documentation available at https://teklia.gitlab.io/atr/dan/. - ## Installation To use DAN in your own scripts, install it using pip: @@ -55,7 +54,9 @@ See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/form See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/train/) on the official DAN documentation. ### Synthetic data generation + See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/generate/) on the official DAN documentation. ### Model prediction + See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/predict/) on the official DAN documentation. diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py index 9c8de085..f288fbbf 100644 --- a/dan/predict/prediction.py +++ b/dan/predict/prediction.py @@ -148,10 +148,6 @@ class DAN: features = self.encoder(input_tensor.float()) features_size = features.size() - coverage_vector = torch.zeros( - (features.size(0), 1, features.size(2), features.size(3)), - device=self.device, - ) pos_features = self.decoder.features_updater.get_pos_features(features) features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute( 2, 0, 1 @@ -179,7 +175,6 @@ class DAN: confidence_scores.append( torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values ) - coverage_vector = torch.clamp(coverage_vector + weights, 0, 1) predicted_tokens = torch.cat( [ predicted_tokens, -- GitLab