Fangrui Liu committed on
Commit
d0f7013
•
1 Parent(s): fc3e81d
Files changed (4) hide show
  1. README.md +1 -1
  2. app.py +91 -0
  3. funcs.py +72 -0
  4. requirements.txt +3 -0
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: GPTs Myscale Backend
3
  emoji: 📚
4
  colorFrom: gray
5
  colorTo: gray
 
1
  ---
2
+ title: GPTs Myscale Backend RestAPI
3
  emoji: 📚
4
  colorFrom: gray
5
  colorTo: gray
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import inspect
3
+ import os
4
+ import requests
5
+ import json
6
+ from io import BytesIO
7
+ from typing import List, Type
8
+
9
+ from flask import Flask, jsonify, render_template, request, send_file
10
+ from flask_restx import Resource, Api, fields
11
+ from funcs import emb_wiki, emb_arxiv, WikiKnowledgeBase, ArXivKnowledgeBase
12
+
13
# Flask application plus the Flask-RESTX wrapper that serves the REST
# endpoints and the generated API documentation.
app = Flask(__name__)
api = Api(
    app,
    title="MyScale Open Knowledge Base",
    description="An API to get relevant page from MyScale Open Knowledge Base",
    version="0.1",
    terms_url="https://myscale.com/terms/",
    # NOTE(review): this value looks like an obfuscated placeholder rather
    # than a real address -- confirm before publishing.
    contact_email="[email protected]",
)

# Response schema: the retrieved documents rendered as a single string and
# the number of documents that were retrieved.
query_result = api.model(
    "QueryResult",
    dict(
        documents=fields.String,
        num_retrieved=fields.Integer,
    ),
)

# Maps the `knowledge_base` path segment to a zero-argument factory; a new
# knowledge-base object is built every time a factory is called.
kb_list = dict(
    wiki=lambda: WikiKnowledgeBase(embedding=emb_wiki),
    arxiv=lambda: ArXivKnowledgeBase(embedding=emb_arxiv),
)

# Request arguments accepted by the query endpoint.
query_parser = api.parser()
query_parser.add_argument(
    "subject",
    type=str,
    required=True,
    help="a sentence or phrase describes the subject you want to query.",
)
query_parser.add_argument(
    "where_str",
    type=str,
    required=True,
    help="a sql-like where string to build filter",
)
query_parser.add_argument(
    "limit",
    type=int,
    required=False,
    default=4,
    help="desired number of retrieved documents",
)
49
+
50
+
51
@api.route(
    "/get_related_docs/<string:knowledge_base>",
    doc={
        "description": (
            "Get some related papers.\nYou should use schema here:\n\n"
            "CREATE TABLE ArXiv (\n"
            " `id` String,\n"
            " `abstract` String, -- abstract of the paper. avoid using this column to do LIKE match\n"
            " `pubdate` DateTime, \n"
            " `title` String, -- title of the paper\n"
            " `categories` Array(String), -- arxiv category of the paper\n"
            " `authors` Array(String), -- authors of the paper\n"
            " `comment` String, -- extra comments of the paper\n"
            "ORDER BY id\n\n"
            "CREATE TABLE Wikipedia (\n"
            " `id` String,\n"
            " `text` String, -- abstract of the wiki page. avoid using this column to do LIKE match\n"
            " `title` String, -- title of the paper\n"
            # The query code selects the column `views`, so the published
            # schema must use the same name or generated filters will fail.
            " `views` Float32,\n"
            " `url` String, -- URL to this wiki page\n"
            "ORDER BY id\n\n"
            "You should avoid using LIKE on long text columns."
        ),
    },
)
@api.param("knowledge_base", "Knowledge base used to query. Must be one of ['wiki', 'arxiv']")
class get_related_docs(Resource):
    """REST resource that retrieves documents related to a subject."""

    @api.expect(query_parser)
    @api.marshal_with(query_result)
    def get(self, knowledge_base):
        """Query the selected knowledge base and return the matching documents.

        Args:
            knowledge_base: Key into ``kb_list`` taken from the URL path.

        Returns:
            A dict with ``documents`` (the retrieved rows rendered as text)
            and ``num_retrieved`` (how many rows were returned).

        An unknown ``knowledge_base`` aborts the request with HTTP 404
        instead of surfacing a ``KeyError`` as an internal server error.
        """
        kb_factory = kb_list.get(knowledge_base)
        if kb_factory is None:
            api.abort(
                404,
                f"Unknown knowledge base {knowledge_base!r}; "
                f"must be one of {sorted(kb_list)}",
            )

        args = query_parser.parse_args()
        kb = kb_factory()
        # Diagnostics go through the application logger rather than stdout.
        app.logger.debug(
            "kb=%s subject=%r where_str=%r limit=%r",
            kb,
            args.subject,
            args.where_str,
            args.limit,
        )
        docs, num_docs = kb(args.subject, args.where_str, args.limit)
        return {"documents": docs, "num_retrieved": num_docs}
87
+
88
+
89
if __name__ == "__main__":
    # The server listens on every interface, so the interactive debugger
    # (which permits arbitrary code execution) must not be on by default.
    # Set FLASK_DEBUG=1 to opt in during local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)
funcs.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import List, Tuple
3
+ import clickhouse_connect
4
+ from sentence_transformers import SentenceTransformer
5
+ from InstructorEmbedding import INSTRUCTOR
6
+
7
+
8
# Embedding models, instantiated once when this module is imported.
emb_wiki = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
emb_arxiv = INSTRUCTOR("hkunlp/instructor-xl")
10
+
11
class ArXivKnowledgeBase:
    """Vector-search helper over the ``default.ChatArXiv`` table.

    Calling an instance embeds a subject string, runs a distance-ordered
    query against the configured table and returns the rows as text.
    """

    def __init__(self, embedding: "SentenceTransformer") -> None:
        """Open a database client and store the embedding model.

        Connection settings default to the values that were previously
        hard-coded and can be overridden with the environment variables
        ``MYSCALE_HOST``, ``MYSCALE_PORT``, ``MYSCALE_USER`` and
        ``MYSCALE_PASSWORD`` so credentials need not live in source.

        Args:
            embedding: Model whose ``encode`` method turns text into a vector.
        """
        import os  # local import: this module has no top-level `import os`

        self.db = clickhouse_connect.get_client(
            host=os.environ.get(
                "MYSCALE_HOST",
                "msc-4a9e710a.us-east-1.aws.staging.myscale.cloud",
            ),
            port=int(os.environ.get("MYSCALE_PORT", "443")),
            username=os.environ.get("MYSCALE_USER", "chatdata"),
            password=os.environ.get("MYSCALE_PASSWORD", "myscale_rocks"),
        )
        self.embedding = embedding
        self.table = "default.ChatArXiv"
        self.embedding_col = "vector"
        self.must_have_cols = [
            "id",
            "abstract",
            "authors",
            "categories",
            "comment",
            "title",
            "pubdate",
        ]

    def __call__(self, subject: str, where_str: str = None, limit: int = 5) -> Tuple[str, int]:
        """Return the rows closest to ``subject`` and how many were found.

        Args:
            subject: Text to embed and search for.
            where_str: Optional SQL filter expression (without the ``WHERE``
                keyword). ``None``, empty and whitespace-only values mean
                "no filter".
            limit: Maximum number of rows to return.

        Returns:
            A tuple ``(documents, count)`` where ``documents`` is the rows
            joined by newlines (one ``str(row)`` per line) and ``count`` is
            the number of rows.

        Raises:
            ValueError: If ``limit`` is not an integer value or is negative.
        """
        # Coerce before interpolation so only a plain integer reaches the SQL.
        limit = int(limit)
        if limit < 0:
            raise ValueError(f"limit must be non-negative, got {limit}")

        # SECURITY NOTE: `where_str` is raw SQL by design and is inserted
        # verbatim; callers exposing it to untrusted users should restrict
        # the database account accordingly.
        condition = where_str.strip() if where_str else ""
        where_clause = f"WHERE {condition}" if condition else ""

        query_vector = ",".join(map(str, self.embedding.encode(subject).tolist()))
        columns = ",".join(self.must_have_cols)

        q_str = f"""
            SELECT dist, {columns}
            FROM {self.table}
            {where_clause}
            ORDER BY distance({self.embedding_col}, [{query_vector}])
                AS dist ASC
            LIMIT {limit}
            """

        docs = list(self.db.query(q_str).named_results())
        return "\n".join(str(doc) for doc in docs), len(docs)
44
+
45
class WikiKnowledgeBase(ArXivKnowledgeBase):
    """Knowledge base that queries the ``wiki.Wikipedia`` table."""

    def __init__(self, embedding: SentenceTransformer) -> None:
        """Initialise like the parent, then switch to the Wikipedia settings.

        Args:
            embedding: Model passed through to the parent initializer.
        """
        super().__init__(embedding)
        # Replace the table, vector column and selected columns that the
        # parent initializer configured.
        self.embedding_col = "emb"
        self.must_have_cols = ["text", "title", "views", "url"]
        self.table = "wiki.Wikipedia"
51
+
52
+
53
if __name__ == '__main__':
    # Manual smoke test: query the Wikipedia knowledge base with no filter
    # and print the (documents, count) result.
    # kb = ArXivKnowledgeBase(embedding=emb_arxiv)
    kb = WikiKnowledgeBase(embedding=emb_wiki)
    d = kb(
        "When did Steven Jobs die?",
        "",
        5,
    )
    print(d)
58
+
59
+
60
# Schema-like dictionary describing an object with a `todos` string array.
# NOTE(review): nothing in the visible code reads this value -- confirm
# whether it is leftover scratch data.
d = {
    "components": {
        "schemas": {
            "type": "object",
            "properties": {
                "todos": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "The list of todos.",
                },
            },
        },
    },
}
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ clickhouse_connect
2
+ flask
3
+ flask-restx