import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Tuple
from dataclasses import dataclass
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@dataclass
class ModelConfig:
hidden_size: int = 768
num_heads: int = 8
segment_size: int = 512
memory_size: int = 1024
max_length: int = 2048
model_name: str = "gpt2"
device: str = "cuda" if torch.cuda.is_available() else "cpu"
class CompressiveMemory(nn.Module):
"""Long-term memory component that compresses and stores information"""
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.memory_size = config.memory_size
# Initialize memory components
self.memory = nn.Parameter(torch.randn(config.memory_size, config.hidden_size))
self.memory_key = nn.Linear(config.hidden_size, config.hidden_size)
self.memory_value = nn.Linear(config.hidden_size, config.hidden_size)
        # Memory statistics (registered as a buffer so it follows the module across devices)
        self.updates = 0
        self.register_buffer("memory_usage", torch.zeros(config.memory_size))
# Initialize on specified device
self.to(config.device)
def forward(self, query: torch.Tensor) -> torch.Tensor:
"""Retrieve information from memory using attention"""
# Scale query for stable attention
query = query / torch.sqrt(torch.tensor(self.hidden_size, dtype=torch.float32))
# Compute attention scores
attention = torch.matmul(query, self.memory.T)
attention_weights = F.softmax(attention, dim=-1)
# Update memory usage statistics
with torch.no_grad():
self.memory_usage += attention_weights.sum(dim=0)
# Retrieve from memory
retrieved = torch.matmul(attention_weights, self.memory)
return retrieved
def update_memory(self, keys: torch.Tensor, values: torch.Tensor):
"""Update memory with new information"""
# Compress inputs
compressed_keys = self.memory_key(keys)
compressed_values = self.memory_value(values)
        # Compute update: let each memory slot attend over the new keys and pull in
        # the corresponding values, yielding a [memory_size, hidden_size] update
        with torch.no_grad():
            slot_scores = torch.matmul(self.memory, compressed_keys.T)
            slot_weights = F.softmax(slot_scores, dim=-1)
            update = torch.matmul(slot_weights, compressed_values)
            # Progressive update with decay
            decay = 0.9
            update_rate = 0.1
            self.memory.data = decay * self.memory.data + update_rate * update
# Track updates
self.updates += 1
# Optional: Reset rarely used memory locations
if self.updates % 1000 == 0:
rarely_used = self.memory_usage < (self.memory_usage.mean() / 10)
self.memory.data[rarely_used] = torch.randn_like(
self.memory.data[rarely_used]
) * 0.1
self.memory_usage[rarely_used] = 0
def reset_memory(self):
"""Reset memory to initial state"""
self.memory.data = torch.randn_like(self.memory.data) * 0.1
self.memory_usage.zero_()
self.updates = 0
class InfiniteAttention(nn.Module):
"""Main attention module combining local and long-term memory attention"""
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
# Core attention components
self.query = nn.Linear(config.hidden_size, config.hidden_size)
self.key = nn.Linear(config.hidden_size, config.hidden_size)
self.value = nn.Linear(config.hidden_size, config.hidden_size)
# Multi-head attention setup
self.num_heads = config.num_heads
self.head_dim = config.hidden_size // config.num_heads
assert self.head_dim * config.num_heads == config.hidden_size, "hidden_size must be divisible by num_heads"
# Memory component
self.memory = CompressiveMemory(config)
        # Output projection and a learned gate that blends local vs. memory attention
        # (gate starts at 0, i.e. sigmoid(0) = 0.5: an even mix of the two paths)
        self.output = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.gate = nn.Parameter(torch.zeros(1))
# Load base language model and tokenizer
try:
self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
self.base_model = AutoModelForCausalLM.from_pretrained(config.model_name)
self.base_model.to(config.device)
except Exception as e:
logger.error(f"Error loading base model: {str(e)}")
raise
# Move model to specified device
self.to(config.device)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
"""Split tensor into attention heads"""
batch_size, seq_length, _ = x.size()
return x.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
"""Merge attention heads back together"""
batch_size, _, seq_length, _ = x.size()
return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.config.hidden_size)
    def get_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Get token embeddings from the base model (model-agnostic via get_input_embeddings)"""
        return self.base_model.get_input_embeddings()(input_ids)
    def process_segment(self, segment: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Process a single segment with local attention plus memory retrieval"""
        batch_size, seq_length, _ = segment.size()
        # Project Q, K, V once; keep the unsplit projections for the memory path
        q_flat = self.query(segment)
        k_flat = self.key(segment)
        v_flat = self.value(segment)
        # Multi-head local attention
        q = self.split_heads(q_flat)
        k = self.split_heads(k_flat)
        v = self.split_heads(v_flat)
        # Scale query for stable attention
        q = q / (self.head_dim ** 0.5)
        # Compute local attention scores
        local_attn = torch.matmul(q, k.transpose(-2, -1))
        if mask is not None:
            local_attn = local_attn.masked_fill(mask == 0, float('-inf'))
        local_attn = F.softmax(local_attn, dim=-1)
        local_output = self.merge_heads(torch.matmul(local_attn, v))
        # Retrieve from long-term memory using the unsplit query projection
        # (flattening the head-split tensor would scramble token features)
        memory_output = self.memory(q_flat.reshape(-1, self.config.hidden_size))
        memory_output = memory_output.view(batch_size, seq_length, self.config.hidden_size)
        # Write this segment's keys/values into memory
        self.memory.update_memory(k_flat.reshape(-1, self.config.hidden_size),
                                  v_flat.reshape(-1, self.config.hidden_size))
        # Blend local and memory outputs with a learned gate
        gate = torch.sigmoid(self.gate)
        combined = torch.cat([gate * local_output,
                              (1 - gate) * memory_output], dim=-1)
        return self.output(combined)
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Process the input sequence segment by segment"""
        seq_length = x.size(1)
        segment_size = self.config.segment_size
        output_segments = []
        # Slice the sequence into fixed-size segments; the final segment may be shorter
        for start in range(0, seq_length, segment_size):
            segment = x[:, start:start + segment_size]
            output_segments.append(self.process_segment(segment, mask))
        # Concatenate segment outputs along the sequence dimension
        return torch.cat(output_segments, dim=1)
def generate_response(self, input_text: str, max_new_tokens: int = 100) -> str:
"""Generate response from input text"""
try:
            # Prepare input, truncating so the prompt plus the new tokens fit
            # inside the base model's position-embedding limit; keep the most
            # recent context if truncation is needed
            self.tokenizer.truncation_side = "left"
            max_input_tokens = self.base_model.config.max_position_embeddings - max_new_tokens
            inputs = self.tokenizer(input_text,
                                    return_tensors="pt",
                                    truncation=True,
                                    max_length=max_input_tokens)
            input_ids = inputs["input_ids"].to(self.config.device)
# Get embeddings
embeddings = self.get_embeddings(input_ids)
# Process through infinite attention
attended = self.forward(embeddings)
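            # NOTE: as written, this pass updates the compressive memory as a side
            # effect, but `attended` itself is not fed into base_model.generate
            # below, which conditions on input_ids alone.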
            # Generate response using the base model
outputs = self.base_model.generate(
input_ids,
max_new_tokens=max_new_tokens,
num_return_sequences=1,
pad_token_id=self.tokenizer.eos_token_id,
do_sample=True,
temperature=0.7,
top_p=0.9,
)
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
except Exception as e:
logger.error(f"Error in generate_response: {str(e)}")
return f"Error generating response: {str(e)}"
class ChatBot:
"""Manages chat history and message processing"""
def __init__(self, config: ModelConfig):
self.config = config
self.model = InfiniteAttention(config)
self.history: List[Tuple[str, str]] = []
        self.max_history_tokens = 4096  # Cap on history tokens kept in the prompt (tune to the base model's context window)
def count_tokens(self, text: str) -> int:
"""Count tokens in text using model's tokenizer"""
return len(self.model.tokenizer.encode(text))
def get_truncated_history(self) -> str:
"""Get history truncated to max tokens"""
history_text = ""
token_count = 0
for msg, response in reversed(self.history):
new_text = f"User: {msg}\nAssistant: {response}\n"
new_tokens = self.count_tokens(new_text)
if token_count + new_tokens > self.max_history_tokens:
break
history_text = new_text + history_text
token_count += new_tokens
return history_text.strip()
def process_message(self, message: str) -> Tuple[str, List[Tuple[str, str]]]:
"""Process a message and return response with updated history"""
try:
# Skip empty messages
if not message.strip():
return "", self.history
# Prepare context with history
history_text = self.get_truncated_history()
context = f"{history_text}\nUser: {message}\nAssistant:"
# Generate response
full_response = self.model.generate_response(context)
# Extract just the new response (after "Assistant:")
response = full_response.split("Assistant:")[-1].strip()
# Update history
self.history.append((message, response))
return response, self.history
except Exception as e:
error_msg = f"Error processing message: {str(e)}"
logger.error(error_msg)
return error_msg, self.history
def save_conversation(self, filename: str):
"""Save conversation history to file"""
try:
with open(filename, 'w', encoding='utf-8') as f:
for msg, response in self.history:
f.write(f"User: {msg}\n")
f.write(f"Assistant: {response}\n\n")
except Exception as e:
logger.error(f"Error saving conversation: {str(e)}")
def load_conversation(self, filename: str):
"""Load conversation history from file"""
try:
with open(filename, 'r', encoding='utf-8') as f:
content = f.read()
# Reset history
self.history = []
# Parse content
conversations = content.strip().split('\n\n')
for conv in conversations:
if 'User:' in conv and 'Assistant:' in conv:
parts = conv.split('Assistant:')
msg = parts[0].replace('User:', '').strip()
response = parts[1].strip()
self.history.append((msg, response))
except Exception as e:
logger.error(f"Error loading conversation: {str(e)}")
def create_gradio_interface():
"""Create and configure Gradio interface"""
# Initialize config and chatbot
config = ModelConfig()
chatbot = ChatBot(config)
    def user_message(message: str, history: List[Tuple[str, str]]) -> str:
        """Handle incoming user messages; gr.ChatInterface expects just the reply text"""
        response, _ = chatbot.process_message(message)
        return response
def save_chat(filename: str):
"""Save chat history to file"""
if not filename.endswith('.txt'):
filename += '.txt'
chatbot.save_conversation(filename)
return f"Conversation saved to {filename}"
def load_chat(filename: str):
"""Load chat history from file"""
if not filename.endswith('.txt'):
filename += '.txt'
chatbot.load_conversation(filename)
return f"Conversation loaded from {filename}"
# Create main chat interface
chat_interface = gr.ChatInterface(
fn=user_message,
title="Long Context AI Chat",
description="Chat with an AI that can handle very long conversations",
examples=[
["Tell me a story about space exploration"],
["What were the key points from our earlier discussion?"],
["Can you summarize everything we've talked about so far?"]
],
retry_btn=None,
undo_btn="Delete Last",
clear_btn="Clear"
)
# Add save/load functionality
with gr.Blocks() as interface:
chat_interface.render()
with gr.Row():
save_file = gr.Textbox(
label="Save conversation to file",
placeholder="conversation.txt"
)
save_btn = gr.Button("Save")
save_output = gr.Textbox(label="Save Status")
load_file = gr.Textbox(
label="Load conversation from file",
placeholder="conversation.txt"
)
load_btn = gr.Button("Load")
load_output = gr.Textbox(label="Load Status")
save_btn.click(
fn=save_chat,
inputs=[save_file],
outputs=[save_output]
)
load_btn.click(
fn=load_chat,
inputs=[load_file],
outputs=[load_output]
)
return interface
def main():
"""Main application entry point"""
try:
# Create interface
interface = create_gradio_interface()
# Launch with configuration
interface.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
debug=True,
auth=None, # Add authentication if needed
ssl_keyfile=None, # Add SSL if needed
ssl_certfile=None
)
except Exception as e:
logger.error(f"Error launching application: {str(e)}")
raise
if __name__ == "__main__":
main()
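# A minimal sketch of driving the chatbot without the Gradio UI (assumes the
# default ModelConfig above); handy for quick local testing:
#
#     config = ModelConfig()
#     bot = ChatBot(config)
#     reply, history = bot.process_message("Give me a one-line summary of this app.")
#     print(reply)
#     bot.save_conversation("conversation.txt")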