thankrandomness committed
Commit a3db2dc · 1 Parent(s): 6d03bc9

calculate avg_similarity

Files changed (1):
  1. app.py (+20 -28)
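
The core change in this commit is to average the similarity_score of every retrieved result inside evaluate_efficiency. A minimal standalone sketch of that calculation, assuming results are dicts with a 'similarity_score' key as in the diff below (the helper name average_similarity and the sample values are illustrative, not code from app.py):

def average_similarity(results):
    # Mean of the 'similarity_score' field over all retrieved results;
    # falls back to 0 when nothing was retrieved, mirroring the guard in the diff.
    if not results:
        return 0
    return sum(r['similarity_score'] for r in results) / len(results)

# Example with two hypothetical retrieved codes
print(average_similarity([
    {'code': 'A01', 'similarity_score': 0.42},
    {'code': 'B20', 'similarity_score': 0.58},
]))  # 0.5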
app.py CHANGED

@@ -106,31 +106,9 @@ def retrieve_relevant_text(input_text, similarity_threshold=0.1): # Lower thres
 def evaluate_efficiency(dataset_split, similarity_threshold=0.1):
     y_true = []
     y_pred = []
-    # texts = []  # To store texts for debugging
-
-    # for i, row in enumerate(dataset_split):
-    #     for note in row['notes']:
-    #         text = note.get('text', '')
-    #         annotations_list = [annotation['code'] for annotation in note.get('annotations', []) if 'code' in annotation]
-
-    #         if text and annotations_list:
-    #             # Store the original text for each entry in y_true for debugging
-    #             texts.append(text)
-    #             y_true.extend(annotations_list)
-
-    #             # Retrieve predictions for the current text
-    #             retrieved_results = retrieve_relevant_text(text, similarity_threshold=similarity_threshold)
-    #             retrieved_codes = [result['code'] for result in retrieved_results]
-
-    #             # Limit predictions to the length of true annotations to ensure consistent lengths
-    #             y_pred.extend(retrieved_codes[:len(annotations_list)])
-
-    # # Debugging output
-    # for idx, (text, true_codes, pred_codes) in enumerate(zip(texts, y_true, y_pred)):
-    #     print(f"\nExample {idx + 1}")
-    #     print(f"Text: {text}")
-    #     print(f"Ground Truth Codes (y_true): {true_codes}")
-    #     print(f"Predicted Codes (y_pred): {pred_codes}")
+    total_similarity = 0
+    total_items = 0
+
     for i, row in enumerate(dataset_split):
         for note in row['notes']:
             text = note.get('text', '')
@@ -140,15 +118,28 @@ def evaluate_efficiency(dataset_split, similarity_threshold=0.1):
                 retrieved_results = retrieve_relevant_text(text, similarity_threshold=similarity_threshold)
                 retrieved_codes = [result['code'] for result in retrieved_results]
 
+                # Sum up similarity scores for average calculation
+                for result in retrieved_results:
+                    total_similarity += result['similarity_score']
+                    total_items += 1
+
                 # Ground truth
                 y_true.extend(annotations_list)
                 # Predictions (limit to length of true annotations to avoid mismatch)
                 y_pred.extend(retrieved_codes[:len(annotations_list)])
 
+                for result in retrieved_results:
+                    print(f"  Code: {result['code']}, Similarity Score: {result['similarity_score']:.2f}")
+
     # Debugging output to check for mismatches and understand results
     print("Sample y_true:", y_true[:10])
     print("Sample y_pred:", y_pred[:10])
 
+    if total_items > 0:
+        avg_similarity = total_similarity / total_items
+    else:
+        avg_similarity = 0
+
     if len(y_true) != len(y_pred):
         min_length = min(len(y_true), len(y_pred))
         y_true = y_true[:min_length]
@@ -159,10 +150,10 @@ def evaluate_efficiency(dataset_split, similarity_threshold=0.1):
     recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
     f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
 
-    return precision, recall, f1
+    return precision, recall, f1, avg_similarity
 
 # Calculate retrieval efficiency metrics
-precision, recall, f1 = evaluate_efficiency(dataset['validation'], similarity_threshold=0.1)
+precision, recall, f1, avg_similarity = evaluate_efficiency(dataset['validation'], similarity_threshold=0.1)
 
 # Gradio interface
 def gradio_interface(input_text):
@@ -179,7 +170,8 @@ def gradio_interface(input_text):
     return "\n".join(formatted_results)
 
 # Display retrieval efficiency metrics
-metrics = f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}"
+# metrics = f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}"
+metrics = f"Accuracy: {avg_similarity:.2f}"
 
 with gr.Blocks() as interface:
     gr.Markdown("# Text Retrieval with Efficiency Metrics")
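
The commit also swaps the displayed metrics string from precision/recall/F1 to the new average-similarity value, shown under the label "Accuracy". Since avg_similarity is a mean retrieval similarity score rather than a classification accuracy, a display string that keeps all four returned values could look like the sketch below (the numbers are placeholders standing in for evaluate_efficiency's output, not results from this Space):

# Placeholder values for the four-tuple returned by evaluate_efficiency
precision, recall, f1, avg_similarity = 0.41, 0.37, 0.39, 0.52
metrics = (
    f"Precision: {precision:.2f}, Recall: {recall:.2f}, "
    f"F1 Score: {f1:.2f}, Avg. Similarity: {avg_similarity:.2f}"
)
print(metrics)  # Precision: 0.41, Recall: 0.37, F1 Score: 0.39, Avg. Similarity: 0.52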