Compare revisions — atr/dan

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (5)
@@ -5,7 +5,7 @@ repos:
       - id: isort
         args: ["--profile", "black"]
   - repo: https://github.com/ambv/black
-    rev: 22.12.0
+    rev: 23.1.0
     hooks:
       - id: black
   - repo: https://github.com/pycqa/flake8
...
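Note: the bump from black 22.12.0 to 23.1.0 pulls in black's 2023 stable style, which (among other changes) removes blank lines directly after a block opener; that is most likely what produces the whitespace-only hunks further down. A minimal before/after sketch (hypothetical snippet, not from this repository):

# Accepted by black 22.12.0:
def forward(self, x):

    return x


# black 23.1.0 removes the blank line right after the signature:
def forward(self, x):
    return x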
@@ -214,7 +214,6 @@ class GlobalDecoderLayer(Module):
         memory_key_padding_mask=None,
         predict_last_n_only=None,
     ):
-
         if memory_value is None:
             memory_value = memory_key
...
@@ -365,7 +365,6 @@ def apply_preprocessing(sample, preprocessings):
     resize_ratio = [1, 1]
     img = sample["img"]
     for preprocessing in preprocessings:
-
        if preprocessing["type"] == "dpi":
            ratio = preprocessing["target"] / preprocessing["source"]
            temp_img = img
...
@@ -31,6 +31,27 @@ def split_text(text, level, word_separators, line_separators):
     return text_split, offset


+def compute_coverage(text: str, max_value: float, offset: int, attentions):
+    """
+    Aggregates attention maps for the current text piece (char, word, line)
+    :param text: Text piece selected with offset after splitting DAN prediction
+    :param max_value: Maximum "attention intensity" for parts of a text piece, used for normalization
+    :param offset: Offset value to get the relevant part of text piece
+    :param attentions: Attention weights of size (n_char, feature_height, feature_width)
+    """
+    _, height, width = attentions.shape
+
+    # Blank vector to accumulate weights for the current text
+    coverage_vector = np.zeros((height, width))
+    for i in range(len(text)):
+        local_weight = cv2.resize(attentions[i + offset], (width, height))
+        coverage_vector = np.clip(coverage_vector + local_weight, 0, 1)
+
+    # Normalize coverage vector
+    coverage_vector = (coverage_vector / max_value * 255).astype(np.uint8)
+    return coverage_vector
+
+
 def plot_attention(
     image,
     text,
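For reference, the new helper can be exercised standalone; a minimal sketch (dummy shapes and values, assuming numpy and opencv-python as pinned in the requirements below, plus the compute_coverage definition added above):

import cv2
import numpy as np

# Dummy attention weights: 10 characters over an 8x16 feature map
attentions = np.random.rand(10, 8, 16).astype(np.float32)
max_value = attentions.sum(0).max()  # same normalizer plot_attention computes

# Coverage map for a 4-character piece starting at character offset 3
coverage = compute_coverage("word", max_value, 3, attentions)
print(coverage.shape, coverage.dtype)  # (8, 16) uint8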
@@ -50,6 +71,7 @@ def plot_attention(
     :param scale: Scaling factor for the output gif image
     :param outname: Name of the gif image
     """
     height, width, _ = image.shape
     attention_map = []
@@ -64,20 +86,15 @@ def plot_attention(
     tot_len = 0
     max_value = weights.sum(0).max()

     for text_piece in text_list:
-        # blank vector to accumulate weights for the current word/line
-        coverage_vector = np.zeros((height, width))
-        for i in range(len(text_piece)):
-            local_weight = weights[i + tot_len]
-            local_weight = cv2.resize(local_weight, (width, height))
-            coverage_vector = np.clip(coverage_vector + local_weight, 0, 1)
+        # Accumulate weights for the current word/line and resize to original image size
+        coverage_vector = compute_coverage(text_piece, max_value, tot_len, weights)
+        coverage_vector = cv2.resize(coverage_vector, (width, height))

         # Keep track of text length
         tot_len += len(text_piece) + offset

-        # Normalize coverage vector
-        coverage_vector = (coverage_vector / max_value * 255).astype(np.uint8)

         # Blend coverage vector with original image
         blank_array = np.zeros((height, width)).astype(np.uint8)
         coverage_vector = Image.fromarray(
...
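After the merge, the loop body reads as reconstructed below. One behavioral nuance of the refactor: weights are now accumulated, clipped and normalized at feature-map resolution and upscaled once, whereas the old code upscaled each per-character map to image size before clipping:

for text_piece in text_list:
    # Accumulate weights for the current word/line and resize to original image size
    coverage_vector = compute_coverage(text_piece, max_value, tot_len, weights)
    coverage_vector = cv2.resize(coverage_vector, (width, height))

    # Keep track of text length
    tot_len += len(text_piece) + offset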
@@ -30,7 +30,7 @@ class DropoutScheduler:
             self.init_teta_list_module(child)

     def update_dropout_rate(self):
-        for (module, p) in self.teta_list:
+        for module, p in self.teta_list:
             module.p = self.function(p, self.step_num, self.T)
...
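The change here is purely stylistic: the redundant parentheses around the loop target are dropped. For context, self.function is expected to map the initial rate p, the current step and the horizon T to a new dropout probability; a hypothetical example of such a schedule (only the signature comes from the hunk, the linear ramp itself is illustrative):

def linear_dropout(p, step_num, T):
    # Ramp the dropout probability from 0 up to its initial value p over T steps
    return p * min(step_num / T, 1.0)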
@@ -148,7 +148,6 @@ class ZoomRatio:
 class ElasticDistortion:
     def __init__(self, kernel_size=(7, 7), sigma=5, alpha=1):
-
         self.kernel_size = kernel_size
         self.sigma = sigma
         self.alpha = alpha
...
 black==22.12.0
 doc8==1.1.1
 mkdocs==1.4.2
-mkdocs-material==9.1.0
+mkdocs-material==9.1.2
 mkdocstrings==0.20.0
 mkdocstrings-python==0.8.3
 recommonmark==0.7.1
...
 arkindex-client==1.0.11
-boto3==1.26.83
+boto3==1.26.90
 editdistance==0.6.2
 fontTools==4.38.0
 imageio==2.26.0
-mlflow==2.0.1
+mlflow==2.2.2
 networkx==3.0
 numpy==1.23.5
 opencv-python==4.7.0.72
...