Harshil Darji committed on
Commit
4a3eaed
·
1 Parent(s): 4cf33f6

update app

Browse files
Files changed (1) hide show
  1. app.py +69 -60
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
  import re
3
  import string
4
 
@@ -12,7 +11,8 @@ from transformers import (
12
  pipeline,
13
  )
14
 
15
- st.set_page_config(page_title="German Legal NER", page_icon="⚖️", layout="wide")
 
16
  logging.set_verbosity(logging.ERROR)
17
 
18
  st.markdown(
@@ -24,9 +24,7 @@ st.markdown(
24
  padding-left: 3rem;
25
  padding-right: 3rem;
26
  }
27
-
28
  header, footer {visibility: hidden;}
29
-
30
  .entity {
31
  position: relative;
32
  display: inline-block;
@@ -34,7 +32,6 @@ header, footer {visibility: hidden;}
34
  font-weight: normal;
35
  cursor: help;
36
  }
37
-
38
  .entity .tooltip {
39
  visibility: hidden;
40
  background-color: #333;
@@ -52,12 +49,10 @@ header, footer {visibility: hidden;}
52
  transition: opacity 0.05s;
53
  font-size: 11px;
54
  }
55
-
56
  .entity:hover .tooltip {
57
  visibility: visible;
58
  opacity: 1;
59
  }
60
-
61
  .entity.marked {
62
  background-color: rgba(255, 230, 0, 0.4);
63
  line-height: 1.3;
@@ -69,39 +64,32 @@ header, footer {visibility: hidden;}
69
  unsafe_allow_html=True,
70
  )
71
 
72
- # Load model
73
- tkn = os.getenv("tkn")
74
- tokenizer = AutoTokenizer.from_pretrained("harshildarji/JuraNER", use_auth_token=tkn)
75
- model = AutoModelForTokenClassification.from_pretrained(
76
- "harshildarji/JuraNER", use_auth_token=tkn
77
- )
78
- ner = pipeline("ner", model=model, tokenizer=tokenizer)
79
-
80
- # Entity labels
81
  entity_labels = {
82
- "AN": "Lawyer",
83
- "EUN": "European legal norm",
84
- "GRT": "Court",
85
- "GS": "Law",
86
  "INN": "Institution",
87
- "LD": "Country",
88
- "LDS": "Landscape",
89
- "LIT": "Legal literature",
90
- "MRK": "Brand",
91
- "ORG": "Organization",
92
  "PER": "Person",
93
- "RR": "Judge",
94
- "RS": "Court decision",
95
- "ST": "City",
96
- "STR": "Street",
97
- "UN": "Company",
98
- "VO": "Ordinance",
99
- "VS": "Regulation",
100
- "VT": "Contract",
 
101
  }
102
 
103
 
104
- # Fixed colors
105
  def generate_fixed_colors(keys, alpha=0.25):
106
  cmap = cm.get_cmap("tab20", len(keys))
107
  rgba_colors = {}
@@ -112,16 +100,35 @@ def generate_fixed_colors(keys, alpha=0.25):
112
  return rgba_colors
113
 
114
 
115
- ENTITY_COLORS = generate_fixed_colors(list(entity_labels.keys()), alpha=0.30)
116
 
117
- # UI
118
- st.markdown("#### German Legal NER")
119
- uploaded_file = st.file_uploader("Upload a .txt file", type="txt")
120
- threshold = st.slider("Confidence threshold:", 0.0, 1.0, 0.8, 0.01)
121
- st.markdown("---")
122
 
 
 
 
 
 
 
 
 
123
 
124
- # Merge logic
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def merge_entities(entities):
126
  if not entities:
127
  return []
@@ -135,10 +142,7 @@ def merge_entities(entities):
135
  prev = merged[-1]
136
  if ent["index"] == prev["index"] + 1:
137
  tok = ent["word"]
138
- if tok.startswith("##"):
139
- prev["word"] += tok[2:]
140
- else:
141
- prev["word"] += " " + tok
142
  prev["end"] = ent["end"]
143
  prev["index"] = ent["index"]
144
  prev["score_sum"] += ent["score"]
@@ -172,7 +176,7 @@ def merge_entities(entities):
172
  return final
173
 
174
 
175
- # HTML highlighting
176
  def highlight_entities(line, merged_entities, threshold):
177
  html = ""
178
  last_end = 0
@@ -200,24 +204,29 @@ def highlight_entities(line, merged_entities, threshold):
200
  return html
201
 
202
 
 
 
 
 
 
 
203
  if uploaded_file:
204
  raw_bytes = uploaded_file.read()
205
  encoding = detect(raw_bytes)["encoding"]
206
  if encoding is None:
207
- st.error("Could not detect file encoding.")
208
  else:
209
  text = raw_bytes.decode(encoding)
210
 
211
- with st.spinner("Processing..."):
212
- for line in text.splitlines():
213
- if not line.strip():
214
- st.write("")
215
- continue
216
-
217
- tokens = ner(line)
218
- merged = merge_entities(tokens)
219
- html_line = highlight_entities(line, merged, threshold)
220
- st.markdown(
221
- f'<div style="margin:0;padding:0;line-height:1.7;">{html_line}</div>',
222
- unsafe_allow_html=True,
223
- )
 
 
1
  import re
2
  import string
3
 
 
11
  pipeline,
12
  )
13
 
14
+ # Streamlit page setup
15
+ st.set_page_config(page_title="Juristische NER", page_icon="⚖️", layout="wide")
16
  logging.set_verbosity(logging.ERROR)
17
 
18
  st.markdown(
 
24
  padding-left: 3rem;
25
  padding-right: 3rem;
26
  }
 
27
  header, footer {visibility: hidden;}
 
28
  .entity {
29
  position: relative;
30
  display: inline-block;
 
32
  font-weight: normal;
33
  cursor: help;
34
  }
 
35
  .entity .tooltip {
36
  visibility: hidden;
37
  background-color: #333;
 
49
  transition: opacity 0.05s;
50
  font-size: 11px;
51
  }
 
52
  .entity:hover .tooltip {
53
  visibility: visible;
54
  opacity: 1;
55
  }
 
56
  .entity.marked {
57
  background-color: rgba(255, 230, 0, 0.4);
58
  line-height: 1.3;
 
64
  unsafe_allow_html=True,
65
  )
66
 
67
+ # Entity label mapping
 
 
 
 
 
 
 
 
68
  entity_labels = {
69
+ "AN": "Rechtsbeistand",
70
+ "EUN": "EUNorm",
71
+ "GRT": "Gericht",
72
+ "GS": "Norm",
73
  "INN": "Institution",
74
+ "LD": "Land",
75
+ "LDS": "Bezirk",
76
+ "LIT": "Schrifttum",
77
+ "MRK": "Marke",
78
+ "ORG": "Organisation",
79
  "PER": "Person",
80
+ "RR": "RichterIn",
81
+ "RS": "Entscheidung",
82
+ "ST": "Stadt",
83
+ "STR": "Strasse",
84
+ "UN": "Unternehmen",
85
+ "VO": "Verordnung",
86
+ "VS": "Richtlinie",
87
+ "VT": "Vertrag",
88
+ "RED": "Schwärzung",
89
  }
90
 
91
 
92
+ # Color generator
93
  def generate_fixed_colors(keys, alpha=0.25):
94
  cmap = cm.get_cmap("tab20", len(keys))
95
  rgba_colors = {}
 
100
  return rgba_colors
101
 
102
 
103
+ ENTITY_COLORS = generate_fixed_colors(list(entity_labels.keys()))
104
 
 
 
 
 
 
105
 
106
+ # Caching model
107
+ @st.cache_resource
108
+ def load_ner_pipeline():
109
+ return pipeline(
110
+ "ner",
111
+ model=AutoModelForTokenClassification.from_pretrained("harshildarji/JuraNER"),
112
+ tokenizer=AutoTokenizer.from_pretrained("harshildarji/JuraNER"),
113
+ )
114
 
115
+
116
+ # Caching NER + merge per line
117
+ @st.cache_data(show_spinner=False)
118
+ def get_ner_merged_lines(text):
119
+ ner = load_ner_pipeline()
120
+ results = []
121
+ for line in text.splitlines():
122
+ if not line.strip():
123
+ results.append(("", []))
124
+ continue
125
+ tokens = ner(line)
126
+ merged = merge_entities(tokens)
127
+ results.append((line, merged))
128
+ return results
129
+
130
+
131
+ # Entity merging
132
  def merge_entities(entities):
133
  if not entities:
134
  return []
 
142
  prev = merged[-1]
143
  if ent["index"] == prev["index"] + 1:
144
  tok = ent["word"]
145
+ prev["word"] += tok[2:] if tok.startswith("##") else " " + tok
 
 
 
146
  prev["end"] = ent["end"]
147
  prev["index"] = ent["index"]
148
  prev["score_sum"] += ent["score"]
 
176
  return final
177
 
178
 
179
+ # Highlighting
180
  def highlight_entities(line, merged_entities, threshold):
181
  html = ""
182
  last_end = 0
 
204
  return html
205
 
206
 
207
+ # UI
208
+ st.markdown("#### Juristische Named Entity Recognition (NER)")
209
+ uploaded_file = st.file_uploader("Bitte laden Sie eine .txt-Datei hoch:", type="txt")
210
+ threshold = st.slider("Schwellenwert für das Modellvertrauen:", 0.0, 1.0, 0.8, 0.01)
211
+ st.markdown("---")
212
+
213
  if uploaded_file:
214
  raw_bytes = uploaded_file.read()
215
  encoding = detect(raw_bytes)["encoding"]
216
  if encoding is None:
217
+ st.error("Zeichenkodierung konnte nicht erkannt werden.")
218
  else:
219
  text = raw_bytes.decode(encoding)
220
 
221
+ with st.spinner("Modell wird auf jede Zeile angewendet..."):
222
+ merged_all_lines = get_ner_merged_lines(text)
223
+
224
+ for line, merged in merged_all_lines:
225
+ if not line.strip():
226
+ continue
227
+
228
+ html_line = highlight_entities(line, merged, threshold)
229
+ st.markdown(
230
+ f'<div style="margin-bottom:0.8rem; line-height:1.7;">{html_line}</div>',
231
+ unsafe_allow_html=True,
232
+ )