File size: 1,562 Bytes
2edd118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# SparkGPT.py
from . import SparkApi
# The following API credentials are read from OS environment variables.
import os

# Spark API credentials. They are looked up (in this order) when the module is
# imported; if any variable is unset, os.environ raises KeyError here.
appid, api_secret, api_key = (
    os.environ[_env_name] for _env_name in ("APPID", "APISecret", "APIKey")
)


from .BaseLLM import BaseLLM

    


class SparkGPT(BaseLLM):
    """LLM wrapper that talks to the iFlytek Spark chat service via ``SparkApi``.

    The conversation is held as one plain-text transcript in ``self.messages``.
    Each turn is written as ``"<Role>: <text>"`` (roles "System", "User" and
    "AI"), and consecutive turns are separated by a newline. The whole
    transcript is sent to the service as a single user message by
    :meth:`get_response`.
    """

    # Supported model name -> (API ``domain`` value, websocket endpoint).
    # NOTE(review): endpoints use plain ``ws://``; check whether ``SparkApi``
    # accepts ``wss://`` so that traffic can be encrypted.
    _MODEL_CONFIGS = {
        "Spark2.0": ("generalv2", "ws://spark-api.xf-yun.com/v2.1/chat"),  # v2.0
        "Spark1.5": ("general", "ws://spark-api.xf-yun.com/v1.1/chat"),  # v1.5
    }

    # Written between turns so a role tag never runs into the previous turn's
    # text (previously "System: ...User: ..." was emitted with no break).
    _TURN_SEPARATOR = "\n"

    def __init__(self, model="Spark2.0"):
        """Select the Spark model version and start with an empty transcript.

        Args:
            model: A key of ``_MODEL_CONFIGS`` ("Spark2.0" or "Spark1.5").

        Raises:
            ValueError: If ``model`` is not a supported model name. It is a
                subclass of ``Exception``, so existing ``except Exception``
                handlers still catch it.
        """
        super().__init__()
        try:
            self.domain, self.Spark_url = self._MODEL_CONFIGS[model]
        except (KeyError, TypeError):
            # TypeError covers unhashable values, which are equally unknown.
            raise ValueError(
                f"Unknown Spark model {model!r}; "
                f"expected one of {sorted(self._MODEL_CONFIGS)}"
            ) from None
        self.messages = ''

    def initialize_message(self):
        """Reset the conversation transcript to empty."""
        self.messages = ''

    def _append_turn(self, role, payload):
        """Append one ``"<role>: <payload>"`` turn, newline-separated from prior turns."""
        # Build the full turn first so a bad payload (e.g. not a str) raises
        # before the transcript is modified.
        turn = role + ": " + payload
        if self.messages:
            turn = self._TURN_SEPARATOR + turn
        self.messages += turn

    def ai_message(self, payload):
        """Append an assistant ("AI") turn to the transcript."""
        self._append_turn("AI", payload)

    def system_message(self, payload):
        """Append a "System" turn to the transcript."""
        self._append_turn("System", payload)

    def user_message(self, payload):
        """Append a "User" turn to the transcript."""
        self._append_turn("User", payload)

    def get_response(self):
        """Send the transcript to Spark and return the reply text.

        The entire transcript is sent as a single ``{"role": "user"}`` message,
        authenticated with the module-level ``appid``/``api_key``/``api_secret``.

        Returns:
            The value of ``SparkApi.answer`` after ``SparkApi.main`` returns
            (presumably the accumulated reply text).
        """
        message_json = [{"role": "user", "content": self.messages}]
        # The reply is read back from module-level state on SparkApi, so it is
        # cleared before each request.
        # NOTE(review): because that state is shared, concurrent calls from
        # several instances/threads would interfere with each other.
        SparkApi.answer = ""
        SparkApi.main(
            appid, api_key, api_secret, self.Spark_url, self.domain, message_json
        )
        return SparkApi.answer

    def print_prompt(self):
        """Debug aid: print the transcript's type, then the transcript itself."""
        print(type(self.messages))
        print(self.messages)