import os
import requests

from conversation import Conversation


class BaseModel:
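    """Minimal client for a hosted text-generation model.

    Flattens a Conversation into a plain-text prompt and posts it to the
    model's HTTP endpoint, returning the generated reply.
    """
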
    name: str
    endpoint: str
    namespace: str
    generation_params: dict

    def __init__(self, name, endpoint, namespace, generation_params):
        self.name = name
        self.endpoint = endpoint
        self.namespace = namespace
        self.generation_params = generation_params

    def generate_response(self, conversation):
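        """Build a prompt from `conversation` and return the model's reply."""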
        prompt = self._get_prompt(conversation)
        response = self._get_response(prompt)
        return response

    def _get_prompt(self, conversation: Conversation):
        print(conversation.__dict__)  # debug: dump the conversation state
        # Start with the long-term memory and the base prompt, then append
        # each message as "<speaker>: <text>".
        prompt = "\n".join(
            [conversation.memory, conversation.prompt]
        ).strip()

        for message in conversation.messages:
            prompt += f"\n{message['from'].strip()}: {message['value'].strip()}"
        # End with the bot's label so the model completes its next turn.
        prompt += f"\n{conversation.bot_label}:"
        print(prompt)  # debug: show the final prompt sent to the model
        return prompt

    def _get_response(self, text):
        # API_BASE_PATH is expected to be a URL template with escaped
        # placeholders ("\{\}"); un-escape them before formatting in the
        # endpoint and namespace. Indexing the environment directly fails
        # fast with a KeyError if the variable is unset, rather than
        # formatting the string "None".
        api = os.environ["API_BASE_PATH"].replace(r"\{\}", "{}")
        api = api.format(self.endpoint, self.namespace)

        payload = {"instances": [text], "parameters": self.generation_params}
        resp = requests.post(api, json=payload, timeout=600)
        assert resp.status_code == 200, (resp.content, resp.status_code)
        return resp.json()["predictions"][0].strip()
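

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The endpoint, namespace, and generation
# parameters below are placeholder assumptions, and the SimpleNamespace is a
# duck-typed stand-in for a real conversation.Conversation instance, exposing
# only the attributes _get_prompt() reads. Calling generate_response() would
# additionally require the API_BASE_PATH environment variable and a reachable
# model service.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    model = BaseModel(
        name="base-model",
        endpoint="example-endpoint",      # hypothetical endpoint name
        namespace="example-namespace",    # hypothetical namespace
        generation_params={"max_new_tokens": 64, "temperature": 0.7},
    )

    conversation = SimpleNamespace(
        memory="The assistant is helpful and concise.",
        prompt="",
        messages=[{"from": "User", "value": "Hello!"}],
        bot_label="Assistant",
    )

    # Only builds the prompt; no HTTP request is made here.
    prompt = model._get_prompt(conversation)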