Spaces:
Sleeping
Sleeping
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
# Streamlit requires set_page_config() to be the very first st.* call
# in the script, before any widget or output is rendered.
st.set_page_config(
    page_title="FDA NDA Submission Assistant",
    initial_sidebar_state="auto",
    layout="centered",
)
# Apply custom CSS for retro 80s green theme
def apply_custom_css():
    """Inject style.css into the page; warn (don't crash) if it is missing.

    Reads the stylesheet from the app's working directory and embeds it
    via an inline <style> tag, which requires unsafe_allow_html.
    """
    try:
        # Explicit encoding avoids platform-dependent mojibake if the
        # CSS ever contains non-ASCII characters.
        with open("style.css", encoding="utf-8") as f:
            st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
    except FileNotFoundError:
        # Best-effort styling: fall back to Streamlit defaults.
        st.warning("style.css not found. Using default styles.")
@st.cache_resource
def load_model():
    """Load the base model plus PEFT adapter once and reuse it across reruns.

    Streamlit re-executes the entire script on every user interaction;
    without caching, the 7B-parameter model would be reloaded from disk
    on each rerun. @st.cache_resource keeps one shared copy per process.

    Returns:
        (tokenizer, model): the tokenizer and the adapter-augmented model
        in eval mode. On any load failure the error is shown in the UI
        and the script run is halted via st.stop().
    """
    model_path = "HuggingFaceH4/zephyr-7b-beta"
    peft_model_path = "yitzashapiro/FDA-guidance-zephyr-7b-beta-PEFT"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",  # let accelerate place layers on GPU/CPU
            torch_dtype=torch.float16  # halves memory; assumes fp16-capable hardware
        ).eval()
        model.load_adapter(peft_model_path)
        st.success("Model loaded successfully.")
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()  # halts this script run; the return below is never reached
    return tokenizer, model
def generate_response(tokenizer, model, user_input):
    """Generate a model completion for a single-turn user question.

    Args:
        tokenizer: Hugging Face tokenizer; used via its chat template when
            available, otherwise as a plain callable.
        model: causal LM exposing .generate() and .device.
        user_input: the user's question as a plain string.

    Returns:
        The decoded completion (prompt tokens stripped), or a fallback
        error string if generation fails.
    """
    messages = [
        {"role": "user", "content": user_input}
    ]
    try:
        if hasattr(tokenizer, 'apply_chat_template'):
            # BUG FIX: the previous max_length=45 truncated the prompt to
            # 45 tokens, silently cutting off most questions. Allow a
            # full-length prompt instead.
            input_ids = tokenizer.apply_chat_template(
                conversation=messages,
                max_length=2048,
                truncation=True,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors='pt'
            )
        else:
            # Fallback for tokenizers without chat-template support.
            input_ids = tokenizer(
                user_input,
                return_tensors='pt',
                truncation=True,
                max_length=2048
            )['input_ids']
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        # Attend to every non-pad position; for a single unpadded sequence
        # this marks the whole prompt (assumes pad ids only appear as padding).
        attention_mask = (input_ids != pad_token_id).long()
        # BUG FIX: the old call passed BOTH max_length=2048 and
        # max_new_tokens=500; transformers warns and ignores max_length.
        # Keep only max_new_tokens, and pass pad_token_id explicitly to
        # silence the "no pad token" generation warning.
        output_ids = model.generate(
            input_ids.to(model.device),
            max_new_tokens=500,
            attention_mask=attention_mask.to(model.device),
            pad_token_id=pad_token_id
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
        return response
    except Exception as e:
        st.error(f"Error generating response: {e}")
        return "An error occurred while generating the response."
def main():
    """Render the app UI and run the question → answer loop once per rerun."""
    apply_custom_css()
    st.title("FDA NDA Submission Assistant")
    st.write("Ask the model about submitting an NDA to the FDA.")
    tokenizer, model = load_model()
    user_input = st.text_input("Enter your question:", "What's the best way to submit an NDA to the FDA?")
    # Guard clauses: nothing to do until the button is pressed with real input.
    if not st.button("Generate Response"):
        return
    if not user_input.strip():
        st.error("Please enter a valid question.")
        return
    try:
        with st.spinner("Generating response..."):
            answer = generate_response(tokenizer, model, user_input)
        st.success("Response:")
        st.write(answer)
    except Exception as e:
        st.error(f"An error occurred: {e}")


if __name__ == "__main__":
    main()