harir commited on
Commit
4b8a8ee
·
verified ·
1 Parent(s): 95faede

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +40 -6
models.py CHANGED
@@ -13,7 +13,7 @@ You will be tasked to classify sentences as 'J' or 'V'
13
 
14
  Text: "{sentence}"
15
 
16
- Please classify this text as either 'J', 'W', or 'V'. Only output 'J', 'W', or 'V' with no additional explanation.<|endoftext|>
17
  <|assistant|>
18
  """
19
  return prompt
@@ -33,15 +33,42 @@ Please revise this text such that it maintains the criticism in the original tex
33
  """
34
  return prompt
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def query_model_score(sentence, api_key, model_id, prompt_fun):
37
  API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
38
  headers = {"Authorization": f"Bearer {api_key}"}
39
  prompt = prompt_fun(sentence)
40
  def query(payload):
41
- print(payload)
42
  response = requests.post(API_URL, headers=headers, json=payload)
43
  return response.json()
44
- parameters = {"max_new_tokens" : 20, "temperature": 0.0, "return_full_text": False}
45
  options = {"wait_for_model": True}
46
  data = query({"inputs": f"{prompt}", "parameters": parameters, "options": options})
47
  score = data[0]['generated_text']
@@ -57,7 +84,7 @@ def query_model_revise(sentence, api_key, model_id, prompt_fun):
57
  def query(payload):
58
  response = requests.post(API_URL, headers=headers, json=payload)
59
  return response.json()
60
- parameters = {"max_new_tokens" : 200, "temperature": 0.0, "return_full_text": False}
61
  options = {"wait_for_model": True}
62
  data = query({"inputs": f"{prompt}", "parameters": parameters, "options": options})
63
  revision = data[0]['generated_text']
@@ -75,6 +102,13 @@ def revise_review(review, api_key, model_id, highlight_color):
75
  "message": ""
76
  }
77
 
 
 
 
 
 
 
 
78
  try:
79
  review = review.replace('"', "'")
80
  sentences = parser.parse_sentences(review)
@@ -83,13 +117,13 @@ def revise_review(review, api_key, model_id, highlight_color):
83
  review_revision = ""
84
  for sentence in sentences:
85
  if len(sentence) > 20:
86
- score = query_model_score(sentence, api_key, model_id, zephyr_score)
87
  if score == 0:
88
  review_revision += " " + sentence
89
  else:
90
  review_score = 1
91
  revision_count +=1
92
- revision = query_model_revise(sentence, api_key, model_id, zephyr_revise)
93
  revision = revision.strip().strip('"')
94
  review_revision += f"<div style='background-color: {highlight_color}; display: inline;'>{revision}</div>"
95
  else:
 
13
 
14
  Text: "{sentence}"
15
 
16
+ Please classify this text as either 'J' or 'V'. Only output 'J' or 'V' with no additional explanation.<|endoftext|>
17
  <|assistant|>
18
  """
19
  return prompt
 
33
  """
34
  return prompt
35
 
36
+ def mistral_score(sentence):
37
+ prompt = f"""<s>[INST]
38
+ You are an assistant helping with paper reviews.
39
+ You will be tasked to classify sentences as 'J' or 'V'
40
+
41
+ 'J' is positive or 'J' is encouraging.
42
+ 'J' has a neutral tone or 'J' is professional.
43
+ 'V' is overly blunt or 'V' contains excessive negativity and no constructive feedback.
44
+ 'V' contains an accusatory tone or 'V' contains sweeping generalizations or 'V' contains personal attacks.
45
+
46
+ Text: "{sentence}"
47
+
48
+ Please classify this text as either 'J' or 'V'. Only output 'J' or 'V' with no additional explanation. [/INST]"""
49
+ return prompt
50
+
51
+ def mistral_revise(sentence):
52
+ prompt = f"""<s>[INST]
53
+ You are an assistant that helps users revise Paper Reviews.
54
+ Paper reviews exist to provide authors of academic research papers constructive critism.
55
+
56
+ This is text found in a review.
57
+ This text was classified as 'toxic':
58
+
59
+ Text: "{sentence}"
60
+
61
+ Please revise this text such that it maintains the criticism in the original text and delivers it in a friendly but professional manner. Make minimal changes to the original text. [/INST] Revised Text: """
62
+ return prompt
63
+
64
  def query_model_score(sentence, api_key, model_id, prompt_fun):
65
  API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
66
  headers = {"Authorization": f"Bearer {api_key}"}
67
  prompt = prompt_fun(sentence)
68
  def query(payload):
 
69
  response = requests.post(API_URL, headers=headers, json=payload)
70
  return response.json()
71
+ parameters = {"max_new_tokens" : 5, "temperature": 0.1, "return_full_text": False}
72
  options = {"wait_for_model": True}
73
  data = query({"inputs": f"{prompt}", "parameters": parameters, "options": options})
74
  score = data[0]['generated_text']
 
84
  def query(payload):
85
  response = requests.post(API_URL, headers=headers, json=payload)
86
  return response.json()
87
+ parameters = {"max_new_tokens" : 200, "temperature": 0.1, "return_full_text": False}
88
  options = {"wait_for_model": True}
89
  data = query({"inputs": f"{prompt}", "parameters": parameters, "options": options})
90
  revision = data[0]['generated_text']
 
102
  "message": ""
103
  }
104
 
105
+ if 'zephyr' in model_id:
106
+ revision_prompt = zephyr_revise
107
+ score_prompt = zephyr_score
108
+ elif 'mistral' in model_id:
109
+ revision_prompt = mistral_revise
110
+ score_prompt = mistral_score
111
+
112
  try:
113
  review = review.replace('"', "'")
114
  sentences = parser.parse_sentences(review)
 
117
  review_revision = ""
118
  for sentence in sentences:
119
  if len(sentence) > 20:
120
+ score = query_model_score(sentence, api_key, model_id, score_prompt)
121
  if score == 0:
122
  review_revision += " " + sentence
123
  else:
124
  review_score = 1
125
  revision_count +=1
126
+ revision = query_model_revise(sentence, api_key, model_id, revision_prompt)
127
  revision = revision.strip().strip('"')
128
  review_revision += f"<div style='background-color: {highlight_color}; display: inline;'>{revision}</div>"
129
  else: