# Hugging Face Space: automated code review assistant powered by Gemma-2b.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import login
import os
import logging
from datetime import datetime
import json
from typing import List, Dict
import warnings
import spaces  # HF Spaces helper; kept even though unused here (side-effect registration)

# Silence library warnings (they are noisy in Space logs).
warnings.filterwarnings('ignore')

# Module-level logging configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment-driven configuration: HF auth token and the review model id.
HF_TOKEN = os.getenv("HUGGING_FACE_TOKEN")
MODEL_NAME = os.getenv("MODEL_NAME", "google/gemma-2b-it")

# Cache directory for downloaded model weights.
CACHE_DIR = "/home/user/.cache/huggingface"
os.makedirs(CACHE_DIR, exist_ok=True)

# JSON file where review history and metrics are persisted between runs.
HISTORY_FILE = "/home/user/review_history.json"
class Review:
    """A single code-review record: the submitted code, its language,
    the model's suggestions, and timing metadata."""

    def __init__(self, code: str, language: str, suggestions: str):
        self.code = code
        self.language = language
        self.suggestions = suggestions
        # ISO-8601 creation time; overwritten when restored from disk.
        self.timestamp = datetime.now().isoformat()
        # Seconds the model took to produce the review; set by the caller.
        self.response_time = 0.0

    def to_dict(self) -> Dict:
        """Serialize to a JSON-compatible dict (inverse of from_dict)."""
        return {
            'timestamp': self.timestamp,
            'language': self.language,
            'code': self.code,
            'suggestions': self.suggestions,
            'response_time': self.response_time
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "Review":
        """Rebuild a Review from a dict produced by to_dict().

        Bug fix: the decorator was missing, so ``Review.from_dict(data)``
        bound ``data`` to ``cls`` and crashed with a TypeError.
        """
        review = cls(data['code'], data['language'], data['suggestions'])
        review.timestamp = data['timestamp']
        review.response_time = data['response_time']
        return review
class CodeReviewer: | |
def __init__(self): | |
self.model = None | |
self.tokenizer = None | |
self.device = None | |
self.review_history: List[Review] = [] | |
self.metrics = { | |
'total_reviews': 0, | |
'avg_response_time': 0.0, | |
'reviews_today': 0 | |
} | |
self._initialized = False | |
self.load_history() | |
def load_history(self): | |
"""Load review history from file.""" | |
try: | |
if os.path.exists(HISTORY_FILE): | |
with open(HISTORY_FILE, 'r') as f: | |
data = json.load(f) | |
self.review_history = [Review.from_dict(r) for r in data['history']] | |
self.metrics = data['metrics'] | |
logger.info(f"Loaded {len(self.review_history)} reviews from history") | |
except Exception as e: | |
logger.error(f"Error loading history: {e}") | |
def save_history(self): | |
"""Save review history to file.""" | |
try: | |
data = { | |
'history': [r.to_dict() for r in self.review_history], | |
'metrics': self.metrics | |
} | |
with open(HISTORY_FILE, 'w') as f: | |
json.dump(data, f) | |
logger.info("Saved review history") | |
except Exception as e: | |
logger.error(f"Error saving history: {e}") | |
def ensure_initialized(self): | |
"""Ensure model is initialized.""" | |
if not self._initialized: | |
self.initialize_model() | |
self._initialized = True | |
def initialize_model(self): | |
"""Initialize the model and tokenizer.""" | |
try: | |
if HF_TOKEN: | |
login(token=HF_TOKEN, add_to_git_credential=False) | |
logger.info("Loading tokenizer...") | |
self.tokenizer = AutoTokenizer.from_pretrained( | |
MODEL_NAME, | |
token=HF_TOKEN, | |
trust_remote_code=True, | |
cache_dir=CACHE_DIR | |
) | |
special_tokens = { | |
'pad_token': '[PAD]', | |
'eos_token': '</s>', | |
'bos_token': '<s>' | |
} | |
num_added = self.tokenizer.add_special_tokens(special_tokens) | |
logger.info(f"Added {num_added} special tokens") | |
logger.info("Tokenizer loaded successfully") | |
logger.info("Loading model...") | |
self.model = AutoModelForCausalLM.from_pretrained( | |
MODEL_NAME, | |
device_map="auto", | |
torch_dtype=torch.float16, | |
trust_remote_code=True, | |
low_cpu_mem_usage=True, | |
cache_dir=CACHE_DIR, | |
token=HF_TOKEN | |
) | |
if num_added > 0: | |
logger.info("Resizing model embeddings for special tokens") | |
self.model.resize_token_embeddings(len(self.tokenizer)) | |
self.device = next(self.model.parameters()).device | |
logger.info(f"Model loaded successfully on {self.device}") | |
self._initialized = True | |
return True | |
except Exception as e: | |
logger.error(f"Error initializing model: {e}") | |
self._initialized = False | |
return False | |
def create_review_prompt(self, code: str, language: str) -> str: | |
"""Create a structured prompt for code review.""" | |
return f"""Review this {language} code. List specific points in these sections: | |
Issues: | |
Improvements: | |
Best Practices: | |
Security: | |
Code: | |
```{language} | |
{code} | |
```""" | |
def review_code(self, code: str, language: str) -> str: | |
"""Perform code review using the model.""" | |
try: | |
if not self._initialized and not self.initialize_model(): | |
return "Error: Model initialization failed. Please try again later." | |
start_time = datetime.now() | |
prompt = self.create_review_prompt(code, language) | |
try: | |
inputs = self.tokenizer( | |
prompt, | |
return_tensors="pt", | |
truncation=True, | |
max_length=512, | |
padding=True | |
) | |
if inputs is None: | |
raise ValueError("Failed to tokenize input") | |
inputs = inputs.to(self.device) | |
except Exception as token_error: | |
logger.error(f"Tokenization error: {token_error}") | |
return "Error: Failed to process input code. Please try again." | |
try: | |
with torch.no_grad(): | |
outputs = self.model.generate( | |
**inputs, | |
max_new_tokens=512, | |
do_sample=True, | |
temperature=0.7, | |
top_p=0.95, | |
num_beams=1, | |
early_stopping=True, | |
pad_token_id=self.tokenizer.pad_token_id, | |
eos_token_id=self.tokenizer.eos_token_id | |
) | |
except Exception as gen_error: | |
logger.error(f"Generation error: {gen_error}") | |
return "Error: Failed to generate review. Please try again." | |
try: | |
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
suggestions = response[len(prompt):].strip() | |
except Exception as decode_error: | |
logger.error(f"Decoding error: {decode_error}") | |
return "Error: Failed to decode model output. Please try again." | |
# Create and save review | |
end_time = datetime.now() | |
review = Review(code, language, suggestions) | |
review.response_time = (end_time - start_time).total_seconds() | |
# Update metrics first | |
self.metrics['total_reviews'] += 1 | |
total_time = self.metrics['avg_response_time'] * (self.metrics['total_reviews'] - 1) | |
total_time += review.response_time | |
self.metrics['avg_response_time'] = total_time / self.metrics['total_reviews'] | |
today = datetime.now().date() | |
# Add review to history | |
self.review_history.append(review) | |
# Update today's reviews count | |
self.metrics['reviews_today'] = sum( | |
1 for r in self.review_history | |
if datetime.fromisoformat(r.timestamp).date() == today | |
) | |
# Save to file | |
self.save_history() | |
if self.device and self.device.type == "cuda": | |
del inputs, outputs | |
torch.cuda.empty_cache() | |
return suggestions | |
except Exception as e: | |
logger.error(f"Error during code review: {e}") | |
return f"Error performing code review: {str(e)}" | |
def update_metrics(self, review: Review): | |
"""Update metrics with new review.""" | |
self.metrics['total_reviews'] += 1 | |
total_time = self.metrics['avg_response_time'] * (self.metrics['total_reviews'] - 1) | |
total_time += review.response_time | |
self.metrics['avg_response_time'] = total_time / self.metrics['total_reviews'] | |
today = datetime.now().date() | |
self.metrics['reviews_today'] = sum( | |
1 for r in self.review_history | |
if datetime.fromisoformat(r.timestamp).date() == today | |
) | |
def get_history(self) -> List[Dict]: | |
"""Get formatted review history.""" | |
return [ | |
{ | |
'timestamp': r.timestamp, | |
'language': r.language, | |
'code': r.code, | |
'suggestions': r.suggestions, | |
'response_time': f"{r.response_time:.2f}s" | |
} | |
for r in reversed(self.review_history[-10:]) | |
] | |
def get_metrics(self) -> Dict: | |
"""Get current metrics.""" | |
return { | |
'Total Reviews': self.metrics['total_reviews'], | |
'Average Response Time': f"{self.metrics['avg_response_time']:.2f}s", | |
'Reviews Today': self.metrics['reviews_today'], | |
'Device': str(self.device) if self.device else "Not initialized" | |
} | |
# Module-level singleton; the model itself is loaded lazily on first review.
reviewer = CodeReviewer()

# Build the Gradio UI. Component creation order inside this context
# determines the rendered layout, so keep the structure as-is.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Code Review Assistant")
    gr.Markdown("An automated code review system powered by Gemma-2b")
    with gr.Tabs():
        # Tab 1: submit code and view the generated review.
        with gr.Tab("Review Code"):
            with gr.Row():
                with gr.Column():
                    code_input = gr.Textbox(
                        lines=10,
                        placeholder="Enter your code here...",
                        label="Code"
                    )
                    language_input = gr.Dropdown(
                        choices=["python", "javascript", "java", "cpp", "typescript", "go", "rust"],
                        value="python",
                        label="Language"
                    )
                    submit_btn = gr.Button("Submit for Review", variant="primary")
                with gr.Column():
                    output = gr.Textbox(
                        label="Review Results",
                        lines=10
                    )
        # Tab 2: plain-text dump of recent reviews.
        with gr.Tab("History"):
            with gr.Row():
                refresh_history = gr.Button("Refresh History", variant="secondary")
            history_output = gr.Textbox(
                label="Review History",
                lines=20,
                value="Click 'Refresh History' to view review history"
            )
        # Tab 3: aggregate performance metrics as JSON.
        with gr.Tab("Metrics"):
            with gr.Row():
                refresh_metrics = gr.Button("Refresh Metrics", variant="secondary")
            metrics_output = gr.JSON(
                label="Performance Metrics"
            )

    def review_code_interface(code: str, language: str) -> str:
        """Gradio handler: validate input, lazily init the model, run a review."""
        if not code.strip():
            return "Please enter some code to review."
        try:
            reviewer.ensure_initialized()
            result = reviewer.review_code(code, language)
            return result
        except Exception as e:
            logger.error(f"Interface error: {e}")
            return f"Error: {str(e)}"

    def get_history_interface() -> str:
        """Gradio handler: format the recent review history as plain text."""
        try:
            history = reviewer.get_history()
            if not history:
                return "No reviews yet."
            result = ""
            for review in history:
                result += f"Time: {review['timestamp']}\n"
                result += f"Language: {review['language']}\n"
                result += f"Response Time: {review['response_time']}\n"
                result += "Code:\n```\n" + review['code'] + "\n```\n"
                result += "Suggestions:\n" + review['suggestions'] + "\n"
                result += "-" * 80 + "\n\n"
            return result
        except Exception as e:
            logger.error(f"History error: {e}")
            return "Error retrieving history"

    def get_metrics_interface() -> Dict:
        """Gradio handler: return current metrics (zeroed defaults if empty)."""
        try:
            metrics = reviewer.get_metrics()
            if not metrics:
                return {
                    'Total Reviews': 0,
                    'Average Response Time': '0.00s',
                    'Reviews Today': 0,
                    'Device': str(reviewer.device) if reviewer.device else "Not initialized"
                }
            return metrics
        except Exception as e:
            logger.error(f"Metrics error: {e}")
            return {"error": str(e)}

    def update_all_outputs(code: str, language: str) -> tuple:
        """Update all outputs after code review."""
        # One submit refreshes all three panes so history/metrics stay in sync.
        result = review_code_interface(code, language)
        history = get_history_interface()
        metrics = get_metrics_interface()
        return result, history, metrics

    # Wire events: submit updates all panes; refresh buttons update their own.
    submit_btn.click(
        update_all_outputs,
        inputs=[code_input, language_input],
        outputs=[output, history_output, metrics_output]
    )
    refresh_history.click(
        get_history_interface,
        outputs=history_output
    )
    refresh_metrics.click(
        get_metrics_interface,
        outputs=metrics_output
    )

    # Clickable example inputs shown below the form.
    gr.Examples(
        examples=[
            ["""def add_numbers(a, b):
    return a + b""", "python"],
            ["""function calculateSum(numbers) {
    let sum = 0;
    for(let i = 0; i < numbers.length; i++) {
        sum += numbers[i];
    }
    return sum;
}""", "javascript"]
        ],
        inputs=[code_input, language_input]
    )

# Launch the app (Spaces expects the server on 0.0.0.0:7860).
if __name__ == "__main__":
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        quiet=False
    )