Update app.py
Browse files
app.py
CHANGED
@@ -3,9 +3,8 @@ import json
|
|
3 |
import streamlit as st
|
4 |
import faiss
|
5 |
import numpy as np
|
6 |
-
import torch
|
7 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
8 |
from sentence_transformers import SentenceTransformer
|
|
|
9 |
from reportlab.lib.pagesizes import A4
|
10 |
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
|
11 |
from reportlab.lib.styles import getSampleStyleSheet
|
@@ -16,23 +15,32 @@ with open('milestones.json', 'r') as f:
|
|
16 |
|
17 |
# Age categories for dropdown selection
|
18 |
age_categories = {
|
19 |
-
"Up to 2 months": 2,
|
20 |
-
"Up to
|
21 |
-
"Up to
|
22 |
-
"Up to
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
}
|
24 |
|
25 |
# Initialize FAISS and Sentence Transformer
|
26 |
-
|
27 |
|
28 |
def create_faiss_index(data):
|
29 |
-
descriptions
|
|
|
30 |
for age, categories in data.items():
|
31 |
for entry in categories:
|
32 |
descriptions.append(entry['description'])
|
33 |
-
age_keys.append(int(age))
|
34 |
|
35 |
-
embeddings =
|
36 |
index = faiss.IndexFlatL2(embeddings.shape[1])
|
37 |
index.add(embeddings)
|
38 |
return index, descriptions, age_keys
|
@@ -41,85 +49,67 @@ index, descriptions, age_keys = create_faiss_index(milestones)
|
|
41 |
|
42 |
# Function to retrieve the closest milestone
|
43 |
def retrieve_milestone(user_input):
|
44 |
-
user_embedding =
|
45 |
_, indices = index.search(user_embedding, 1)
|
46 |
return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
|
47 |
|
48 |
-
#
|
49 |
-
|
50 |
-
|
51 |
-
@st.cache_resource # Cache model to avoid reloading on every interaction
|
52 |
-
def load_model():
|
53 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
54 |
-
model = AutoModelForCausalLM.from_pretrained(
|
55 |
-
MODEL_NAME,
|
56 |
-
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
57 |
-
device_map="auto" # Auto-select GPU/CPU
|
58 |
-
)
|
59 |
-
return tokenizer, model
|
60 |
-
|
61 |
-
tokenizer, granite_model = load_model()
|
62 |
|
63 |
def generate_response(user_input, child_age):
|
64 |
relevant_milestone = retrieve_milestone(user_input)
|
65 |
-
prompt = (
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(granite_model.device)
|
73 |
-
output = granite_model.generate(**inputs, max_length=512)
|
74 |
-
return tokenizer.decode(output[0], skip_special_tokens=True)
|
75 |
|
76 |
# Streamlit UI Styling
|
77 |
st.set_page_config(page_title="Tiny Triumphs Tracker", page_icon="👶", layout="wide")
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
-
st.markdown("<h1
|
80 |
-
st.markdown("Track your child's key growth milestones from birth to 5 years and detect early developmental concerns.")
|
81 |
|
82 |
# User selects child's age
|
83 |
selected_age = st.selectbox("📅 Select child's age:", list(age_categories.keys()))
|
84 |
child_age = age_categories[selected_age]
|
85 |
|
86 |
# User input for traits and skills
|
87 |
-
placeholder_text = "
|
88 |
user_input = st.text_area("✍️ Enter child's behavioral traits and skills:", placeholder=placeholder_text)
|
89 |
|
90 |
def generate_pdf_report(ai_response):
|
91 |
pdf_file = "progress_report.pdf"
|
92 |
doc = SimpleDocTemplate(pdf_file, pagesize=A4)
|
93 |
styles = getSampleStyleSheet()
|
94 |
-
|
95 |
-
elements
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
for part in ai_response.split('\n'):
|
103 |
part = part.strip().lstrip('0123456789.- ')
|
104 |
if part:
|
105 |
elements.append(Paragraph(f"• {part}", styles['Normal']))
|
106 |
elements.append(Spacer(1, 5))
|
107 |
-
|
108 |
-
disclaimer = ("This report is AI-generated and is for informational purposes only. "
|
109 |
-
"It should not be considered a substitute for professional medical advice. "
|
110 |
-
"Always consult a qualified pediatrician for expert guidance on your child's development.")
|
111 |
elements.append(Spacer(1, 12))
|
112 |
elements.append(Paragraph(disclaimer, styles['Italic']))
|
113 |
-
|
114 |
doc.build(elements)
|
115 |
return pdf_file
|
116 |
|
117 |
if st.button("🔍 Analyze", help="Click to analyze the child's development milestones"):
|
118 |
ai_response = generate_response(user_input, child_age)
|
119 |
-
|
120 |
st.subheader("📊 Development Insights:")
|
121 |
st.markdown(f"<div style='background-color:#44475a; color:#ffffff; padding: 15px; border-radius: 10px;'>{ai_response}</div>", unsafe_allow_html=True)
|
122 |
-
|
123 |
pdf_file = generate_pdf_report(ai_response)
|
124 |
with open(pdf_file, "rb") as f:
|
125 |
st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
|
|
|
3 |
import streamlit as st
|
4 |
import faiss
|
5 |
import numpy as np
|
|
|
|
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
+
from transformers import pipeline
|
8 |
from reportlab.lib.pagesizes import A4
|
9 |
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
|
10 |
from reportlab.lib.styles import getSampleStyleSheet
|
|
|
15 |
|
16 |
# Age categories for dropdown selection
|
17 |
age_categories = {
|
18 |
+
"Up to 2 months": 2,
|
19 |
+
"Up to 4 months": 4,
|
20 |
+
"Up to 6 months": 6,
|
21 |
+
"Up to 9 months": 9,
|
22 |
+
"Up to 1 year": 12,
|
23 |
+
"Up to 15 months": 15,
|
24 |
+
"Up to 18 months": 18,
|
25 |
+
"Up to 2 years": 24,
|
26 |
+
"Up to 30 months": 30,
|
27 |
+
"Up to 3 years": 36,
|
28 |
+
"Up to 4 years": 48,
|
29 |
+
"Up to 5 years": 60
|
30 |
}
|
31 |
|
32 |
# Initialize FAISS and Sentence Transformer
|
33 |
+
model = SentenceTransformer('all-MiniLM-L6-v2')
|
34 |
|
35 |
def create_faiss_index(data):
|
36 |
+
descriptions = []
|
37 |
+
age_keys = []
|
38 |
for age, categories in data.items():
|
39 |
for entry in categories:
|
40 |
descriptions.append(entry['description'])
|
41 |
+
age_keys.append(int(age))
|
42 |
|
43 |
+
embeddings = model.encode(descriptions, convert_to_numpy=True)
|
44 |
index = faiss.IndexFlatL2(embeddings.shape[1])
|
45 |
index.add(embeddings)
|
46 |
return index, descriptions, age_keys
|
|
|
49 |
|
50 |
# Function to retrieve the closest milestone
|
51 |
def retrieve_milestone(user_input):
|
52 |
+
user_embedding = model.encode([user_input], convert_to_numpy=True)
|
53 |
_, indices = index.search(user_embedding, 1)
|
54 |
return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
|
55 |
|
56 |
+
# Initialize IBM Granite Model
|
57 |
+
ibm_model = pipeline("text-generation", model="ibm-granite", max_length=512)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
def generate_response(user_input, child_age):
|
60 |
relevant_milestone = retrieve_milestone(user_input)
|
61 |
+
prompt = (f"The child is {child_age} months old. Based on the given traits: {user_input}, "
|
62 |
+
f"determine whether the child is meeting expected milestones. "
|
63 |
+
f"Relevant milestone: {relevant_milestone}. "
|
64 |
+
"If there are any concerns, suggest steps the parents can take. ")
|
65 |
+
response = ibm_model(prompt)
|
66 |
+
return response[0]['generated_text']
|
|
|
|
|
|
|
|
|
67 |
|
68 |
# Streamlit UI Styling
|
69 |
st.set_page_config(page_title="Tiny Triumphs Tracker", page_icon="👶", layout="wide")
|
70 |
+
st.markdown("""
|
71 |
+
<style>
|
72 |
+
.stApp { background-color: #1e1e2e; color: #ffffff; }
|
73 |
+
.stTitle { text-align: center; color: #ffcc00; font-size: 36px; font-weight: bold; }
|
74 |
+
</style>
|
75 |
+
""", unsafe_allow_html=True)
|
76 |
|
77 |
+
st.markdown("<h1 class='stTitle'>👶 Tiny Triumphs Tracker</h1>", unsafe_allow_html=True)
|
78 |
+
st.markdown("Track your child's key growth milestones from birth to 5 years and detect early developmental concerns.", unsafe_allow_html=True)
|
79 |
|
80 |
# User selects child's age
|
81 |
selected_age = st.selectbox("📅 Select child's age:", list(age_categories.keys()))
|
82 |
child_age = age_categories[selected_age]
|
83 |
|
84 |
# User input for traits and skills
|
85 |
+
placeholder_text = "Describe your child's behavior and skills."
|
86 |
user_input = st.text_area("✍️ Enter child's behavioral traits and skills:", placeholder=placeholder_text)
|
87 |
|
88 |
def generate_pdf_report(ai_response):
|
89 |
pdf_file = "progress_report.pdf"
|
90 |
doc = SimpleDocTemplate(pdf_file, pagesize=A4)
|
91 |
styles = getSampleStyleSheet()
|
92 |
+
elements = []
|
93 |
+
elements.append(Paragraph("Child Development Progress Report", styles['Title']))
|
94 |
+
elements.append(Spacer(1, 12))
|
95 |
+
elements.append(Paragraph("Development Insights:", styles['Heading2']))
|
96 |
+
elements.append(Spacer(1, 10))
|
97 |
+
response_parts = ai_response.split('\n')
|
98 |
+
for part in response_parts:
|
|
|
|
|
99 |
part = part.strip().lstrip('0123456789.- ')
|
100 |
if part:
|
101 |
elements.append(Paragraph(f"• {part}", styles['Normal']))
|
102 |
elements.append(Spacer(1, 5))
|
103 |
+
disclaimer = "This report is AI-generated and is for informational purposes only. "
|
|
|
|
|
|
|
104 |
elements.append(Spacer(1, 12))
|
105 |
elements.append(Paragraph(disclaimer, styles['Italic']))
|
|
|
106 |
doc.build(elements)
|
107 |
return pdf_file
|
108 |
|
109 |
if st.button("🔍 Analyze", help="Click to analyze the child's development milestones"):
|
110 |
ai_response = generate_response(user_input, child_age)
|
|
|
111 |
st.subheader("📊 Development Insights:")
|
112 |
st.markdown(f"<div style='background-color:#44475a; color:#ffffff; padding: 15px; border-radius: 10px;'>{ai_response}</div>", unsafe_allow_html=True)
|
|
|
113 |
pdf_file = generate_pdf_report(ai_response)
|
114 |
with open(pdf_file, "rb") as f:
|
115 |
st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
|