Add predicted objects to predict command
Depends on #35 (closed)
Once we are able to get the coverage vector for each detected object, we can work on finding a polygon for each of them. This will be integrated into the predict command and triggered by a new flag, --predict-objects.
Implement two functions in dan.predict.attention:
import cv2
import numpy as np


def get_predicted_polygons(
    text,
    weights,
    level,
    word_separators=["\n", " "],
    line_separators=["\n"],
):
    # Split text into characters, words or lines
    text_list, offset = split_text(text, level, word_separators, line_separators)
    max_value = weights.sum(0).max()
    return [
        get_polygon(text_piece, level, max_value, weights)
        for text_piece in text_list
    ]


def get_polygon(text_piece, level, max_value, weights):
    coverage_vector = compute_coverage(text_piece, max_value, level, weights)
    # Generate a binary image for the current channel.
    bin_img = coverage_vector.copy()
    bin_img[bin_img > 0] = 1
    # Detect the objects contours.
    contours, _ = cv2.findContours(
        np.uint8(bin_img), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    return {
        "confidence": ...,  # average of coverage vector on contour
        "polygon": contour,  # still to be selected from `contours`
        "type": level,
    }
The get_polygon function will be used in dan.predict.attention.plot_attention if a new boolean parameter output_polygons is enabled.
The loop in plot_attention becomes:
polygons = []
for text_piece in text_list:
    coverage_vector = compute_coverage(...)
    if output_polygons:
        polygons.append(get_polygon(...))
    # Keep track of text length
    tot_len += len(text_piece) + offset
    attention_map.append(blend_coverage(...))
return polygons
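The control flow of this loop can be exercised in isolation with stand-ins for the real helpers. The sketch below only illustrates how output_polygons gates the polygon list and how tot_len advances past each piece and its separator. The functions split_words and collect are hypothetical, and the assumption that tot_len marks the start of each piece in the attention weights is mine.

```python
def split_words(text):
    """Hypothetical stand-in for split_text at word level."""
    # Returns the pieces and the number of separator characters between them.
    return text.split(" "), 1


def collect(text, output_polygons=False):
    """Mimic the loop structure with placeholders instead of images and contours."""
    text_list, offset = split_words(text)
    tot_len = 0
    attention_map = []
    polygons = []
    for text_piece in text_list:
        # Assumed character range covered by this piece in the attention weights.
        span = (tot_len, tot_len + len(text_piece))
        if output_polygons:
            # Placeholder for get_polygon(...)
            polygons.append({"type": "word", "span": span})
        # Keep track of text length, skipping the separator
        tot_len += len(text_piece) + offset
        # Placeholder for blend_coverage(...)
        attention_map.append(span)
    return attention_map, polygons


maps, polys = collect("ab cde", output_polygons=True)
maps_only, no_polys = collect("ab cde")
```

With the flag disabled, the attention maps are still produced and the polygon list stays empty, which matches the intent that polygons are an optional extra output.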
I suppose the image generation can also be moved into a separate blend_coverage function.
The get_predicted_polygons function will be used in the predict method when attentions is enabled. We will add a new parameter, attention_level, which will be used as level in the get_predicted_polygons call. The result will be stored in a new key, objects.
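Since attention_level ends up as the level argument, it decides how the predicted text is cut into pieces before a polygon is computed for each piece. The following is a rough sketch of what a split_text with the signature used above could look like. The level names "char", "word" and "line", the returned offset values, and the regex-based splitting are all assumptions, not the existing implementation.

```python
import re


def split_text(text, level, word_separators, line_separators):
    """Hypothetical sketch: split text by level, return (pieces, offset)."""
    if level == "char":
        # Every character is its own piece, with no separator in between.
        return list(text), 0
    if level == "word":
        separators = word_separators
    elif level == "line":
        separators = line_separators
    else:
        raise ValueError(f"Unsupported level: {level}")
    # Build an alternation of the escaped separators and split on it.
    pattern = "|".join(re.escape(separator) for separator in separators)
    # Assumption: one separator character sits between consecutive pieces.
    return re.split(pattern, text), 1


words, word_offset = split_text("ab cd\nef", "word", ["\n", " "], ["\n"])
lines, line_offset = split_text("ab cd\nef", "line", ["\n", " "], ["\n"])
chars, char_offset = split_text("ab", "char", ["\n", " "], ["\n"])
```

A fixed offset of 1 only holds for single-character separators; multi-character separators would need the offset derived from the matched separator.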
The predict command will have a new CLI argument, --predict-objects. This flag will control the output_polygons value in the plot_attention call.