Spaces:
Running
Running
Harshil Darji
committed on
Commit
·
4a3eaed
1
Parent(s):
4cf33f6
update app
Browse files
app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import os
|
2 |
import re
|
3 |
import string
|
4 |
|
@@ -12,7 +11,8 @@ from transformers import (
|
|
12 |
pipeline,
|
13 |
)
|
14 |
|
15 |
-
|
|
|
16 |
logging.set_verbosity(logging.ERROR)
|
17 |
|
18 |
st.markdown(
|
@@ -24,9 +24,7 @@ st.markdown(
|
|
24 |
padding-left: 3rem;
|
25 |
padding-right: 3rem;
|
26 |
}
|
27 |
-
|
28 |
header, footer {visibility: hidden;}
|
29 |
-
|
30 |
.entity {
|
31 |
position: relative;
|
32 |
display: inline-block;
|
@@ -34,7 +32,6 @@ header, footer {visibility: hidden;}
|
|
34 |
font-weight: normal;
|
35 |
cursor: help;
|
36 |
}
|
37 |
-
|
38 |
.entity .tooltip {
|
39 |
visibility: hidden;
|
40 |
background-color: #333;
|
@@ -52,12 +49,10 @@ header, footer {visibility: hidden;}
|
|
52 |
transition: opacity 0.05s;
|
53 |
font-size: 11px;
|
54 |
}
|
55 |
-
|
56 |
.entity:hover .tooltip {
|
57 |
visibility: visible;
|
58 |
opacity: 1;
|
59 |
}
|
60 |
-
|
61 |
.entity.marked {
|
62 |
background-color: rgba(255, 230, 0, 0.4);
|
63 |
line-height: 1.3;
|
@@ -69,39 +64,32 @@ header, footer {visibility: hidden;}
|
|
69 |
unsafe_allow_html=True,
|
70 |
)
|
71 |
|
72 |
-
#
|
73 |
-
tkn = os.getenv("tkn")
|
74 |
-
tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraNER", use_auth_token=tkn)
|
75 |
-
model = AutoModelForTokenClassification.from_pretrained(
|
76 |
-
"harshildarji/JuraNER", use_auth_token=tkn
|
77 |
-
)
|
78 |
-
ner = pipeline("ner", model=model, tokenizer=tokenizer)
|
79 |
-
|
80 |
-
# Entity labels
|
81 |
entity_labels = {
|
82 |
-
"AN": "
|
83 |
-
"EUN": "
|
84 |
-
"GRT": "
|
85 |
-
"GS": "
|
86 |
"INN": "Institution",
|
87 |
-
"LD": "
|
88 |
-
"LDS": "
|
89 |
-
"LIT": "
|
90 |
-
"MRK": "
|
91 |
-
"ORG": "
|
92 |
"PER": "Person",
|
93 |
-
"RR": "
|
94 |
-
"RS": "
|
95 |
-
"ST": "
|
96 |
-
"STR": "
|
97 |
-
"UN": "
|
98 |
-
"VO": "
|
99 |
-
"VS": "
|
100 |
-
"VT": "
|
|
|
101 |
}
|
102 |
|
103 |
|
104 |
-
#
|
105 |
def generate_fixed_colors(keys, alpha=0.25):
|
106 |
cmap = cm.get_cmap("tab20", len(keys))
|
107 |
rgba_colors = {}
|
@@ -112,16 +100,35 @@ def generate_fixed_colors(keys, alpha=0.25):
|
|
112 |
return rgba_colors
|
113 |
|
114 |
|
115 |
-
ENTITY_COLORS = generate_fixed_colors(list(entity_labels.keys())
|
116 |
|
117 |
-
# UI
|
118 |
-
st.markdown("#### German Legal NER")
|
119 |
-
uploaded_file = st.file_uploader("Upload a .txt file", type="txt")
|
120 |
-
threshold = st.slider("Confidence threshold:", 0.0, 1.0, 0.8, 0.01)
|
121 |
-
st.markdown("---")
|
122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
def merge_entities(entities):
|
126 |
if not entities:
|
127 |
return []
|
@@ -135,10 +142,7 @@ def merge_entities(entities):
|
|
135 |
prev = merged[-1]
|
136 |
if ent["index"] == prev["index"] + 1:
|
137 |
tok = ent["word"]
|
138 |
-
if tok.startswith("##")
|
139 |
-
prev["word"] += tok[2:]
|
140 |
-
else:
|
141 |
-
prev["word"] += " " + tok
|
142 |
prev["end"] = ent["end"]
|
143 |
prev["index"] = ent["index"]
|
144 |
prev["score_sum"] += ent["score"]
|
@@ -172,7 +176,7 @@ def merge_entities(entities):
|
|
172 |
return final
|
173 |
|
174 |
|
175 |
-
#
|
176 |
def highlight_entities(line, merged_entities, threshold):
|
177 |
html = ""
|
178 |
last_end = 0
|
@@ -200,24 +204,29 @@ def highlight_entities(line, merged_entities, threshold):
|
|
200 |
return html
|
201 |
|
202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
if uploaded_file:
|
204 |
raw_bytes = uploaded_file.read()
|
205 |
encoding = detect(raw_bytes)["encoding"]
|
206 |
if encoding is None:
|
207 |
-
st.error("
|
208 |
else:
|
209 |
text = raw_bytes.decode(encoding)
|
210 |
|
211 |
-
with st.spinner("
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
)
|
|
|
|
|
1 |
import re
|
2 |
import string
|
3 |
|
|
|
11 |
pipeline,
|
12 |
)
|
13 |
|
14 |
+
# Streamlit page setup
|
15 |
+
st.set_page_config(page_title="Juristische NER", page_icon="⚖️", layout="wide")
|
16 |
logging.set_verbosity(logging.ERROR)
|
17 |
|
18 |
st.markdown(
|
|
|
24 |
padding-left: 3rem;
|
25 |
padding-right: 3rem;
|
26 |
}
|
|
|
27 |
header, footer {visibility: hidden;}
|
|
|
28 |
.entity {
|
29 |
position: relative;
|
30 |
display: inline-block;
|
|
|
32 |
font-weight: normal;
|
33 |
cursor: help;
|
34 |
}
|
|
|
35 |
.entity .tooltip {
|
36 |
visibility: hidden;
|
37 |
background-color: #333;
|
|
|
49 |
transition: opacity 0.05s;
|
50 |
font-size: 11px;
|
51 |
}
|
|
|
52 |
.entity:hover .tooltip {
|
53 |
visibility: visible;
|
54 |
opacity: 1;
|
55 |
}
|
|
|
56 |
.entity.marked {
|
57 |
background-color: rgba(255, 230, 0, 0.4);
|
58 |
line-height: 1.3;
|
|
|
64 |
unsafe_allow_html=True,
|
65 |
)
|
66 |
|
67 |
+
# Entity label mapping
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
entity_labels = {
|
69 |
+
"AN": "Rechtsbeistand",
|
70 |
+
"EUN": "EUNorm",
|
71 |
+
"GRT": "Gericht",
|
72 |
+
"GS": "Norm",
|
73 |
"INN": "Institution",
|
74 |
+
"LD": "Land",
|
75 |
+
"LDS": "Bezirk",
|
76 |
+
"LIT": "Schrifttum",
|
77 |
+
"MRK": "Marke",
|
78 |
+
"ORG": "Organisation",
|
79 |
"PER": "Person",
|
80 |
+
"RR": "RichterIn",
|
81 |
+
"RS": "Entscheidung",
|
82 |
+
"ST": "Stadt",
|
83 |
+
"STR": "Strasse",
|
84 |
+
"UN": "Unternehmen",
|
85 |
+
"VO": "Verordnung",
|
86 |
+
"VS": "Richtlinie",
|
87 |
+
"VT": "Vertrag",
|
88 |
+
"RED": "Schwärzung",
|
89 |
}
|
90 |
|
91 |
|
92 |
+
# Color generator
|
93 |
def generate_fixed_colors(keys, alpha=0.25):
|
94 |
cmap = cm.get_cmap("tab20", len(keys))
|
95 |
rgba_colors = {}
|
|
|
100 |
return rgba_colors
|
101 |
|
102 |
|
103 |
+
ENTITY_COLORS = generate_fixed_colors(list(entity_labels.keys()))
|
104 |
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
+
# Caching model
|
107 |
+
@st.cache_resource
|
108 |
+
def load_ner_pipeline():
|
109 |
+
return pipeline(
|
110 |
+
"ner",
|
111 |
+
model=AutoModelForTokenClassification.from_pretrained("harshildarji/JuraNER"),
|
112 |
+
tokenizer=AutoTokenizer.from_pretrained("harshildarji/JuraNER"),
|
113 |
+
)
|
114 |
|
115 |
+
|
116 |
+
# Caching NER + merge per line
|
117 |
+
@st.cache_data(show_spinner=False)
|
118 |
+
def get_ner_merged_lines(text):
|
119 |
+
ner = load_ner_pipeline()
|
120 |
+
results = []
|
121 |
+
for line in text.splitlines():
|
122 |
+
if not line.strip():
|
123 |
+
results.append(("", []))
|
124 |
+
continue
|
125 |
+
tokens = ner(line)
|
126 |
+
merged = merge_entities(tokens)
|
127 |
+
results.append((line, merged))
|
128 |
+
return results
|
129 |
+
|
130 |
+
|
131 |
+
# Entity merging
|
132 |
def merge_entities(entities):
|
133 |
if not entities:
|
134 |
return []
|
|
|
142 |
prev = merged[-1]
|
143 |
if ent["index"] == prev["index"] + 1:
|
144 |
tok = ent["word"]
|
145 |
+
prev["word"] += tok[2:] if tok.startswith("##") else " " + tok
|
|
|
|
|
|
|
146 |
prev["end"] = ent["end"]
|
147 |
prev["index"] = ent["index"]
|
148 |
prev["score_sum"] += ent["score"]
|
|
|
176 |
return final
|
177 |
|
178 |
|
179 |
+
# Highlighting
|
180 |
def highlight_entities(line, merged_entities, threshold):
|
181 |
html = ""
|
182 |
last_end = 0
|
|
|
204 |
return html
|
205 |
|
206 |
|
207 |
+
# UI
|
208 |
+
st.markdown("#### Juristische Named Entity Recognition (NER)")
|
209 |
+
uploaded_file = st.file_uploader("Bitte laden Sie eine .txt-Datei hoch:", type="txt")
|
210 |
+
threshold = st.slider("Schwellenwert für das Modellvertrauen:", 0.0, 1.0, 0.8, 0.01)
|
211 |
+
st.markdown("---")
|
212 |
+
|
213 |
if uploaded_file:
|
214 |
raw_bytes = uploaded_file.read()
|
215 |
encoding = detect(raw_bytes)["encoding"]
|
216 |
if encoding is None:
|
217 |
+
st.error("Zeichenkodierung konnte nicht erkannt werden.")
|
218 |
else:
|
219 |
text = raw_bytes.decode(encoding)
|
220 |
|
221 |
+
with st.spinner("Modell wird auf jede Zeile angewendet..."):
|
222 |
+
merged_all_lines = get_ner_merged_lines(text)
|
223 |
+
|
224 |
+
for line, merged in merged_all_lines:
|
225 |
+
if not line.strip():
|
226 |
+
continue
|
227 |
+
|
228 |
+
html_line = highlight_entities(line, merged, threshold)
|
229 |
+
st.markdown(
|
230 |
+
f'<div style="margin-bottom:0.8rem; line-height:1.7;">{html_line}</div>',
|
231 |
+
unsafe_allow_html=True,
|
232 |
+
)
|
|