import streamlit as st
import requests
import time
import pandas as pd
from dotenv import load_dotenv
import os
from groq import Groq

# Load API keys
load_dotenv()
HF_API_KEY = st.secrets.get("HUGGINGFACE_API_KEY") or os.getenv("HUGGINGFACE_API_KEY")
GROQ_API_KEY = st.secrets.get("GROQ_API_KEY") or os.getenv("GROQ_API_KEY")

# Initialize Groq client
groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
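# Keys are read from Streamlit secrets first, then from the local .env / environment.
# The Groq client is only created when a key is present; query_groq() checks for the
# missing client and reports it instead of raising.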
# Model lists
HF_MODELS = [
    "bigcode/starcoder2-3b",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "tiiuae/falcon-7b-instruct"
]
GROQ_MODELS = [
    "llama3-8b-8192",
    "llama3-70b-8192",
    "mixtral-8x7b-32768"
]
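# The first entry of each list doubles as the fallback default: run_query() falls
# back to GROQ_MODELS[0] when a Hugging Face call fails, and to HF_MODELS[0] when a
# Groq call fails.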
# Hugging Face API
def query_hf(model_id, prompt, max_new_tokens=300):
    url = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {"Authorization": f"Bearer {HF_API_KEY}"}
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}}
    t0 = time.time()
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)
        latency = time.time() - t0
        output = response.json()
        if isinstance(output, list) and output and "generated_text" in output[0]:
            return output[0]["generated_text"], latency, True
        elif isinstance(output, dict) and "error" in output:
            return f"⚠️ HF Error: {output['error']}", latency, False
        else:
            return str(output), latency, False
    except Exception as e:
        latency = time.time() - t0
        return f"⚠️ HF Exception: {e}", latency, False
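# The Inference API typically returns a list like [{"generated_text": "..."}] on
# success, or a dict such as {"error": "Model ... is currently loading"} while a
# model is cold; both shapes are handled above. Example call (model from HF_MODELS):
#   text, secs, ok = query_hf("bigcode/starcoder2-3b", "Write a Python hello world")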
# Groq API
def query_groq(model_id, prompt, max_tokens=300):
    if not groq_client:
        return "⚠️ Groq API key not set.", 0, False
    t0 = time.time()
    try:
        response = groq_client.chat.completions.create(
            model=model_id,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens
        )
        latency = time.time() - t0
        return response.choices[0].message.content, latency, True
    except Exception as e:
        latency = time.time() - t0
        return f"⚠️ Groq Exception: {e}", latency, False
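# Both query functions share the same return contract: (output_text, latency_seconds,
# success_flag). run_query() relies on this so either backend can stand in as the
# fallback without special-casing.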
# Prompt builder
def build_prompt(mode, user_input, language="Python"):
    if mode == "Generate":
        return f"Write {language} code for this task:\n{user_input}\nOnly return code."
    elif mode == "Debug":
        return f"Debug this code and explain briefly:\n{user_input}"
    elif mode == "Explain":
        return f"Explain this code step by step:\n{user_input}"
    return user_input
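# Example: build_prompt("Generate", "reverse a string", "Python") produces
#   "Write Python code for this task:\nreverse a string\nOnly return code."
# Unknown modes fall through and return the user input unchanged.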
# Run query with fallback
def run_query(backend, model_id, mode, user_input, lang="Python"):
    prompt = build_prompt(mode, user_input, lang)
    if backend == "hf":
        output, latency, success = query_hf(model_id, prompt)
        if success:
            return output, latency, f"Hugging Face ({model_id}) ✅"
        # fallback to Groq
        fallback_model = GROQ_MODELS[0]  # default to Llama3-8B
        output2, latency2, success2 = query_groq(fallback_model, prompt)
        status2 = "✅" if success2 else "❌"
        return output2, latency + latency2, f"Hugging Face ❌ → Groq ({fallback_model}) {status2}"
    elif backend == "groq":
        output, latency, success = query_groq(model_id, prompt)
        if success:
            return output, latency, f"Groq ({model_id}) ✅"
        # fallback to Hugging Face
        fallback_model = HF_MODELS[0]  # default to Starcoder2
        output2, latency2, success2 = query_hf(fallback_model, prompt)
        status2 = "✅" if success2 else "❌"
        return output2, latency + latency2, f"Groq ❌ → Hugging Face ({fallback_model}) {status2}"
    return "⚠️ Invalid backend", 0, "None"
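# Fallback flow: the selected backend is tried first; on any failure the first model
# of the other backend's list is used instead, and the reported time is the sum of
# both attempts' latencies.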
# Streamlit UI
st.set_page_config(page_title="CodeCraft AI", layout="wide")
st.title("🧑‍💻 CodeCraft AI (Model Selection + Fallback)")
st.write("Choose models yourself. If one fails, fallback ensures smooth output.")

backend = st.radio("Choose Backend", ["hf", "groq"], format_func=lambda x: "Hugging Face" if x == "hf" else "Groq")
if backend == "hf":
    model_id = st.selectbox("Choose Hugging Face Model", HF_MODELS)
else:
    model_id = st.selectbox("Choose Groq Model", GROQ_MODELS)

tab1, tab2, tab3, tab4 = st.tabs(["Generate", "Debug", "Explain", "Analytics"])

# Track logs
if "logs" not in st.session_state:
    st.session_state.logs = []
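# st.session_state keeps the logs list alive across Streamlit reruns; every successful
# request appends a (mode, latency, status) tuple that the Analytics tab reads back.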
# Tab 1: Generate
with tab1:
    st.subheader("Code Generation")
    lang = st.selectbox("Choose language", ["Python", "JavaScript"])
    problem = st.text_area("Enter problem statement")
    if st.button("Generate Code", key="gen_btn"):
        if problem.strip():
            output, latency, status = run_query(backend, model_id, "Generate", problem, lang)
            st.code(output, language=lang.lower())
            st.success(f"{status} | Time: {latency:.2f}s")
            st.session_state.logs.append(("Generate", latency, status))
        else:
            st.warning("Please enter a problem.")
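# lang.lower() maps the selectbox choice to the "python" / "javascript" tags that
# st.code uses for syntax highlighting of the generated snippet.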
# Tab 2: Debug
with tab2:
    st.subheader("Debug Code")
    buggy_code = st.text_area("Paste buggy code here")
    if st.button("Debug Code", key="debug_btn"):
        if buggy_code.strip():
            output, latency, status = run_query(backend, model_id, "Debug", buggy_code)
            st.text_area("AI Fix & Explanation", output, height=300)
            st.success(f"{status} | Time: {latency:.2f}s")
            st.session_state.logs.append(("Debug", latency, status))
        else:
            st.warning("Please paste code.")
# Tab 3: Explain
with tab3:
    st.subheader("Explain Code")
    code_input = st.text_area("Paste code to explain")
    if st.button("Explain Code", key="explain_btn"):
        if code_input.strip():
            output, latency, status = run_query(backend, model_id, "Explain", code_input)
            st.text_area("AI Explanation", output, height=300)
            st.success(f"{status} | Time: {latency:.2f}s")
            st.session_state.logs.append(("Explain", latency, status))
        else:
            st.warning("Please paste code.")
# Tab 4: Analytics
with tab4:
    st.subheader("Usage Analytics")
    if st.session_state.logs:
        df = pd.DataFrame(st.session_state.logs, columns=["Mode", "Latency", "Status"])
        st.write(df)
        st.bar_chart(df.groupby("Mode")["Latency"].mean())
        st.download_button("Download Logs as CSV", df.to_csv(index=False), "logs.csv", "text/csv")
    else:
        st.info("No usage yet. Try generating, debugging, or explaining first!")
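# The bar chart plots mean latency in seconds per mode (df.groupby("Mode")["Latency"].mean()),
# and the download button serves the same DataFrame as CSV via to_csv(index=False).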