Sadem-12 committed on
Commit
4751fac
·
verified ·
1 Parent(s): 6f2b05f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -28
app.py CHANGED
@@ -9,17 +9,14 @@ import os
9
 
10
  print("Installation complete. Loading models...")
11
 
12
- # Load models once at startup
13
  model_name = "csebuetnlp/mT5_multilingual_XLSum"
14
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
15
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
16
 
17
- # If you have a GPU, use it
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
  print(f"Using device: {device}")
20
  model = model.to(device)
21
 
22
- # Load question generator once
23
  question_generator = pipeline(
24
  "text2text-generation",
25
  model="valhalla/t5-small-e2e-qg",
@@ -43,9 +40,8 @@ def summarize_text(text, src_lang):
43
  return summary
44
 
45
  def generate_questions(summary):
46
- # Generate questions one at a time with beam search
47
  questions = []
48
- for _ in range(3): # Generate 3 questions
49
  result = question_generator(
50
  summary,
51
  max_length=64,
@@ -57,39 +53,36 @@ def generate_questions(summary):
57
  )
58
  questions.append(result[0]['generated_text'])
59
 
60
- # Remove duplicates
61
  questions = list(set(questions))
62
  return questions
63
 
64
  def generate_concept_map(summary, questions):
65
- # Use NetworkX and matplotlib for rendering
66
  G = nx.DiGraph()
67
 
68
- # Add summary as central node
69
  summary_short = summary[:50] + "..." if len(summary) > 50 else summary
70
  G.add_node("summary", label=summary_short)
71
 
72
- # Add question nodes and edges
73
  for i, question in enumerate(questions):
74
  q_short = question[:30] + "..." if len(question) > 30 else question
75
  node_id = f"Q{i}"
76
  G.add_node(node_id, label=q_short)
77
  G.add_edge("summary", node_id)
78
 
79
- # Create the plot directly in memory
80
  plt.figure(figsize=(10, 8))
81
- pos = nx.spring_layout(G, seed=42) # Fixed seed for consistent layout
82
  nx.draw(G, pos, with_labels=False, node_color='skyblue',
83
  node_size=1500, arrows=True, connectionstyle='arc3,rad=0.1',
84
  edgecolors='black', linewidths=1)
85
 
86
- # Add labels with better font handling
87
- # FIX: Removed 'wrap' parameter which is not supported in this version of NetworkX
88
  labels = nx.get_node_attributes(G, 'label')
89
  nx.draw_networkx_labels(G, pos, labels=labels, font_size=9,
90
  font_family='sans-serif')
91
 
92
- # Save to memory buffer
93
  buf = io.BytesIO()
94
  plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
95
  buf.seek(0)
@@ -101,7 +94,7 @@ def analyze_text(text, lang):
101
  if not text.strip():
102
  return "Please enter some text.", "No questions generated.", None
103
 
104
- # Process the text
105
  try:
106
  print("Generating summary...")
107
  summary = summarize_text(text, lang)
@@ -122,40 +115,32 @@ def analyze_text(text, lang):
122
  print(traceback.format_exc())
123
  return f"Error processing text: {str(e)}", "", None
124
 
125
- # Alternative simpler concept map function in case the above still has issues
126
  def generate_simple_concept_map(summary, questions):
127
  """Fallback concept map generator with minimal dependencies"""
128
  plt.figure(figsize=(10, 8))
129
 
130
- # Create a simple radial layout
131
  n_questions = len(questions)
132
 
133
- # Draw the central node (summary)
134
  plt.scatter([0], [0], s=1000, color='skyblue', edgecolors='black')
135
  plt.text(0, 0, summary[:50] + "..." if len(summary) > 50 else summary,
136
  ha='center', va='center', fontsize=9)
137
 
138
- # Draw the question nodes in a circle around the summary
139
  radius = 5
140
  for i, question in enumerate(questions):
141
  angle = 2 * 3.14159 * i / max(n_questions, 1)
142
  x = radius * 0.8 * -1 * (max(n_questions, 1) - 1) * ((i / max(n_questions - 1, 1)) - 0.5)
143
  y = radius * 0.6 * (i % 2 * 2 - 1)
144
 
145
- # Draw node
146
  plt.scatter([x], [y], s=800, color='lightgreen', edgecolors='black')
147
 
148
- # Draw edge from summary to question
149
  plt.plot([0, x], [0, y], 'k-', alpha=0.6)
150
 
151
- # Add question text
152
  plt.text(x, y, question[:30] + "..." if len(question) > 30 else question,
153
  ha='center', va='center', fontsize=8)
154
 
155
  plt.axis('equal')
156
  plt.axis('off')
157
 
158
- # Save to memory buffer
159
  buf = io.BytesIO()
160
  plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
161
  buf.seek(0)
@@ -184,14 +169,11 @@ def analyze_text_with_fallback(text, lang):
184
 
185
  print("Creating concept map...")
186
  try:
187
- # Try the main concept map generator first
188
  concept_map_image = generate_concept_map(summary, questions)
189
  except Exception as e:
190
  print(f"Main concept map failed: {e}, using fallback")
191
- # If it fails, use the fallback generator
192
  concept_map_image = generate_simple_concept_map(summary, questions)
193
 
194
- # Format questions as a list
195
  questions_text = "\n".join([f"- {q}" for q in questions])
196
 
197
  return summary, questions_text, concept_map_image
@@ -202,7 +184,7 @@ def analyze_text_with_fallback(text, lang):
202
  return f"Error processing text: {str(e)}", "", None
203
 
204
  iface = gr.Interface(
205
- fn=analyze_text_with_fallback, # Use the function with fallback
206
  inputs=[gr.Textbox(lines=10, placeholder="Enter text here..."), gr.Dropdown(["ar", "en"], label="Language")],
207
  outputs=[gr.Textbox(label="Summary"), gr.Textbox(label="Questions"), gr.Image(label="Concept Map")],
208
  examples=examples,
@@ -210,5 +192,4 @@ iface = gr.Interface(
210
  description="Enter a text in Arabic or English and the model will summarize it and generate questions and a concept map."
211
  )
212
 
213
- # For Colab, we need to use a public URL
214
  iface.launch(share=True)
 
9
 
10
  print("Installation complete. Loading models...")
11
 
 
12
  model_name = "csebuetnlp/mT5_multilingual_XLSum"
13
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
14
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
15
 
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  print(f"Using device: {device}")
18
  model = model.to(device)
19
 
 
20
  question_generator = pipeline(
21
  "text2text-generation",
22
  model="valhalla/t5-small-e2e-qg",
 
40
  return summary
41
 
42
  def generate_questions(summary):
 
43
  questions = []
44
+ for _ in range(3):
45
  result = question_generator(
46
  summary,
47
  max_length=64,
 
53
  )
54
  questions.append(result[0]['generated_text'])
55
 
 
56
  questions = list(set(questions))
57
  return questions
58
 
59
  def generate_concept_map(summary, questions):
60
+
61
  G = nx.DiGraph()
62
 
63
+
64
  summary_short = summary[:50] + "..." if len(summary) > 50 else summary
65
  G.add_node("summary", label=summary_short)
66
 
67
+
68
  for i, question in enumerate(questions):
69
  q_short = question[:30] + "..." if len(question) > 30 else question
70
  node_id = f"Q{i}"
71
  G.add_node(node_id, label=q_short)
72
  G.add_edge("summary", node_id)
73
 
74
+
75
  plt.figure(figsize=(10, 8))
76
+ pos = nx.spring_layout(G, seed=42)
77
  nx.draw(G, pos, with_labels=False, node_color='skyblue',
78
  node_size=1500, arrows=True, connectionstyle='arc3,rad=0.1',
79
  edgecolors='black', linewidths=1)
80
 
81
+
 
82
  labels = nx.get_node_attributes(G, 'label')
83
  nx.draw_networkx_labels(G, pos, labels=labels, font_size=9,
84
  font_family='sans-serif')
85
 
 
86
  buf = io.BytesIO()
87
  plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
88
  buf.seek(0)
 
94
  if not text.strip():
95
  return "Please enter some text.", "No questions generated.", None
96
 
97
+
98
  try:
99
  print("Generating summary...")
100
  summary = summarize_text(text, lang)
 
115
  print(traceback.format_exc())
116
  return f"Error processing text: {str(e)}", "", None
117
 
 
118
  def generate_simple_concept_map(summary, questions):
119
  """Fallback concept map generator with minimal dependencies"""
120
  plt.figure(figsize=(10, 8))
121
 
 
122
  n_questions = len(questions)
123
 
 
124
  plt.scatter([0], [0], s=1000, color='skyblue', edgecolors='black')
125
  plt.text(0, 0, summary[:50] + "..." if len(summary) > 50 else summary,
126
  ha='center', va='center', fontsize=9)
127
 
 
128
  radius = 5
129
  for i, question in enumerate(questions):
130
  angle = 2 * 3.14159 * i / max(n_questions, 1)
131
  x = radius * 0.8 * -1 * (max(n_questions, 1) - 1) * ((i / max(n_questions - 1, 1)) - 0.5)
132
  y = radius * 0.6 * (i % 2 * 2 - 1)
133
 
 
134
  plt.scatter([x], [y], s=800, color='lightgreen', edgecolors='black')
135
 
 
136
  plt.plot([0, x], [0, y], 'k-', alpha=0.6)
137
 
 
138
  plt.text(x, y, question[:30] + "..." if len(question) > 30 else question,
139
  ha='center', va='center', fontsize=8)
140
 
141
  plt.axis('equal')
142
  plt.axis('off')
143
 
 
144
  buf = io.BytesIO()
145
  plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
146
  buf.seek(0)
 
169
 
170
  print("Creating concept map...")
171
  try:
 
172
  concept_map_image = generate_concept_map(summary, questions)
173
  except Exception as e:
174
  print(f"Main concept map failed: {e}, using fallback")
 
175
  concept_map_image = generate_simple_concept_map(summary, questions)
176
 
 
177
  questions_text = "\n".join([f"- {q}" for q in questions])
178
 
179
  return summary, questions_text, concept_map_image
 
184
  return f"Error processing text: {str(e)}", "", None
185
 
186
  iface = gr.Interface(
187
+ fn=analyze_text_with_fallback,
188
  inputs=[gr.Textbox(lines=10, placeholder="Enter text here..."), gr.Dropdown(["ar", "en"], label="Language")],
189
  outputs=[gr.Textbox(label="Summary"), gr.Textbox(label="Questions"), gr.Image(label="Concept Map")],
190
  examples=examples,
 
192
  description="Enter a text in Arabic or English and the model will summarize it and generate questions and a concept map."
193
  )
194
 
 
195
  iface.launch(share=True)