"""Flask server exposing a llama.cpp chat-completion endpoint behind basic
username:password authentication, with Sentry error reporting.

Endpoints:
    GET  /             -> hello banner
    GET  /chat         -> readiness probe (200/503)
    POST /chat         -> chat completion (requires Authorization header)
    GET  /sentry_check -> deliberately raises to test Sentry wiring
"""

import datetime
import json
import os
import threading
import traceback

import sentry_sdk
import waitress
from dotenv import load_dotenv
from flask import Flask, Response, abort, request
from llama_cpp import Llama
from sentry_sdk.integrations.flask import FlaskIntegration
from werkzeug.exceptions import HTTPException
from werkzeug.security import check_password_hash, generate_password_hash

load_dotenv()

# NOTE(review): the DSN is hard-coded; a DSN is not a full secret, but moving
# it to an environment variable would allow per-environment configuration.
sentry_sdk.init(
    dsn="https://5dcf8a99012c4c86b9b1f0293f6b4c2e@o4505516024004608.ingest.sentry.io/4505541971935232",
    integrations=[
        FlaskIntegration(),
    ],
    # Set traces_sample_rate to 1.0 to capture 100%
    # of transactions for performance monitoring.
    # We recommend adjusting this value in production.
    traces_sample_rate=1.0,
)

# Variables
DEBUGMODEENABLED = os.getenv('debugModeEnabled', 'False') == 'True'
modelName = "vicuna"
llm = None            # Llama instance; created lazily by load_alpaca()
AlpacaLoaded = False  # becomes True once the model has loaded successfully


# Chat Functions
def load_alpaca():
    """Load the GGML model from ./resources into the module-global `llm`.

    Returns one of "Done", "Error", or "Already Loaded".
    """
    global llm, AlpacaLoaded, modelName
    if not AlpacaLoaded:
        print("Loading Alpaca...")
        try:
            llm = Llama(
                model_path=f"./resources/{modelName}-ggml-model-q4.bin",
                use_mmap=False,
                n_threads=2,
                verbose=False,
                n_ctx=2048,
            )  # use_mlock=True
            AlpacaLoaded = True
            print("Done loading Alpaca.")
            return "Done"
        except AttributeError:
            # NOTE(review): llama_cpp may raise other exception types for a
            # missing file -- confirm AttributeError is the right one here.
            print("Error loading Alpaca. \nPlease make sure you have the model file in the resources folder.")
            return "Error"
    else:
        print("Alpaca already loaded.")
        return "Already Loaded"


def getChatResponse(modelOutput):
    """Extract the assistant message text from a create_chat_completion result."""
    return str(modelOutput["choices"][0]['message']['content'])


def reload_alpaca():
    """Release the current model (if any) and load it again.

    Blocks on console input as a manual gate so an operator can confirm the
    old model's memory has been released before reloading.
    """
    global llm, AlpacaLoaded, modelName
    if AlpacaLoaded:
        llm = None
        # Fixed typo: was "Pleease confirm ...".
        input("Please confirm that the memory is cleared!")
        AlpacaLoaded = False
    load_alpaca()
    return "Done"


# Authentication Functions
def loadHashes():
    """Load the username -> password-hash map from disk; empty dict if absent."""
    global hashesDict
    try:
        with open("resources/hashes.json", "r") as f:
            hashesDict = json.load(f)
    except FileNotFoundError:
        hashesDict = {}


def saveHashes():
    """Persist the in-memory hash map to resources/hashes.json."""
    global hashesDict
    with open("resources/hashes.json", "w") as f:
        json.dump(hashesDict, f)


def addHashes(username: str, password: str):
    """Store a scrypt hash of `password` for `username` and persist it."""
    global hashesDict
    hashesDict[username] = generate_password_hash(password, method='scrypt')
    saveHashes()


def checkCredentials(username: str, password: str):
    """Return True iff `username` exists and `password` matches its hash."""
    global hashesDict
    if username in hashesDict:
        return check_password_hash(hashesDict[username], password)
    else:
        return False


def verifyHeaders():
    """Validate the `Authorization: user:password` header of the request.

    Aborts with 401 when the header is missing/malformed and 403 when the
    credentials are wrong; returns the authenticated username otherwise.
    """
    # Check + Obtain Authorization header.
    try:
        # maxsplit=1 (bug fix): passwords containing ':' used to raise
        # ValueError here and be rejected with 401 even when valid.
        user, passw = request.headers['Authorization'].split(":", 1)
    except (KeyError, ValueError):
        abort(401)
    # Check if Authorization header is valid.
    credentialsValid = checkCredentials(user, passw)
    if not credentialsValid:
        abort(403)
    else:
        return user


loadHashes()
#addHashes("test", "test")


# General Functions
def getIsoTime():
    """Return the current local time as an ISO-8601 string."""
    return str(datetime.datetime.now().isoformat())


# Flask App
app = Flask(__name__)


@app.route('/')
def main():
    return """
Hello, World!
"""


@app.route('/chat', methods=['GET', 'POST'])
def chat():
    """POST: run a chat completion on the JSON message list (auth required).

    GET: readiness probe -- 200 "Ready" once the model is loaded, else 503.
    """
    if request.method == 'POST':
        print("Chat Completion Requested.")
        verifyHeaders()
        print("Headers verified")
        messages = request.get_json()
        print("Got Message" + str(messages))
        if AlpacaLoaded:
            modelOutput = llm.create_chat_completion(messages=messages, max_tokens=1024)
            responseMessage = modelOutput["choices"][0]['message']
            print(f"\n\nResponseMessage: {responseMessage}\n\n")
            return Response(json.dumps(responseMessage, indent=2), content_type='application/json')
        else:
            print("Alpaca not loaded. ")
            abort(503, "Alpaca not loaded. Please wait a few seconds and try again.")
    else:
        return "Ready" if AlpacaLoaded else "Not Ready", 200 if AlpacaLoaded else 503


@app.route('/sentry_check')
def trigger_error():
    # Deliberate ZeroDivisionError so the Sentry integration can be verified.
    division_by_zero = 1 / 0


@app.errorhandler(HTTPException)
def handle_exception(e):
    """Render HTTP errors (4xx/5xx aborts) as a JSON body with the same code."""
    errorInfo = json.dumps({"error": f"{e.code} - {e.name}", "message": e.description}, indent=2)
    return Response(errorInfo, content_type='application/json'), e.code


@app.errorhandler(Exception)
def handle_errors(e):
    """Catch-all: log the traceback, report to Sentry, and return JSON 500."""
    print(f"INTERNAL SERVER ERROR 500 @ {request.path}")
    exceptionInfo = f"{type(e).__name__}: {str(e)}"
    errorTraceback = traceback.format_exc()
    print(errorTraceback)
    sentry_sdk.capture_exception(e)
    # Removed needless f-prefix on a placeholder-free literal.
    errorInfo = json.dumps({"error": "500 - Internal Server Error", "message": exceptionInfo}, indent=2)
    return Response(errorInfo, content_type='application/json'), 500


if __name__ == '__main__':
    # Load the model in the background so the server starts accepting
    # requests immediately; /chat returns 503 until loading finishes.
    threading.Thread(target=load_alpaca, daemon=True).start()
    port = int(os.getenv("port", "8080"))
    print("Server successfully started.")
    if DEBUGMODEENABLED:
        app.run(host='0.0.0.0', port=port)
    else:
        waitress.serve(app, host='0.0.0.0', port=port, url_scheme='https')