Update app.py
app.py
CHANGED
@@ -4,7 +4,7 @@ import streamlit as st
 import faiss
 import numpy as np
 import torch
-from transformers import
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from sentence_transformers import SentenceTransformer
 from reportlab.lib.pagesizes import A4
 from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
@@ -23,7 +23,7 @@ age_categories = {
 }
 
 # Initialize FAISS and Sentence Transformer
-
+embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
 
 def create_faiss_index(data):
     descriptions, age_keys = [], []
@@ -32,7 +32,7 @@ def create_faiss_index(data):
         descriptions.append(entry['description'])
         age_keys.append(int(age))  # Convert age to int
 
-    embeddings =
+    embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
     index = faiss.IndexFlatL2(embeddings.shape[1])
     index.add(embeddings)
     return index, descriptions, age_keys
@@ -41,14 +41,14 @@ index, descriptions, age_keys = create_faiss_index(milestones)
 
 # Function to retrieve the closest milestone
 def retrieve_milestone(user_input):
-    user_embedding =
+    user_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
     _, indices = index.search(user_embedding, 1)
     return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
 
-# Load IBM Granite model and tokenizer
-model_name = "ibm/granite-
+# Load IBM Granite 3.1 model and tokenizer
+model_name = "ibm-granite/granite-3.1-8b-instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-granite_model =
+granite_model = AutoModelForSeq2SeqLM.from_pretrained(
     model_name, torch_dtype=torch.float16, device_map="auto"
 )
 
@@ -125,4 +125,4 @@ if st.button("🔍 Analyze", help="Click to analyze the child's development mile
     with open(pdf_file, "rb") as f:
         st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
 
-    st.warning("⚠️ The results provided are generated by AI and should be interpreted with caution. Please consult a pediatrician for professional advice.")
+    st.warning("⚠️ The results provided are generated by AI and should be interpreted with caution. Please consult a pediatrician for professional advice.")
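For reference, the retrieval path these hunks wire up works as follows: encode(..., convert_to_numpy=True) returns float32 vectors, IndexFlatL2 does exact L2 search over them, and retrieve_milestone is a 1-nearest-neighbour lookup. A minimal standalone sketch of that path, with an assumed shape for the milestones dict and placeholder descriptions (the real data structure is not visible in this diff):

# Minimal standalone sketch of the retrieval path this commit wires up.
# The shape of `milestones` is not shown in the diff; the dict below is an
# illustrative placeholder, not data from the app.
import faiss
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

milestones = {
    "12": {"description": "Says a first word and stands while holding on."},
    "24": {"description": "Uses two-word phrases and runs steadily."},
}

descriptions = [entry["description"] for entry in milestones.values()]
# encode(..., convert_to_numpy=True) returns float32 vectors (384-dim for MiniLM)
embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)

index = faiss.IndexFlatL2(embeddings.shape[1])  # exact L2 nearest-neighbour index
index.add(embeddings)

user_embedding = embedding_model.encode(["My child just started walking"], convert_to_numpy=True)
_, indices = index.search(user_embedding, 1)  # k=1: single closest milestone
print(descriptions[indices[0][0]])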
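On the model-loading change: granite-3.1-8b-instruct is a decoder-only model whose model card loads it with AutoModelForCausalLM, so the AutoModelForSeq2SeqLM call as committed may raise a configuration error at load time. A hedged sketch of the causal-LM loading path, with an illustrative chat-template prompt (the app's actual prompting code sits outside this diff):

# Hedged alternative for the model-loading hunk: Granite 3.1 Instruct is a
# decoder-only model, so its model card loads it with AutoModelForCausalLM.
# The committed code keeps AutoModelForSeq2SeqLM; this sketch shows the
# causal-LM path. The chat-template prompt below is illustrative only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "ibm-granite/granite-3.1-8b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
granite_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

messages = [{"role": "user", "content": "Summarize typical milestones for a 24-month-old."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(granite_model.device)
output_ids = granite_model.generate(input_ids, max_new_tokens=128)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))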
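The last hunk only streams an already-built pdf_file through st.download_button; the report-building code is outside the diff. Given the reportlab imports at the top of app.py, the report is presumably assembled with SimpleDocTemplate. A sketch under that assumption, with hypothetical names throughout:

# Illustrative only: one way the progress-report PDF could be assembled with
# the reportlab classes imported at the top of app.py. The actual build code
# sits outside the hunks in this commit; function and variable names here are
# assumptions, not the app's.
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer

def build_progress_report(pdf_file: str, summary_text: str) -> str:
    styles = getSampleStyleSheet()
    doc = SimpleDocTemplate(pdf_file, pagesize=A4)
    story = [
        Paragraph("Child Development Progress Report", styles["Title"]),
        Spacer(1, 12),
        Paragraph(summary_text, styles["Normal"]),
    ]
    doc.build(story)  # writes the PDF to pdf_file
    return pdf_file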