Update app.py
app.py CHANGED
@@ -6,29 +6,26 @@ from dotenv import load_dotenv
 import os
 from groq import Groq
 
-# Load
+# Load API keys
 load_dotenv()
 HF_API_KEY = st.secrets.get("HUGGINGFACE_API_KEY") or os.getenv("HUGGINGFACE_API_KEY")
 GROQ_API_KEY = st.secrets.get("GROQ_API_KEY") or os.getenv("GROQ_API_KEY")
 
-# Initialize Groq client
+# Initialize Groq client
 groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None
 
-# Old hard-coded MODELS dict (mode -> backend -> model id); its body is truncated in this render
-        "groq": "llama3-70b-8192"
-    }
-}
+# Model lists
+HF_MODELS = [
+    "bigcode/starcoder2-3b",
+    "mistralai/Mistral-7B-Instruct-v0.2",
+    "tiiuae/falcon-7b-instruct"
+]
+
+GROQ_MODELS = [
+    "llama3-8b-8192",
+    "llama3-70b-8192",
+    "mixtral-8x7b-32768"
+]
 
 # Hugging Face API
 def query_hf(model_id, prompt, max_new_tokens=300):
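query_hf is only declared in this hunk; its body is unchanged and not shown. As a rough sketch of the contract run_query relies on below — a (output, latency, success) tuple from a Hugging Face Inference API call — it likely resembles the following; the endpoint, payload shape, and error handling here are assumptions, not part of this commit:

import os
import time
import requests

HF_API_KEY = os.getenv("HUGGINGFACE_API_KEY")  # the app also checks st.secrets

def query_hf(model_id, prompt, max_new_tokens=300):
    # Call the hosted Inference API for the selected model
    headers = {"Authorization": f"Bearer {HF_API_KEY}"}
    payload = {"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}}
    start = time.time()
    try:
        resp = requests.post(
            f"https://api-inference.huggingface.co/models/{model_id}",
            headers=headers,
            json=payload,
            timeout=60,
        )
        resp.raise_for_status()
        output = resp.json()[0]["generated_text"]
        return output, time.time() - start, True
    except Exception as exc:
        # Report failure so run_query can fall back to the other backend
        return str(exc), time.time() - start, False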
@@ -71,44 +68,50 @@ def query_groq(model_id, prompt, max_tokens=300):
 
 # Prompt builder
 def build_prompt(mode, user_input, language="Python"):
-    if mode == "
-        return f"
-    elif mode == "
+    if mode == "Generate":
+        return f"Write {language} code for this task:\n{user_input}\nOnly return code."
+    elif mode == "Debug":
         return f"Debug this code and explain briefly:\n{user_input}"
-    elif mode == "
+    elif mode == "Explain":
         return f"Explain this code step by step:\n{user_input}"
     return user_input
 
-#
-def run_query(backend, mode, user_input, lang="Python"):
-    model_id = MODELS[mode][backend]
+# Run query with fallback
+def run_query(backend, model_id, mode, user_input, lang="Python"):
     prompt = build_prompt(mode, user_input, lang)
 
     if backend == "hf":
         output, latency, success = query_hf(model_id, prompt)
         if success:
-            return output, latency, "Hugging Face ✅"
+            return output, latency, f"Hugging Face ({model_id}) ✅"
         else:
             # fallback to Groq
-
-
+            fallback_model = GROQ_MODELS[0]  # default to Llama3-8B
+            output2, latency2, success2 = query_groq(fallback_model, prompt)
+            return output2, latency + latency2, f"Hugging Face ❌ → Groq ({fallback_model}) ✅"
     elif backend == "groq":
         output, latency, success = query_groq(model_id, prompt)
         if success:
-            return output, latency, "Groq ✅"
+            return output, latency, f"Groq ({model_id}) ✅"
         else:
-            # fallback to
-
-
+            # fallback to Hugging Face
+            fallback_model = HF_MODELS[0]  # default to Starcoder2
+            output2, latency2, success2 = query_hf(fallback_model, prompt)
+            return output2, latency + latency2, f"Groq ❌ → Hugging Face ({fallback_model}) ✅"
     return "⚠️ Invalid backend", 0, "None"
 
 # Streamlit UI
 st.set_page_config(page_title="CodeCraft AI", layout="wide")
-st.title("🧑‍💻 CodeCraft AI (
-st.write("
+st.title("🧑‍💻 CodeCraft AI (Model Selection + Fallback)")
+st.write("Choose models yourself. If one fails, fallback ensures smooth output.")
 
 backend = st.radio("Choose Backend", ["hf", "groq"], format_func=lambda x: "Hugging Face" if x == "hf" else "Groq")
 
+if backend == "hf":
+    model_id = st.selectbox("Choose Hugging Face Model", HF_MODELS)
+else:
+    model_id = st.selectbox("Choose Groq Model", GROQ_MODELS)
+
 tab1, tab2, tab3, tab4 = st.tabs(["Generate", "Debug", "Explain", "Analytics"])
 
 # Track logs
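query_groq is likewise unchanged and outside these hunks. A minimal sketch, assuming the Groq SDK's chat-completions call and the same (output, latency, success) contract; the message format and error handling below are assumptions:

import os
import time
from groq import Groq

GROQ_API_KEY = os.getenv("GROQ_API_KEY")  # the app also checks st.secrets
groq_client = Groq(api_key=GROQ_API_KEY) if GROQ_API_KEY else None

def query_groq(model_id, prompt, max_tokens=300):
    start = time.time()
    if groq_client is None:
        return "Groq client not configured", 0.0, False
    try:
        resp = groq_client.chat.completions.create(
            model=model_id,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
        )
        output = resp.choices[0].message.content
        return output, time.time() - start, True
    except Exception as exc:
        return str(exc), time.time() - start, False

Worth noting about the fallback path in the hunk above: on failure run_query reports latency + latency2, so the time shown in the UI covers the failed call plus the retry, and the status string records which backend actually produced the answer.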
@@ -119,10 +122,10 @@ if "logs" not in st.session_state:
 with tab1:
     st.subheader("Code Generation")
     lang = st.selectbox("Choose language", ["Python", "JavaScript"])
-    problem = st.text_area("Enter
+    problem = st.text_area("Enter problem statement")
     if st.button("Generate Code", key="gen_btn"):
         if problem.strip():
-            output, latency, status = run_query(backend,
+            output, latency, status = run_query(backend, model_id, "Generate", problem, lang)
             st.code(output, language=lang.lower())
             st.success(f"{status} | Time: {latency:.2f}s")
             st.session_state.logs.append(("Generate", latency, status))
@@ -135,7 +138,7 @@ with tab2:
     buggy_code = st.text_area("Paste buggy code here")
     if st.button("Debug Code", key="debug_btn"):
         if buggy_code.strip():
-            output, latency, status = run_query(backend, "
+            output, latency, status = run_query(backend, model_id, "Debug", buggy_code)
             st.text_area("AI Fix & Explanation", output, height=300)
             st.success(f"{status} | Time: {latency:.2f}s")
             st.session_state.logs.append(("Debug", latency, status))
@@ -148,7 +151,7 @@ with tab3:
     code_input = st.text_area("Paste code to explain")
     if st.button("Explain Code", key="explain_btn"):
         if code_input.strip():
-            output, latency, status = run_query(backend, "
+            output, latency, status = run_query(backend, model_id, "Explain", code_input)
             st.text_area("AI Explanation", output, height=300)
             st.success(f"{status} | Time: {latency:.2f}s")
             st.session_state.logs.append(("Explain", latency, status))
@@ -162,5 +165,6 @@ with tab4:
         df = pd.DataFrame(st.session_state.logs, columns=["Mode", "Latency", "Status"])
         st.write(df)
         st.bar_chart(df.groupby("Mode")["Latency"].mean())
+        st.download_button("Download Logs as CSV", df.to_csv(index=False), "logs.csv", "text/csv")
     else:
         st.info("No usage yet. Try generating, debugging, or explaining first!")
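st.session_state.logs, initialised under the "# Track logs" comment outside these hunks, holds one (mode, latency, status) tuple per request; the Analytics tab turns it into the DataFrame and chart shown above. A small pandas-only illustration of that aggregation — the rows below are made-up values, not app output:

import pandas as pd

# Each tab appends ("Generate" / "Debug" / "Explain", latency_seconds, status_string)
sample_logs = [
    ("Generate", 1.8, "Groq (llama3-8b-8192) ✅"),                    # illustrative values only
    ("Generate", 2.1, "Hugging Face (bigcode/starcoder2-3b) ✅"),
    ("Debug", 2.4, "Hugging Face ❌ → Groq (llama3-8b-8192) ✅"),
]
df = pd.DataFrame(sample_logs, columns=["Mode", "Latency", "Status"])
print(df.groupby("Mode")["Latency"].mean())  # the series st.bar_chart plots per mode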