"""Streamlit front-end that sends free-text movie questions to a Groq-hosted LLM."""

# Import necessary libraries
import os
from io import BytesIO  # noqa: F401  (currently unused; kept for compatibility)

import pandas as pd
import requests
import streamlit as st
from groq import Groq
from PIL import Image  # noqa: F401  (currently unused; kept for compatibility)
from streamlit_lottie import st_lottie  # Import streamlit-lottie

# SECURITY: an API key used to be hard-coded in this file. Any key that was
# committed here must be treated as leaked and revoked/rotated in the Groq
# console. Supply the key through the environment instead, e.g.
#     export GROQ_API_KEY=...
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

# Initialize the Groq client only when a key is available, so a missing key
# produces a readable error in the UI instead of crashing at import time.
client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None

DEFAULT_DATASET_PATH = "movie_dataset.csv"
# Replace with your desired Lottie URL
LOTTIE_ANIMATION_URL = "https://assets5.lottiefiles.com/packages/lf20_OZ6W7g.json"


@st.cache_data
def load_data(path: str = DEFAULT_DATASET_PATH) -> pd.DataFrame:
    """Read the movie CSV at *path* into a DataFrame.

    Wrapped in ``st.cache_data`` so Streamlit reruns reuse the parsed frame
    instead of re-reading the file on every interaction.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    return pd.read_csv(path)


def get_groq_completion(query: str, model: str = "gemma2-9b-it") -> str:
    """Send *query* as a single user message to *model* and return the reply text.

    Raises:
        RuntimeError: if no ``GROQ_API_KEY`` was configured at startup.
        Errors raised by the Groq SDK for failed requests propagate unchanged.
    """
    if client is None:
        raise RuntimeError(
            "GROQ_API_KEY is not set; export it before starting the app."
        )
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": query}],
        model=model,
    )
    return chat_completion.choices[0].message.content


def load_lottie_url(url: str, timeout: float = 10.0):
    """Fetch a Lottie animation JSON document from *url*.

    Returns the decoded JSON, or ``None`` when the request fails, times out,
    returns a non-200 status, or the body is not valid JSON. The animation is
    decorative, so every failure is treated as "no animation" rather than
    crashing the page.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:  # body was not valid JSON
        return None


def main():
    """Render the chatbot page: dataset preview, animation, and query box."""
    # set_page_config must be the first Streamlit command executed in a run.
    st.set_page_config(
        page_title="Movie RAG Chatbot",
        page_icon="🎥",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    # NOTE(review): the original called st.markdown(..., unsafe_allow_html=True)
    # here with an empty CSS payload ("dark theme and animations"); add a
    # <style> block back if that styling is still wanted.

    st.title("🎬 Movie RAG Chatbot")

    # Load the dataset (cached across reruns by load_data).
    try:
        data = load_data()
    except FileNotFoundError:
        st.error(f"Dataset not found: {DEFAULT_DATASET_PATH}")
        st.stop()
    st.write("### Dataset Overview", data.head())

    lottie_json = load_lottie_url(LOTTIE_ANIMATION_URL)
    if lottie_json:
        st_lottie(lottie_json, speed=1, width=None, height=None, key="loading")

    user_query = st.text_input(
        "Ask about movies or actors:",
        placeholder="e.g., Tell me about sci-fi movies.",
    )
    if st.button("Get Response"):
        # Don't spend an API call on an empty question.
        if not user_query.strip():
            st.warning("Please enter a question first.")
            return
        # NOTE(review): despite the "RAG" name, the dataset is not passed to
        # the model here, so replies are not grounded in movie_dataset.csv.
        try:
            groq_response = get_groq_completion(user_query)
        except Exception as exc:  # UI boundary: show the failure, not a traceback
            st.error(f"Could not get a response from Groq: {exc}")
            return
        st.write("### Chatbot Response")
        # "\n\n" starts a new Markdown paragraph; the original embedded a raw
        # newline inside a single-quoted f-string, which is a syntax error.
        st.markdown(f"Groq Model:\n\n{groq_response}")


# Run the main app
if __name__ == "__main__":
    main()