Jaane committed on
Commit
d26463a
·
verified ·
1 Parent(s): 4272847

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -112
app.py CHANGED
@@ -3,18 +3,21 @@ import torch
3
  from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline
4
  from sentence_transformers import SentenceTransformer, util
5
  import requests
6
- import warnings
7
  import os
8
- from concurrent.futures import ThreadPoolExecutor
 
9
 
10
- # Set environment variables and suppress warnings
11
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Reduce TensorFlow verbosity
12
- warnings.filterwarnings("ignore", category=FutureWarning) # Suppress FutureWarnings
13
- warnings.filterwarnings("ignore", category=UserWarning) # Suppress UserWarnings
 
14
 
15
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
 
 
16
 
17
- # GPT-powered sentence segmentation function
18
  def segment_into_sentences_groq(passage):
19
  headers = {
20
  "Authorization": f"Bearer {GROQ_API_KEY}",
@@ -25,146 +28,109 @@ def segment_into_sentences_groq(passage):
25
  "messages": [
26
  {
27
  "role": "system",
28
- "content": "you are to segment the sentence by adding '1!2@3#' at the end of each sentence. Return only the segmented sentences, nothing else."
29
  },
30
  {
31
  "role": "user",
32
- "content": f"Segment this passage into sentences with '1!2@3#' as a delimiter: {passage}"
33
  }
34
  ],
35
- "temperature": 0.7,
36
- "max_tokens": 1024
37
  }
38
-
39
  response = requests.post("https://api.groq.com/openai/v1/chat/completions", json=payload, headers=headers)
40
  if response.status_code == 200:
41
- try:
42
- segmented_text = response.json()["choices"][0]["message"]["content"]
43
- sentences = segmented_text.split("1!2@3#")
44
- return [sentence.strip() for sentence in sentences if sentence.strip()]
45
- except (KeyError, IndexError):
46
- raise ValueError("Unexpected response structure from Groq API.")
47
  else:
48
  raise ValueError(f"Groq API error: {response.text}")
49
 
50
-
51
  class TextEnhancer:
52
  def __init__(self):
53
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
54
- self.executor = ThreadPoolExecutor(max_workers=3) # Parallel processing pool
55
-
56
- # Load models
57
- self._load_models()
58
-
59
- def _load_models(self):
60
  self.paraphrase_tokenizer = AutoTokenizer.from_pretrained("prithivida/parrot_paraphraser_on_T5")
61
  self.paraphrase_model = T5ForConditionalGeneration.from_pretrained("prithivida/parrot_paraphraser_on_T5").to(self.device)
62
-
63
  self.grammar_pipeline = pipeline(
64
  "text2text-generation",
65
  model="Grammarly/coedit-large",
66
  device=0 if self.device == "cuda" else -1
67
  )
68
-
69
  self.similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2').to(self.device)
70
 
71
- def enhance_text(self, text, min_similarity=0.8):
72
  sentences = segment_into_sentences_groq(text)
73
-
74
- # Process sentences in parallel
75
- results = list(self.executor.map(lambda s: self._process_sentence(s, min_similarity), sentences))
76
-
77
- # Join enhanced sentences into a single text
78
- enhanced_text = ". ".join(results).strip() + "."
79
- return enhanced_text
80
-
81
- def _process_sentence(self, sentence, min_similarity):
82
- if not sentence.strip():
83
- return sentence
84
-
85
- # Generate paraphrases
86
- inputs = self.paraphrase_tokenizer(
87
- f"paraphrase: {sentence}",
88
- return_tensors="pt",
89
- padding=True,
90
- max_length=150,
91
- truncation=True
92
- ).to(self.device)
93
-
94
- outputs = self.paraphrase_model.generate(
95
- **inputs,
96
- max_length=len(sentence.split()) + 20,
97
- num_return_sequences=3,
98
- num_beams=3,
99
- temperature=0.7
100
- )
101
-
102
- paraphrases = [
103
- self.paraphrase_tokenizer.decode(output, skip_special_tokens=True)
104
- for output in outputs
105
- ]
106
-
107
- # Calculate semantic similarity
108
- sentence_embedding = self.similarity_model.encode(sentence, convert_to_tensor=True)
109
- paraphrase_embeddings = self.similarity_model.encode(paraphrases, convert_to_tensor=True)
110
- similarities = util.cos_sim(sentence_embedding, paraphrase_embeddings).squeeze()
111
-
112
- # Filter paraphrases by similarity
113
- valid_paraphrases = [
114
- para for para, sim in zip(paraphrases, similarities)
115
- if sim >= min_similarity
116
- ]
117
-
118
- # Grammar correction for the most similar paraphrase
119
- if valid_paraphrases:
120
- corrected = self.grammar_pipeline(valid_paraphrases[0])[0]["generated_text"]
121
- return self._humanize_text(corrected)
122
- else:
123
- return sentence
124
-
125
- def _humanize_text(self, text):
126
- # Introduce minor variations to mimic human-written text
127
- import random
128
- contractions = {"can't": "cannot", "won't": "will not", "it's": "it is"}
129
- words = text.split()
130
- text = " ".join([contractions.get(word, word) if random.random() > 0.9 else word for word in words])
131
-
132
- if random.random() > 0.7:
133
- text = text.replace(" and ", ", and ")
134
- return text
135
-
136
-
137
  def create_interface():
138
  enhancer = TextEnhancer()
139
-
140
  def process_text(text, similarity_threshold):
141
  try:
142
  return enhancer.enhance_text(text, min_similarity=similarity_threshold / 100)
143
  except Exception as e:
144
  return f"Error: {str(e)}"
145
-
146
- interface = gr.Interface(
147
  fn=process_text,
148
  inputs=[
149
- gr.Textbox(
150
- label="Input Text",
151
- placeholder="Enter text to enhance...",
152
- lines=10
153
- ),
154
- gr.Slider(
155
- minimum=50,
156
- maximum=100,
157
- value=80,
158
- label="Minimum Semantic Similarity (%)"
159
- )
160
  ],
161
- outputs=gr.Textbox(label="Enhanced Text", lines=10),
162
  title="Text Enhancement System",
163
- description="Improve text quality while preserving original meaning.",
164
  )
165
- return interface
166
-
167
 
168
  if __name__ == "__main__":
169
  interface = create_interface()
170
- interface.launch()
 
3
  from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline
4
  from sentence_transformers import SentenceTransformer, util
5
  import requests
 
6
  import os
7
+ import warnings
8
+ from transformers import logging
9
 
10
+ # Suppress warnings
11
+ warnings.filterwarnings("ignore", category=FutureWarning)
12
+ warnings.filterwarnings("ignore", category=UserWarning)
13
+ warnings.filterwarnings("ignore")
14
+ logging.set_verbosity_error()
15
 
16
+ # Set API keys and environment variables
17
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY") # Ensure you set this in Hugging Face Spaces
18
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
19
 
20
+ # Groq API sentence segmentation
21
  def segment_into_sentences_groq(passage):
22
  headers = {
23
  "Authorization": f"Bearer {GROQ_API_KEY}",
 
28
  "messages": [
29
  {
30
  "role": "system",
31
+ "content": "Segment sentences by adding '1!2@3#' at the end of each sentence."
32
  },
33
  {
34
  "role": "user",
35
+ "content": f"Segment the passage: {passage}"
36
  }
37
  ],
38
+ "temperature": 1.0,
39
+ "max_tokens": 8192
40
  }
 
41
  response = requests.post("https://api.groq.com/openai/v1/chat/completions", json=payload, headers=headers)
42
  if response.status_code == 200:
43
+ data = response.json()
44
+ segmented_text = data.get("choices", [{}])[0].get("message", {}).get("content", "")
45
+ sentences = segmented_text.split("1!2@3#")
46
+ return [sentence.strip() for sentence in sentences if sentence.strip()]
 
 
47
  else:
48
  raise ValueError(f"Groq API error: {response.text}")
49
 
50
+ # Text enhancement class
51
  class TextEnhancer:
52
  def __init__(self):
53
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
54
  self.paraphrase_tokenizer = AutoTokenizer.from_pretrained("prithivida/parrot_paraphraser_on_T5")
55
  self.paraphrase_model = T5ForConditionalGeneration.from_pretrained("prithivida/parrot_paraphraser_on_T5").to(self.device)
 
56
  self.grammar_pipeline = pipeline(
57
  "text2text-generation",
58
  model="Grammarly/coedit-large",
59
  device=0 if self.device == "cuda" else -1
60
  )
 
61
  self.similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2').to(self.device)
62
 
63
+ def enhance_text(self, text, min_similarity=0.8, max_variations=2):
64
  sentences = segment_into_sentences_groq(text)
65
+ enhanced_sentences = []
66
+
67
+ for sentence in sentences:
68
+ if not sentence.strip():
69
+ continue
70
+
71
+ # Generate paraphrases
72
+ inputs = self.paraphrase_tokenizer(
73
+ f"paraphrase: {sentence}",
74
+ return_tensors="pt",
75
+ padding=True,
76
+ max_length=150,
77
+ truncation=True
78
+ ).to(self.device)
79
+
80
+ outputs = self.paraphrase_model.generate(
81
+ **inputs,
82
+ max_length=150,
83
+ num_return_sequences=max_variations,
84
+ num_beams=max_variations
85
+ )
86
+ paraphrases = [
87
+ self.paraphrase_tokenizer.decode(output, skip_special_tokens=True)
88
+ for output in outputs
89
+ ]
90
+
91
+ # Calculate semantic similarity
92
+ sentence_embedding = self.similarity_model.encode(sentence)
93
+ paraphrase_embeddings = self.similarity_model.encode(paraphrases)
94
+ similarities = util.cos_sim(sentence_embedding, paraphrase_embeddings)
95
+
96
+ # Select the most similar paraphrase
97
+ valid_paraphrases = [
98
+ para for para, sim in zip(paraphrases, similarities[0])
99
+ if sim >= min_similarity
100
+ ]
101
+ if valid_paraphrases:
102
+ corrected = self.grammar_pipeline(
103
+ valid_paraphrases[0],
104
+ max_length=150,
105
+ num_return_sequences=1
106
+ )[0]["generated_text"]
107
+ enhanced_sentences.append(corrected)
108
+ else:
109
+ enhanced_sentences.append(sentence)
110
+
111
+ return ". ".join(enhanced_sentences).strip() + "."
112
+
113
+ # Gradio interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  def create_interface():
115
  enhancer = TextEnhancer()
116
+
117
  def process_text(text, similarity_threshold):
118
  try:
119
  return enhancer.enhance_text(text, min_similarity=similarity_threshold / 100)
120
  except Exception as e:
121
  return f"Error: {str(e)}"
122
+
123
+ return gr.Interface(
124
  fn=process_text,
125
  inputs=[
126
+ gr.Textbox(lines=10, placeholder="Enter text to enhance...", label="Input Text"),
127
+ gr.Slider(50, 100, 80, label="Minimum Semantic Similarity (%)")
 
 
 
 
 
 
 
 
 
128
  ],
129
+ outputs=gr.Textbox(lines=10, label="Enhanced Text"),
130
  title="Text Enhancement System",
131
+ description="Enhance text quality with semantic preservation."
132
  )
 
 
133
 
134
  if __name__ == "__main__":
135
  interface = create_interface()
136
+ interface.launch(server_name="0.0.0.0", server_port=7860)