Prakhar Bhandari committed
Commit b77d203 · 1 Parent(s): babec93

First attempt at incorporating multiple graphs

kg_builder/src/__pycache__/api_connections.cpython-39.pyc CHANGED
Binary files a/kg_builder/src/__pycache__/api_connections.cpython-39.pyc and b/kg_builder/src/__pycache__/api_connections.cpython-39.pyc differ
 
kg_builder/src/__pycache__/knowledge_graph_builder.cpython-39.pyc CHANGED
Binary files a/kg_builder/src/__pycache__/knowledge_graph_builder.cpython-39.pyc and b/kg_builder/src/__pycache__/knowledge_graph_builder.cpython-39.pyc differ
 
kg_builder/src/__pycache__/models.cpython-39.pyc CHANGED
Binary files a/kg_builder/src/__pycache__/models.cpython-39.pyc and b/kg_builder/src/__pycache__/models.cpython-39.pyc differ
 
kg_builder/src/__pycache__/utils.cpython-39.pyc CHANGED
Binary files a/kg_builder/src/__pycache__/utils.cpython-39.pyc and b/kg_builder/src/__pycache__/utils.cpython-39.pyc differ
 
kg_builder/src/api_connections.py CHANGED
@@ -11,17 +11,21 @@ from typing import Optional, List
 
 load_dotenv() # This loads the variables from .env into os.environ
 
-# Now use os.getenv to access your variables
-url = os.getenv("NEO4J_URL")
-username = os.getenv("NEO4J_USERNAME")
-password = os.getenv("NEO4J_PASSWORD")
-openai_api_key = os.getenv("OPENAI_API_KEY")
+def get_graph_connection(category):
+    if category == "Chemotherapy":
+        url = os.getenv("CHEMO_NEO4J_URL")
+        username = os.getenv("CHEMO_NEO4J_USERNAME")
+        password = os.getenv("CHEMO_NEO4J_PASSWORD")
+    elif category == "Traffic Law":
+        url = os.getenv("TRAFFIC_NEO4J_URL")
+        username = os.getenv("TRAFFIC_NEO4J_USERNAME")
+        password = os.getenv("TRAFFIC_NEO4J_PASSWORD")
+    else:
+        raise ValueError(f"Unknown category: {category}")
+
+    return Neo4jGraph(url=url, username=username, password=password)
 
-graph = Neo4jGraph(
-    url=url,
-    username=username,
-    password=password
-)
+openai_api_key = os.getenv("OPENAI_API_KEY")
 
 def get_llm():
     api_key = os.getenv("OPENAI_API_KEY")
@@ -30,44 +34,23 @@ def get_llm():
     return ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0)
 
 def get_extraction_chain(
+    category,
     allowed_nodes: Optional[List[str]] = None,
     allowed_rels: Optional[List[str]] = None
     ):
+    if category == "Chemotherapy":
+        # Chemotherapy-specific prompt
+        prompt_text = ""
+    elif category == "Traffic Law":
+        # Traffic Law-specific prompt
+        prompt_text = "[Traffic Law-specific instructions]"
+    else:
+        raise ValueError("Unknown category")
+
     llm = get_llm()
     prompt = ChatPromptTemplate.from_messages(
     [(
-        "system",
-        f"""# Knowledge Graph Instructions for GPT-4
-## 1. Overview
-You are a sophisticated algorithm tailored for parsing Wikipedia pages to construct a knowledge graph about chemotherapy and related cancer treatments.
-- **Nodes** symbolize entities such as medical conditions, drugs, symptoms, treatments, and associated medical concepts.
-- The goal is to create a precise and comprehensible knowledge graph, serving as a reliable resource for medical practitioners and scholarly research.
-
-## 2. Labeling Nodes
-- **Consistency**: Utilize uniform labels for node types to maintain clarity.
-  - For instance, consistently label drugs as **"Drug"**, symptoms as **"Symptom"**, and treatments as **"Treatment"**.
-- **Node IDs**: Apply descriptive, legible identifiers for node IDs, sourced directly from the text.
-
-{'- **Allowed Node Labels:**' + ", ".join(['Drug', 'Symptom', 'Treatment', 'MedicalCondition', 'ResearchStudy']) if allowed_nodes else ""}
-{'- **Allowed Relationship Types**:' + ", ".join(['Treats', 'Causes', 'Researches', 'Recommends']) if allowed_rels else ""}
-
-## 3. Handling Numerical Data and Dates
-- Integrate numerical data and dates as attributes of the corresponding nodes.
-- **No Isolated Nodes for Dates/Numbers**: Directly associate dates and numerical figures as attributes with pertinent nodes.
-- **Property Format**: Follow a straightforward key-value pattern for properties, with keys in camelCase, for example, `approvedYear`, `dosageAmount`.
-
-## 4. Coreference Resolution
-- **Entity Consistency**: Guarantee uniform identification of each entity across the graph.
-  - For example, if "Methotrexate" and "MTX" reference the same medication, uniformly apply "Methotrexate" as the node ID.
-
-## 5. Relationship Naming Conventions
-- **Clarity and Standardization**: Utilize clear and standardized relationship names, preferring uppercase with underscores for readability.
-  - For instance, use "HAS_SIDE_EFFECT" instead of "HASSIDEEFFECT", use "CAN_RESULT_FROM" instead of "CANRESULTFROM" etc. You keep making the same mistakes of storing the relationships without the "_" in between the words. Any further similar errors will lead to termination.
-- **Relevance and Specificity**: Choose relationship names that accurately reflect the connection between nodes, such as "INHIBITS" or "ACTIVATES" for interactions between substances.
-
-## 6. Strict Compliance
-Rigorous adherence to these instructions is essential. Failure to comply with the specified formatting and labeling norms will necessitate output revision or discard.
-"""),
+        "system", prompt_text),
     ("human", "Use the given format to extract information from the following input: {input}"),
     ("human", "Tip: Precision in the node and relationship creation is vital for the integrity of the knowledge graph."),
     ])
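
Note: the new per-category connection is driven entirely by environment variables. A minimal usage sketch, assuming the `CHEMO_NEO4J_*` and `TRAFFIC_NEO4J_*` entries from the diff are present in `.env`:

```python
# Sketch only: exercising get_graph_connection with the two categories wired up above.
# Assumes .env defines CHEMO_NEO4J_URL/USERNAME/PASSWORD and TRAFFIC_NEO4J_URL/USERNAME/PASSWORD.
from api_connections import get_graph_connection

chemo_graph = get_graph_connection("Chemotherapy")   # backed by the CHEMO_NEO4J_* variables
traffic_graph = get_graph_connection("Traffic Law")  # backed by the TRAFFIC_NEO4J_* variables

# Any other string raises, so typos fail fast rather than silently hitting the wrong database:
# get_graph_connection("Chemo")  -> ValueError: Unknown category: Chemo
```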
kg_builder/src/graph_creation.py ADDED
@@ -0,0 +1,90 @@
+from langchain_community.document_loaders import WikipediaLoader
+from langchain.text_splitter import TokenTextSplitter
+from knowledge_graph_builder import extract_and_store_graph
+from langchain.schema import Document
+from dotenv import load_dotenv
+from tqdm import tqdm
+import os
+
+# Load environment variables
+load_dotenv()
+
+# Define articles to load
+articles = {
+    "Chemotherapy": "Chemotherapy",
+    "Traffic Law": "Traffic laws in the United States"
+}
+
+def build_graph_for_article(article_name, category):
+    print(f"Loading documents for: {article_name}")
+    # Load and process the Wikipedia article
+    raw_documents = WikipediaLoader(query=article_name).load()
+    if not raw_documents:
+        print(f"Failed to load content for {article_name}")
+        return
+
+    text_splitter = TokenTextSplitter(chunk_size=4096, chunk_overlap=96)
+    documents = text_splitter.split_documents(raw_documents[:5])  # Only process the first 5 documents
+
+    print("Building the knowledge graph...")
+    for i, document in tqdm(enumerate(documents), total=len(documents)):
+        extract_and_store_graph(document, category)
+
+def main():
+    for category, title in articles.items():
+        build_graph_for_article(title, category)
+
+if __name__ == "__main__":
+    main()
+
+# import os
+# from openai import OpenAI
+# from api_connections import get_graph_connection
+# from knowledge_graph_builder import extract_and_store_graph
+# from query_graph import query_knowledge_graph
+# from langchain_community.document_loaders import WikipediaLoader
+# from langchain.text_splitter import TokenTextSplitter
+# from tqdm import tqdm
+
+# def get_llm():
+#     api_key = os.getenv("OPENAI_API_KEY")
+#     if not api_key:
+#         raise ValueError("No OpenAI API key found in environment variables.")
+#     return OpenAI(api_key=api_key)
+
+# def classify_query(query):
+#     llm = get_llm()
+#     response = llm.Completion.create(
+#         model="text-davinci-003", # Consider updating to the latest model as necessary
+#         prompt=f"Classify the following query into 'Chemotherapy' or 'Traffic Law': {query}",
+#         max_tokens=60
+#     )
+#     return response.choices[0].text.strip()
+
+# def main():
+#     print("Starting the script...")
+#     # Take Wikipedia article name as input
+#     article_name = input("Enter the Wikipedia article name: ")
+
+#     print(f"Loading documents for: {article_name}")
+#     # Load and process the Wikipedia article
+#     raw_documents = WikipediaLoader(query=article_name).load()
+#     text_splitter = TokenTextSplitter(chunk_size=4096, chunk_overlap=96)
+#     documents = text_splitter.split_documents(raw_documents[:5]) # Only process the first 5 documents
+
+#     print("Building the knowledge graph...")
+#     # Build the knowledge graph from the documents
+#     for i, d in tqdm(enumerate(documents), total=len(documents)):
+#         extract_and_store_graph(d)
+
+#     print("Graph construction complete. Please enter your query.")
+#     # Take a query related to the graph
+#     user_query = input("Enter your query related to the graph: ")
+
+#     print(f"Querying the graph with: {user_query}")
+#     # Query the graph and print the answer
+#     answer = query_knowledge_graph(user_query)
+#     print("Answer to your query:", answer)
+
+# if __name__ == "__main__":
+#     main()
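
One possible refinement, not part of this commit: `WikipediaLoader` accepts a `load_max_docs` parameter, so the five-document budget can be enforced at fetch time instead of slicing afterwards. A sketch:

```python
# Sketch: cap results at the loader rather than with raw_documents[:5].
# Same five-document budget as the commit; avoids fetching pages that are immediately discarded.
from langchain_community.document_loaders import WikipediaLoader
from langchain.text_splitter import TokenTextSplitter

raw_documents = WikipediaLoader(query="Chemotherapy", load_max_docs=5).load()
text_splitter = TokenTextSplitter(chunk_size=4096, chunk_overlap=96)
documents = text_splitter.split_documents(raw_documents)
```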
kg_builder/src/knowledge_graph_builder.py CHANGED
@@ -1,5 +1,5 @@
 
-from api_connections import graph
+from api_connections import get_graph_connection
 
 from langchain_community.graphs.graph_document import (
     Node as BaseNode,
@@ -22,8 +22,11 @@ from langchain.chains.openai_functions import (
 
 def extract_and_store_graph(
     document: Document,
+    category: str,
     nodes: Optional[List[str]] = None,
     rels: Optional[List[str]] = None) -> None:
+
+    graph = get_graph_connection(category)
     # Extract graph data using OpenAI functions
-    extract_chain = get_extraction_chain(nodes, rels)
+    extract_chain = get_extraction_chain(category, nodes, rels)  # pass the category through so the right prompt is selected
     data = extract_chain.invoke(document.page_content)['function']
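
Worth noting: with this change, every chunk passed to `extract_and_store_graph` opens a fresh `Neo4jGraph` connection. A minimal sketch of reusing one connection per category, assuming `get_graph_connection` as defined above (the caching helper itself is hypothetical):

```python
# Hypothetical helper: memoize one Neo4jGraph per category so repeated chunks
# share a connection instead of reconnecting on every document.
from functools import lru_cache
from api_connections import get_graph_connection

@lru_cache(maxsize=None)
def get_cached_graph(category: str):
    return get_graph_connection(category)
```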
kg_builder/src/main.py CHANGED
@@ -1,33 +1,43 @@
+import os
+from openai import OpenAI
+from api_connections import get_graph_connection
 from knowledge_graph_builder import extract_and_store_graph
 from query_graph import query_knowledge_graph
 from langchain_community.document_loaders import WikipediaLoader
 from langchain.text_splitter import TokenTextSplitter
 from tqdm import tqdm
 
-def main():
-    print("Starting the script...")
-    # Take Wikipedia article name as input
-    article_name = input("Enter the Wikipedia article name: ")
-
-    print(f"Loading documents for: {article_name}")
-    # Load and process the Wikipedia article
-    raw_documents = WikipediaLoader(query=article_name).load()
-    text_splitter = TokenTextSplitter(chunk_size=4096, chunk_overlap=96)
-    documents = text_splitter.split_documents(raw_documents[:5])  # Only process the first 5 documents
-
-    print("Building the knowledge graph...")
-    # Build the knowledge graph from the documents
-    for i, d in tqdm(enumerate(documents), total=len(documents)):
-        extract_and_store_graph(d)
-
-    print("Graph construction complete. Please enter your query.")
-    # Take a query related to the graph
-    user_query = input("Enter your query related to the graph: ")
-
-    print(f"Querying the graph with: {user_query}")
-    # Query the graph and print the answer
-    answer = query_knowledge_graph(user_query)
-    print("Answer to your query:", answer)
+def get_llm():
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        raise ValueError("No OpenAI API key found in environment variables.")
+    return OpenAI(api_key=api_key)
+
+def classify_query(query):
+    llm = get_llm()
+    response = llm.completions.create(  # the v1 OpenAI client exposes `completions` (lowercase), not `Completion`
+        model="text-davinci-003",  # Consider updating to the latest model as necessary
+        prompt=f"Classify the following query into 'Chemotherapy' or 'Traffic Law': {query}",
+        max_tokens=60
+    )
+    return response.choices[0].text.strip()
+
+def main():
+    print("Starting the script...")
+
+    # Get user query
+    query = input("Please enter your query: ")
+
+    # Classify the query
+    category = classify_query(query)
+    print(f"Query classified into category: {category}")
+
+    # Get the correct graph connection
+    graph = get_graph_connection(category)
+
+    # Query the correct graph
+    result = query_knowledge_graph(graph, query)
+    print(f"Query result: {result}")
 
 if __name__ == "__main__":
     main()
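
Since `classify_query` returns free-form completion text, the result may not exactly equal "Chemotherapy" or "Traffic Law" (trailing punctuation, different casing), in which case `get_graph_connection` raises. A hedged sketch of normalizing the output before routing; the helper is illustrative, not part of this commit:

```python
# Hypothetical helper: map the classifier's free-text answer onto a known category key.
def normalize_category(raw: str) -> str:
    text = raw.strip().lower()
    if "chemo" in text:
        return "Chemotherapy"
    if "traffic" in text:
        return "Traffic Law"
    raise ValueError(f"Could not map classifier output to a category: {raw!r}")
```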
kg_builder/src/query_graph.py CHANGED
@@ -2,7 +2,6 @@ from langchain.chains import GraphCypherQAChain
 from langchain_openai import ChatOpenAI
-from api_connections import graph  # Importing 'graph' from 'api_connections.py'
 
-def query_knowledge_graph(query):
+def query_knowledge_graph(graph, query):  # the target graph is now passed in by the caller; the module-level 'graph' no longer exists in api_connections.py
     print("Refreshing the graph schema...")
     # Refresh the graph schema before querying
     graph.refresh_schema()
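
With the new signature, callers select the graph before querying. A minimal end-to-end sketch under the assumptions above (the question string is just an example):

```python
# Sketch: route a question to the Traffic Law graph using the updated API.
from api_connections import get_graph_connection
from query_graph import query_knowledge_graph

graph = get_graph_connection("Traffic Law")
answer = query_knowledge_graph(graph, "Which speed limits apply in residential areas?")
print(answer)
```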