"""Streamlit front end for the EasyRec movie recommender.

Lets the user pick one of the embedding models listed in ``configs.yaml``,
type a free-text description, and shows the movies whose precomputed
embeddings score highest (dot product) against the embedded description.
"""

import streamlit as st
import pandas as pd
import torch
import yaml
from embeddings import load_model, compute_embeddings

# Number of recommendations shown for each query.
TOP_K = 5

# Load configuration from YAML file (safe_load: no arbitrary object construction).
with open("configs.yaml", "r", encoding="utf-8") as file:
    configs = yaml.safe_load(file)


@st.cache_data
def _load_movie_data(csv_path):
    """Read the processed movie dataset once per path instead of on every rerun."""
    return pd.read_csv(csv_path)


@st.cache_resource
def _load_cached_model(model_name):
    """Return ``(model, tokenizer)`` for *model_name*, loaded once per process.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the model would be reloaded each time the user types.
    """
    return load_model(model_name)


@st.cache_resource
def _load_movie_embeddings(embedding_file_path):
    """Load the precomputed movie-embedding tensor from a ``.pt`` file.

    The result is cached per path, so replacing the file on disk requires
    clearing Streamlit's cache (or restarting the app) to take effect.
    """
    # map_location="cpu" lets embeddings saved from a GPU session load on a
    # CPU-only host; they are moved to the query's device before scoring.
    return torch.load(embedding_file_path, map_location="cpu")


def _format_runtime(runtime):
    """Format a runtime given in minutes as ``"Xh Ym"``, or ``"Ym"`` under an hour.

    Returns ``"N/A"`` for a missing value. A runtime column with gaps is read
    by pandas as float, which would otherwise render as ``"2.0h 15.0m"`` or
    ``"nanm"``.
    """
    if pd.isna(runtime):
        return "N/A"
    hours, minutes = divmod(int(runtime), 60)
    return f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m"


# Load the processed movie dataset
movie_data = _load_movie_data(configs["processed_dataset"])

# Streamlit app
st.title("🎬 EasyRec Movie Recommender")

# Dropdown for model selection.
# Assumes configs.yaml lists the selectable model names under 'hf_models'.
model_names = configs['hf_models']
selected_model_name = st.selectbox("Select a model:", model_names)

# Load (or reuse the cached) model for the user's selection.
model, tokenizer = _load_cached_model(selected_model_name)

# User input for movie description
user_description = st.text_input(
    "Enter a description of the type of movie you're interested in:",
    placeholder="e.g. A romantic comedy with a twist...",
)

# A whitespace-only query carries no signal, so it is not embedded or ranked.
if user_description and user_description.strip():
    # NOTE(review): the file stem reuses configs['movie_embeddings'], the same
    # value used as the directory prefix -- confirm this matches how the
    # embedding files were written.
    embedding_dir_path = f"{configs['movie_embeddings']}/{selected_model_name}"
    embedding_file_path = f"{embedding_dir_path}/{configs['movie_embeddings']}.pt"
    movie_embeddings = _load_movie_embeddings(embedding_file_path)

    # Pure inference: no autograd graph is needed for scoring.
    with torch.no_grad():
        # compute_embeddings takes a list of texts, so wrap the single query.
        user_embedding = compute_embeddings([user_description], model, tokenizer)
        # Align the catalogue with the query's device/dtype so matmul cannot
        # fail on a CPU/GPU or precision mismatch. Presumably movie_embeddings
        # is (num_movies, dim) and user_embedding is (1, dim) -- TODO confirm;
        # flatten() then yields one score per movie.
        catalogue = movie_embeddings.to(
            device=user_embedding.device, dtype=user_embedding.dtype
        )
        similarity_scores = torch.matmul(catalogue, user_embedding.T).flatten()

    # argsort + slice (rather than topk) tolerates catalogues with fewer than
    # TOP_K movies.
    top_k_indices = torch.argsort(similarity_scores, descending=True)[:TOP_K].tolist()

    # Display recommendations
    st.write("## 🎉 Top Recommendations:")
    for rank, movie_id in enumerate(top_k_indices, start=1):
        # Assumes row i of the embedding tensor corresponds to row i of the
        # CSV (positional lookup via iloc) -- confirm against the embedding job.
        movie = movie_data.iloc[movie_id]

        # Render one markdown section per recommended movie.
        st.markdown(f"### {rank}. {movie.title}")
        st.markdown(f"**Release Date:** {movie.release_date}    **Runtime:** {_format_runtime(movie.runtime)}")
        st.markdown(f"⭐ {movie.vote_average} ({movie.vote_count} votes)")
        st.markdown(f"**Overview:** {movie.overview}")
        st.markdown(f"**Genres:** {movie.genres}")
        st.markdown(f"**Production Companies:** {movie.production_companies}")
        st.markdown(f"**Production Countries:** {movie.production_countries}")
        st.markdown("---")

# Additional styling (optional): placeholder for custom CSS, currently empty.
st.markdown(
    """ """,
    unsafe_allow_html=True
)