safiaa02 commited on
Commit
f3b63f9
·
verified ·
1 Parent(s): 1fa3d49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -4,7 +4,7 @@ import streamlit as st
4
  import faiss
5
  import numpy as np
6
  import torch
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from sentence_transformers import SentenceTransformer
9
  from reportlab.lib.pagesizes import A4
10
  from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
@@ -23,7 +23,7 @@ age_categories = {
23
  }
24
 
25
  # Initialize FAISS and Sentence Transformer
26
- model = SentenceTransformer('all-MiniLM-L6-v2')
27
 
28
  def create_faiss_index(data):
29
  descriptions, age_keys = [], []
@@ -32,7 +32,7 @@ def create_faiss_index(data):
32
  descriptions.append(entry['description'])
33
  age_keys.append(int(age)) # Convert age to int
34
 
35
- embeddings = model.encode(descriptions, convert_to_numpy=True)
36
  index = faiss.IndexFlatL2(embeddings.shape[1])
37
  index.add(embeddings)
38
  return index, descriptions, age_keys
@@ -41,14 +41,14 @@ index, descriptions, age_keys = create_faiss_index(milestones)
41
 
42
  # Function to retrieve the closest milestone
43
  def retrieve_milestone(user_input):
44
- user_embedding = model.encode([user_input], convert_to_numpy=True)
45
  _, indices = index.search(user_embedding, 1)
46
  return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
47
 
48
- # Load IBM Granite model and tokenizer
49
- model_name = "ibm/granite-13b-chat"
50
  tokenizer = AutoTokenizer.from_pretrained(model_name)
51
- granite_model = AutoModelForCausalLM.from_pretrained(
52
  model_name, torch_dtype=torch.float16, device_map="auto"
53
  )
54
 
@@ -125,4 +125,4 @@ if st.button("🔍 Analyze", help="Click to analyze the child's development mile
125
  with open(pdf_file, "rb") as f:
126
  st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
127
 
128
- st.warning("⚠️ The results provided are generated by AI and should be interpreted with caution. Please consult a pediatrician for professional advice.")
 
4
  import faiss
5
  import numpy as np
6
  import torch
7
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
8
  from sentence_transformers import SentenceTransformer
9
  from reportlab.lib.pagesizes import A4
10
  from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer
 
23
  }
24
 
25
  # Initialize FAISS and Sentence Transformer
26
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
27
 
28
  def create_faiss_index(data):
29
  descriptions, age_keys = [], []
 
32
  descriptions.append(entry['description'])
33
  age_keys.append(int(age)) # Convert age to int
34
 
35
+ embeddings = embedding_model.encode(descriptions, convert_to_numpy=True)
36
  index = faiss.IndexFlatL2(embeddings.shape[1])
37
  index.add(embeddings)
38
  return index, descriptions, age_keys
 
41
 
42
  # Function to retrieve the closest milestone
43
  def retrieve_milestone(user_input):
44
+ user_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
45
  _, indices = index.search(user_embedding, 1)
46
  return descriptions[indices[0][0]] if indices[0][0] < len(descriptions) else "No relevant milestone found."
47
 
48
+ # Load IBM Granite 3.1 model and tokenizer
49
+ model_name = "ibm-granite/granite-3.1-8b-instruct"
50
  tokenizer = AutoTokenizer.from_pretrained(model_name)
51
+ granite_model = AutoModelForSeq2SeqLM.from_pretrained(
52
  model_name, torch_dtype=torch.float16, device_map="auto"
53
  )
54
 
 
125
  with open(pdf_file, "rb") as f:
126
  st.download_button(label="📥 Download Progress Report", data=f, file_name="progress_report.pdf", mime="application/pdf")
127
 
128
+ st.warning("⚠️ The results provided are generated by AI and should be interpreted with caution. Please consult a pediatrician for professional advice.")