Testys committed on
Commit
2210a0e
1 Parent(s): 1549ba5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +40 -30
main.py CHANGED
@@ -36,40 +36,50 @@ sentiment_model = SentimentCNNModel(
36
  sentiment_model.load_state_dict(torch.load(sentiment_model_name, map_location=torch.device('cpu')))
37
  sentiment_model.eval()
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
def analyze_text(text):
    """Run named-entity recognition and sentiment classification on ``text``.

    The whole string is tokenized once per model (no truncation or
    windowing) and fed to the module-level ``ner_model`` and
    ``sentiment_model``.

    Args:
        text: Input string to analyse.

    Returns:
        A ``(ner_labels, sentiment)`` tuple. ``ner_labels`` is a list of
        ``"token: label"`` strings, one per token/prediction pair;
        ``sentiment`` is the entry of ``sentiment_config["id2label"]`` for
        the predicted class id.
    """
    # ---- Named Entity Recognition ----
    encoded_for_ner = ner_tokenizer(text, return_tensors="pt")
    token_ids = encoded_for_ner['input_ids'].squeeze().tolist()

    # Map every token id back to its token string.
    tokens = []
    for token_id in token_ids:
        tokens.append(ner_tokenizer.convert_ids_to_tokens(token_id))

    with torch.no_grad():
        ner_logits = ner_model(**encoded_for_ner)

    # Highest-scoring class along the last axis; [0] picks the first entry
    # (presumably the single batch item -- TODO confirm model output shape).
    predicted_label_ids = torch.argmax(ner_logits, dim=-1)[0].tolist()
    label_lookup = ner_config["id2labels"]
    label_names = [label_lookup[str(label_id)] for label_id in predicted_label_ids]

    # Pair each token with its predicted label.
    ner_labels = []
    for token, label_name in zip(tokens, label_names):
        ner_labels.append(f"{token}: {label_name}")

    # ---- Sentiment analysis ----
    encoded_for_sentiment = sentiment_tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        sentiment_logits = sentiment_model(**encoded_for_sentiment)
        predicted_classes = torch.argmax(sentiment_logits, dim=1).tolist()

    first_class_id = predicted_classes[0]
    sentiment = sentiment_config["id2label"][str(first_class_id)]

    return ner_labels, sentiment
73
 
74
  def main():
75
  st.set_page_config(page_title="YorubaCNN for NER and Sentiment Analysis", layout="wide")
@@ -139,4 +149,4 @@ def main():
139
  """, unsafe_allow_html=True)
140
 
141
  if __name__ == "__main__":
142
- main()
 
36
  sentiment_model.load_state_dict(torch.load(sentiment_model_name, map_location=torch.device('cpu')))
37
  sentiment_model.eval()
38
 
39
def analyze_text(text, window_size=512, stride=256):
    """Run NER and sentiment analysis over ``text`` using overlapping windows.

    ``text`` is cut into character windows of ``window_size`` characters
    whose start positions are ``stride`` characters apart. Each window is
    tokenized (truncated to at most ``window_size`` tokens) and passed to the
    module-level ``ner_model`` and ``sentiment_model``. Iteration stops as
    soon as a window reaches the end of the text, so no window is processed
    that lies entirely inside the previous one.

    Args:
        text: Input string to analyse.
        window_size: Window length in characters; also used as the
            tokenizers' ``max_length``. Must be positive.
        stride: Distance in characters between consecutive window starts.
            Must be positive. A stride larger than ``window_size`` leaves
            the characters between windows unprocessed.

    Returns:
        A ``(all_ner_labels, final_sentiment)`` tuple.
        ``all_ner_labels`` is the concatenation, in window order, of
        ``"token: label"`` strings; text covered by two overlapping windows
        is reported once per window (overlaps are not merged).
        ``final_sentiment`` is the most frequent per-window sentiment label
        (ties resolve to the label seen first), or ``None`` when ``text`` is
        empty.

    Raises:
        ValueError: If ``window_size`` or ``stride`` is not positive.
    """
    from collections import Counter

    if window_size <= 0 or stride <= 0:
        raise ValueError("window_size and stride must be positive integers")

    # Nothing to analyse: avoid indexing into an empty majority vote below.
    if not text:
        return [], None

    # Loop-invariant lookups.
    ner_id2label = ner_config["id2labels"]
    sentiment_id2label = sentiment_config["id2label"]

    all_ner_labels = []
    all_sentiments = []

    for start in range(0, len(text), stride):
        window = text[start:start + window_size]

        # --- Named Entity Recognition ---
        ner_inputs = ner_tokenizer(
            window,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=window_size,
        )
        # Index the first row instead of squeeze(): squeeze() collapses a
        # one-token sequence to a scalar, which cannot be iterated.
        # Assumes input_ids is (1, seq_len) for a single string -- TODO confirm.
        token_ids = ner_inputs['input_ids'][0].tolist()
        tokens = [ner_tokenizer.convert_ids_to_tokens(token_id) for token_id in token_ids]

        with torch.no_grad():
            ner_outputs = ner_model(**ner_inputs)

        ner_predictions = torch.argmax(ner_outputs, dim=-1)[0].tolist()
        window_labels = [ner_id2label[str(label_id)] for label_id in ner_predictions]
        all_ner_labels.extend(
            f"{token}: {label}" for token, label in zip(tokens, window_labels)
        )

        # --- Sentiment analysis ---
        sentiment_inputs = sentiment_tokenizer(
            window,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=window_size,
        )
        with torch.no_grad():
            sentiment_outputs = sentiment_model(**sentiment_inputs)
        sentiment_id = torch.argmax(sentiment_outputs, dim=1).tolist()[0]
        all_sentiments.append(sentiment_id2label[str(sentiment_id)])

        # This window already reached the end of the text; any later window
        # would be a strict suffix of it and would only duplicate NER output
        # and double-count its sentiment vote.
        if start + window_size >= len(text):
            break

    # Aggregate per-window sentiments by majority vote.
    final_sentiment = Counter(all_sentiments).most_common(1)[0][0]

    return all_ner_labels, final_sentiment
83
 
84
  def main():
85
  st.set_page_config(page_title="YorubaCNN for NER and Sentiment Analysis", layout="wide")
 
149
  """, unsafe_allow_html=True)
150
 
151
  if __name__ == "__main__":
152
+ main()