From fb222870bb3355a9c8ee898f27dc63e00d454019 Mon Sep 17 00:00:00 2001
From: Manon blanco <blanco@teklia.com>
Date: Mon, 22 May 2023 11:00:32 +0000
Subject: [PATCH] Remove coverage_vector from predict.py

---
 README.md                 | 3 ++-
 dan/predict/prediction.py | 5 -----
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 3f50eb52..0c78c71c 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,6 @@
 
 For more details about this package, make sure to see the documentation available at https://teklia.gitlab.io/atr/dan/.
 
-
 ## Installation
 
 To use DAN in your own scripts, install it using pip:
@@ -55,7 +54,9 @@ See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/form
 See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/train/) on the official DAN documentation.
 
 ### Synthetic data generation
+
 See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/generate/) on the official DAN documentation.
 
 ### Model prediction
+
 See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/predict/) on the official DAN documentation.
diff --git a/dan/predict/prediction.py b/dan/predict/prediction.py
index 9c8de085..f288fbbf 100644
--- a/dan/predict/prediction.py
+++ b/dan/predict/prediction.py
@@ -148,10 +148,6 @@ class DAN:
 
             features = self.encoder(input_tensor.float())
             features_size = features.size()
-            coverage_vector = torch.zeros(
-                (features.size(0), 1, features.size(2), features.size(3)),
-                device=self.device,
-            )
             pos_features = self.decoder.features_updater.get_pos_features(features)
             features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
                 2, 0, 1
@@ -179,7 +175,6 @@ class DAN:
                 confidence_scores.append(
                     torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
                 )
-                coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
                 predicted_tokens = torch.cat(
                     [
                         predicted_tokens,
-- 
GitLab