import gc
import os

import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# Load the Hugging Face token and authenticate (needed for gated/private repos)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
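# Note: on a Hugging Face Space, HF_TOKEN is typically added as a Space secret and
# surfaces here as an environment variable; without it, downloads of gated
# repositories will fail.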
# Define models
MODELS = {
    "atlas-flash-1215": {
        "name": "🦅 Atlas-Flash 1215",
        "sizes": {
            "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
        },
        "emoji": "🦅",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_FLASH_1215",
    },
    "atlas-pro-0403": {
        "name": "🚀 Atlas-Pro 0403",
        "sizes": {
            "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
        },
        "emoji": "🚀",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_PRO_0403",
    },
}
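# Each entry is self-describing: the sidebar, the loader, and the system-prompt lookup
# all read from this dict, so adding another Atlas variant only requires a new entry here.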
# Profile pictures
USER_PFP = "user.png"
AI_PFP = "ai_pfp.png"

st.set_page_config(
    page_title="Atlas Model Inference",
    page_icon="🦅",
    layout="wide",
    menu_items={
        'Get Help': 'https://huggingface.co/collections/Spestly/athena-1-67623e58bfaadd3c2fcffb86',
        'Report a bug': 'https://huggingface.co/Spestly/Athena-1-1.5B/discussions/new',
        'About': 'Athena Model Inference Platform'
    }
)
st.markdown(
    """
    <style>
    .stSlider > div > div > div > div {
        background-color: #1f78b4 !important;
    }
    .stButton > button {
        background-color: #1f78b4 !important;
        color: white !important;
        border: none !important;
    }
    .stButton > button:hover {
        background-color: #16609a !important;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
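# The app keeps the loaded tokenizer/model and the chat history in st.session_state,
# so they persist across the reruns Streamlit triggers on every user interaction.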
class AtlasInferenceApp:
    def __init__(self):
        if "current_model" not in st.session_state:
            st.session_state.current_model = {"tokenizer": None, "model": None, "config": None}
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = []

    def clear_memory(self):
        # Release cached GPU memory and force a garbage-collection pass
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    def load_model(self, model_key, model_size):
        try:
            self.clear_memory()
            # Drop any previously loaded model before loading a new one
            if st.session_state.current_model["model"] is not None:
                del st.session_state.current_model["model"]
                del st.session_state.current_model["tokenizer"]
                self.clear_memory()

            model_path = MODELS[model_key]["sizes"][model_size]
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            if tokenizer.pad_token is None:
                # Some tokenizers ship without a pad token; fall back to EOS so padding works
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            st.session_state.current_model.update({
                "tokenizer": tokenizer,
                "model": model,
                "config": {
                    "name": f"{MODELS[model_key]['name']} {model_size}",
                    "path": model_path,
                    "system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),
                }
            })
            return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
        except Exception as e:
            return f"❌ Error: {str(e)}"
    def respond(self, message, max_tokens, temperature, top_p, top_k, image=None):
        if not st.session_state.current_model["model"] or not st.session_state.current_model["tokenizer"]:
            return "⚠️ Please select and load a model first"
        try:
            system_prompt = st.session_state.current_model["config"]["system_prompt"]
            if not system_prompt:
                return "⚠️ System prompt not found for the selected model."

            # Alpaca-style prompt: system prompt, instruction, then the response marker
            prompt = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"
            inputs = st.session_state.current_model["tokenizer"](
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True
            )
            # Move inputs to the device the model was dispatched to by device_map="auto"
            inputs = inputs.to(st.session_state.current_model["model"].device)

            with torch.no_grad():
                output = st.session_state.current_model["model"].generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
                    eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
                )
            response = st.session_state.current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
            # Strip the echoed prompt so only the model's reply remains
            if prompt in response:
                response = response.replace(prompt, "").strip()
            return response
        except Exception as e:
            return f"⚠️ Generation Error: {str(e)}"
        finally:
            self.clear_memory()
    def main(self):
        st.title("🦅 AtlasUI - Experimental 🧪")

        with st.sidebar:
            st.header("🚀 Model Selection")
            model_key = st.selectbox(
                "Choose Atlas Variant",
                list(MODELS.keys()),
                format_func=lambda x: f"{MODELS[x]['name']} {'🧪' if MODELS[x]['experimental'] else ''}"
            )
            model_size = st.selectbox(
                "Choose Model Size",
                list(MODELS[model_key]["sizes"].keys())
            )
            if st.button("Load Model"):
                with st.spinner("Loading model... This may take a few minutes."):
                    status = self.load_model(model_key, model_size)
                    st.success(status)

            st.header("🔧 Generation Parameters")
            max_tokens = st.slider("Max New Tokens", min_value=10, max_value=512, value=256, step=10)
            temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.4, step=0.1)
            top_p = st.slider("Top-P", min_value=0.1, max_value=1.0, value=0.9, step=0.1)
            top_k = st.slider("Top-K", min_value=1, max_value=100, value=50, step=1)

            if st.button("Clear Chat History"):
                st.session_state.chat_history = []
                st.rerun()

            st.markdown("*⚠️ CAUTION: Atlas is an experimental model and this is just a preview. Responses may not be as expected. Please double-check sensitive information!*")

        # Replay the conversation so far
        for message in st.session_state.chat_history:
            with st.chat_message(
                message["role"],
                avatar=USER_PFP if message["role"] == "user" else AI_PFP
            ):
                st.markdown(message["content"])
                if "image" in message and message["image"]:
                    st.image(message["image"], caption="Uploaded Image", use_column_width=True)

        if prompt := st.chat_input("Message Atlas..."):
            uploaded_image = None
            if MODELS[model_key]["is_vision"]:
                # Only reached by vision-capable variants (none of the current entries)
                uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
            st.session_state.chat_history.append({"role": "user", "content": prompt, "image": uploaded_image})

            with st.chat_message("user", avatar=USER_PFP):
                st.markdown(prompt)
                if uploaded_image:
                    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

            with st.chat_message("assistant", avatar=AI_PFP):
                with st.spinner("Generating response..."):
                    response = self.respond(prompt, max_tokens, temperature, top_p, top_k, image=uploaded_image)
                    st.markdown(response)
            st.session_state.chat_history.append({"role": "assistant", "content": response})
def run():
    try:
        app = AtlasInferenceApp()
        app.main()
    except Exception as e:
        st.error(f"⚠️ Application Error: {str(e)}")

if __name__ == "__main__":
    run()
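# To try the UI outside the Space, the usual Streamlit entry point applies (assuming
# this file is saved as app.py):
#   streamlit run app.py
# This assumes streamlit, torch, transformers, accelerate, and huggingface_hub are
# installed, and that HF_TOKEN plus the ATLAS_* system-prompt variables are set in
# the environment.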