import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from dotenv import load_dotenv
from functools import lru_cache

# Load environment variables
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# App title and description
st.title("I am Your GrowBuddy 🌱")
st.write("Let me help you start gardening. Let's grow together!")

# Load the tokenizer and model only once per session (st.cache_resource keeps them across reruns)
@st.cache_resource
def load_model():
    try:
        tokenizer = AutoTokenizer.from_pretrained("TheSheBots/UrbanGardening", use_auth_token=HF_TOKEN, use_fast=True)
        # The model is loaded in full float32 precision for CPU inference; no quantization
        # is applied here (see the optional int8 sketch after the CPU placement below)
        model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", use_auth_token=HF_TOKEN, torch_dtype=torch.float32)
        return tokenizer, model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None, None

# Load model and tokenizer (cached across reruns by st.cache_resource)
tokenizer, model = load_model()
if not tokenizer or not model:
    st.stop()

# Run inference on CPU; the weights are already float32, so .to(device) only moves the model
device = torch.device("cpu")
model = model.to(device)
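
# Hedged sketch, not part of the original app: the comments above mention 8-bit
# quantization for CPU, but the model is actually kept in float32. One way to get
# int8 weights on CPU is PyTorch dynamic quantization of the Linear layers.
# Guarded behind a hypothetical QUANTIZE_INT8 environment variable so the default
# behavior of the app is unchanged.
if os.getenv("QUANTIZE_INT8") == "1":
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )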

# Initialize session state messages (st.session_state survives Streamlit's
# top-to-bottom script reruns, so the conversation persists across interactions)
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Hello there! How can I help you with gardening today?"}
    ]

# Display conversation history on every rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# LRU-cache responses for repeated queries to avoid redundant computation;
# keyed on the prompt string only (tokenizer and model are module-level globals)
@lru_cache(maxsize=128)
def cached_generate_response(prompt):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    outputs = model.generate(**inputs, max_new_tokens=50, temperature=0.7, do_sample=True)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
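
# Hedged note, not part of the original app: for causal LMs, outputs[0] holds the
# prompt tokens followed by the newly generated tokens, so the decoded response
# above echoes the question back. A minimal sketch that decodes only the new
# tokens would look like:
#
#     new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
#     response = tokenizer.decode(new_tokens, skip_special_tokens=True)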

# Generate a response for a prompt, surfacing errors in the UI
def generate_response(prompt):
    try:
        # lru_cache returns the stored response for previously seen prompts
        return cached_generate_response(prompt)
    except Exception as e:
        st.error(f"Error during text generation: {e}")
        return "Sorry, I couldn't process your request."

# User input field for gardening questions
user_input = st.chat_input("Type your gardening question here:")
if user_input:
    with st.chat_message("user"):
        st.write(user_input)
    with st.chat_message("assistant"):
        with st.spinner("Generating your answer..."):
            response = generate_response(user_input)
            st.write(response)
    # Update session state with new messages
    st.session_state.messages.append({"role": "user", "content": user_input})
    st.session_state.messages.append({"role": "assistant", "content": response})
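
# Usage note (assumptions: the script is saved as app.py; HF_TOKEN is provided via
# a local .env file or the Space's secrets; streamlit, transformers, torch and
# python-dotenv are installed):
#
#     streamlit run app.py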