import json
import os
from typing import Any, Dict, List

from llama_cpp import Llama

import gemma_tools as gem

# Upper bound on the model context window, in tokens.
MAX_TOKENS = 8192


class EndpointHandler:
    """Custom inference-endpoint handler for a GGUF-quantized Gemma-2B-it model.

    Loads the model once at construction time and serves generation requests
    through ``__call__``.
    """

    def __init__(self, data):
        # `data` (the path the hosting platform hands in) is unused: the
        # weights are always fetched from the Hugging Face Hub.
        self.model = Llama.from_pretrained(
            "lmstudio-ai/gemma-2b-it-GGUF",
            filename="gemma-2b-it-q4_k_m.gguf",
            n_ctx=MAX_TOKENS,  # was a duplicated literal 8192; keep in sync with the constant
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Validate the request, build the Gemma chat prompt, and run inference.

        Args:
            data: Request payload; consumed by ``gem.get_args_or_none`` and
                may carry an optional ``max_length`` (default 512).

        Returns:
            The raw completion object from ``llama_cpp`` on success, a plain
            dict on argument-validation failure, or a JSON string describing
            the error on prompt-format / ``max_length`` problems.
            NOTE(review): these three return shapes are inconsistent; kept
            as-is for backward compatibility with existing callers.
        """
        args = gem.get_args_or_none(data)
        fmat = "system\n{system_prompt} \nuser\n{prompt} \nmodel"
        if not args[0]:
            # NOTE(review): args is indexed with 0 here but with string keys
            # everywhere else — presumably gem.get_args_or_none returns a
            # mapping whose falsy args[0] flags a validation failure.
            # TODO confirm against gemma_tools.
            return {
                "status": args["status"],
                "message": args["description"]
            }
        try:
            fmat = fmat.format(system_prompt=args["system_prompt"], prompt=args["inputs"])
        except (KeyError, IndexError):
            # .format raises KeyError/IndexError when the template fields
            # cannot be satisfied from args.
            return json.dumps({
                "status": "error",
                "reason": "invalid format"
            })
        max_length = data.pop("max_length", 512)
        try:
            max_length = int(max_length)
        except (TypeError, ValueError):
            return json.dumps({
                "status": "error",
                "reason": "max_length was passed as something that was absolutely not a plain old int"
            })
        res = self.model(
            fmat,
            temperature=args["temperature"],
            top_p=args["top_p"],
            top_k=args["top_k"],
            max_tokens=max_length,
        )
        return res