Skip to content
Snippets Groups Projects
Commit cd07bb2b authored by Maciej Sumalvico's avatar Maciej Sumalvico
Browse files

Refactoring and bugfixes in the precision/recall metric

Changes in get_precision_recall():
- Refactoring: separating the funtionality of merging alignments from scoring.
- Changed the definition of true/false and positive/negative. The characters
  that are originally wrong and wrongly corrected are now false positives
  (previously: false negatives). This also changes the evaluation results quite
  significantly!
- Bugfix: consider only one-to-one or one-to-zero alignments, but not
  one-to-many. This also changes the results.
- Code cleaning.
parent 9fce122f
Branches master
No related tags found
No related merge requests found
......@@ -135,97 +135,50 @@ def get_adjusted_distance(l1, l2):
def get_precision_recall(ocr, cor, gt):
"""Calculate number of true/false positive/negative edits of given OCR vs GT and COR vs GT line by aligning them.
Align both OCR and COR with GT line,
then zip through alignment (which constitutes characters and gaps)
to count everything as true/false that is correct/incorrect w.r.t. GT in COR,
and everything as positive/negative that is incorrect/correct w.r.t. GT in OCR,
so precision and recall (or similar) metrics can be accumulated
for post-correction (as a classifier).
(If both OCR and COR are incorrect, this will be a false negative,
regardless of whether COR did alter OCR or not.)
"""
Calculate number of true/false positive/negative edits of given OCR vs
GT and COR vs GT line by aligning them.
Return the true positive, true negative, false positive, false negative counts as a tuple.
Return the true positive, true negative, false positive, false negative
counts as a tuple.
"""
# scoring = SimpleScoring(2, -1)
# aligner = StrictGlobalSequenceAligner(scoring, -2)
# a = Sequence(ocr)
# b = Sequence(cor)
# c = Sequence(gt)
# # create a vocabulary and encode the sequences
# vocabulary = Vocabulary()
# ocr_seq = vocabulary.encodeSequence(a)
# cor_seq = vocabulary.encodeSequence(b)
# gt_seq = vocabulary.encodeSequence(c)
def _merge_alignments(al_1, al_2):
'''
Merges alignment `al_1` between sequences A, C and `al_2` between
sequences B, C into a three-way alignment between A, B, C.
'''
x1, y1 = next(al_1)
x2, y2 = next(al_2)
while True:
try:
if y1 == y2 and y1 != GAP_ELEMENT:
yield x1, x2, y1
x1, y1 = next(al_1)
x2, y2 = next(al_2)
elif y1 == GAP_ELEMENT:
yield x1, '', ''
x1, y1 = next(al_1)
elif y2 == GAP_ELEMENT:
yield '', x2, ''
x2, y2 = next(al_2)
else:
raise RuntimeError(\
'Sequence mismatch in three-way alignment.')
except StopIteration:
break
# _, alignments = aligner.align(ocr_seq, gt_seq, backtrace=True)
# a = vocabulary.decodeSequenceAlignment(alignments[0]) # best result
# _, alignments = aligner.align(cor_seq, gt_seq, backtrace=True)
# b = vocabulary.decodeSequenceAlignment(alignments[0]) # best result
alignment_ocr = get_best_alignment(ocr, gt)
alignment_cor = get_best_alignment(cor, gt)
al_ocr = get_best_alignment(ocr, gt)
al_cor = get_best_alignment(cor, gt)
# positives: incorrect characters before post-correction
# negatives: correct characters before post-correction
# true: correct after post-correction
# false: incorrect after post-correction
# true positives: correctly corrected
# false positives: incorrectly corrected
# true negatives: correctly unchanged
# false negatives: incorrectly unchanged
# precision: ratio of correct changes
# recall: ratio of detected errors
i = 0
j = 0
FP = 0
TP = 0
FN = 0
TN = 0
while i < len(alignment_ocr) and j < len(alignment_cor):
ocr_sym = alignment_ocr[i][0] or ''
while i < len(alignment_ocr) and not alignment_ocr[i][1]:
i += 1
ocr_sym += alignment_ocr[i][0] or ''
cor_sym = alignment_cor[j][0] or ''
while j < len(alignment_cor) and not alignment_cor[j][1]:
j += 1
cor_sym += alignment_cor[j][0] or ''
gt_sym = alignment_ocr[i][1]
# loop invariants:
assert alignment_ocr[i][1] == alignment_cor[j][1] # gt is synchronous with i/j
assert alignment_ocr[i][1] # jumping over gaps
assert alignment_cor[j][1] # jumping over gaps
i += 1
j += 1
# fill from final gaps:
while i < len(alignment_ocr) and not alignment_ocr[i][1]:
ocr_sym += alignment_ocr[i][0]
i += 1
while j < len(alignment_cor) and not alignment_cor[j][1]:
cor_sym += alignment_cor[j][0]
j += 1
#print('ocr_sym: "%s"' % ocr_sym)
#print('cor_sym: "%s"' % cor_sym)
#print('gt_sym: "%s"' % gt_sym)
# counting:
if ocr_sym == gt_sym:
if cor_sym == gt_sym:
TN += 1
else:
FP += 1
else:
if cor_sym == gt_sym:
TP += 1
else:
FN += 1
# what if ocr_sym != cor_sym ?
assert i == len(alignment_ocr)
assert j == len(alignment_cor)
TP, FP, TN, FN = 0, 0, 0, 0
for c_ocr, c_cor, c_gt in _merge_alignments(iter(al_ocr), iter(al_cor)):
is_correct = (c_cor == c_gt)
is_changed = (c_cor != c_ocr)
TP += 1 if is_changed and is_correct else 0
FP += 1 if is_changed and not is_correct else 0
TN += 1 if not is_changed and is_correct else 0
FN += 1 if not is_changed and not is_correct else 0
return (TP, TN, FP, FN)
......@@ -240,9 +193,12 @@ def compute_total_precision_recall(line_triplets, silent=False):
FN += l_FN
if not silent:
print_line(ocr, cor, gt)
print("TP: %d / TN: %d / FP: %d / FN: %d" %
(l_TP, l_TN, l_FP, l_FN))
print('precision: %.3f / recall %.3f' %
(1 if l_TP+l_FP == 0 else l_TP / (l_TP+l_FP),
1 if l_TP+l_FN == 0 else l_TP / (l_TP+l_FN)))
print()
return TP, TN, FP, FN
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment