Prakhar Bhandari committed
Commit · b77d203
Parent(s): babec93

First attempt at incorporating multiple graphs
Files changed:
- kg_builder/src/__pycache__/api_connections.cpython-39.pyc +0 -0
- kg_builder/src/__pycache__/knowledge_graph_builder.cpython-39.pyc +0 -0
- kg_builder/src/__pycache__/models.cpython-39.pyc +0 -0
- kg_builder/src/__pycache__/utils.cpython-39.pyc +0 -0
- kg_builder/src/api_connections.py +25 -42
- kg_builder/src/graph_creation.py +90 -0
- kg_builder/src/knowledge_graph_builder.py +4 -1
- kg_builder/src/main.py +31 -21
- kg_builder/src/query_graph.py +1 -1
kg_builder/src/__pycache__/api_connections.cpython-39.pyc CHANGED
Binary files a/kg_builder/src/__pycache__/api_connections.cpython-39.pyc and b/kg_builder/src/__pycache__/api_connections.cpython-39.pyc differ

kg_builder/src/__pycache__/knowledge_graph_builder.cpython-39.pyc CHANGED
Binary files a/kg_builder/src/__pycache__/knowledge_graph_builder.cpython-39.pyc and b/kg_builder/src/__pycache__/knowledge_graph_builder.cpython-39.pyc differ

kg_builder/src/__pycache__/models.cpython-39.pyc CHANGED
Binary files a/kg_builder/src/__pycache__/models.cpython-39.pyc and b/kg_builder/src/__pycache__/models.cpython-39.pyc differ

kg_builder/src/__pycache__/utils.cpython-39.pyc CHANGED
Binary files a/kg_builder/src/__pycache__/utils.cpython-39.pyc and b/kg_builder/src/__pycache__/utils.cpython-39.pyc differ
kg_builder/src/api_connections.py CHANGED
@@ -11,17 +11,21 @@ from typing import Optional, List
 
 load_dotenv()  # This loads the variables from .env into os.environ
 
-
-
-
-
-
+def get_graph_connection(category):
+    if category == "Chemotherapy":
+        url = os.getenv("CHEMO_NEO4J_URL")
+        username = os.getenv("CHEMO_NEO4J_USERNAME")
+        password = os.getenv("CHEMO_NEO4J_PASSWORD")
+    elif category == "Traffic Law":
+        url = os.getenv("TRAFFIC_NEO4J_URL")
+        username = os.getenv("TRAFFIC_NEO4J_USERNAME")
+        password = os.getenv("TRAFFIC_NEO4J_PASSWORD")
+    else:
+        raise ValueError(f"Unknown category: {category}")
 
-
-        url=url,
-        username=username,
-        password=password
-    )
+    return Neo4jGraph(url=url, username=username, password=password)
+
+openai_api_key = os.getenv("OPENAI_API_KEY")
 
 def get_llm():
     api_key = os.getenv("OPENAI_API_KEY")
@@ -30,44 +34,23 @@ def get_llm():
     return ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0)
 
 def get_extraction_chain(
+    category,
     allowed_nodes: Optional[List[str]] = None,
     allowed_rels: Optional[List[str]] = None
 ):
+    if category == "Chemotherapy":
+        # Chemotherapy-specific prompt
+        prompt_text = ""
+    elif category == "Traffic Law":
+        # Traffic Law-specific prompt
+        prompt_text = "[Traffic Law-specific instructions]"
+    else:
+        raise ValueError("Unknown category")
+
     llm = get_llm()
     prompt = ChatPromptTemplate.from_messages(
         [(
-        "system",
-        f"""# Knowledge Graph Instructions for GPT-4
-## 1. Overview
-You are a sophisticated algorithm tailored for parsing Wikipedia pages to construct a knowledge graph about chemotherapy and related cancer treatments.
-- **Nodes** symbolize entities such as medical conditions, drugs, symptoms, treatments, and associated medical concepts.
-- The goal is to create a precise and comprehensible knowledge graph, serving as a reliable resource for medical practitioners and scholarly research.
-
-## 2. Labeling Nodes
-- **Consistency**: Utilize uniform labels for node types to maintain clarity.
-  - For instance, consistently label drugs as **"Drug"**, symptoms as **"Symptom"**, and treatments as **"Treatment"**.
-- **Node IDs**: Apply descriptive, legible identifiers for node IDs, sourced directly from the text.
-
-{'- **Allowed Node Labels:**' + ", ".join(['Drug', 'Symptom', 'Treatment', 'MedicalCondition', 'ResearchStudy']) if allowed_nodes else ""}
-{'- **Allowed Relationship Types**:' + ", ".join(['Treats', 'Causes', 'Researches', 'Recommends']) if allowed_rels else ""}
-
-## 3. Handling Numerical Data and Dates
-- Integrate numerical data and dates as attributes of the corresponding nodes.
-- **No Isolated Nodes for Dates/Numbers**: Directly associate dates and numerical figures as attributes with pertinent nodes.
-- **Property Format**: Follow a straightforward key-value pattern for properties, with keys in camelCase, for example, `approvedYear`, `dosageAmount`.
-
-## 4. Coreference Resolution
-- **Entity Consistency**: Guarantee uniform identification of each entity across the graph.
-  - For example, if "Methotrexate" and "MTX" reference the same medication, uniformly apply "Methotrexate" as the node ID.
-
-## 5. Relationship Naming Conventions
-- **Clarity and Standardization**: Utilize clear and standardized relationship names, preferring uppercase with underscores for readability.
-  - For instance, use "HAS_SIDE_EFFECT" instead of "HASSIDEEFFECT" and "CAN_RESULT_FROM" instead of "CANRESULTFROM"; relationships keep being stored without the "_" between words, and further errors of this kind will cause the output to be discarded.
-- **Relevance and Specificity**: Choose relationship names that accurately reflect the connection between nodes, such as "INHIBITS" or "ACTIVATES" for interactions between substances.
-
-## 6. Strict Compliance
-Rigorous adherence to these instructions is essential. Failure to comply with the specified formatting and labeling norms will necessitate output revision or discard.
-"""),
+        "system", prompt_text),
         ("human", "Use the given format to extract information from the following input: {input}"),
         ("human", "Tip: Precision in the node and relationship creation is vital for the integrity of the knowledge graph."),
     ])
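Note that the new get_graph_connection assumes one Neo4j instance per category, configured entirely through environment variables. A minimal .env sketch covering the names this diff reads — the values are placeholders, not part of the commit:

    # .env — placeholder values; only the variable names come from this diff
    CHEMO_NEO4J_URL=bolt://localhost:7687
    CHEMO_NEO4J_USERNAME=neo4j
    CHEMO_NEO4J_PASSWORD=<chemo-db-password>
    TRAFFIC_NEO4J_URL=bolt://localhost:7688
    TRAFFIC_NEO4J_USERNAME=neo4j
    TRAFFIC_NEO4J_PASSWORD=<traffic-db-password>
    OPENAI_API_KEY=<openai-key>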
kg_builder/src/graph_creation.py ADDED
@@ -0,0 +1,90 @@
+from langchain_community.document_loaders import WikipediaLoader
+from langchain.text_splitter import TokenTextSplitter
+from knowledge_graph_builder import extract_and_store_graph
+from langchain.schema import Document
+from dotenv import load_dotenv
+from tqdm import tqdm
+import os
+
+# Load environment variables
+load_dotenv()
+
+# Define articles to load
+articles = {
+    "Chemotherapy": "Chemotherapy",
+    "Traffic Law": "Traffic laws in the United States"
+}
+
+def build_graph_for_article(article_name, category):
+    print(f"Loading documents for: {article_name}")
+    # Load and process the Wikipedia article
+    raw_documents = WikipediaLoader(query=article_name).load()
+    if not raw_documents:
+        print(f"Failed to load content for {article_name}")
+        return
+
+    text_splitter = TokenTextSplitter(chunk_size=4096, chunk_overlap=96)
+    documents = text_splitter.split_documents(raw_documents[:5])  # Only process the first 5 documents
+
+    print("Building the knowledge graph...")
+    for i, document in tqdm(enumerate(documents), total=len(documents)):
+        extract_and_store_graph(document, category)
+
+def main():
+    for category, title in articles.items():
+        build_graph_for_article(title, category)
+
+if __name__ == "__main__":
+    main()
+
+# import os
+# from openai import OpenAI
+# from api_connections import get_graph_connection
+# from knowledge_graph_builder import extract_and_store_graph
+# from query_graph import query_knowledge_graph
+# from langchain_community.document_loaders import WikipediaLoader
+# from langchain.text_splitter import TokenTextSplitter
+# from tqdm import tqdm
+
+# def get_llm():
+#     api_key = os.getenv("OPENAI_API_KEY")
+#     if not api_key:
+#         raise ValueError("No OpenAI API key found in environment variables.")
+#     return OpenAI(api_key=api_key)
+
+# def classify_query(query):
+#     llm = get_llm()
+#     response = llm.Completion.create(
+#         model="text-davinci-003",  # Consider updating to the latest model as necessary
+#         prompt=f"Classify the following query into 'Chemotherapy' or 'Traffic Law': {query}",
+#         max_tokens=60
+#     )
+#     return response.choices[0].text.strip()
+
+# def main():
+#     print("Starting the script...")
+#     # Take Wikipedia article name as input
+#     article_name = input("Enter the Wikipedia article name: ")
+
+#     print(f"Loading documents for: {article_name}")
+#     # Load and process the Wikipedia article
+#     raw_documents = WikipediaLoader(query=article_name).load()
+#     text_splitter = TokenTextSplitter(chunk_size=4096, chunk_overlap=96)
+#     documents = text_splitter.split_documents(raw_documents[:5])  # Only process the first 5 documents
+
+#     print("Building the knowledge graph...")
+#     # Build the knowledge graph from the documents
+#     for i, d in tqdm(enumerate(documents), total=len(documents)):
+#         extract_and_store_graph(d)
+
+#     print("Graph construction complete. Please enter your query.")
+#     # Take a query related to the graph
+#     user_query = input("Enter your query related to the graph: ")
+
+#     print(f"Querying the graph with: {user_query}")
+#     # Query the graph and print the answer
+#     answer = query_knowledge_graph(user_query)
+#     print("Answer to your query:", answer)
+
+# if __name__ == "__main__":
+#     main()
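Since build_graph_for_article is importable, a single category can also be rebuilt without looping over the whole articles dict. A hypothetical usage sketch, using only names defined in graph_creation.py above:

    # Hypothetical one-off rebuild of just the traffic-law graph
    from graph_creation import build_graph_for_article

    build_graph_for_article("Traffic laws in the United States", "Traffic Law")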
kg_builder/src/knowledge_graph_builder.py CHANGED
@@ -1,5 +1,5 @@
 
-from api_connections import
+from api_connections import get_graph_connection
 
 from langchain_community.graphs.graph_document import (
     Node as BaseNode,
@@ -22,8 +22,11 @@ from langchain.chains.openai_functions import (
 
 def extract_and_store_graph(
     document: Document,
+    category: str,
     nodes: Optional[List[str]] = None,
     rels: Optional[List[str]] = None) -> None:
+
+    graph = get_graph_connection(category)
     # Extract graph data using OpenAI functions
     extract_chain = get_extraction_chain(nodes, rels)
     data = extract_chain.invoke(document.page_content)['function']
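One thing to note: the unchanged context line above still calls get_extraction_chain(nodes, rels), but after this commit the function's first parameter is category, so nodes would be bound to it. A sketch of the presumably intended call — an assumption, not something this commit contains (and get_extraction_chain would also need to be imported from api_connections if it is not already):

    # Assumed follow-up fix: thread the category through to the extraction chain
    extract_chain = get_extraction_chain(category, nodes, rels)
    data = extract_chain.invoke(document.page_content)['function']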
kg_builder/src/main.py CHANGED
@@ -1,33 +1,43 @@
+import os
+from openai import OpenAI
+from api_connections import get_graph_connection
 from knowledge_graph_builder import extract_and_store_graph
 from query_graph import query_knowledge_graph
 from langchain_community.document_loaders import WikipediaLoader
 from langchain.text_splitter import TokenTextSplitter
 from tqdm import tqdm
 
-def
-
-
-
-
-    print(f"Loading documents for: {article_name}")
-    # Load and process the Wikipedia article
-    raw_documents = WikipediaLoader(query=article_name).load()
-    text_splitter = TokenTextSplitter(chunk_size=4096, chunk_overlap=96)
-    documents = text_splitter.split_documents(raw_documents[:5])  # Only process the first 5 documents
+def get_llm():
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        raise ValueError("No OpenAI API key found in environment variables.")
+    return OpenAI(api_key=api_key)
 
-
-
-
-
+def classify_query(query):
+    llm = get_llm()
+    response = llm.Completion.create(
+        model="text-davinci-003",  # Consider updating to the latest model as necessary
+        prompt=f"Classify the following query into 'Chemotherapy' or 'Traffic Law': {query}",
+        max_tokens=60
+    )
+    return response.choices[0].text.strip()
 
-
-
-    user_query = input("Enter your query related to the graph: ")
+def main():
+    print("Starting the script...")
 
-
-
-
-
+    # Get user query
+    query = input("Please enter your query: ")
+
+    # Classify the query
+    category = classify_query(query)
+    print(f"Query classified into category: {category}")
+
+    # Get the correct graph connection
+    graph = get_graph_connection(category)
+
+    # Query the correct graph
+    result = query_knowledge_graph(graph, query)
+    print(f"Query result: {result}")
 
 if __name__ == "__main__":
     main()
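As committed, classify_query calls llm.Completion.create on an openai.OpenAI client, but the v1 Python SDK exposes no Completion attribute (that spelling belongs to the pre-1.0 module API), and text-davinci-003 has since been retired. A minimal sketch of the same classification step against the v1 chat API — the model name is an assumption, chosen to match the chat models used elsewhere in this repo:

    # Sketch assuming the OpenAI v1 SDK; not part of this commit.
    def classify_query(query):
        client = get_llm()
        # v1 SDK: chat completions live under client.chat.completions
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",  # assumption; any current chat model works
            messages=[{
                "role": "user",
                "content": f"Classify the following query into 'Chemotherapy' or 'Traffic Law': {query}",
            }],
            max_tokens=60,
        )
        return response.choices[0].message.content.strip()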
kg_builder/src/query_graph.py CHANGED
@@ -2,7 +2,7 @@ from langchain.chains import GraphCypherQAChain
 from langchain_openai import ChatOpenAI
 from api_connections import graph  # Importing 'graph' from 'api_connections.py'
 
-def query_knowledge_graph(query):
+def query_knowledge_graph(graph, query):
     print("Refreshing the graph schema...")
     # Refresh the graph schema before querying
     graph.refresh_schema()
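The context line from api_connections import graph is now stale: this commit removes the module-level graph from api_connections.py, so the import will fail even though graph now arrives as a parameter. A sketch of the presumably intended module head — an assumption, with the chain construction modeled on GraphCypherQAChain's documented usage rather than on code shown in this diff:

    # Assumed cleanup, not in this commit.
    from langchain.chains import GraphCypherQAChain
    from langchain_openai import ChatOpenAI

    def query_knowledge_graph(graph, query):
        # 'graph' is supplied by the caller (see main.py); the stale
        # module-level import can simply be deleted.
        print("Refreshing the graph schema...")
        graph.refresh_schema()
        # Depending on the installed langchain version,
        # allow_dangerous_requests=True may also be required here.
        chain = GraphCypherQAChain.from_llm(
            ChatOpenAI(temperature=0), graph=graph, verbose=True
        )
        return chain.invoke({"query": query})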