from flask import Flask, request, abort, Response
from werkzeug.security import generate_password_hash, check_password_hash
from werkzeug.exceptions import HTTPException

import os, threading, json, waitress, datetime, traceback
from llama_cpp import Llama
from dotenv import load_dotenv
load_dotenv()
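
#Environment variables read from .env / the process environment:
#  debugModeEnabled - "True" runs Flask's debug server instead of waitress
#  port             - TCP port to listen on (default 8080)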

import sentry_sdk
from sentry_sdk.integrations.flask import FlaskIntegration

sentry_sdk.init(
    dsn="https://5dcf8a99012c4c86b9b1f0293f6b4c2e@o4505516024004608.ingest.sentry.io/4505541971935232",
    integrations=[
        FlaskIntegration(),
    ],

    # Set traces_sample_rate to 1.0 to capture 100%
    # of transactions for performance monitoring.
    # We recommend adjusting this value in production.
    traces_sample_rate=1.0
)

#Variables
DEBUGMODEENABLED = (os.getenv('debugModeEnabled', 'False') == 'True')
modelName = "vicuna"
llm = None
AlpacaLoaded = False

#Chat Functions
def load_alpaca():
  global llm, AlpacaLoaded, modelName
  if not AlpacaLoaded:
    print("Loading Alpaca...")
    try:
      llm = Llama(model_path=f"./resources/{modelName}-ggml-model-q4.bin", use_mmap=False, n_threads=2, verbose=False, n_ctx=2048) #use_mlock=True
      AlpacaLoaded = True
      print("Done loading Alpaca.")
      return "Done"
    except (ValueError, FileNotFoundError):
      print("Error loading Alpaca. Please make sure you have the model file in the resources folder.")
      return "Error"
  else:
    print("Alpaca already loaded.")
    return "Already Loaded"

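#Helper: extract the assistant's reply text from a create_chat_completion() result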
def getChatResponse(modelOutput):
  return str(modelOutput["choices"][0]['message']['content'])

def reload_alpaca():
  global llm, AlpacaLoaded, modelName
  if AlpacaLoaded:
    llm = None
    input("Pleease confirm that the memory is cleared!")
    AlpacaLoaded = False
  load_alpaca()
  return "Done"
#Authentication Functions
def loadHashes():
  global hashesDict
  try:
    with open("resources/hashes.json", "r") as f:
      hashesDict = json.load(f)
  except FileNotFoundError:
    hashesDict = {}
  
def saveHashes():
  global hashesDict
  with open("resources/hashes.json", "w") as f:
    json.dump(hashesDict, f)

def addHashes(username: str, password: str):
  global hashesDict
  hashesDict[username] = generate_password_hash(password, method='scrypt')
  saveHashes()

def checkCredentials(username: str, password: str):
  global hashesDict
  if username in hashesDict:
    return check_password_hash(hashesDict[username], password)
  else:
    return False

def verifyHeaders():
  #Check + obtain the Authorization header, expected as "username:password"
  try:
    user, passw = request.headers['Authorization'].split(":", 1)
  except (KeyError, ValueError):
    abort(401)

  #Check if Authorization header is valid
  credentialsValid = checkCredentials(user, passw)
  if not credentialsValid:
    abort(403)
  else:
    return user


loadHashes()
#addHashes("test", "test")
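#To register a user, uncomment the addHashes() call above with real credentials,
#run once, then re-comment it; hashes persist in resources/hashes.json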

#General Functions
def getIsoTime():
  return str(datetime.datetime.now().isoformat())

#Flask App
app = Flask(__name__)

@app.route('/')
def main():
  return """<!DOCTYPE HTML>
<html>
  <head><meta name='color-scheme' content='dark'></head>
  <body><p>Hello, World!</p></body>
</html>"""

@app.route('/chat', methods=['GET', 'POST'])
def chat():
  if request.method == 'POST':
    print("Chat Completion Requested.")
    verifyHeaders()
    print("Headers verified")
    messages = request.get_json()
    print("Got Message" + str(messages))

    if AlpacaLoaded:
      modelOutput = llm.create_chat_completion(messages=messages, max_tokens=1024)
      responseMessage = modelOutput["choices"][0]['message']
      print(f"\n\nResponseMessage: {responseMessage}\n\n")
      return Response(json.dumps(responseMessage, indent=2), content_type='application/json')
    else:
      print("Alpaca not loaded. ")
      abort(503, "Alpaca not loaded. Please wait a few seconds and try again.")
  else:
    return "Ready" if AlpacaLoaded else "Not Ready", 200 if AlpacaLoaded else 503

@app.route('/sentry_check')
def trigger_error():
  #Deliberately raise a ZeroDivisionError to verify Sentry error reporting
  division_by_zero = 1 / 0

@app.errorhandler(HTTPException)
def handle_exception(e):
  errorInfo = json.dumps({"error": f"{e.code} - {e.name}", "message": e.description}, indent=2)
  return Response(errorInfo, content_type='application/json'), e.code

@app.errorhandler(Exception)
def handle_errors(e):
  print(f"INTERNAL SERVER ERROR 500 @ {request.path}")
  exceptionInfo = f"{type(e).__name__}: {str(e)}"
  errorTraceback = traceback.format_exc()
  print(errorTraceback)
  sentry_sdk.capture_exception(e)
  errorInfo = json.dumps({"error": "500 - Internal Server Error", "message": exceptionInfo}, indent=2)
  return Response(errorInfo, content_type='application/json'), 500



if __name__ == '__main__':
  threading.Thread(target=load_alpaca, daemon=True).start()

  port = int(os.getenv("port", "8080"))
  print("Server successfully started.")
  if DEBUGMODEENABLED:
    app.run(host='0.0.0.0', port=port)
  else:
    waitress.serve(app, host='0.0.0.0', port=port, url_scheme='https')
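
#Example .env (illustrative values, matching the os.getenv() calls above):
#  debugModeEnabled=False
#  port=8080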