import gradio as gr
import os
import json
from openai import OpenAI
import sys  # Used for logging errors to stderr

# Load sensitive information from environment variables
RUNPOD_API_KEY = os.getenv('RUNPOD_API_KEY')
RUNPOD_ENDPOINT_ID = os.getenv('RUNPOD_ENDPOINT_ID')

# --- Basic Input Validation ---
if not RUNPOD_API_KEY:
    raise ValueError("RunPod API key not found. Please set the RUNPOD_API_KEY environment variable.")
if not RUNPOD_ENDPOINT_ID:
    raise ValueError("RunPod Endpoint ID not found. Please set the RUNPOD_ENDPOINT_ID environment variable.")

BASE_URL = f"https://api.runpod.ai/v2/{RUNPOD_ENDPOINT_ID}/openai/v1"
MODEL_NAME = "karths/coder_commit_32B"  # The specific model hosted on RunPod
MAX_TOKENS = 4096  # Max tokens for the model response

# --- OpenAI Client Initialization ---
client = OpenAI(
    api_key=RUNPOD_API_KEY,
    base_url=BASE_URL,
)
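# Optional sanity check (a minimal sketch, not part of the original app flow): RunPod's
# OpenAI-compatible endpoints generally expose the standard /models route, so
# client.models.list() can be used to confirm that the API key and endpoint ID resolve
# before launching the UI. The helper name below is hypothetical and is never called here.
def _list_available_models():
    try:
        return [m.id for m in client.models.list().data]
    except Exception as exc:
        print(f"Endpoint check failed: {exc}", file=sys.stderr)
        return []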
# --- Gradio App Configuration ---
title = "Python Maintainability Refactoring Demo"
description = """
## Instructions for Using the Model
### Model Loading Time:
- Please allow time for the model on the GPU server to initialize if it is starting fresh ("cold start"). The response will appear token by token.
### Code Submission:
- Enter or paste the Python code you wish to refactor, or use one of the provided examples.
### Python Code Constraints:
- Keep the code reasonably sized. Large code blocks may run into limits imposed by the GPU instance and the model. The maximum response length is set to {} tokens.
### Understanding Changes:
- Read the "Changes made" section (if provided by the model) in the refactored response; it describes the modifications that were made.
### Usage Recommendation:
- Intended for research and evaluation purposes.
""".format(MAX_TOKENS)

system_prompt = """### Instruction:
Refactor the provided Python code to improve its maintainability and efficiency and reduce complexity. Include the refactored code along with comments on the changes made for improving the metrics.
### Input:
"""

css = """.toast-wrap { display: none !important } """
examples = [
    ["""def analyze_sales_data(sales_records):
    active_sales = filter(lambda record: record['status'] == 'active', sales_records)
    sales_by_category = {}
    for record in active_sales:
        category = record['category']
        total_sales = record['units_sold'] * record['price_per_unit']
        if category not in sales_by_category:
            sales_by_category[category] = {'total_sales': 0, 'total_units': 0}
        sales_by_category[category]['total_sales'] += total_sales
        sales_by_category[category]['total_units'] += record['units_sold']
    average_sales_data = []
    for category, data in sales_by_category.items():
        average_sales = data['total_sales'] / data['total_units'] if data['total_units'] > 0 else 0  # Avoid division by zero
        sales_by_category[category]['average_sales'] = average_sales
        average_sales_data.append((category, average_sales))
    average_sales_data.sort(key=lambda x: x[1], reverse=True)
    for rank, (category, _) in enumerate(average_sales_data, start=1):
        sales_by_category[category]['rank'] = rank
    return sales_by_category"""],
    ["""import pandas as pd
import re
import ast
from code_bert_score import score  # Assuming this library is available in the environment
import numpy as np

def preprocess_code(source_text):
    def remove_comments_and_docstrings(source_code):
        # Remove single-line comments
        source_code = re.sub(r'#.*', '', source_code)
        # Remove multi-line strings (docstrings)
        source_code = re.sub(r'(\'\'\'(.*?)\'\'\'|\"\"\"(.*?)\"\"\")', '', source_code, flags=re.DOTALL)
        return source_code.strip()  # Added strip
    # Pattern to extract code specifically from markdown blocks if present
    pattern = r"```python\s+(.+?)\s+```"
    matches = re.findall(pattern, source_text, re.DOTALL)
    code_to_process = '\n'.join(matches) if matches else source_text
    cleaned_code = remove_comments_and_docstrings(code_to_process)
    return cleaned_code

def evaluate_dataframe(df):
    results = {'P': [], 'R': [], 'F1': [], 'F3': []}
    for index, row in df.iterrows():
        try:
            # Ensure inputs are lists of strings
            cands = [preprocess_code(str(row['generated_text']))]  # Added str() conversion
            refs = [preprocess_code(str(row['output']))]  # Added str() conversion
            # Ensure code_bert_score.score returns four values
            score_results = score(cands, refs, lang='python')
            if len(score_results) == 4:
                P, R, F1, F3 = score_results
                results['P'].append(P.item() if hasattr(P, 'item') else P)  # Handle potential tensor output
                results['R'].append(R.item() if hasattr(R, 'item') else R)
                results['F1'].append(F1.item() if hasattr(F1, 'item') else F1)
                results['F3'].append(F3.item() if hasattr(F3, 'item') else F3)  # Assuming F3 is returned
            else:
                print(f"Warning: Unexpected number of return values from score function for row {index}. Got {len(score_results)} values.")
                for key in results.keys():
                    results[key].append(np.nan)  # Append NaN for unexpected format
        except Exception as e:
            print(f"Error processing row {index}: {e}")
            for key in results.keys():
                results[key].append(np.nan)  # Use NaN for errors
    df_metrics = pd.DataFrame(results)
    return df_metrics

def evaluate_dataframe_multiple_runs(df, runs=3):
    all_results = []
    print(f"Starting evaluation for {runs} runs...")
    for run in range(runs):
        print(f"Run {run + 1}/{runs}")
        df_metrics = evaluate_dataframe(df.copy())  # Use a copy to avoid side effects if df is modified
        all_results.append(df_metrics)
        print(f"Run {run + 1} completed.")
    if not all_results:
        print("No results collected.")
        return pd.DataFrame(), pd.DataFrame()
    # Concatenate results and calculate statistics
    try:
        concatenated_results = pd.concat(all_results)
        df_metrics_mean = concatenated_results.groupby(level=0).mean()
        df_metrics_std = concatenated_results.groupby(level=0).std()
        print("Mean and standard deviation calculated.")
    except Exception as e:
        print(f"Error calculating statistics: {e}")
        # Return empty DataFrames or handle as appropriate
        return pd.DataFrame(), pd.DataFrame()
    return df_metrics_mean, df_metrics_std"""]
]
# --- Core Logic (Modified for Streaming) ---
def gen_solution_stream(prompt):
    """
    Generates a solution for a given problem prompt by calling the LLM via RunPod
    and yielding the response chunks as they arrive (streaming).

    Parameters:
    - prompt (str): The problem prompt including the system message and user input.

    Yields:
    - str: Chunks of the generated solution text.
    - str: An error message if an exception occurs.
    """
    try:
        # Call the OpenAI-compatible endpoint on RunPod with streaming enabled
        stream = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,  # Keep temperature low for deterministic refactoring
            top_p=1.0,
            max_tokens=MAX_TOKENS,
            stream=True  # Enable streaming
        )
        # Yield content chunks from the stream
        for chunk in stream:
            if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                content = chunk.choices[0].delta.content
                yield content
            # Optional: handle the finish reason if needed
            # if chunk.choices and chunk.choices[0].finish_reason:
            #     print(f"\nStream finished with reason: {chunk.choices[0].finish_reason}")
    except Exception as e:
        error_message = f"Error: Could not get streaming response from the model. Details: {str(e)}"
        print(error_message, file=sys.stderr)  # Log the error to stderr
        yield error_message  # Yield the error message so it is displayed in the UI
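# Minimal standalone usage sketch (a hypothetical helper, never called by the app):
# streams a refactoring of the first bundled example straight to stdout, which can be
# handy for testing the RunPod endpoint without going through the Gradio UI.
def _demo_stream_to_stdout():
    for chunk in gen_solution_stream(system_prompt + examples[0][0]):
        print(chunk, end="", flush=True)
    print()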
# --- Gradio Interface Function (Modified for Streaming) ---
def predict(message, history):
    """
    Handles the user input, calls the backend model stream, and yields the response chunks.
    The 'history' parameter is required by gr.ChatInterface but is not used here.
    """
    # Construct the full prompt
    input_prompt = system_prompt + str(message)
    # Get the refactored code stream from the backend
    response_stream = gen_solution_stream(input_prompt)
    # Accumulate the chunks and yield the growing buffer; gr.ChatInterface replaces the
    # bot message with each yielded value, so the output updates incrementally in the UI.
    buffer = ""
    for chunk in response_stream:
        buffer += chunk
        yield buffer
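# Illustrative only (assumed behaviour of gr.ChatInterface with generator functions):
# each value yielded by `predict` replaces the previous bot message, so yielding the
# growing `buffer` makes the reply appear token by token. The helper below mimics that
# consumption loop outside Gradio and is never invoked by the app.
def _demo_predict_locally(code="def add(a, b):\n    return a + b"):
    last = ""
    for partial in predict(code, history=[]):
        last = partial  # each yield is the full response accumulated so far
    return last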
# --- Launch Gradio Interface ---
# Use gr.ChatInterface for a chat-like experience
gr.ChatInterface(
    predict,  # Pass the generator function so responses stream into the chat
    chatbot=gr.Chatbot(height=500, label="Refactored Code and Explanation", show_copy_button=True),
    textbox=gr.Textbox(lines=10, label="Python Code", placeholder="Enter or paste your Python code here..."),
    title=title,
    description=description,
    theme="abidlabs/Lime",  # Or choose another theme, e.g. gr.themes.Default()
    examples=examples,
    cache_examples=False,  # Consider enabling caching if the examples rarely change
    submit_btn="Submit Code",
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    css=css  # Apply custom CSS if needed
).queue().launch(share=True)  # share=True creates a public link (use with caution)
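# Note (assumption about the Gradio version in use): .queue() enables the request queue
# that generator-based (streaming) responses rely on, and share=True exposes the app via
# a temporary public link in addition to the local URL.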