import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Tuple
from dataclasses import dataclass
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class ModelConfig:
    """Configuration for the model and memory components"""
    hidden_size: int = 768
    num_heads: int = 8
    segment_size: int = 512
    memory_size: int = 1024
    max_length: int = 2048
    model_name: str = "gpt2"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


class CompressiveMemory(nn.Module):
    """Long-term memory component that compresses and stores information"""
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.memory_size = config.memory_size
        # Initialize memory components
        self.memory = nn.Parameter(torch.randn(config.memory_size, config.hidden_size))
        self.memory_key = nn.Linear(config.hidden_size, config.hidden_size)
        self.memory_value = nn.Linear(config.hidden_size, config.hidden_size)
        # Memory statistics (registered as a buffer so it follows .to(device))
        self.updates = 0
        self.register_buffer("memory_usage", torch.zeros(config.memory_size))
        # Initialize on specified device
        self.to(config.device)

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        """Retrieve information from memory using attention"""
        # Scale query for stable attention
        query = query / (self.hidden_size ** 0.5)
        # Compute attention scores over the memory slots
        attention = torch.matmul(query, self.memory.T)
        attention_weights = F.softmax(attention, dim=-1)
        # Update memory usage statistics
        with torch.no_grad():
            self.memory_usage += attention_weights.sum(dim=0)
        # Retrieve from memory
        retrieved = torch.matmul(attention_weights, self.memory)
        return retrieved

    def update_memory(self, keys: torch.Tensor, values: torch.Tensor):
        """Update memory with new information"""
        # Compress inputs
        compressed_keys = self.memory_key(keys)
        compressed_values = self.memory_value(values)
        with torch.no_grad():
            # Outer-product style update: [hidden_size, hidden_size]
            update = torch.matmul(compressed_keys.T, compressed_values)
            # Progressive update with decay, restricted to the slots the update
            # actually covers so the shapes always line up
            decay = 0.9
            update_rate = 0.1
            n = min(update.size(0), self.memory_size)
            self.memory.data[:n] = decay * self.memory.data[:n] + update_rate * update[:n]
            # Track updates
            self.updates += 1
            # Optional: reset rarely used memory locations
            if self.updates % 1000 == 0:
                rarely_used = self.memory_usage < (self.memory_usage.mean() / 10)
                self.memory.data[rarely_used] = (
                    torch.randn_like(self.memory.data[rarely_used]) * 0.1
                )
                self.memory_usage[rarely_used] = 0

    def reset_memory(self):
        """Reset memory to initial state"""
        self.memory.data = torch.randn_like(self.memory.data) * 0.1
        self.memory_usage.zero_()
        self.updates = 0


class InfiniteAttention(nn.Module):
    """Main attention module combining local and long-term memory attention"""
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        # Core attention components
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)
        # Multi-head attention setup
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_size // config.num_heads
        assert self.head_dim * config.num_heads == config.hidden_size, \
            "hidden_size must be divisible by num_heads"
        # Memory component
        self.memory = CompressiveMemory(config)
        # Output and gating
        self.output = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.gate = nn.Parameter(torch.zeros(1))
        # Load base language model and tokenizer
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
            self.base_model = AutoModelForCausalLM.from_pretrained(config.model_name)
            self.base_model.to(config.device)
            # GPT-2 tokenizers have no pad token by default; reuse EOS
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        except Exception as e:
            logger.error(f"Error loading base model: {str(e)}")
            raise
        # Move model to specified device
        self.to(config.device)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Split tensor into attention heads"""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

    def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Merge attention heads back together"""
        batch_size, _, seq_length, _ = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.config.hidden_size)

    def get_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Get input embeddings from the base model"""
        # get_input_embeddings() works for any AutoModelForCausalLM backbone
        return self.base_model.get_input_embeddings()(input_ids)

    def process_segment(self, segment: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Process a single segment with local attention plus memory retrieval"""
        # Project once; keep the flat projections for the memory path
        q_proj = self.query(segment)
        k_proj = self.key(segment)
        v_proj = self.value(segment)
        # Compute multi-head Q, K, V for local attention
        q = self.split_heads(q_proj)
        k = self.split_heads(k_proj)
        v = self.split_heads(v_proj)
        # Scale query
        q = q / (self.head_dim ** 0.5)
        # Compute local attention scores
        local_attn = torch.matmul(q, k.transpose(-2, -1))
        if mask is not None:
            local_attn = local_attn.masked_fill(mask == 0, float('-inf'))
        # Apply softmax
        local_attn = F.softmax(local_attn, dim=-1)
        # Compute local attention output
        local_output = self.merge_heads(torch.matmul(local_attn, v))
        # Retrieve from compressive memory using the un-split query projection
        memory_output = self.memory(q_proj.reshape(-1, self.config.hidden_size))
        memory_output = memory_output.view_as(segment)
        # Update memory with this segment's keys and values
        self.memory.update_memory(k_proj.reshape(-1, self.config.hidden_size),
                                  v_proj.reshape(-1, self.config.hidden_size))
        # Combine local and memory outputs using a learned gate
        gate = torch.sigmoid(self.gate)
        combined = torch.cat([
            gate * local_output,
            (1 - gate) * memory_output
        ], dim=-1)
        return self.output(combined)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Process the input sequence segment by segment"""
        seg_len = self.config.segment_size
        output_segments = []
        # Slice the sequence into fixed-size segments; the final slice simply
        # carries whatever tokens remain
        for start in range(0, x.size(1), seg_len):
            segment = x[:, start:start + seg_len]
            output_segments.append(self.process_segment(segment, mask))
        # Combine all segment outputs back into one sequence
        return torch.cat(output_segments, dim=1)
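    # Shape sketch (batch=1, hidden_size=768, segment_size=512): an input of
    # [1, 1300, 768] is processed as slices of length 512, 512 and 276, each
    # mapped back to [1, slice_len, 768] by process_segment, then concatenated
    # to [1, 1300, 768].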

    def generate_response(self, input_text: str, max_new_tokens: int = 100) -> str:
        """Generate a response from input text"""
        try:
            # Prepare input
            inputs = self.tokenizer(input_text,
                                    return_tensors="pt",
                                    truncation=False)
            input_ids = inputs["input_ids"].to(self.config.device)
            # Get embeddings for the full context
            embeddings = self.get_embeddings(input_ids)
            # Run the infinite-attention stack so the compressive memory sees
            # the whole context (its output is not fed back into generation here)
            attended = self.forward(embeddings)
            # The base model can only condition on its own context window, so
            # keep the most recent tokens for generation
            max_context = self.tokenizer.model_max_length - max_new_tokens
            if input_ids.size(1) > max_context:
                input_ids = input_ids[:, -max_context:]
            # Generate the response with the base model
            outputs = self.base_model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            logger.error(f"Error in generate_response: {str(e)}")
            return f"Error generating response: {str(e)}"


class ChatBot:
    """Manages chat history and message processing"""
    def __init__(self, config: ModelConfig):
        self.config = config
        self.model = InfiniteAttention(config)
        self.history: List[Tuple[str, str]] = []
        self.max_history_tokens = 4096  # Adjust based on your needs

    def count_tokens(self, text: str) -> int:
        """Count tokens in text using the model's tokenizer"""
        return len(self.model.tokenizer.encode(text))

    def get_truncated_history(self) -> str:
        """Get history truncated to the token budget, newest turns first"""
        history_text = ""
        token_count = 0
        for msg, response in reversed(self.history):
            new_text = f"User: {msg}\nAssistant: {response}\n"
            new_tokens = self.count_tokens(new_text)
            if token_count + new_tokens > self.max_history_tokens:
                break
            history_text = new_text + history_text
            token_count += new_tokens
        return history_text.strip()
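    # Example: with max_history_tokens=4096, a long conversation keeps only its
    # most recent exchanges in the prompt; the oldest turns are dropped first
    # once the budget is exceeded.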

    def process_message(self, message: str) -> Tuple[str, List[Tuple[str, str]]]:
        """Process a message and return the response with updated history"""
        try:
            # Skip empty messages
            if not message.strip():
                return "", self.history
            # Prepare context with history
            history_text = self.get_truncated_history()
            context = f"{history_text}\nUser: {message}\nAssistant:"
            # Generate response
            full_response = self.model.generate_response(context)
            # Extract just the new response (after the final "Assistant:")
            response = full_response.split("Assistant:")[-1].strip()
            # Update history
            self.history.append((message, response))
            return response, self.history
        except Exception as e:
            error_msg = f"Error processing message: {str(e)}"
            logger.error(error_msg)
            return error_msg, self.history
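    # The context assembled above has the form:
    #   User: <older message>
    #   Assistant: <older reply>
    #   ...
    #   User: <new message>
    #   Assistant:
    # and the reply is whatever the model generates after the final "Assistant:".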

    def save_conversation(self, filename: str):
        """Save conversation history to a file"""
        try:
            with open(filename, 'w', encoding='utf-8') as f:
                for msg, response in self.history:
                    f.write(f"User: {msg}\n")
                    f.write(f"Assistant: {response}\n\n")
        except Exception as e:
            logger.error(f"Error saving conversation: {str(e)}")

    def load_conversation(self, filename: str):
        """Load conversation history from a file"""
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                content = f.read()
            # Reset history
            self.history = []
            # Parse content: one "User/Assistant" exchange per blank-line block
            conversations = content.strip().split('\n\n')
            for conv in conversations:
                if 'User:' in conv and 'Assistant:' in conv:
                    parts = conv.split('Assistant:', 1)
                    msg = parts[0].replace('User:', '', 1).strip()
                    response = parts[1].strip()
                    self.history.append((msg, response))
        except Exception as e:
            logger.error(f"Error loading conversation: {str(e)}")


def create_gradio_interface():
    """Create and configure the Gradio interface"""
    # Initialize config and chatbot
    config = ModelConfig()
    chatbot = ChatBot(config)

    def user_message(message: str, history: List[Tuple[str, str]]) -> str:
        """Handle incoming user messages (ChatInterface expects only the reply text)"""
        response, _ = chatbot.process_message(message)
        return response

    def save_chat(filename: str):
        """Save chat history to a file"""
        if not filename.endswith('.txt'):
            filename += '.txt'
        chatbot.save_conversation(filename)
        return f"Conversation saved to {filename}"

    def load_chat(filename: str):
        """Load chat history from a file"""
        if not filename.endswith('.txt'):
            filename += '.txt'
        chatbot.load_conversation(filename)
        return f"Conversation loaded from {filename}"

    # Create main chat interface
    chat_interface = gr.ChatInterface(
        fn=user_message,
        title="Long Context AI Chat",
        description="Chat with an AI that can handle very long conversations",
        examples=[
            ["Tell me a story about space exploration"],
            ["What were the key points from our earlier discussion?"],
            ["Can you summarize everything we've talked about so far?"]
        ],
        retry_btn=None,
        undo_btn="Delete Last",
        clear_btn="Clear"
    )
    # Add save/load functionality
    with gr.Blocks() as interface:
        chat_interface.render()
        with gr.Row():
            save_file = gr.Textbox(
                label="Save conversation to file",
                placeholder="conversation.txt"
            )
            save_btn = gr.Button("Save")
            save_output = gr.Textbox(label="Save Status")
            load_file = gr.Textbox(
                label="Load conversation from file",
                placeholder="conversation.txt"
            )
            load_btn = gr.Button("Load")
            load_output = gr.Textbox(label="Load Status")
        save_btn.click(
            fn=save_chat,
            inputs=[save_file],
            outputs=[save_output]
        )
        load_btn.click(
            fn=load_chat,
            inputs=[load_file],
            outputs=[load_output]
        )
    return interface


def main():
    """Main application entry point"""
    try:
        # Create interface
        interface = create_gradio_interface()
        # Launch with configuration
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            debug=True,
            auth=None,          # Add authentication if needed
            ssl_keyfile=None,   # Add SSL if needed
            ssl_certfile=None
        )
    except Exception as e:
        logger.error(f"Error launching application: {str(e)}")
        raise


if __name__ == "__main__":
    main()
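
# Usage sketch (assuming this file is saved as app.py, the usual entry point
# for a Hugging Face Space):
#   pip install torch transformers gradio
#   python app.py
# then open http://localhost:7860 in a browser.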