vera / backend /run.py
liujch1998's picture
Chrome extension
c942159
from flask import Flask, render_template, redirect, request, jsonify, make_response
import datetime
import torch
import transformers
# Device that inputs and the scoring head are moved to. Prefer CUDA, but fall
# back to CPU so the module still works on hosts without a GPU instead of
# failing on the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Identifier passed to `from_pretrained` for both the tokenizer and the encoder.
MODEL_NAME = 'liujch1998/vera'
class Interactive:
    """Score a single statement with a T5 encoder followed by a linear head.

    The head's weight, its bias and the calibration divisor are not separate
    checkpoints: they are read from fixed rows of the encoder's shared
    embedding matrix (rows 32099, 32098 and 32097 respectively).
    """

    def __init__(self):
        """Load the tokenizer and encoder and assemble the scoring head."""
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = transformers.T5EncoderModel.from_pretrained(
            MODEL_NAME,
            low_cpu_mem_usage=True,
            device_map='auto',
            torch_dtype='auto',
            offload_folder='offload',
        )
        embedding = self.model.shared
        # Hidden size is stashed on the model object itself.
        self.model.D = embedding.embedding_dim

        # Create the layer with matching shape and dtype, then overwrite its
        # parameters with values taken from the embedding matrix.
        head = torch.nn.Linear(self.model.D, 1, dtype=self.model.dtype).to(device)
        head.weight = torch.nn.Parameter(embedding.weight[32099, :].unsqueeze(0))  # (1, D)
        head.bias = torch.nn.Parameter(embedding.weight[32098, 0].unsqueeze(0))  # (1)
        self.linear = head

        self.model.eval()
        # Scalar the raw logit is divided by to obtain the calibrated logit.
        self.t = embedding.weight[32097, 0].item()

    def run(self, statement):
        """Return raw and calibrated logits/scores for ``statement``.

        The statement is tokenized as a batch of one (truncated to 128
        tokens), encoded, and the hidden state at the last position is fed
        through the linear head.

        Returns:
            A dict with the keys ``timestamp`` (local time formatted as
            ``%Y%m%d-%H%M%S``), ``statement``, ``logit``,
            ``logit_calibrated``, ``score`` and ``score_calibrated``; the
            four numeric values are Python floats.
        """
        encoded = self.tokenizer.batch_encode_plus(
            [statement],
            return_tensors='pt',
            padding='longest',
            truncation='longest_first',
            max_length=128,
        )
        token_ids = encoded.input_ids.to(device)

        with torch.no_grad():
            encoder_out = self.model(token_ids)
            states = encoder_out.last_hidden_state.to(device)  # (B=1, L, D)
            final_state = states[0, -1, :]  # (D)
            raw_logit = self.linear(final_state).squeeze(-1)  # ()
            scaled_logit = raw_logit / self.t
            raw_score = raw_logit.sigmoid()
            scaled_score = scaled_logit.sigmoid()

        stamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
        return {
            'timestamp': stamp,
            'statement': statement,
            'logit': raw_logit.item(),
            'logit_calibrated': scaled_logit.item(),
            'score': raw_score.item(),
            'score_calibrated': scaled_score.item(),
        }
# Build the scorer once at import time; Interactive.__init__ loads the
# tokenizer and the model, so every request reuses the same instance.
interactive = Interactive()
# Flask application object on which the route handlers are registered.
app = Flask(__name__)
@app.route('/', methods=['GET', 'POST'])
def main():
    """Score the statement carried in the request's JSON body.

    Expects a JSON object with a string field ``statement``.

    Returns:
        * 200 with the dict produced by ``interactive.run`` on success.
        * 400 with ``{'success': False, 'message': 'Please provide a
          statement.'}`` when the body is missing, is not valid JSON, is not
          a JSON object, or has no string ``statement`` field.
        * 500 with ``{'success': False, 'message': 'Internal error.'}`` when
          scoring raises.
    """
    print(request)

    # silent=True yields None instead of raising on a missing/invalid JSON
    # body, so every malformed request goes through the same 400 path below.
    payload = request.get_json(silent=True)
    statement = payload.get('statement') if isinstance(payload, dict) else None

    # Reject a missing or non-string statement here; otherwise it would only
    # fail inside the tokenizer and be misreported as a server-side error.
    if not isinstance(statement, str):
        return jsonify({
            'success': False,
            'message': 'Please provide a statement.',
        }), 400

    try:
        result = interactive.run(statement)
    except Exception:
        # Top-level boundary: keep the client-facing message generic but
        # record the traceback so the failure can be diagnosed.
        app.logger.exception('Failed to score statement')
        return jsonify({
            'success': False,
            'message': 'Internal error.',
        }), 500
    return jsonify(result)
if __name__ == "__main__":
    # TLS certificate chain and private key handed to the server.
    tls_cert = '/etc/letsencrypt/live/qa.cs.washington.edu/fullchain.pem'
    tls_key = '/etc/letsencrypt/live/qa.cs.washington.edu/privkey.pem'
    # 8372 is when you type Vera on a phone keypad
    app.run(
        host="0.0.0.0",
        port=8372,
        threaded=True,
        ssl_context=(tls_cert, tls_key),
    )