File size: 1,676 Bytes
b77d203
 
 
c8025cd
a16ade2
 
c8025cd
b77d203
 
 
 
a16ade2
 
c8025cd
b77d203
 
a16ade2
f253677
a16ade2
f253677
 
 
 
a16ade2
 
f253677
 
 
 
 
a16ade2
 
 
 
c8025cd
b77d203
 
 
a16ade2
b77d203
 
a16ade2
b77d203
a16ade2
b77d203
 
 
c8025cd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
from openai import OpenAI
from api_connections import get_graph_connection
from query_graph import query_knowledge_graph

import openai

def get_llm():
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("No OpenAI API key found in environment variables.")
    openai.api_key = api_key
    return openai

def classify_query(query):
    llm = get_llm()
    try:
        response = openai.chat.completions.create(
            model="gpt-3.5-turbo-16k",
            messages=[
                {"role": "system", "content": "Respond with classifying the query into 'Chemotherapy' or 'Traffic Law' based on the content of the user's query. Do not respond with anything else. Only 'Chemotherapy' or 'Traffic Law' depending on whichever field the query is closest to."},
                {"role": "user", "content": f"{query}"}
            ],
            max_tokens=60
        )
        category = response.choices[0].message.content
        if category in ["Chemotherapy", "Traffic Law"]:
            return category
        else:
            return None
    except Exception as e:
        print(f"Error during classification: {e}")
        return None


def main():
    print("Starting the script...")
    # Get user query
    query = input("Please enter your query: ")    
    # Classify the query
    category = classify_query(query)
    print(f"Query classified into category: {category}")    
    # Get the correct graph connection
    graph = get_graph_connection(category)    
    # Query the correct graph
    result = query_knowledge_graph(graph, query)
    print(f"Query result: {result}")
if __name__ == "__main__":
    main()