madhavkotecha committed on
Commit 4335938 · verified · 1 Parent(s): 392c154

Create app.py

Files changed (1)
  1. app.py +317 -0
app.py ADDED
@@ -0,0 +1,317 @@
+ import nltk
+ import sklearn_crfsuite
+ from sklearn_crfsuite import metrics
+ from nltk.stem import LancasterStemmer
+ import numpy as np
+ from sklearn.metrics import confusion_matrix
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+ import re
+ import gradio as gr
+
+ lancaster = LancasterStemmer()
+
+ nltk.download('brown')
+ nltk.download('universal_tagset')
+
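+ # A CRF-based POS tagger: hand-crafted affix/context features over the Brown
+ # corpus (universal tagset), fitted with a linear-chain CRF (sklearn-crfsuite).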
+ class CRF_POS_Tagger:
+     def __init__(self, train=False):
+         print("Loading Data...")
+         self.corpus = nltk.corpus.brown.tagged_sents(tagset='universal')
+         print("Data Loaded...")
+         self.corpus = [[(word, tag) for word, tag in sentence] for sentence in self.corpus]
+         self.actual_tag = []
+         self.predicted_tag = []
+         self.prefixes = [
+             "a", "anti", "auto", "bi", "co", "dis", "en", "em", "ex", "in", "im",
+             "inter", "mis", "non", "over", "pre", "re", "sub", "trans", "un", "under"
+         ]
+
+         self.suffixes = [
+             "able", "ible", "al", "ance", "ence", "dom", "er", "or", "ful", "hood",
+             "ic", "ing", "ion", "tion", "ity", "ty", "ive", "less", "ly", "ment",
+             "ness", "ous", "ship", "y", "es", "s"
+         ]
+
+         self.prefix_pattern = f"^({'|'.join(self.prefixes)})"
+         self.suffix_pattern = f"({'|'.join(self.suffixes)})$"
+
+         self.X = [[self.word_features(sentence, i) for i in range(len(sentence))] for sentence in self.corpus]
+         self.y = [[postag for _, postag in sentence] for sentence in self.corpus]
+
+         # 80/20 train/test split
+         self.split = int(0.8 * len(self.X))
+         self.X_train = self.X[:self.split]
+         self.y_train = self.y[:self.split]
+         self.X_test = self.X[self.split:]
+         self.y_test = self.y[self.split:]
+         print("Features Extracted...")
+         self.crf_model = sklearn_crfsuite.CRF(algorithm='lbfgs', c1=0.1, c2=0.1, max_iterations=100, all_possible_transitions=True)
+         print("Model Created...")
+         if train:
+             self.train()
+
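+     # Split a word into (prefix, stem, suffix) using the affix lists above;
+     # parts that do not match come back as empty strings.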
+     def word_splitter(self, word):
+         prefix = ""
+         stem = word
+         suffix = ""
+
+         prefix_match = re.match(self.prefix_pattern, word)
+         if prefix_match:
+             prefix = prefix_match.group(1)
+             stem = word[len(prefix):]
+
+         suffix_match = re.search(self.suffix_pattern, stem)
+         if suffix_match:
+             suffix = suffix_match.group(1)
+             stem = stem[: -len(suffix)]
+
+         return prefix, stem, suffix
+
+     # Extract a feature dict for the word at position i of a sentence
+     def word_features(self, sentence, i):
+         word = sentence[i][0]
+         prefix, stem, suffix = self.word_splitter(word)
+         # Two earlier feature sets, kept commented out for reference:
+         # features = {
+         #     'word': word,
+         #     'prefix': prefix,
+         #     # 'stem': stem,
+         #     'stem': lancaster.stem(word),
+         #     'suffix': suffix,
+         #     'position': i,
+         #     'is_first': i == 0,  # word is first in the sentence
+         #     'is_last': i == len(sentence) - 1,  # word is last in the sentence
+         #     # 'is_capitalized': word[0].upper() == word[0],
+         #     'is_all_caps': word.isupper(),  # word is all uppercase
+         #     'is_all_lower': word.islower(),  # word is all lowercase
+
+         #     'prefix-1': word[0],
+         #     'prefix-2': word[:2],
+         #     'prefix-3': word[:3],
+         #     'suffix-1': word[-1],
+         #     'suffix-2': word[-2:],
+         #     'suffix-3': word[-3:],
+
+         #     'prefix-un': word[:2] == 'un',
+         #     'prefix-re': word[:2] == 're',
+         #     'prefix-over': word[:4] == 'over',
+         #     'prefix-dis': word[:3] == 'dis',
+         #     'prefix-mis': word[:3] == 'mis',
+         #     'prefix-pre': word[:3] == 'pre',
+         #     'prefix-non': word[:3] == 'non',
+         #     'prefix-de': word[:2] == 'de',
+         #     'prefix-in': word[:2] == 'in',
+         #     'prefix-en': word[:2] == 'en',
+
+         #     'suffix-ed': word[-2:] == 'ed',
+         #     'suffix-ing': word[-3:] == 'ing',
+         #     'suffix-es': word[-2:] == 'es',
+         #     'suffix-ly': word[-2:] == 'ly',
+         #     'suffix-ment': word[-4:] == 'ment',
+         #     'suffix-er': word[-2:] == 'er',
+         #     'suffix-ive': word[-3:] == 'ive',
+         #     'suffix-ous': word[-3:] == 'ous',
+         #     'suffix-ness': word[-4:] == 'ness',
+         #     'ends_with_s': word[-1] == 's',
+         #     'ends_with_es': word[-2:] == 'es',
+
+         #     'has_hyphen': '-' in word,  # word contains a hyphen
+         #     'is_numeric': word.isdigit(),  # word is numeric
+         #     'capitals_inside': word[1:].lower() != word[1:],
+         #     'is_title_case': word.istitle(),  # first letter is uppercase
+         # }
+         # features = {
+         #     'word': word,
+         #     'prefix': prefix,
+         #     'stem': lancaster.stem(word),
+         #     'suffix': suffix,
+         #     'position': i,
+         #     'is_first': i == 0,
+         #     'is_last': i == len(sentence) - 1,
+         #     'is_all_caps': word.isupper(),
+         #     'is_all_lower': word.islower(),
+
+         #     'prev_word': sentence[i-1][0] if i > 0 else "<START>",
+         #     'next_word': sentence[i+1][0] if i < len(sentence) - 1 else "<END>",
+         #     'prev_is_capitalized': sentence[i-1][0].istitle() if i > 0 else False,
+         #     'next_is_capitalized': sentence[i+1][0].istitle() if i < len(sentence) - 1 else False,
+         #     'prev_is_numeric': sentence[i-1][0].isdigit() if i > 0 else False,
+         #     'next_is_numeric': sentence[i+1][0].isdigit() if i < len(sentence) - 1 else False,
+         #     'prev_suffix': self.word_splitter(sentence[i-1][0])[2] if i > 0 else "<START>",
+         #     'next_suffix': self.word_splitter(sentence[i+1][0])[2] if i < len(sentence) - 1 else "<END>",
+         #     'prev_prefix': self.word_splitter(sentence[i-1][0])[0] if i > 0 else "<START>",
+         #     'next_prefix': self.word_splitter(sentence[i+1][0])[0] if i < len(sentence) - 1 else "<END>",
+         # }
+
+         features = {
+             'word': word,
+             'is_first': i == 0,
+             'is_last': i == len(sentence) - 1,
+             'is_capitalized': word[0].upper() == word[0],
+             'is_all_caps': word.upper() == word,
+             'is_all_lower': word.lower() == word,
+
+             'prefix-1': word[0],
+             'prefix-2': word[:2],
+             'prefix-3': word[:3],
+             'suffix-1': word[-1],
+             'suffix-2': word[-2:],
+             'suffix-3': word[-3:],
+
+             'prev_word': '' if i == 0 else sentence[i-1][0],
+             'next_word': '' if i == len(sentence)-1 else sentence[i+1][0],
+
+             'has_hyphen': '-' in word,
+             'is_numeric': word.isdigit(),
+             'capitals_inside': word[1:].lower() != word[1:]
+         }
+
+         if i > 0:
+             # prev_word, prev_postag = sentence[i-1]
+             prev_word = sentence[i-1][0]
+             prev_prefix, prev_stem, prev_suffix = self.word_splitter(prev_word)
+
+             features.update({
+                 'prev_word': prev_word,
+                 # 'prev_postag': prev_postag,
+                 'prev_prefix': prev_prefix,
+                 'prev_stem': lancaster.stem(prev_word),
+                 'prev_suffix': prev_suffix,
+                 'prev:is_all_caps': prev_word.isupper(),
+                 'prev:is_all_lower': prev_word.islower(),
+                 'prev:is_numeric': prev_word.isdigit(),
+                 'prev:is_title_case': prev_word.istitle(),
+             })
+
+         if i < len(sentence)-1:
+             next_word = sentence[i+1][0]
+             next_prefix, next_stem, next_suffix = self.word_splitter(next_word)
+             features.update({
+                 'next_word': next_word,
+                 'next_prefix': next_prefix,
+                 'next_stem': lancaster.stem(next_word),
+                 'next_suffix': next_suffix,
+                 'next:is_all_caps': next_word.isupper(),
+                 'next:is_all_lower': next_word.islower(),
+                 'next:is_numeric': next_word.isdigit(),
+                 'next:is_title_case': next_word.istitle(),
+             })
+
+         return features
+
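+     # Fit the CRF; `data` (optional) is a list of (features, tags) pairs per
+     # sentence, otherwise the 80% training split from __init__ is used.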
+     def train(self, data=None):
+         if data:
+             X_train, y_train = zip(*data)
+         else:
+             X_train, y_train = self.X_train, self.y_train
+
+         print("Training CRF Model...", len(X_train), len(y_train))
+
+         # Ensure X_train is a list of lists of dictionaries
+         X_train = [list(map(dict, x)) for x in X_train]
+         self.crf_model.fit(X_train, y_train)
+
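+     # Predict tag sequences for a list of featurized sentences.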
+     def predict(self, X_test):
+         return self.crf_model.predict(X_test)
+
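+     # Score one fold; also accumulate flattened tag lists for the confusion matrix.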
+     def accuracy(self, test_data):
+         X_test, y_test = zip(*test_data)
+         y_pred = self.predict(X_test)
+         self.actual_tag.extend([item for sublist in y_test for item in sublist])
+         self.predicted_tag.extend([item for sublist in y_pred for item in sublist])
+         print(len(self.actual_tag), len(self.predicted_tag))
+         return metrics.flat_accuracy_score(y_test, y_pred)
+
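+     # 5-fold cross-validation over contiguous slices of the corpus,
+     # using a separate tagger instance so this one's state is untouched.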
+     def cross_validation(self):
+         validator = CRF_POS_Tagger()
+         data = list(zip(self.X, self.y))
+         print("Cross-Validation...")
+         accuracies = []
+         for i in range(5):
+             n1 = int(i / 5.0 * len(data))
+             n2 = int((i + 1) / 5.0 * len(data))
+             test_data = data[n1:n2]
+             train_data = data[:n1] + data[n2:]
+             validator.train(train_data)
+             acc = validator.accuracy(test_data)
+             accuracies.append(acc)
+         self.actual_tag = validator.actual_tag
+         self.predicted_tag = validator.predicted_tag
+         return accuracies, sum(accuracies) / 5.0
+
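+     # Row-normalized confusion matrix over the accumulated tags,
+     # saved as a heatmap to Confusion_matrix.png.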
+     def con_matrix(self):
+         self.labels = np.unique(self.actual_tag)
+         print("Labels:", self.labels)
+         conf_matrix = confusion_matrix(self.actual_tag, self.predicted_tag, labels=self.labels)
+         normalized_matrix = conf_matrix / np.sum(conf_matrix, axis=1, keepdims=True)
+         plt.figure(figsize=(10, 7))
+         sns.heatmap(normalized_matrix, annot=True, fmt='.2f', cmap='Blues', xticklabels=self.labels, yticklabels=self.labels)
+         plt.xlabel('Predicted Tags')
+         plt.ylabel('Actual Tags')
+         plt.title('Confusion Matrix Heatmap')
+         plt.savefig("Confusion_matrix.png")
+         plt.show()
+
+         return normalized_matrix
+
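+     # Per-tag precision, recall, and F-scores (F1, F0.5, F2) from the confusion matrix.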
+     def per_pos_accuracy(self, conf_matrix):
+         print("Per Tag Precision, Recall, and F-Score:")
+         per_tag_metrics = {}
+
+         for i, tag in enumerate(self.labels):
+             true_positives = conf_matrix[i, i]
+             false_positives = np.sum(conf_matrix[:, i]) - true_positives
+             false_negatives = np.sum(conf_matrix[i, :]) - true_positives
+
+             precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
+             recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
+             f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+             beta_0_5 = 0.5
+             beta_2 = 2.0
+
+             f0_5_score = (1 + beta_0_5**2) * (precision * recall) / ((beta_0_5**2 * precision) + recall) if (precision + recall) > 0 else 0
+             f2_score = (1 + beta_2**2) * (precision * recall) / ((beta_2**2 * precision) + recall) if (precision + recall) > 0 else 0
+
+             per_tag_metrics[tag] = {
+                 'Precision': precision,
+                 'Recall': recall,
+                 'f1-Score': f1_score,
+                 'f05-Score': f0_5_score,
+                 'f2-Score': f2_score
+             }
+
+             print(f"{tag}: Precision = {precision:.2f}, Recall = {recall:.2f}, f1-Score = {f1_score:.2f}, "
+                   f"f05-Score = {f0_5_score:.2f}, f2-Score = {f2_score:.2f}")
+
+         return per_tag_metrics
+
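+     # Tokenize raw input (splitting off trailing punctuation), featurize each
+     # token, and return the sentence as "word[TAG]" pairs.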
+     def tagging(self, text):
+         sentence = (re.sub(r'(\S)([.,;:!?])', r'\1 \2', text.strip())).split()
+         sentence_list = [[word] for word in sentence]
+         features = [self.word_features(sentence_list, i) for i in range(len(sentence_list))]
+
+         predicted_tags = self.crf_model.predict([features])
+         output = "".join(f"{sentence[i]}[{predicted_tags[0][i]}] " for i in range(len(sentence)))
+         return output
+
+
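+ # Run 5-fold cross-validation and report aggregate metrics before launching the app.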
+ validate = CRF_POS_Tagger()
+ accuracies, avg_accuracy = validate.cross_validation()
+ print(f"Cross-Validation Accuracies: {accuracies}")
+ print(f"Average Accuracy: {avg_accuracy}")
+
+ conf_matrix = validate.con_matrix()
+ print(validate.per_pos_accuracy(conf_matrix))
+
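+ # Train a fresh tagger on the 80% training split for the interactive demo.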
+ tagger = CRF_POS_Tagger(True)
+ interface = gr.Interface(fn=tagger.tagging,
+                          inputs=gr.Textbox(
+                              label="Input Sentence",
+                              placeholder="Enter your sentence here...",
+                          ),
+                          outputs=gr.Textbox(
+                              label="Tagged Output",
+                              placeholder="Tagged sentence appears here...",
+                          ),
+                          title="Conditional Random Field POS Tagger",
+                          description="CS626 Assignment 1B (Autumn 2024)",
+                          theme=gr.themes.Soft())
+ interface.launch(inline=False, share=True)
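+
+ # Quick local sanity check, bypassing the UI (illustrative output only;
+ # actual tags depend on the trained model):
+ #   print(tagger.tagging("The quick brown fox jumps over the lazy dog."))
+ #   # e.g. "The[DET] quick[ADJ] brown[ADJ] fox[NOUN] jumps[VERB] over[ADP] the[DET] lazy[ADJ] dog[NOUN] .[.]"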