test code
- classifier.py +147 -0
- merger.py +181 -0
- requirements.txt +4 -1
- ru_errant.py +117 -18
classifier.py
ADDED
@@ -0,0 +1,147 @@
from __future__ import annotations

from collections import defaultdict
from string import punctuation

import Levenshtein
from errant.edit import Edit


def edit_to_tuple(edit: Edit, idx: int = 0) -> list[int | str]:
    cor_toks_str = " ".join([tok.text for tok in edit.c_toks])
    return [edit.o_start, edit.o_end, edit.type, cor_toks_str, idx]


def classify(edit: Edit) -> list[list[int | str]]:
    """Classifies an Edit via updating its `type` attribute."""
    # Insertion and deletion
    if ((not edit.o_toks and edit.c_toks) or (edit.o_toks and not edit.c_toks)):
        error_cats = get_one_sided_type(edit.o_toks, edit.c_toks)
    elif edit.o_toks != edit.c_toks:
        error_cats = get_two_sided_type(edit.o_toks, edit.c_toks)
    else:
        error_cats = {"NA": edit.c_toks[0].text}
    new_edit_list = []
    if error_cats:
        for error_cat, correct_str in error_cats.items():
            edit.type = error_cat
            edit_tuple = edit_to_tuple(edit)
            edit_tuple[3] = correct_str
            new_edit_list.append(edit_tuple)
    return new_edit_list


def get_edit_info(toks):
    pos = []
    dep = []
    morph = dict()
    for tok in toks:
        pos.append(tok.tag_)
        dep.append(tok.dep_)
        morphs = str(tok.morph).split('|')
        for m in morphs:
            if len(m.strip()):
                k, v = m.strip().split('=')
                morph[k] = v
    return pos, dep, morph


def get_one_sided_type(o_toks, c_toks):
    """Classifies a zero-to-one or one-to-zero error based on a token list."""
    pos_list, _, _ = get_edit_info(o_toks if o_toks else c_toks)
    if "PUNCT" in pos_list or "SPACE" in pos_list:
        return {"PUNCT": c_toks[0].text if c_toks else ""}
    return {"SPELL": c_toks[0].text if c_toks else ""}


def get_two_sided_type(o_toks, c_toks) -> dict[str, str]:
    """Classifies a one-to-one or one-to-many or many-to-one error based on token lists."""
    # one-to-one cases
    if len(o_toks) == len(c_toks) == 1:
        if (
            all(char in punctuation + " " for char in o_toks[0].text) and
            all(char in punctuation + " " for char in c_toks[0].text)
        ):
            return {"PUNCT": c_toks[0].text}
        source_w, correct_w = o_toks[0].text, c_toks[0].text
        if source_w != correct_w:
            # if both strings are lowercase or both are uppercase,
            # and neither contains "ё", it can only be a "SPELL" error
            if (((source_w.islower() and correct_w.islower()) or
                    (source_w.isupper() and correct_w.isupper())) and
                    "ё" not in source_w + correct_w):
                return {"SPELL": correct_w}
            # edits with multiple errors (e.g. SPELL + CASE)
            # Step 1. Make a char-level Levenshtein table
            char_edits = Levenshtein.editops(source_w, correct_w)
            # Step 2. Classify operations (CASE, YO, SPELL)
            edits_classified = classify_char_edits(char_edits, source_w, correct_w)
            # Step 3. Combine the same-typed errors into minimal string pairs
            separated_edits = get_edit_strings(source_w, correct_w, edits_classified)
            return separated_edits
    # one-to-many and many-to-one cases
    if all(char in punctuation + " " for char in o_toks.text + c_toks.text):
        return {"PUNCT": c_toks.text}
    joint_corr_str = " ".join([tok.text for tok in c_toks])
    joint_corr_str = joint_corr_str.replace("- ", "-").replace(" -", "-")
    return {"SPELL": joint_corr_str}


def classify_char_edits(char_edits, source_w, correct_w):
    """Classifies char-level Levenshtein operations into SPELL, YO and CASE."""
    edits_classified = []
    for edit in char_edits:
        if edit[0] == "replace":
            if "ё" in [source_w[edit[1]], correct_w[edit[2]]]:
                edits_classified.append((*edit, "YO"))
            elif source_w[edit[1]].lower() == correct_w[edit[2]].lower():
                edits_classified.append((*edit, "CASE"))
            else:
                if (
                    (source_w[edit[1]].islower() and correct_w[edit[2]].isupper()) or
                    (source_w[edit[1]].isupper() and correct_w[edit[2]].islower())
                ):
                    edits_classified.append((*edit, "CASE"))
                edits_classified.append((*edit, "SPELL"))
        else:
            edits_classified.append((*edit, "SPELL"))
    return edits_classified


def get_edit_strings(source: str, correction: str,
                     edits_classified: list[tuple]) -> dict[str, str]:
    """
    Applies classified (SPELL, YO and CASE) char operations to the source word separately.
    Returns a dict mapping error type to the source string with corrections of this type only.
    """
    separated_edits = defaultdict(lambda: source)
    shift = 0  # char position shift to account for deletions and insertions
    for edit in edits_classified:
        edit_type = edit[3]
        curr_src = separated_edits[edit_type]
        if edit_type == "CASE":  # SOURCE letter spelled in CORRECTION case
            if correction[edit[2]].isupper():
                correction_char = source[edit[1]].upper()
            else:
                correction_char = source[edit[1]].lower()
        else:
            if edit[0] == "delete":
                correction_char = ""
            elif edit[0] == "insert":
                correction_char = correction[edit[2]]
            elif source[edit[1]].isupper():
                correction_char = correction[edit[2]].upper()
            else:
                correction_char = correction[edit[2]].lower()
        if edit[0] == "replace":
            separated_edits[edit_type] = curr_src[:edit[1] + shift] + correction_char + \
                curr_src[edit[1] + shift + 1:]
        elif edit[0] == "delete":
            separated_edits[edit_type] = curr_src[:edit[1] + shift] + \
                curr_src[edit[1] + shift + 1:]
            shift -= 1
        elif edit[0] == "insert":
            separated_edits[edit_type] = curr_src[:edit[1] + shift] + correction_char + \
                curr_src[edit[1] + shift:]
            shift += 1
    return dict(separated_edits)
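
As a quick illustration of how the char-level helpers above split a combined error into per-type corrections, here is a short sanity-check sketch. It is not part of the commit: it assumes classifier.py is on the import path and that the Levenshtein package from requirements.txt is installed; the word pair "werd" -> "Word" is borrowed from the annotate_errors docstring in ru_errant.py.

# Hypothetical sanity check for the char-level helpers (not in the diff).
import Levenshtein
from classifier import classify_char_edits, get_edit_strings

source_w, correct_w = "werd", "Word"
char_edits = Levenshtein.editops(source_w, correct_w)
# -> [('replace', 0, 0), ('replace', 1, 1)]
classified = classify_char_edits(char_edits, source_w, correct_w)
# -> [('replace', 0, 0, 'CASE'), ('replace', 1, 1, 'SPELL')]
print(get_edit_strings(source_w, correct_w, classified))
# -> {'CASE': 'Werd', 'SPELL': 'word'}  (one corrected string per error type)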
merger.py
ADDED
@@ -0,0 +1,181 @@
from __future__ import annotations

import itertools
import re
from string import punctuation

import Levenshtein
from errant.alignment import Alignment
from errant.edit import Edit


def get_rule_edits(alignment: Alignment) -> list[Edit]:
    """Groups word-level alignment according to merging rules."""
    edits = []
    # Split alignment into groups
    alignment_groups = group_alignment(alignment, "new")
    for op, group in alignment_groups:
        group = list(group)
        # Ignore M
        if op == "M":
            continue
        # T is always split
        if op == "T":
            for seq in group:
                edits.append(Edit(alignment.orig, alignment.cor, seq[1:]))
        # Process D, I and S subsequences
        else:
            processed = process_seq(group, alignment)
            # Turn the processed sequence into edits
            for seq in processed:
                edits.append(Edit(alignment.orig, alignment.cor, seq[1:]))
    return edits


def group_alignment(alignment: Alignment, mode: str = "default") -> list[tuple[str, list[tuple]]]:
    """
    Does initial alignment grouping:
    1. Make groups of MDM, MIM or MSM.
    2. In remaining operations, make groups of Ms, groups of Ts, and D/I/Ss.
       Do not group what was on the sides of M[DIS]M: SSMDMS -> [SS, MDM, S], not [MDM, SSS].
    3. Sort groups by the order in which they appear in the alignment.
    """
    if mode == "new":
        op_groups = []
        # Format the operation type sequence as a string to use regex sequence search
        all_ops_seq = "".join([op[0][0] for op in alignment.align_seq])
        # Find M[DIS]M groups and merge (need them to detect hyphen vs. space spelling)
        ungrouped_ids = list(range(len(alignment.align_seq)))
        for match in re.finditer("M[DIS]M", all_ops_seq):
            start, end = match.start(), match.end()
            op_groups.append(("MSM", alignment.align_seq[start:end]))
            for idx in range(start, end):
                ungrouped_ids.remove(idx)
        # Group remaining operations by default rules (groups of M, T and the rest)
        if ungrouped_ids:
            def get_group_type(operation):
                return operation if operation in {"M", "T"} else "DIS"
            curr_group = [alignment.align_seq[ungrouped_ids[0]]]
            last_oper_type = get_group_type(curr_group[0][0][0])
            for i, idx in enumerate(ungrouped_ids[1:], start=1):
                operation = alignment.align_seq[idx]
                oper_type = get_group_type(operation[0][0])
                if (oper_type == last_oper_type and
                        (idx - ungrouped_ids[i - 1] == 1 or oper_type in {"M", "T"})):
                    curr_group.append(operation)
                else:
                    op_groups.append((last_oper_type, curr_group))
                    curr_group = [operation]
                    last_oper_type = oper_type
            if curr_group:
                op_groups.append((last_oper_type, curr_group))
        # Sort groups by the start id of the first group entry
        op_groups = sorted(op_groups, key=lambda x: x[1][0][1])
    else:
        grouped = itertools.groupby(alignment.align_seq,
                                    lambda x: x[0][0] if x[0][0] in {"M", "T"} else False)
        op_groups = [(op, list(group)) for op, group in grouped]
    return op_groups


def process_seq(seq: list[tuple], alignment: Alignment) -> list[tuple]:
    """Applies merging rules to previously formed alignment groups (`seq`)."""
    # Return single alignments
    if len(seq) <= 1:
        return seq
    # Get the ops for the whole sequence
    ops = [op[0] for op in seq]

    # Get indices of all start-end combinations in the seq: 012 = 01, 02, 12
    combos = list(itertools.combinations(range(0, len(seq)), 2))
    # Sort them starting with the largest spans first
    combos.sort(key=lambda x: x[1] - x[0], reverse=True)
    # Loop through combos
    for start, end in combos:
        # Ignore ranges that do NOT contain a substitution, deletion or insertion.
        if not any(type_ in ops[start:end + 1] for type_ in ["D", "I", "S"]):
            continue
        # Merge all D xor I ops. (95% of human multi-token edits contain S).
        if set(ops[start:end + 1]) == {"D"} or set(ops[start:end + 1]) == {"I"}:
            return (process_seq(seq[:start], alignment)
                    + merge_edits(seq[start:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Get the tokens in orig and cor.
        o = alignment.orig[seq[start][1]:seq[end][2]]
        c = alignment.cor[seq[start][3]:seq[end][4]]
        if ops[start:end + 1] in [["M", "D", "M"], ["M", "I", "M"], ["M", "S", "M"]]:
            # merge hyphens
            if (o[start + 1].text == "-" or c[start + 1].text == "-") and len(o) != len(c):
                return (process_seq(seq[:start], alignment)
                        + merge_edits(seq[start:end + 1])
                        + process_seq(seq[end + 1:], alignment))
            # if it is not a hyphen-space edit, return only the punct edit
            return seq[start + 1: end]
        # Merge possessive suffixes: [friends -> friend 's]
        if o[-1].tag_ == "POS" or c[-1].tag_ == "POS":
            return (process_seq(seq[:end - 1], alignment)
                    + merge_edits(seq[end - 1:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Case changes
        if o[-1].lower == c[-1].lower:
            # Merge first token I or D: [Cat -> The big cat]
            if (start == 0 and
                    (len(o) == 1 and c[0].text[0].isupper()) or
                    (len(c) == 1 and o[0].text[0].isupper())):
                return (merge_edits(seq[start:end + 1])
                        + process_seq(seq[end + 1:], alignment))
            # Merge with previous punctuation: [, we -> . We], [we -> . We]
            if (len(o) > 1 and is_punct(o[-2])) or \
                    (len(c) > 1 and is_punct(c[-2])):
                return (process_seq(seq[:end - 1], alignment)
                        + merge_edits(seq[end - 1:end + 1])
                        + process_seq(seq[end + 1:], alignment))
        # Merge whitespace/hyphens: [acat -> a cat], [sub - way -> subway]
        s_str = re.sub("['-]", "", "".join([tok.lower_ for tok in o]))
        t_str = re.sub("['-]", "", "".join([tok.lower_ for tok in c]))
        if s_str == t_str or s_str.replace(" ", "") == t_str.replace(" ", ""):
            return (process_seq(seq[:start], alignment)
                    + merge_edits(seq[start:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Merge same POS or auxiliary/infinitive/phrasal verbs:
        # [to eat -> eating], [watch -> look at]
        pos_set = set([tok.pos for tok in o] + [tok.pos for tok in c])
        if len(o) != len(c) and (len(pos_set) == 1 or pos_set.issubset({"AUX", "PART", "VERB"})):
            return (process_seq(seq[:start], alignment)
                    + merge_edits(seq[start:end + 1])
                    + process_seq(seq[end + 1:], alignment))
        # Split rules take effect when we get to the smallest chunks
        if end - start < 2:
            # Split adjacent substitutions
            if len(o) == len(c) == 2:
                return (process_seq(seq[:start + 1], alignment)
                        + process_seq(seq[start + 1:], alignment))
            # Split similar substitutions at sequence boundaries
            if ((ops[start] == "S" and char_cost(o[0].text, c[0].text) > 0.75) or
                    (ops[end] == "S" and char_cost(o[-1].text, c[-1].text) > 0.75)):
                return (process_seq(seq[:start + 1], alignment)
                        + process_seq(seq[start + 1:], alignment))
            # Split final determiners
            if (end == len(seq) - 1 and
                    ((ops[-1] in {"D", "S"} and o[-1].pos == "DET") or
                     (ops[-1] in {"I", "S"} and c[-1].pos == "DET"))):
                return process_seq(seq[:-1], alignment) + [seq[-1]]
    return seq


def is_punct(token) -> bool:
    return token.text in punctuation


def char_cost(a: str, b: str) -> float:
    """Calculate the cost of character alignment, i.e. char similarity."""

    return Levenshtein.ratio(a, b)


def merge_edits(seq: list[tuple]) -> list[tuple]:
    """Merge the input alignment sequence into a single edit span."""

    if seq:
        return [("X", seq[0][1], seq[-1][2], seq[0][3], seq[-1][4])]
    return seq
requirements.txt
CHANGED
@@ -1 +1,4 @@
-git+https://github.com/huggingface/evaluate@main
+git+https://github.com/huggingface/evaluate@main
+git+https://github.com/Askinkaty/errant/@4183e57
+Levenshtein
+ru-core-news-lg @ https://huggingface.co/spacy/ru_core_news_lg/resolve/main/ru_core_news_lg-any-py3-none-any.whl
ru_errant.py
CHANGED
@@ -12,11 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """TODO: Add a description here."""
+from __future__ import annotations
 
+import re
+from collections import Counter, namedtuple
+from typing import Iterable
+from tqdm.auto import tqdm
+
+from errant.annotator import Annotator
+from errant.commands.compare_m2 import process_edits
+from errant.commands.compare_m2 import evaluate_edits
+from errant.commands.compare_m2 import merge_dict
+from errant.edit import Edit
+import spacy
+from spacy.tokenizer import Tokenizer
+from spacy.util import compile_prefix_regex, compile_infix_regex, compile_suffix_regex
+import classifier
+import merger
 import evaluate
 import datasets
 
-
 # TODO: Add BibTeX citation
 _CITATION = """\
 @InProceedings{huggingface:module,

@@ -31,7 +46,6 @@ _DESCRIPTION = """\
 This new module is designed to solve this great ML task and is crafted with a lot of care.
 """
 
-
 # TODO: Add description of the arguments of the module here
 _KWARGS_DESCRIPTION = """
 Calculates how good are predictions given some references, using certain scores

@@ -57,6 +71,40 @@ Examples:
 BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
 
 
+def update_spacy_tokenizer(nlp):
+    """
+    Changes the spaCy tokenizer to parse additional patterns.
+    """
+    infix_re = compile_infix_regex(nlp.Defaults.infixes[:-1] + ["\]\("])
+    simple_url_re = re.compile(r'''^https?://''')
+    nlp.tokenizer = Tokenizer(
+        nlp.vocab,
+        prefix_search=compile_prefix_regex(nlp.Defaults.prefixes + ['\\\\\"']).search,
+        suffix_search=compile_suffix_regex(nlp.Defaults.suffixes + ['\\\\']).search,
+        infix_finditer=infix_re.finditer,
+        token_match=None,
+        url_match=simple_url_re.match
+    )
+    return nlp
+
+
+def annotate_errors(self, orig: str, cor: str, merging: str = "rules") -> list[Edit]:
+    """
+    Overrides the `Annotator.annotate()` function to allow multiple errors per token.
+    This is necessary to parse combined errors, e.g.:
+    ["werd", "Word"] >>> Errors: ["SPELL", "CASE"]
+    The `classify()` method called inside is implemented in classifier.py
+    (which also overrides the original classifier).
+    """
+
+    alignment = self.annotator.align(orig, cor, False)
+    edits = self.annotator.merge(alignment, merging)
+    classified_edits = []
+    for edit in edits:
+        classified_edits.extend(self.annotator.classify(edit))
+    return sorted(classified_edits, key=lambda x: (x[0], x[2]))
+
+
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class RuErrant(evaluate.Metric):
     """TODO: Short description of my evaluation module."""

@@ -70,26 +118,77 @@ class RuErrant(evaluate.Metric):
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
             # This defines the format of each prediction and reference
-            features=datasets.Features(
+            features=datasets.Features(
+                {
+                    "sources": datasets.Value("string", id="sequence"),
+                    "corrections": datasets.Value("string", id="sequence"),
+                    "answers": datasets.Value("string", id="sequence"),
+                }
+            ),
             # Homepage of the module for documentation
             homepage="http://module.homepage",
             # Additional links to the codebase or references
-            codebase_urls=["
+            codebase_urls=["https://github.com/ai-forever/sage"],
             reference_urls=["http://path.to.reference.url/new_module"]
         )
 
     def _download_and_prepare(self, dl_manager):
+        self.annotator = Annotator("ru",
+                                   nlp=update_spacy_tokenizer(spacy.load("ru_core_news_lg")),
+                                   merger=merger,
+                                   classifier=classifier)
+
+    def _compute(self, sources, corrections, answers):
+        """
+        Evaluates iterables of sources, hyp and ref corrections with the ERRANT metric.
+
+        Args:
+            sources (Iterable[str]): an iterable of source texts;
+            corrections (Iterable[str]): an iterable of gold corrections for the source texts;
+            answers (Iterable[str]): an iterable of evaluated corrections for the source texts.
+
+        Returns:
+            dict[str, tuple[float, ...]]: a dict mapping error categories to the corresponding
+                P, R, F1 metric values.
+        """
+        best_dict = Counter({"tp": 0, "fp": 0, "fn": 0})
+        best_cats = {}
+        sents = zip(sources, corrections, answers)
+        pb = tqdm(sents, desc="Calculating errant metric", total=len(sources))
+        for sent_id, sent in enumerate(pb):
+            src = self.annotator.parse(sent[0])
+            ref = self.annotator.parse(sent[1])
+            hyp = self.annotator.parse(sent[2])
+            # Align hyp and ref corrections and annotate errors
+            hyp_edits = self.annotate_errors(src, hyp)
+            ref_edits = self.annotate_errors(src, ref)
+            # Process the edits for detection/correction based on args
+            ProcessingArgs = namedtuple("ProcessingArgs",
+                                        ["dt", "ds", "single", "multi", "filt", "cse"],
+                                        defaults=[False, False, False, False, [], True])
+            processing_args = ProcessingArgs()
+            hyp_dict = process_edits(hyp_edits, processing_args)
+            ref_dict = process_edits(ref_edits, processing_args)
+            # Evaluate edits and get best TP, FP, FN hyp+ref combo.
+            EvaluationArgs = namedtuple("EvaluationArgs",
+                                        ["beta", "verbose"],
+                                        defaults=[1.0, False])
+            evaluation_args = EvaluationArgs()
+            count_dict, cat_dict = evaluate_edits(
+                hyp_dict, ref_dict, best_dict, sent_id, evaluation_args)
+            # Merge these dicts with best_dict and best_cats
+            best_dict += Counter(count_dict)  # corpus-level TP, FP, FN
+            best_cats = merge_dict(best_cats, cat_dict)  # corpus-level error-type-wise TP, FP, FN
+        cat_prf = {}
+        for cat, values in best_cats.items():
+            tp, fp, fn = values  # fp - extra corrections, fn - missed corrections
+            p = float(tp) / (tp + fp) if tp + fp else 1.0
+            r = float(tp) / (tp + fn) if tp + fn else 1.0
+            f = (2 * p * r) / (p + r) if p + r else 0.0
+            cat_prf[cat] = (p, r, f)
+
+        for error_category in ["CASE", "PUNCT", "SPELL", "YO"]:
+            if error_category not in cat_prf:
+                cat_prf[error_category] = (1.0, 1.0, 1.0)
+
+        return cat_prf
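
For context, a hypothetical end-to-end usage sketch of the metric defined above. It is not part of the diff: the local path is illustrative, and it assumes the module directory (ru_errant.py together with classifier.py and merger.py) is loadable via evaluate.load and that the dependencies from requirements.txt, including the ru_core_news_lg spaCy model, are installed.

# Hypothetical usage of the RuErrant metric module (not in the diff).
import evaluate

# Path to the local directory that contains ru_errant.py, classifier.py and merger.py.
ru_errant = evaluate.load("path/to/ru_errant")

scores = ru_errant.compute(
    sources=["спасибо что так хорошо объясняешь"],        # original texts
    corrections=["Спасибо, что так хорошо объясняешь!"],  # gold corrections
    answers=["Спасибо что так хорошо объясняешь!"],       # system corrections to evaluate
)
# `scores` maps each error category ("CASE", "PUNCT", "SPELL", "YO")
# to a (precision, recall, F1) tuple, as returned by _compute() above.
print(scores)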