karths's picture
Update app.py
ddc3bda verified
import gradio as gr
import os
import json
from openai import OpenAI
import sys # Added for flushing output in case of direct printing
# Secrets and endpoint identity come from the process environment so they are
# never committed with the source.
RUNPOD_API_KEY = os.environ.get('RUNPOD_API_KEY')
RUNPOD_ENDPOINT_ID = os.environ.get('RUNPOD_ENDPOINT_ID')

# --- Basic Input Validation ---
# Fail fast at import time: `not` rejects both an unset variable (None) and an
# empty string.
if not RUNPOD_API_KEY:
    raise ValueError("RunPod API key not found. Please set the RUNPOD_API_KEY environment variable.")
if not RUNPOD_ENDPOINT_ID:
    raise ValueError("RunPod Endpoint ID not found. Please set the RUNPOD_ENDPOINT_ID environment variable.")

# Endpoint-specific RunPod URL that the OpenAI SDK client below talks to.
BASE_URL = "https://api.runpod.ai/v2/{}/openai/v1".format(RUNPOD_ENDPOINT_ID)
MODEL_NAME = "karths/coder_commit_32B"  # The specific model hosted on RunPod
MAX_TOKENS = 4096  # Upper bound on the length of a single model response

# --- OpenAI Client Initialization ---
# The standard OpenAI client, redirected to the RunPod endpoint via base_url.
client = OpenAI(
    api_key=RUNPOD_API_KEY,
    base_url=BASE_URL,
)
# --- Gradio App Configuration ---
# Heading and markdown help text rendered above the chat window; the help text
# interpolates the response-length cap so it stays in sync with MAX_TOKENS.
title = "Python Maintainability Refactoring demo"
description = f"""
## Instructions for Using the Model
### Model Loading Time:
- Please allow time for the model on GPU server to initialize if it's starting fresh ("Cold Start"). The response will appear token by token.
### Code Submission:
- You can enter or paste your Python code you wish to have refactored, or use the provided example.
### Python Code Constraints:
- Keep the code reasonably sized. Large code blocks might face limitations depending on the GPU instance and model constraints. Max response length is set to {MAX_TOKENS} tokens.
### Understanding Changes:
- It's important to read the "Changes made" section (if provided by the model) in the refactored code response. This will help in understanding what modifications have been made.
### Usage Recommendation:
- Intended for research and evaluation purposes.
"""
# Instruction template that the user's code is appended to. It ends with the
# "### Input:" header plus a newline, so the code begins on a fresh line.
system_prompt = "\n".join([
    "### Instruction:",
    "Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with comments on the changes made for improving the metrics.",
    "### Input:",
    "",
])
# Hides elements with the `toast-wrap` class (presumably Gradio's pop-up
# notifications) — NOTE(review): this likely hides error toasts too.
css = """.toast-wrap { display: none !important } """
# Example inputs offered under the chat box. Each entry is a one-element list
# holding the message text.
#
# The snippets are sent to the model verbatim, so they must be valid, properly
# indented Python. The second snippet is a raw string literal so that its own
# backslash escapes ('\n', \s, \', \") reach the model exactly as written
# instead of being interpreted by this file's string literal.
examples = [
    ["""def analyze_sales_data(sales_records):
    active_sales = filter(lambda record: record['status'] == 'active', sales_records)
    sales_by_category = {}
    for record in active_sales:
        category = record['category']
        total_sales = record['units_sold'] * record['price_per_unit']
        if category not in sales_by_category:
            sales_by_category[category] = {'total_sales': 0, 'total_units': 0}
        sales_by_category[category]['total_sales'] += total_sales
        sales_by_category[category]['total_units'] += record['units_sold']
    average_sales_data = []
    for category, data in sales_by_category.items():
        average_sales = data['total_sales'] / data['total_units'] if data['total_units'] > 0 else 0 # Avoid division by zero
        sales_by_category[category]['average_sales'] = average_sales
        average_sales_data.append((category, average_sales))
    average_sales_data.sort(key=lambda x: x[1], reverse=True)
    for rank, (category, _) in enumerate(average_sales_data, start=1):
        sales_by_category[category]['rank'] = rank
    return sales_by_category"""],
    [r"""import pandas as pd
import re
import ast
from code_bert_score import score # Assuming this library is available in the environment
import numpy as np
def preprocess_code(source_text):
    def remove_comments_and_docstrings(source_code):
        # Remove single-line comments
        source_code = re.sub(r'#.*', '', source_code)
        # Remove multi-line strings (docstrings)
        source_code = re.sub(r'(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")', '', source_code, flags=re.DOTALL)
        return source_code.strip() # Added strip
    # Pattern to extract code specifically from markdown blocks if present
    pattern = r"```python\s+(.+?)\s+```"
    matches = re.findall(pattern, source_text, re.DOTALL)
    code_to_process = '\n'.join(matches) if matches else source_text
    cleaned_code = remove_comments_and_docstrings(code_to_process)
    return cleaned_code
def evaluate_dataframe(df):
    results = {'P': [], 'R': [], 'F1': [], 'F3': []}
    for index, row in df.iterrows():
        try:
            # Ensure inputs are lists of strings
            cands = [preprocess_code(str(row['generated_text']))] # Added str() conversion
            refs = [preprocess_code(str(row['output']))] # Added str() conversion
            # Ensure code_bert_score.score returns four values
            score_results = score(cands, refs, lang='python')
            if len(score_results) == 4:
                P, R, F1, F3 = score_results
                results['P'].append(P.item() if hasattr(P, 'item') else P) # Handle potential tensor output
                results['R'].append(R.item() if hasattr(R, 'item') else R)
                results['F1'].append(F1.item() if hasattr(F1, 'item') else F1)
                results['F3'].append(F3.item() if hasattr(F3, 'item') else F3) # Assuming F3 is returned
            else:
                print(f"Warning: Unexpected number of return values from score function for row {index}. Got {len(score_results)} values.")
                for key in results.keys():
                    results[key].append(np.nan) # Append NaN for unexpected format
        except Exception as e:
            print(f"Error processing row {index}: {e}")
            for key in results.keys():
                results[key].append(np.nan) # Use NaN for errors
    df_metrics = pd.DataFrame(results)
    return df_metrics
def evaluate_dataframe_multiple_runs(df, runs=3):
    all_results = []
    print(f"Starting evaluation for {runs} runs...")
    for run in range(runs):
        print(f"Run {run + 1}/{runs}")
        df_metrics = evaluate_dataframe(df.copy()) # Use a copy to avoid side effects if df is modified
        all_results.append(df_metrics)
        print(f"Run {run + 1} completed.")
    if not all_results:
        print("No results collected.")
        return pd.DataFrame(), pd.DataFrame()
    # Concatenate results and calculate statistics
    try:
        concatenated_results = pd.concat(all_results)
        df_metrics_mean = concatenated_results.groupby(level=0).mean()
        df_metrics_std = concatenated_results.groupby(level=0).std()
        print("Mean and standard deviation calculated.")
    except Exception as e:
        print(f"Error calculating statistics: {e}")
        # Return empty DataFrames or handle as appropriate
        return pd.DataFrame(), pd.DataFrame()
    return df_metrics_mean, df_metrics_std"""]
]
# --- Core Logic (Modified for Streaming) ---
def gen_solution_stream(prompt):
    """
    Stream the model's response for ``prompt`` from the RunPod-hosted LLM.

    The prompt is sent as a single user message to the OpenAI-compatible
    chat-completions API with streaming enabled. Each non-empty text delta is
    yielded as soon as it arrives.

    Parameters:
    - prompt (str): The full prompt (instruction template plus the user's code).

    Yields:
    - str: Successive chunks of generated text.
    - str: A single error message if the request or the stream fails. If some
      text was already yielded, the message is prefixed with a blank line, so a
      caller that concatenates chunks does not glue the error onto the partial
      response.
    """
    produced_output = False
    try:
        stream = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,  # Keep temperature low for near-deterministic refactoring
            top_p=1.0,
            max_tokens=MAX_TOKENS,
            stream=True,  # Receive the response incrementally
        )
        for chunk in stream:
            # Skip chunks that carry no text content.
            if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                produced_output = True
                yield chunk.choices[0].delta.content
    except Exception as e:
        # Boundary handler: surface any failure in the UI instead of crashing the app.
        error_message = f"Error: Could not get streaming response from the model. Details: {str(e)}"
        print(error_message, file=sys.stderr)  # Log error to stderr
        # Keep a mid-stream error visually separate from the partial output already shown.
        yield ("\n\n" + error_message) if produced_output else error_message
# --- Gradio Interface Function (Modified for Streaming) ---
def predict(message, history):
    """
    Chat callback: stream the model's refactoring of ``message`` to the UI.

    ``history`` is part of the callback signature used by gr.ChatInterface but
    is ignored here; each submission is answered on its own.

    Yields the response accumulated so far after every new chunk, so each
    yielded value is the complete text to display at that moment.
    """
    accumulated = ""
    for piece in gen_solution_stream(system_prompt + str(message)):
        accumulated = accumulated + piece
        yield accumulated
# --- Launch Gradio Interface ---
# Chat-style UI: the user submits code, and `predict` streams back the answer.
# The Chatbot component is built before the Textbox, matching the original order.
demo = gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(
        label="Refactored Code and Explanation",
        height=500,
        show_copy_button=True,
    ),
    textbox=gr.Textbox(
        label="Python Code",
        lines=10,
        placeholder="Enter or Paste your Python code here...",
    ),
    title=title,
    description=description,
    theme="abidlabs/Lime",
    examples=examples,
    cache_examples=False,
    submit_btn="Submit Code",
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    css=css,
)
demo.queue()
# share=True requests a public share link as well — use with caution.
demo.launch(share=True)