Skip to content
Snippets Groups Projects
Commit fb222870 authored by Manon Blanco's avatar Manon Blanco Committed by Yoann Schneider
Browse files

Remove coverage_vector from predict.py

parent 89f4bb6a
No related branches found
No related tags found
1 merge request!132Remove coverage_vector from predict.py
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
For more details about this package, make sure to see the documentation available at https://teklia.gitlab.io/atr/dan/. For more details about this package, make sure to see the documentation available at https://teklia.gitlab.io/atr/dan/.
## Installation ## Installation
To use DAN in your own scripts, install it using pip: To use DAN in your own scripts, install it using pip:
...@@ -55,7 +54,9 @@ See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/form ...@@ -55,7 +54,9 @@ See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/form
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/train/) on the official DAN documentation. See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/train/) on the official DAN documentation.
### Synthetic data generation ### Synthetic data generation
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/generate/) on the official DAN documentation. See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/generate/) on the official DAN documentation.
### Model prediction ### Model prediction
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/predict/) on the official DAN documentation. See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/predict/) on the official DAN documentation.
...@@ -148,10 +148,6 @@ class DAN: ...@@ -148,10 +148,6 @@ class DAN:
features = self.encoder(input_tensor.float()) features = self.encoder(input_tensor.float())
features_size = features.size() features_size = features.size()
coverage_vector = torch.zeros(
(features.size(0), 1, features.size(2), features.size(3)),
device=self.device,
)
pos_features = self.decoder.features_updater.get_pos_features(features) pos_features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute( features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
2, 0, 1 2, 0, 1
...@@ -179,7 +175,6 @@ class DAN: ...@@ -179,7 +175,6 @@ class DAN:
confidence_scores.append( confidence_scores.append(
torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
) )
coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
predicted_tokens = torch.cat( predicted_tokens = torch.cat(
[ [
predicted_tokens, predicted_tokens,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment