Skip to content
Snippets Groups Projects
Commit fb222870 authored by Manon Blanco's avatar Manon Blanco Committed by Yoann Schneider
Browse files

Remove coverage_vector from predict.py

parent 89f4bb6a
No related branches found
No related tags found
1 merge request!132Remove coverage_vector from predict.py
......@@ -4,7 +4,6 @@
For more details about this package, make sure to see the documentation available at https://teklia.gitlab.io/atr/dan/.
## Installation
To use DAN in your own scripts, install it using pip:
......@@ -55,7 +54,9 @@ See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/datasets/form
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/train/) on the official DAN documentation.
### Synthetic data generation
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/generate/) on the official DAN documentation.
### Model prediction
See the [dedicated section](https://teklia.gitlab.io/atr/dan/usage/predict/) on the official DAN documentation.
......@@ -148,10 +148,6 @@ class DAN:
features = self.encoder(input_tensor.float())
features_size = features.size()
coverage_vector = torch.zeros(
(features.size(0), 1, features.size(2), features.size(3)),
device=self.device,
)
pos_features = self.decoder.features_updater.get_pos_features(features)
features = torch.flatten(pos_features, start_dim=2, end_dim=3).permute(
2, 0, 1
......@@ -179,7 +175,6 @@ class DAN:
confidence_scores.append(
torch.max(torch.softmax(pred[:, :], dim=1), dim=1).values
)
coverage_vector = torch.clamp(coverage_vector + weights, 0, 1)
predicted_tokens = torch.cat(
[
predicted_tokens,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment