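"""AtlasUI: a Streamlit inference front end for the Atlas-Flash and
Atlas-Pro preview models hosted on Hugging Face."""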
import gc
import os

import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
# Authenticate with Hugging Face (required for gated or private model repos)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# Model registry: maps each Atlas variant to its Hugging Face repos and UI metadata
MODELS = {
"atlas-flash-1215": {
"name": "🦁 Atlas-Flash 1215",
"sizes": {
"1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
},
"emoji": "🦁",
"experimental": True,
"is_vision": False,
"system_prompt_env": "ATLAS_FLASH_1215",
},
"atlas-pro-0403": {
"name": "πŸ† Atlas-Pro 0403",
"sizes": {
"1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
},
"emoji": "πŸ†",
"experimental": True,
"is_vision": False,
"system_prompt_env": "ATLAS_PRO_0403",
},
}
# Profile pictures
USER_PFP = "user.png"
AI_PFP = "ai_pfp.png"
st.set_page_config(
page_title="Atlas Model Inference",
page_icon="🦁 ",
layout="wide",
menu_items={
'Get Help': 'https://huggingface.co/collections/Spestly/athena-1-67623e58bfaadd3c2fcffb86',
'Report a bug': 'https://huggingface.co/Spestly/Athena-1-1.5B/discussions/new',
'About': 'Athena Model Inference Platform'
}
)
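# Custom CSS: restyle Streamlit's sliders and buttons in the Atlas blue theme.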
st.markdown(
"""
<style>
.stSlider > div > div > div > div {
background-color: #1f78b4 !important;
}
.stButton > button {
background-color: #1f78b4 !important;
color: white !important;
border: none !important;
}
.stButton > button:hover {
background-color: #16609a !important;
}
</style>
""",
unsafe_allow_html=True,
)
class AtlasInferenceApp:
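    """Streamlit app that loads an Atlas model and serves a chat interface."""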
def __init__(self):
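        # Session state survives Streamlit reruns; keep the model and chat history there.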
if "current_model" not in st.session_state:
st.session_state.current_model = {"tokenizer": None, "model": None, "config": None}
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
def clear_memory(self):
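        """Release cached GPU memory and trigger Python garbage collection."""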
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def load_model(self, model_key, model_size):
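        """Load the chosen model/tokenizer pair, replacing any previously loaded one."""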
try:
self.clear_memory()
if st.session_state.current_model["model"] is not None:
del st.session_state.current_model["model"]
del st.session_state.current_model["tokenizer"]
self.clear_memory()
model_path = MODELS[model_key]["sizes"][model_size]
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            if tokenizer.pad_token is None:
                # Some causal-LM tokenizers ship without a pad token; reuse EOS so padding works.
                tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="auto",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
trust_remote_code=True,
low_cpu_mem_usage=True
)
st.session_state.current_model.update({
"tokenizer": tokenizer,
"model": model,
"config": {
"name": f"{MODELS[model_key]['name']} {model_size}",
"path": model_path,
"system_prompt": os.getenv(MODELS[model_key]["system_prompt_env"], "Default system prompt"),
}
})
return f"βœ… {MODELS[model_key]['name']} {model_size} loaded successfully!"
except Exception as e:
return f"❌ Error: {str(e)}"
def respond(self, message, max_tokens, temperature, top_p, top_k, image=None):
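        """Generate a reply to the user message.

        The `image` argument is accepted for future vision variants but is
        currently unused by the text-only models.
        """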
if not st.session_state.current_model["model"] or not st.session_state.current_model["tokenizer"]:
return "⚠️ Please select and load a model first"
try:
system_prompt = st.session_state.current_model["config"]["system_prompt"]
if not system_prompt:
return "⚠️ System prompt not found for the selected model."
prompt = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"
            tokenizer = st.session_state.current_model["tokenizer"]
            model = st.session_state.current_model["model"]
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True,
            ).to(model.device)  # generate() fails if inputs and model sit on different devices
            with torch.no_grad():
                output = model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
            # Decode only the newly generated tokens so the prompt is never echoed back.
            new_tokens = output[0][inputs.input_ids.shape[-1]:]
            return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
except Exception as e:
return f"⚠️ Generation Error: {str(e)}"
finally:
self.clear_memory()
def main(self):
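        """Render the sidebar controls and the chat interface."""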
st.title("🦁 AtlasUI - Experimental πŸ§ͺ")
with st.sidebar:
st.header("πŸ›  Model Selection")
model_key = st.selectbox(
"Choose Atlas Variant",
list(MODELS.keys()),
                format_func=lambda x: f"{MODELS[x]['name']} {'🧪' if MODELS[x]['experimental'] else ''}"
)
model_size = st.selectbox(
"Choose Model Size",
list(MODELS[model_key]["sizes"].keys())
)
if st.button("Load Model"):
with st.spinner("Loading model... This may take a few minutes."):
                    status = self.load_model(model_key, model_size)
                    if status.startswith("✅"):
                        st.success(status)
                    else:
                        st.error(status)
st.header("πŸ”§ Generation Parameters")
max_tokens = st.slider("Max New Tokens", min_value=10, max_value=512, value=256, step=10)
temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.4, step=0.1)
top_p = st.slider("Top-P", min_value=0.1, max_value=1.0, value=0.9, step=0.1)
top_k = st.slider("Top-K", min_value=1, max_value=100, value=50, step=1)
if st.button("Clear Chat History"):
st.session_state.chat_history = []
st.rerun()
st.markdown("*⚠️ CAUTION: Atlas is an experimental model and this is just a preview. Responses may not be expected. Please double-check sensitive information!*")
for message in st.session_state.chat_history:
with st.chat_message(
message["role"],
avatar=USER_PFP if message["role"] == "user" else AI_PFP
):
st.markdown(message["content"])
if "image" in message and message["image"]:
st.image(message["image"], caption="Uploaded Image", use_column_width=True)
if prompt := st.chat_input("Message Atlas..."):
uploaded_image = None
if MODELS[model_key]["is_vision"]:
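                # Neither current preview variant sets is_vision, so this branch is inactive.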
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
st.session_state.chat_history.append({"role": "user", "content": prompt, "image": uploaded_image})
with st.chat_message("user", avatar=USER_PFP):
st.markdown(prompt)
if uploaded_image:
st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
with st.chat_message("assistant", avatar=AI_PFP):
with st.spinner("Generating response..."):
response = self.respond(prompt, max_tokens, temperature, top_p, top_k, image=uploaded_image)
st.markdown(response)
st.session_state.chat_history.append({"role": "assistant", "content": response})
def run():
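    """Entry point: build the app and surface any startup errors in the UI."""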
try:
app = AtlasInferenceApp()
app.main()
except Exception as e:
st.error(f"⚠️ Application Error: {str(e)}")
if __name__ == "__main__":
run()