safiaa02 commited on
Commit
35e1c43
·
verified ·
1 Parent(s): 781de02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -54
app.py CHANGED
@@ -3,9 +3,8 @@ import json
3
  import streamlit as st
4
  import faiss
5
  import numpy as np
6
- import torch
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from sentence_transformers import SentenceTransformer
 
9
  from reportlab.lib.pagesizes import A4
10
  from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
11
  from reportlab.lib.styles import getSampleStyleSheet
@@ -16,23 +15,32 @@ with open('milestones.json', 'r') as f:
16
 
17
  # Age categories for dropdown selection
18
  age_categories = {
19
- "Up to 2 months": 2, "Up to 4 months": 4, "Up to 6 months": 6,
20
- "Up to 9 months": 9, "Up to 1 year": 12, "Up to 15 months": 15,
21
- "Up to 18 months": 18, "Up to 2 years": 24, "Up to 30 months": 30,
22
- "Up to 3 years": 36, "Up to 4 years": 48, "Up to 5 years": 60
 
 
 
 
 
 
 
 
23
  }
24
 
25
  # Initialize FAISS and Sentence Transformer
26
- embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
27
 
28
  def create_faiss_index(data):
29
- descriptions, age_keys = [], []
 
30
  for age, categories in data.items():
31
  for entry in categories:
32
  descriptions.append(entry['description'])
33
- age_keys.append(int(age)) # Convert age to int
34
 
35
- embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
36
  index = faiss.IndexFlatL2(embeddings.shape[1])
37
  index.add(embeddings)
38
  return index, descriptions, age_keys
@@ -41,85 +49,67 @@ index, descriptions, age_keys = create_faiss_index(milestones)
41
 
42
  # Function to retrieve the closest milestone
43
  def retrieve_milestone(user_input):
44
- user_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
45
  _, indices = index.search(user_embedding, 1)
46
  return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
47
 
48
- # Load IBM Granite 3.1 model and tokenizer from Hugging Face
49
- MODEL_NAME = "ibm-granite/granite-3.1-8b-instruct"
50
-
51
- @st.cache_resource # Cache model to avoid reloading on every interaction
52
- def load_model():
53
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
54
- model = AutoModelForCausalLM.from_pretrained(
55
- MODEL_NAME,
56
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
57
- device_map="auto" # Auto-select GPU/CPU
58
- )
59
- return tokenizer, model
60
-
61
- tokenizer, granite_model = load_model()
62
 
63
  def generate_response(user_input, child_age):
64
  relevant_milestone = retrieve_milestone(user_input)
65
- prompt = (
66
- f"The child is {child_age} months old. Based on the given traits: {user_input}, "
67
- f"determine whether the child is meeting expected milestones. "
68
- f"Relevant milestone: {relevant_milestone}. "
69
- "If there are any concerns, suggest steps the parents can take."
70
- )
71
-
72
- inputs = tokenizer(prompt, return_tensors="pt").to(granite_model.device)
73
- output = granite_model.generate(**inputs, max_length=512)
74
- return tokenizer.decode(output[0], skip_special_tokens=True)
75
 
76
  # Streamlit UI Styling
77
  st.set_page_config(page_title="Tiny Triumphs Tracker", page_icon="👶", layout="wide")
 
 
 
 
 
 
78
 
79
- st.markdown("<h1 style='text-align:center; color:#ffcc00;'>👶 Tiny Triumphs Tracker</h1>", unsafe_allow_html=True)
80
- st.markdown("Track your child's key growth milestones from birth to 5 years and detect early developmental concerns.")
81
 
82
  # User selects child's age
83
  selected_age = st.selectbox("📅 Select child's age:", list(age_categories.keys()))
84
  child_age = age_categories[selected_age]
85
 
86
  # User input for traits and skills
87
- placeholder_text = "For example, your child might say simple words like 'mama' and 'dada' and smile when spoken to."
88
  user_input = st.text_area("✍️ Enter child's behavioral traits and skills:", placeholder=placeholder_text)
89
 
90
  def generate_pdf_report(ai_response):
91
  pdf_file = "progress_report.pdf"
92
  doc = SimpleDocTemplate(pdf_file, pagesize=A4)
93
  styles = getSampleStyleSheet()
94
-
95
- elements = [
96
- Paragraph("Child Development Progress Report", styles['Title']),
97
- Spacer(1, 12),
98
- Paragraph("Development Insights:", styles['Heading2']),
99
- Spacer(1, 10)
100
- ]
101
-
102
- for part in ai_response.split('\n'):
103
  part = part.strip().lstrip('0123456789.- ')
104
  if part:
105
  elements.append(Paragraph(f"• {part}", styles['Normal']))
106
  elements.append(Spacer(1, 5))
107
-
108
- disclaimer = ("This report is AI-generated and is for informational purposes only. "
109
- "It should not be considered a substitute for professional medical advice. "
110
- "Always consult a qualified pediatrician for expert guidance on your child's development.")
111
  elements.append(Spacer(1, 12))
112
  elements.append(Paragraph(disclaimer, styles['Italic']))
113
-
114
  doc.build(elements)
115
  return pdf_file
116
 
117
  if st.button("🔍 Analyze", help="Click to analyze the child's development milestones"):
118
  ai_response = generate_response(user_input, child_age)
119
-
120
  st.subheader("📊 Development Insights:")
121
  st.markdown(f"<div style='background-color:#44475a; color:#ffffff; padding: 15px; border-radius: 10px;'>{ai_response}</div>", unsafe_allow_html=True)
122
-
123
  pdf_file = generate_pdf_report(ai_response)
124
  with open(pdf_file, "rb") as f:
125
  st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
 
3
  import streamlit as st
4
  import faiss
5
  import numpy as np
 
 
6
  from sentence_transformers import SentenceTransformer
7
+ from transformers import pipeline
8
  from reportlab.lib.pagesizes import A4
9
  from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
10
  from reportlab.lib.styles import getSampleStyleSheet
 
15
 
16
  # Age categories for dropdown selection
17
  age_categories = {
18
+ "Up to 2 months": 2,
19
+ "Up to 4 months": 4,
20
+ "Up to 6 months": 6,
21
+ "Up to 9 months": 9,
22
+ "Up to 1 year": 12,
23
+ "Up to 15 months": 15,
24
+ "Up to 18 months": 18,
25
+ "Up to 2 years": 24,
26
+ "Up to 30 months": 30,
27
+ "Up to 3 years": 36,
28
+ "Up to 4 years": 48,
29
+ "Up to 5 years": 60
30
  }
31
 
32
  # Initialize FAISS and Sentence Transformer
33
+ model = SentenceTransformer('all-MiniLM-L6-v2')
34
 
35
  def create_faiss_index(data):
36
+ descriptions = []
37
+ age_keys = []
38
  for age, categories in data.items():
39
  for entry in categories:
40
  descriptions.append(entry['description'])
41
+ age_keys.append(int(age))
42
 
43
+ embeddings = model.encode(descriptions, convert_to_numpy=True)
44
  index = faiss.IndexFlatL2(embeddings.shape[1])
45
  index.add(embeddings)
46
  return index, descriptions, age_keys
 
49
 
50
  # Function to retrieve the closest milestone
51
  def retrieve_milestone(user_input):
52
+ user_embedding = model.encode([user_input], convert_to_numpy=True)
53
  _, indices = index.search(user_embedding, 1)
54
  return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
55
 
56
+ # Initialize IBM Granite Model
57
+ ibm_model = pipeline("text-generation", model="ibm-granite", max_length=512)
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def generate_response(user_input, child_age):
60
  relevant_milestone = retrieve_milestone(user_input)
61
+ prompt = (f"The child is {child_age} months old. Based on the given traits: {user_input}, "
62
+ f"determine whether the child is meeting expected milestones. "
63
+ f"Relevant milestone: {relevant_milestone}. "
64
+ "If there are any concerns, suggest steps the parents can take. ")
65
+ response = ibm_model(prompt)
66
+ return response[0]['generated_text']
 
 
 
 
67
 
68
  # Streamlit UI Styling
69
  st.set_page_config(page_title="Tiny Triumphs Tracker", page_icon="👶", layout="wide")
70
+ st.markdown("""
71
+ <style>
72
+ .stApp { background-color: #1e1e2e; color: #ffffff; }
73
+ .stTitle { text-align: center; color: #ffcc00; font-size: 36px; font-weight: bold; }
74
+ </style>
75
+ """, unsafe_allow_html=True)
76
 
77
+ st.markdown("<h1 class='stTitle'>👶 Tiny Triumphs Tracker</h1>", unsafe_allow_html=True)
78
+ st.markdown("Track your child's key growth milestones from birth to 5 years and detect early developmental concerns.", unsafe_allow_html=True)
79
 
80
  # User selects child's age
81
  selected_age = st.selectbox("📅 Select child's age:", list(age_categories.keys()))
82
  child_age = age_categories[selected_age]
83
 
84
  # User input for traits and skills
85
+ placeholder_text = "Describe your child's behavior and skills."
86
  user_input = st.text_area("✍️ Enter child's behavioral traits and skills:", placeholder=placeholder_text)
87
 
88
  def generate_pdf_report(ai_response):
89
  pdf_file = "progress_report.pdf"
90
  doc = SimpleDocTemplate(pdf_file, pagesize=A4)
91
  styles = getSampleStyleSheet()
92
+ elements = []
93
+ elements.append(Paragraph("Child Development Progress Report", styles['Title']))
94
+ elements.append(Spacer(1, 12))
95
+ elements.append(Paragraph("Development Insights:", styles['Heading2']))
96
+ elements.append(Spacer(1, 10))
97
+ response_parts = ai_response.split('\n')
98
+ for part in response_parts:
 
 
99
  part = part.strip().lstrip('0123456789.- ')
100
  if part:
101
  elements.append(Paragraph(f"• {part}", styles['Normal']))
102
  elements.append(Spacer(1, 5))
103
+ disclaimer = "This report is AI-generated and is for informational purposes only. "
 
 
 
104
  elements.append(Spacer(1, 12))
105
  elements.append(Paragraph(disclaimer, styles['Italic']))
 
106
  doc.build(elements)
107
  return pdf_file
108
 
109
  if st.button("🔍 Analyze", help="Click to analyze the child's development milestones"):
110
  ai_response = generate_response(user_input, child_age)
 
111
  st.subheader("📊 Development Insights:")
112
  st.markdown(f"<div style='background-color:#44475a; color:#ffffff; padding: 15px; border-radius: 10px;'>{ai_response}</div>", unsafe_allow_html=True)
 
113
  pdf_file = generate_pdf_report(ai_response)
114
  with open(pdf_file, "rb") as f:
115
  st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")