Spaces:
Runtime error
Runtime error
from flask import Flask, render_template, redirect, request, jsonify, make_response | |
import datetime | |
import torch | |
import transformers | |
# Use the GPU when one is available; otherwise fall back to the CPU so that
# tensors and modules moved to `device` do not fail on hosts without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint name handed to the `from_pretrained` loaders.
MODEL_NAME = 'liujch1998/vera'
class Interactive:
    """Score a single statement with the T5 encoder checkpoint `MODEL_NAME`.

    The scoring head is a one-output linear layer whose parameters are read
    from rows of the encoder's shared embedding matrix (see the row constants
    below). `run` feeds the hidden state of the last input position through
    that head and reports both the raw and the temperature-scaled result.
    """

    # Rows of `model.shared.weight` that this class reads the head from:
    # the full weight row, and element 0 of the bias / temperature rows.
    _WEIGHT_ROW = 32099
    _BIAS_ROW = 32098
    _TEMPERATURE_ROW = 32097

    def __init__(self):
        """Load the tokenizer and encoder, then build the linear scoring head."""
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = transformers.T5EncoderModel.from_pretrained(
            MODEL_NAME,
            low_cpu_mem_usage=True,
            device_map='auto',
            torch_dtype='auto',
            offload_folder='offload',
        )
        self.model.eval()
        self.model.D = self.model.shared.embedding_dim

        # Detached view: the head is inference-only and must not share
        # autograd state with the embedding matrix.
        embedding = self.model.shared.weight.detach()

        self.linear = torch.nn.Linear(self.model.D, 1, dtype=self.model.dtype)
        self.linear.weight = torch.nn.Parameter(
            embedding[self._WEIGHT_ROW, :].unsqueeze(0).clone(),
            requires_grad=False,
        )  # (1, D)
        self.linear.bias = torch.nn.Parameter(
            embedding[self._BIAS_ROW, 0].unsqueeze(0).clone(),
            requires_grad=False,
        )  # (1)
        # Move the head only AFTER its parameters are assigned. Moving it
        # first would leave the replaced parameters on whatever device
        # `device_map='auto'` chose for the embedding, which need not be
        # `device`, where `run` sends the hidden states.
        self.linear.to(device)

        # Temperature used to calibrate the logit in `run`.
        self.t = embedding[self._TEMPERATURE_ROW, 0].item()

    def run(self, statement):
        """Score one statement.

        Args:
            statement: Text to score; it is tokenized as a batch of one and
                truncated to at most 128 tokens.

        Returns:
            A dict with the keys ``timestamp`` (local time formatted as
            ``%Y%m%d-%H%M%S``), ``statement`` (the input, unchanged),
            ``logit``, ``logit_calibrated`` (logit divided by the
            temperature), ``score`` and ``score_calibrated`` (sigmoid of the
            respective logit). All numeric values are Python floats.
        """
        # Tokenizer `__call__` replaces the deprecated `batch_encode_plus`;
        # the arguments are unchanged.
        encoded = self.tokenizer(
            [statement],
            return_tensors='pt',
            padding='longest',
            truncation='longest_first',
            max_length=128,
        )
        input_ids = encoded.input_ids.to(device)

        with torch.no_grad():
            output = self.model(input_ids)
            last_hidden_state = output.last_hidden_state.to(device)  # (B=1, L, D)
            # A batch of one is never padded, so index -1 is the final token.
            hidden = last_hidden_state[0, -1, :]  # (D)
            logit = self.linear(hidden).squeeze(-1)  # ()
            logit_calibrated = logit / self.t
            score = logit.sigmoid()
            score_calibrated = logit_calibrated.sigmoid()

        return {
            'timestamp': datetime.datetime.now().strftime('%Y%m%d-%H%M%S'),
            'statement': statement,
            'logit': logit.item(),
            'logit_calibrated': logit_calibrated.item(),
            'score': score.item(),
            'score_calibrated': score_calibrated.item(),
        }
# Single shared scorer, built once at import time; constructing it loads the
# tokenizer and the model.
interactive = Interactive()
# Flask application object that is started by the `__main__` guard.
app = Flask(__name__)
def main():
    """Score the statement carried in the JSON body of the current request.

    The body is expected to be a JSON object of the form
    ``{"statement": "<text>"}``.

    Returns:
        * On success: the dict returned by ``interactive.run`` serialized as
          JSON (default status code).
        * ``(response, 400)`` when the body is not a JSON object or
          ``statement`` is missing, not a string, or blank.
        * ``(response, 500)`` when scoring raises.
    """
    # NOTE(review): no route decorator is visible on this function here;
    # confirm that it is registered with `app` so the endpoint is reachable.
    print(request)

    # silent=True yields None instead of raising on a wrong content type or
    # malformed JSON, so every bad body takes the same 400 path below.
    data = request.get_json(silent=True)
    statement = data.get('statement') if isinstance(data, dict) else None

    # Reject bad input here; otherwise it would fail inside the model call
    # and be reported as a 500 "Internal error".
    if not isinstance(statement, str) or not statement.strip():
        return jsonify({
            'success': False,
            'message': 'Please provide a statement.',
        }), 400

    try:
        result = interactive.run(statement)
    except Exception:
        # Request boundary: keep the server alive, but record the traceback
        # instead of discarding it.
        app.logger.exception('Failed to score statement')
        return jsonify({
            'success': False,
            'message': 'Internal error.',
        }), 500

    return jsonify(result)
if __name__ == "__main__":
    # TLS certificate chain and private key handed to the server.
    tls_certificate = '/etc/letsencrypt/live/qa.cs.washington.edu/fullchain.pem'
    tls_private_key = '/etc/letsencrypt/live/qa.cs.washington.edu/privkey.pem'
    app.run(
        host="0.0.0.0",
        port=8372,  # 8372 is what you get when you type "Vera" on a phone keypad
        threaded=True,
        ssl_context=(tls_certificate, tls_private_key),
    )