Spaces:
Runtime error
Runtime error
import os | |
import requests | |
from conversation import Conversation | |
class BaseModel:
    """Text-generation backend reached through an HTTP prediction endpoint.

    An instance turns a conversation into a single text prompt and posts
    it, together with ``generation_params``, to the URL built from the
    ``API_BASE_PATH`` environment variable.
    """

    name: str
    endpoint: str
    namespace: str
    generation_params: dict

    def __init__(self, name, endpoint, namespace, generation_params):
        """Store the model's name, URL components and generation settings.

        Args:
            name: Identifier kept on the instance; not used by the methods
                of this class.
            endpoint: First value substituted into the ``API_BASE_PATH``
                template.
            namespace: Second value substituted into the ``API_BASE_PATH``
                template.
            generation_params: Sent unchanged as the ``"parameters"`` field
                of every request payload.
        """
        self.name = name
        self.endpoint = endpoint
        self.namespace = namespace
        self.generation_params = generation_params

    def generate_response(self, conversation):
        """Build the prompt for *conversation* and return the model's reply."""
        prompt = self._get_prompt(conversation)
        return self._get_response(prompt)

    def _get_prompt(self, conversation: "Conversation"):
        """Flatten *conversation* into the text prompt sent to the model.

        The prompt is the conversation's ``memory`` and ``prompt`` (empty or
        ``None`` parts skipped), then one ``"<from>: <value>"`` line per
        message, then a final ``"<bot_label>:"`` line for the model to
        complete.

        Returns:
            The assembled prompt string.
        """
        # Debug output kept from the original implementation.
        print(conversation.__dict__)
        # Skipping empty parts yields the same text the previous
        # join-then-strip produced for strings, and tolerates None.
        header = "\n".join(
            part for part in (conversation.memory, conversation.prompt) if part
        ).strip()
        lines = [header]
        lines.extend(
            f"{message['from'].strip()}: {message['value'].strip()}"
            for message in conversation.messages
        )
        lines.append(f"{conversation.bot_label}:")
        prompt = "\n".join(lines)
        print(prompt)
        return prompt

    def _get_response(self, text):
        """POST *text* to the prediction endpoint and return the first prediction.

        Args:
            text: Prompt sent as the only element of ``"instances"``.

        Returns:
            The first entry of the response's ``"predictions"`` list, with
            surrounding whitespace stripped.

        Raises:
            ValueError: If ``API_BASE_PATH`` is not set.
            AssertionError: If the HTTP status is not 200; the argument is
                the tuple ``(response_body, status_code)``.
        """
        base_path = os.environ.get("API_BASE_PATH")
        if base_path is None:
            # Fail early instead of posting to the literal URL "None".
            raise ValueError("API_BASE_PATH environment variable is not set")
        # The template may carry backslash-escaped braces ("\{\}"); turn them
        # back into "{}" placeholders before formatting.
        api = base_path.replace(r"\{\}", "{}").format(self.endpoint, self.namespace)
        payload = {"instances": [text], "parameters": self.generation_params}
        resp = requests.post(api, json=payload, timeout=600)
        if resp.status_code != 200:
            # Raised explicitly (rather than via `assert`) so the check is
            # not stripped under `python -O`; the exception type is unchanged.
            raise AssertionError((resp.content, resp.status_code))
        return resp.json()["predictions"][0].strip()