vhr1007 commited on
Commit
500c1ba
1 Parent(s): b55082a
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .env
2
+
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base image
2
+ FROM python:3.8-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Copy requirements.txt and install dependencies
8
+ COPY requirements.txt .
9
+ RUN pip install --no-cache-dir -r requirements.txt
10
+
11
+ # Copy the rest of the application code
12
+ COPY . .
13
+
14
+
15
+ # Command to run the FastAPI application
16
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import login
2
+ from fastapi import FastAPI, Depends, HTTPException
3
+ import logging
4
+ from pydantic import BaseModel
5
+ from sentence_transformers import SentenceTransformer
6
+ from services.qdrant_searcher import QdrantSearcher
7
+ from services.openai_service import generate_rag_response
8
+ from utils.auth import token_required
9
+ from dotenv import load_dotenv
10
+ import os
11
+
12
+ load_dotenv() # Load environment variables from .env file
13
+
14
+ app = FastAPI()
15
+
16
+ os.environ["HF_HOME"] = "/tmp/huggingface_cache"
17
+
18
+ # Ensure the cache directory exists
19
+ cache_dir = os.environ["HF_HOME"]
20
+ if not os.path.exists(cache_dir):
21
+ os.makedirs(cache_dir)
22
+
23
+ # Setup logging
24
+ logging.basicConfig(level=logging.INFO)
25
+ # Load Hugging Face token from environment variable
26
+ huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')
27
+
28
+
29
+ if huggingface_token:
30
+ login(token=huggingface_token, add_to_git_credential=True)
31
+ else:
32
+ raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")
33
+
34
+
35
+ # Initialize the Qdrant searcher
36
+ qdrant_url = os.getenv('QDRANT_URL')
37
+ access_token = os.getenv('QDRANT_ACCESS_TOKEN')
38
+ encoder = SentenceTransformer('paraphrase-MiniLM-L6-v2', trust_remote_code=True) # Replace with your actual encoder
39
+ searcher = QdrantSearcher(encoder, qdrant_url, access_token)
40
+
41
+ # Request body models
42
+ class SearchDocumentsRequest(BaseModel):
43
+ query: str
44
+ limit: int = 3
45
+
46
+ class GenerateRAGRequest(BaseModel):
47
+ search_query: str
48
+
49
+ @app.post("/api/search-documents")
50
+ async def search_documents(
51
+ body: SearchDocumentsRequest,
52
+ credentials: tuple = Depends(token_required)
53
+ ):
54
+ customer_id, user_id = credentials
55
+
56
+ # Check if customer_id or user_id is missing
57
+ if not customer_id or not user_id:
58
+ logging.error("Failed to extract customer_id or user_id from the JWT token.")
59
+ raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
60
+
61
+ logging.info("Received request to search documents")
62
+ try:
63
+ collection_name = "my_embeddings"
64
+ hits, error = searcher.search_documents(collection_name, body.query, user_id, body.limit)
65
+
66
+ if error:
67
+ logging.error(f"Search documents error: {error}")
68
+ raise HTTPException(status_code=500, detail=error)
69
+
70
+ return hits
71
+ except Exception as e:
72
+ logging.error(f"Unexpected error: {e}")
73
+ raise HTTPException(status_code=500, detail=str(e))
74
+
75
+ @app.post("/api/generate-rag-response")
76
+ async def generate_rag_response_api(
77
+ body: GenerateRAGRequest,
78
+ credentials: tuple = Depends(token_required)
79
+ ):
80
+ customer_id, user_id = credentials
81
+
82
+ # Check if customer_id or user_id is missing
83
+ if not customer_id or not user_id:
84
+ logging.error("Failed to extract customer_id or user_id from the JWT token.")
85
+ raise HTTPException(status_code=401, detail="Invalid token: missing customer_id or user_id")
86
+
87
+ logging.info("Received request to generate RAG response")
88
+ try:
89
+ collection_name = "my_embeddings"
90
+ hits, error = searcher.search_documents(collection_name, body.search_query, user_id)
91
+
92
+ if error:
93
+ logging.error(f"Search documents error: {error}")
94
+ raise HTTPException(status_code=500, detail=error)
95
+
96
+ response, error = generate_rag_response(hits, body.search_query)
97
+
98
+ if error:
99
+ logging.error(f"Generate RAG response error: {error}")
100
+ raise HTTPException(status_code=500, detail=error)
101
+
102
+ return {"response": response}
103
+ except Exception as e:
104
+ logging.error(f"Unexpected error: {e}")
105
+ raise HTTPException(status_code=500, detail=str(e))
106
+
107
+ if __name__ == '__main__':
108
+ import uvicorn
109
+ uvicorn.run(app, host='0.0.0.0', port=8000)
config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ import os
3
+
4
+ QDRANT_URL = os.getenv('QDRANT_URL')
5
+ QDRANT_ACCESS_TOKEN = os.getenv('QDRANT_ACCESS_TOKEN')
6
+ OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
7
+ JWKS_URL = os.getenv('JWKS_URL')
requiements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.78.0
2
+ uvicorn==0.17.6
3
+ pandas==1.3.5
4
+ qdrant-client==0.9.2
5
+ sentence-transformers==2.2.2
6
+ openai==0.27.0
7
+ PyJWT==2.6.0
8
+ python-dotenv==0.19.2
services/__init__py ADDED
File without changes
services/openai_service.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from openai import OpenAI
4
+ from openai import OpenAIError, RateLimitError
5
+ from config import OPENAI_API_KEY
6
+
7
+ # Initialize the OpenAI client with the API key from the environment variable
8
+ #api_key = os.getenv('OPENAI_API_KEY')
9
+
10
+ client = OpenAI(api_key=OPENAI_API_KEY)
11
+
12
+ def generate_rag_response(json_output, user_query):
13
+ logging.info("Generating RAG response")
14
+
15
+ # Extract text from the JSON output
16
+ context_texts = [hit['chunk_text'] for hit in json_output]
17
+
18
+ # Create the context for the prompt
19
+ context = "\n".join(context_texts)
20
+ prompt = f"Based on the given context, answer the user query: {user_query}\nContext:\n{context}"
21
+
22
+ main_prompt = [
23
+ {"role": "system", "content": "You are a helpful assistant."},
24
+ {"role": "user", "content": prompt}
25
+ ]
26
+
27
+ try:
28
+ # Create a chat completion request
29
+ chat_completion = client.chat.completions.create(
30
+ messages=main_prompt,
31
+ model="gpt-4o-mini", # Use the gpt-4o-mini model
32
+ timeout=10
33
+ )
34
+
35
+ # Log the response from the model
36
+ logging.info("RAG response generation completed")
37
+ logging.info(f"RAG response: {chat_completion.choices[0].message.content}")
38
+ return chat_completion.choices[0].message.content, None
39
+
40
+ except RateLimitError as e:
41
+ logging.error(f"Rate limit exceeded: {e}")
42
+ return None, "Rate limit exceeded. Please try again later."
43
+ except OpenAIError as e:
44
+ logging.error(f"OpenAI API error: {e}")
45
+ return None, f"An error occurred: {str(e)}"
46
+ except Exception as e:
47
+ logging.error(f"Unexpected error: {e}")
48
+ return None, str(e)
services/qdrant_searcher.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from qdrant_client import QdrantClient
3
+ from qdrant_client.http.models import Filter, FieldCondition
4
+
5
+ class QdrantSearcher:
6
+ def __init__(self, encoder, qdrant_url, access_token):
7
+ self.encoder = encoder
8
+ self.client = QdrantClient(url=qdrant_url, api_key=access_token)
9
+
10
+ def search_documents(self, collection_name, query, user_id, limit=3):
11
+ logging.info("Starting document search")
12
+ query_vector = self.encoder.encode(query).tolist()
13
+ query_filter = Filter(must=[FieldCondition(key="user_id", match={"value": user_id})])
14
+
15
+ try:
16
+ hits = self.client.search(
17
+ collection_name=collection_name,
18
+ query_vector=query_vector,
19
+ limit=limit,
20
+ query_filter=query_filter
21
+ )
22
+ except Exception as e:
23
+ logging.error(f"Error during Qdrant search: {e}")
24
+ return None, str(e)
25
+
26
+ if not hits:
27
+ logging.info("No documents found for the given query")
28
+ return None, "No documents found for the given query."
29
+
30
+ hits_list = []
31
+ for hit in hits:
32
+ hit_info = {
33
+ "id": hit.id,
34
+ "score": hit.score,
35
+ "file_id": hit.payload.get('file_id'),
36
+ "organization_id": hit.payload.get('organization_id'),
37
+ "chunk_index": hit.payload.get('chunk_index'),
38
+ "chunk_text": hit.payload.get('chunk_text'),
39
+ "s3_bucket_key": hit.payload.get('s3_bucket_key')
40
+ }
41
+ hits_list.append(hit_info)
42
+
43
+ logging.info(f"Document search completed with {len(hits_list)} hits")
44
+ return hits_list, None
utils/__init__.py ADDED
File without changes
utils/auth.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from fastapi import Depends, HTTPException
3
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
4
+ import jwt
5
+ from jwt import PyJWKClient
6
+ from config import JWKS_URL
7
+
8
+ security = HTTPBearer()
9
+
10
+ def get_public_key(token: str):
11
+ try:
12
+ jwks_client = PyJWKClient(JWKS_URL)
13
+ signing_key = jwks_client.get_signing_key_from_jwt(token)
14
+ return signing_key.key
15
+ except Exception as e:
16
+ logging.error(f"Error fetching public key: {e}")
17
+ raise
18
+
19
+ def token_required(credentials: HTTPAuthorizationCredentials = Depends(security)):
20
+ token = credentials.credentials
21
+ try:
22
+ public_key = get_public_key(token)
23
+ decoded = jwt.decode(
24
+ token,
25
+ public_key,
26
+ algorithms=['RS256'],
27
+ issuer="https://assuring-lobster-64.clerk.accounts.dev"
28
+ )
29
+ customer_id = decoded.get('org_id')
30
+ user_id = decoded.get('sub')
31
+ logging.info(f"Customer/Org ID: {customer_id}, User ID: {user_id}")
32
+ if not customer_id:
33
+ logging.error("Customer ID is missing in the token!")
34
+ raise HTTPException(status_code=401, detail="Customer ID is missing in the token!")
35
+ return customer_id, user_id
36
+ except jwt.ExpiredSignatureError:
37
+ logging.error("Token has expired")
38
+ raise HTTPException(status_code=401, detail="Token has expired")
39
+ except jwt.InvalidTokenError as e:
40
+ logging.error(f"Invalid token: {e}")
41
+ raise HTTPException(status_code=401, detail="Invalid token")
42
+ except Exception as e:
43
+ logging.error(f"Error decoding token: {e}")
44
+ raise HTTPException(status_code=401, detail=str(e))