Spaces:
Running
Running
ACMCMC
commited on
Commit
·
47c6369
1
Parent(s):
d8b73be
WIP
Browse files- app.py +13 -5
- llm_res.py +1 -1
- utils.py +9 -9
app.py
CHANGED
@@ -45,8 +45,7 @@ with st.container(): # user input
|
|
45 |
col1, col2 = st.columns((6, 1))
|
46 |
|
47 |
with col1:
|
48 |
-
description_input = st.text_area(label="Enter a disease description 👇", placeholder='A
|
49 |
-
|
50 |
with col2:
|
51 |
st.text('') # dummy to center vertically
|
52 |
st.text('') # dummy to center vertically
|
@@ -67,25 +66,34 @@ with st.container():
|
|
67 |
)
|
68 |
status.info(f'Found {len(diseases_related_to_the_user_text)} diseases related to the description you entered.')
|
69 |
status.json(diseases_related_to_the_user_text, expanded=False)
|
|
|
70 |
# 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
|
71 |
status.write("Getting the similarities among the diseases to filter out less promising ones...")
|
72 |
diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
|
73 |
-
get_similarities_among_diseases_uris(diseases_uris)
|
|
|
|
|
|
|
74 |
# 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
|
75 |
# 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
|
76 |
status.write("Augmenting the set of diseases by finding others with related embeddings...")
|
77 |
augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
|
78 |
# print(augmented_set_of_diseases)
|
|
|
79 |
# 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
|
80 |
status.write("Getting the clinical trials related to the diseases found...")
|
81 |
clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
|
82 |
augmented_set_of_diseases, encoder
|
83 |
)
|
|
|
|
|
|
|
84 |
status.write("Getting the details of the clinical trials...")
|
85 |
json_of_clinical_trials = get_clinical_records_by_ids(
|
86 |
[trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
|
87 |
)
|
88 |
status.json(json_of_clinical_trials, expanded=False)
|
|
|
89 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
90 |
status.write("Getting a summary of the clinical trials...")
|
91 |
response, stats_dict = get_short_summary_out_of_json_files(json_of_clinical_trials)
|
@@ -109,9 +117,9 @@ with st.container():
|
|
109 |
st.info(
|
110 |
"""This is a graph of the relevant diseases that we found, based on the description that you entered. The diseases are connected by edges if they are similar to each other. The color of the edges represents the similarity of the diseases.
|
111 |
|
112 |
-
We use the embeddings of the diseases to determine the similarity between them. The embeddings are generated using a Representation Learning algorithm that learns the topological relations among the nodes in the graph, depending on how they are connected. We utilize the (
|
113 |
|
114 |
-
(
|
115 |
|
116 |
Specifically, it optimizes the following cost function:
|
117 |
$$"""
|
|
|
45 |
col1, col2 = st.columns((6, 1))
|
46 |
|
47 |
with col1:
|
48 |
+
description_input = st.text_area(label="Enter a disease description 👇", placeholder='A disorder manifested in memory loss and other cognitive impairments among elderly patients (60+ years old), especially women.')
|
|
|
49 |
with col2:
|
50 |
st.text('') # dummy to center vertically
|
51 |
st.text('') # dummy to center vertically
|
|
|
66 |
)
|
67 |
status.info(f'Found {len(diseases_related_to_the_user_text)} diseases related to the description you entered.')
|
68 |
status.json(diseases_related_to_the_user_text, expanded=False)
|
69 |
+
status.divider()
|
70 |
# 3. Get the similarities of the embeddings of those diseases (cosine similarity of the embeddings of the nodes of such diseases)
|
71 |
status.write("Getting the similarities among the diseases to filter out less promising ones...")
|
72 |
diseases_uris = [disease["uri"] for disease in diseases_related_to_the_user_text]
|
73 |
+
similarities = get_similarities_among_diseases_uris(diseases_uris)
|
74 |
+
status.info(f'Obtained similarity information among the diseases by measuring the cosine similarity of their embeddings. Using the similarity information to filter out less promising diseases.')
|
75 |
+
status.json(similarities, expanded=False)
|
76 |
+
status.divider()
|
77 |
# 4. Potentially filter out the diseases that are not similar enough (e.g. similarity < 0.8)
|
78 |
# 5. Augment the set of diseases: add new diseases that are similar to the ones that are already in the set, until we get 10-15 diseases
|
79 |
status.write("Augmenting the set of diseases by finding others with related embeddings...")
|
80 |
augmented_set_of_diseases = augment_the_set_of_diseaces(diseases_uris)
|
81 |
# print(augmented_set_of_diseases)
|
82 |
+
status.info(f'Augmented set of diseases: {len(augmented_set_of_diseases)} diseases.')
|
83 |
# 6. Query the embeddings of the diseases related to each clinical trial (also in the DB), to get the most similar clinical trials to our set of diseases
|
84 |
status.write("Getting the clinical trials related to the diseases found...")
|
85 |
clinical_trials_related_to_the_diseases = get_clinical_trials_related_to_diseases(
|
86 |
augmented_set_of_diseases, encoder
|
87 |
)
|
88 |
+
status.info(f'Found {len(clinical_trials_related_to_the_diseases)} clinical trials related to the diseases.')
|
89 |
+
status.json(clinical_trials_related_to_the_diseases, expanded=False)
|
90 |
+
status.divider()
|
91 |
status.write("Getting the details of the clinical trials...")
|
92 |
json_of_clinical_trials = get_clinical_records_by_ids(
|
93 |
[trial["nct_id"] for trial in clinical_trials_related_to_the_diseases]
|
94 |
)
|
95 |
status.json(json_of_clinical_trials, expanded=False)
|
96 |
+
status.divider()
|
97 |
# 7. Use an LLM to get a summary of the clinical trials, in plain text format.
|
98 |
status.write("Getting a summary of the clinical trials...")
|
99 |
response, stats_dict = get_short_summary_out_of_json_files(json_of_clinical_trials)
|
|
|
117 |
st.info(
|
118 |
"""This is a graph of the relevant diseases that we found, based on the description that you entered. The diseases are connected by edges if they are similar to each other. The color of the edges represents the similarity of the diseases.
|
119 |
|
120 |
+
We use the embeddings of the diseases to determine the similarity between them. The embeddings are generated using a Representation Learning algorithm that learns the topological relations among the nodes in the graph, depending on how they are connected. We utilize the [PyKeen](https://github.com/pykeen/pykeen) implementation of TransH to train an embedding model.
|
121 |
|
122 |
+
[TransH](https://ojs.aaai.org/index.php/AAAI/article/view/8870) utilizes hyperplanes to model relations between entities. It is a multi-relational model that can handle many-to-many relations between entities. The model is trained on the triples of the graph, where the triples are the subject, relation, and object of the graph. The model learns the embeddings of the entities and the relations, such that the embeddings of the subject and object are close to each other when the relation is true.
|
123 |
|
124 |
Specifically, it optimizes the following cost function:
|
125 |
$$"""
|
llm_res.py
CHANGED
@@ -221,7 +221,7 @@ def process_dictionaty_with_llm_to_generate_response(json_data):
|
|
221 |
return filtered_data
|
222 |
|
223 |
def get_short_summary_out_of_json_files(data_json):
|
224 |
-
|
225 |
|
226 |
# # Task
|
227 |
# You will be given a set of descriptions of clinical trials. Your job is to come up with a short summary (100-200 words) of the descriptions of the clinical trials. Your users are clinical researchers who are experts in medicine, so you should be technical and specific, including scientific terms. Always be faithful to the original information written in the reports.
|
|
|
221 |
return filtered_data
|
222 |
|
223 |
def get_short_summary_out_of_json_files(data_json):
|
224 |
+
prompt_template = """You are an expert clinician working on the analysis of reports of clinical trials.
|
225 |
|
226 |
# # Task
|
227 |
# You will be given a set of descriptions of clinical trials. Your job is to come up with a short summary (100-200 words) of the descriptions of the clinical trials. Your users are clinical researchers who are experts in medicine, so you should be technical and specific, including scientific terms. Always be faithful to the original information written in the reports.
|
utils.py
CHANGED
@@ -125,15 +125,15 @@ def get_similarities_among_diseases_uris(
|
|
125 |
|
126 |
|
127 |
def augment_the_set_of_diseaces(diseases: List[str]) -> str:
|
128 |
-
|
129 |
-
for i in range(15-len(
|
130 |
with engine.connect() as conn:
|
131 |
with conn.begin():
|
132 |
sql = f"""
|
133 |
-
SELECT TOP 1 e2.uri AS new_disease, (SUM(VECTOR_COSINE(e1.embedding, e2.embedding))/ {len(
|
134 |
FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
|
135 |
-
WHERE e1.uri IN ({','.join([f"'{disease}'" for disease in
|
136 |
-
AND e2.uri NOT IN ({','.join([f"'{disease}'" for disease in
|
137 |
AND e2.label != 'nan'
|
138 |
GROUP BY e2.label
|
139 |
ORDER BY score DESC
|
@@ -142,9 +142,9 @@ def augment_the_set_of_diseaces(diseases: List[str]) -> str:
|
|
142 |
result = conn.execute(text(sql))
|
143 |
data = result.fetchall()
|
144 |
|
145 |
-
|
146 |
|
147 |
-
return
|
148 |
|
149 |
def get_embedding(string: str, encoder) -> List[float]:
|
150 |
# Embed the string using sentence-transformers
|
@@ -162,14 +162,14 @@ def get_diseases_related_to_a_textual_description(
|
|
162 |
with engine.connect() as conn:
|
163 |
with conn.begin():
|
164 |
sql = f"""
|
165 |
-
SELECT TOP
|
166 |
FROM Test.DiseaseDescriptions d
|
167 |
ORDER BY distance DESC
|
168 |
"""
|
169 |
result = conn.execute(text(sql))
|
170 |
data = result.fetchall()
|
171 |
|
172 |
-
return [{"uri": row[0], "distance": row[1]} for row in data]
|
173 |
|
174 |
def get_clinical_trials_related_to_diseases(
|
175 |
diseases: List[str], encoder
|
|
|
125 |
|
126 |
|
127 |
def augment_the_set_of_diseaces(diseases: List[str]) -> str:
|
128 |
+
augmented_diseases = diseases.copy()
|
129 |
+
for i in range(15-len(augmented_diseases)):
|
130 |
with engine.connect() as conn:
|
131 |
with conn.begin():
|
132 |
sql = f"""
|
133 |
+
SELECT TOP 1 e2.uri AS new_disease, (SUM(VECTOR_COSINE(e1.embedding, e2.embedding))/ {len(augmented_diseases)}) AS score
|
134 |
FROM Test.EntityEmbeddings e1, Test.EntityEmbeddings e2
|
135 |
+
WHERE e1.uri IN ({','.join([f"'{disease}'" for disease in augmented_diseases])})
|
136 |
+
AND e2.uri NOT IN ({','.join([f"'{disease}'" for disease in augmented_diseases])})
|
137 |
AND e2.label != 'nan'
|
138 |
GROUP BY e2.label
|
139 |
ORDER BY score DESC
|
|
|
142 |
result = conn.execute(text(sql))
|
143 |
data = result.fetchall()
|
144 |
|
145 |
+
augmented_diseases.append(data[0][0])
|
146 |
|
147 |
+
return augmented_diseases
|
148 |
|
149 |
def get_embedding(string: str, encoder) -> List[float]:
|
150 |
# Embed the string using sentence-transformers
|
|
|
162 |
with engine.connect() as conn:
|
163 |
with conn.begin():
|
164 |
sql = f"""
|
165 |
+
SELECT TOP 10 d.uri, VECTOR_COSINE(d.embedding, TO_VECTOR('{string_representation}', DOUBLE)) AS distance
|
166 |
FROM Test.DiseaseDescriptions d
|
167 |
ORDER BY distance DESC
|
168 |
"""
|
169 |
result = conn.execute(text(sql))
|
170 |
data = result.fetchall()
|
171 |
|
172 |
+
return [{"uri": row[0], "distance": row[1]} for row in data if row[1] > 0.8]
|
173 |
|
174 |
def get_clinical_trials_related_to_diseases(
|
175 |
diseases: List[str], encoder
|