jzou19950715 committed
Commit eb04de8 · verified · 1 Parent(s): 720340e

Update app.py

Files changed (1)
  1. app.py +401 -158
app.py CHANGED
@@ -1,178 +1,421 @@
  import os
- import pandas as pd
- import requests
- import json
- import subprocess
  import gradio as gr
- import tempfile
- import sys
- from io import StringIO
- import matplotlib.pyplot as plt
- from pathlib import Path
- import importlib
- import ast

- class AICodeEnvironment:
-     """Environment for AI to execute code safely"""
-     def __init__(self):
-         self.globals_dict = {}
-         self.temp_dir = "temp_outputs"
-         os.makedirs(self.temp_dir, exist_ok=True)
-         self.setup_base_environment()
-
-     def setup_base_environment(self):
-         """Set up the base environment with commonly used packages"""
-         self.globals_dict.update({
-             'pd': pd,
-             'plt': plt,
-             '__builtins__': __builtins__,
-             'print': print
-         })
-
-     def dynamic_import(self, package_name):
-         """Dynamically import packages as needed by AI"""
-         try:
-             # Install package if not present
-             subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", package_name])

-             # Import the package
-             module = importlib.import_module(package_name)
-             self.globals_dict[package_name] = module
-             return True
-         except Exception as e:
-             print(f"Failed to import {package_name}: {str(e)}")
-             return False

-     def handle_imports(self, code):
-         """Extract and handle all imports in the code"""
          try:
-             tree = ast.parse(code)
-             for node in ast.walk(tree):
-                 if isinstance(node, (ast.Import, ast.ImportFrom)):
-                     for name in node.names:
-                         package = name.name.split('.')[0]
-                         if package not in self.globals_dict:
-                             self.dynamic_import(package)
-             return True
          except Exception as e:
-             return False
-
-     def execute_code(self, code):
-         """Execute code and capture all outputs"""
-         # Create temporary stdout to capture prints
-         output_buffer = StringIO()
-         sys.stdout = output_buffer

          try:
-             # Handle imports first
-             self.handle_imports(code)
-
-             # Execute the code
-             exec(code, self.globals_dict)
-
-             # Capture terminal output
-             text_output = output_buffer.getvalue()
-
-             # Handle figures
-             figures = []
-             if 'plt' in self.globals_dict and plt.get_figs():
-                 for i, fig in enumerate(plt.get_figs()):
-                     fig_path = os.path.join(self.temp_dir, f"figure_{len(figures)}.png")
-                     fig.savefig(fig_path)
-                     figures.append(fig_path)
-                 plt.close('all')
-
-             # Check for other visualization libraries
-             if 'fig' in self.globals_dict:
-                 fig = self.globals_dict['fig']
-                 # Handle Plotly figures
-                 if 'plotly.graph_objs' in str(type(fig)):
-                     fig_path = os.path.join(self.temp_dir, f"figure_{len(figures)}.html")
-                     fig.write_html(fig_path)
-                     # Also save static image
-                     img_path = os.path.join(self.temp_dir, f"figure_{len(figures)}.png")
-                     fig.write_image(img_path)
-                     figures.append(img_path)
-
-             return True, text_output, figures

          except Exception as e:
-             return False, str(e), []
-         finally:
-             sys.stdout = sys.__stdout__

- def create_interface():
-     """Create the interface for AI code execution"""
-     env = AICodeEnvironment()
-
-     def process_message(message, history, csv_file, api_key):
-         """Process message and execute any code blocks"""
-         if not api_key:
-             return history + [[message, "Please provide your API key."]], None
-
-         # Update environment with dataframe if CSV uploaded
-         if csv_file:
-             env.globals_dict['df'] = pd.read_csv(csv_file.name)
-
-         # Get response from AI (example structure)
-         response = query_ai(message, api_key)
-
-         # Extract and execute code blocks
-         code_blocks = response.split("```python")
-         outputs = []
-         figures = []
-
-         for block in code_blocks[1:]:  # Skip first split as it's before any code block
-             code = block.split("```")[0].strip()
-             success, output, new_figures = env.execute_code(code)
-             outputs.append(output)
-             figures.extend(new_figures)
-
-         # Format response with outputs
-         modified_response = response
-         for i, output in enumerate(outputs):
-             modified_response = modified_response.replace(
-                 f"```python{code_blocks[i+1].split('```')[0]}```",
-                 f"```python{code_blocks[i+1].split('```')[0]}```\nOutput:\n{output}"
-             )

-         return history + [[message, modified_response]], figures

-     # Create Gradio interface
-     with gr.Blocks() as demo:
-         gr.Markdown("# AI Code Execution Environment")

          with gr.Row():
-             with gr.Column(scale=1):
-                 api_key = gr.Textbox(
-                     label="API Key",
-                     type="password"
-                 )
-                 csv_file = gr.File(
-                     label="Upload CSV",
-                     file_types=[".csv"]
-                 )
-
-             with gr.Column(scale=3):
-                 chatbot = gr.Chatbot(height=500)
-                 gallery = gr.Gallery(label="Outputs")
-
-         with gr.Row():
-             msg = gr.Textbox(
-                 label="Message",
-                 placeholder="Ask me to analyze your data..."
-             )
-             clear = gr.Button("Clear")
-
-         msg.submit(
-             process_message,
-             [msg, chatbot, csv_file, api_key],
-             [chatbot, gallery]
          )
-         clear.click(lambda: ([], []), None, [chatbot, gallery])

-     return demo

  if __name__ == "__main__":
-     demo = create_interface()
-     demo.launch()
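The removed implementation's core pattern, independent of the Gradio wiring, was to split the model's reply on ```python fences and exec each block with stdout redirected so printed output could be echoed back into the chat. A minimal, self-contained sketch of that pattern follows; the helper names (extract_python_blocks, run_block) are illustrative and do not appear in the commit.

import sys
from io import StringIO

def extract_python_blocks(response: str) -> list:
    # Keep only the code inside each ```python fence of a model response.
    return [part.split("```")[0].strip() for part in response.split("```python")[1:]]

def run_block(code: str, env: dict):
    # Execute one block in a shared namespace, capturing anything it prints.
    buffer, original_stdout = StringIO(), sys.stdout
    sys.stdout = buffer
    try:
        exec(code, env)
        return True, buffer.getvalue()
    except Exception as exc:
        return False, str(exc)
    finally:
        sys.stdout = original_stdout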
 
  import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
  import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from typing import List, Tuple
+ from dataclasses import dataclass
+ import logging

+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ @dataclass
+ class ModelConfig:
+     hidden_size: int = 768
+     num_heads: int = 8
+     segment_size: int = 512
+     memory_size: int = 1024
+     max_length: int = 2048
+     model_name: str = "gpt2"
+     device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+ class CompressiveMemory(nn.Module):
+     """Long-term memory component that compresses and stores information"""
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.memory_size = config.memory_size
+
+         # Initialize memory components
+         self.memory = nn.Parameter(torch.randn(config.memory_size, config.hidden_size))
+         self.memory_key = nn.Linear(config.hidden_size, config.hidden_size)
+         self.memory_value = nn.Linear(config.hidden_size, config.hidden_size)
+
+         # Memory statistics
+         self.updates = 0
+         self.memory_usage = torch.zeros(config.memory_size)
+
+         # Initialize on specified device
+         self.to(config.device)
+
+     def forward(self, query: torch.Tensor) -> torch.Tensor:
+         """Retrieve information from memory using attention"""
+         # Scale query for stable attention
+         query = query / torch.sqrt(torch.tensor(self.hidden_size, dtype=torch.float32))
+
+         # Compute attention scores
+         attention = torch.matmul(query, self.memory.T)
+         attention_weights = F.softmax(attention, dim=-1)
+
+         # Update memory usage statistics
+         with torch.no_grad():
+             self.memory_usage += attention_weights.sum(dim=0)
+
+         # Retrieve from memory
+         retrieved = torch.matmul(attention_weights, self.memory)
+         return retrieved
+
+     def update_memory(self, keys: torch.Tensor, values: torch.Tensor):
+         """Update memory with new information"""
+         # Compress inputs
+         compressed_keys = self.memory_key(keys)
+         compressed_values = self.memory_value(values)
+
+         # Compute update
+         with torch.no_grad():
+             update = torch.matmul(compressed_keys.T, compressed_values)

+         # Progressive update with decay
+         decay = 0.9
+         update_rate = 0.1
+         self.memory.data = decay * self.memory.data + update_rate * update[:self.memory_size]
+
+         # Track updates
+         self.updates += 1
+
+         # Optional: Reset rarely used memory locations
+         if self.updates % 1000 == 0:
+             rarely_used = self.memory_usage < (self.memory_usage.mean() / 10)
+             self.memory.data[rarely_used] = torch.randn_like(
+                 self.memory.data[rarely_used]
+             ) * 0.1
+             self.memory_usage[rarely_used] = 0
+
+     def reset_memory(self):
+         """Reset memory to initial state"""
+         self.memory.data = torch.randn_like(self.memory.data) * 0.1
+         self.memory_usage.zero_()
+         self.updates = 0

+ class InfiniteAttention(nn.Module):
+     """Main attention module combining local and long-term memory attention"""
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+
+         # Core attention components
+         self.query = nn.Linear(config.hidden_size, config.hidden_size)
+         self.key = nn.Linear(config.hidden_size, config.hidden_size)
+         self.value = nn.Linear(config.hidden_size, config.hidden_size)
+
+         # Multi-head attention setup
+         self.num_heads = config.num_heads
+         self.head_dim = config.hidden_size // config.num_heads
+         assert self.head_dim * config.num_heads == config.hidden_size, "hidden_size must be divisible by num_heads"
+
+         # Memory component
+         self.memory = CompressiveMemory(config)
+
+         # Output and gating
+         self.output = nn.Linear(config.hidden_size * 2, config.hidden_size)
+         self.gate = nn.Parameter(torch.zeros(1))
+
+         # Load base language model and tokenizer
          try:
+             self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
+             self.base_model = AutoModelForCausalLM.from_pretrained(config.model_name)
+             self.base_model.to(config.device)
          except Exception as e:
+             logger.error(f"Error loading base model: {str(e)}")
+             raise

+         # Move model to specified device
+         self.to(config.device)
+
+     def split_heads(self, x: torch.Tensor) -> torch.Tensor:
+         """Split tensor into attention heads"""
+         batch_size, seq_length, _ = x.size()
+         return x.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+     def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+         """Merge attention heads back together"""
+         batch_size, _, seq_length, _ = x.size()
+         return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.config.hidden_size)
+
+     def get_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         """Get embeddings from base model"""
+         return self.base_model.transformer.wte(input_ids)
+
+     def process_segment(self, segment: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+         """Process a single segment with attention"""
+         # Compute Q, K, V
+         q = self.split_heads(self.query(segment))
+         k = self.split_heads(self.key(segment))
+         v = self.split_heads(self.value(segment))
+
+         # Scale query
+         q = q / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
+
+         # Compute local attention scores
+         local_attn = torch.matmul(q, k.transpose(-2, -1))
+
+         if mask is not None:
+             local_attn = local_attn.masked_fill(mask == 0, float('-inf'))
+
+         # Apply softmax
+         local_attn = F.softmax(local_attn, dim=-1)
+
+         # Compute local attention output
+         local_output = self.merge_heads(torch.matmul(local_attn, v))
+
+         # Get memory output
+         memory_output = self.memory(q.view(-1, self.config.hidden_size))
+         memory_output = memory_output.view(segment.size())
+
+         # Update memory
+         self.memory.update_memory(k.view(-1, self.config.hidden_size),
+                                   v.view(-1, self.config.hidden_size))
+
+         # Combine outputs using learned gate
+         gate = torch.sigmoid(self.gate)
+         combined = torch.cat([
+             gate * local_output,
+             (1 - gate) * memory_output
+         ], dim=-1)
+
+         return self.output(combined)
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+         """Process input sequence by segments"""
+         batch_size = x.size(0)
+
+         # Split into segments
+         segments = x.unfold(1, self.config.segment_size,
+                             step=self.config.segment_size)
+         output_segments = []
+
+         # Process each segment
+         for segment in segments.unbind(1):
+             segment_output = self.process_segment(segment, mask)
+             output_segments.append(segment_output)
+
+         # Handle any remaining tokens
+         remainder_start = segments.size(1) * self.config.segment_size
+         if remainder_start < x.size(1):
+             remainder = x[:, remainder_start:]
+             if remainder.size(1) > 0:
+                 remainder_output = self.process_segment(remainder, mask)
+                 output_segments.append(remainder_output)
+
+         # Combine all segments
+         return torch.cat(output_segments, dim=1)
+
+     def generate_response(self, input_text: str, max_new_tokens: int = 100) -> str:
+         """Generate response from input text"""
          try:
+             # Prepare input
+             inputs = self.tokenizer(input_text,
+                                     return_tensors="pt",
+                                     truncation=False)
+             input_ids = inputs["input_ids"].to(self.config.device)
+
+             # Get embeddings
+             embeddings = self.get_embeddings(input_ids)
+
+             # Process through infinite attention
+             attended = self.forward(embeddings)
+
+             # Generate response using base model with attended context
+             outputs = self.base_model.generate(
+                 input_ids,
+                 max_new_tokens=max_new_tokens,
+                 num_return_sequences=1,
+                 pad_token_id=self.tokenizer.eos_token_id,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_p=0.9,
+             )
+
+             return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

          except Exception as e:
+             logger.error(f"Error in generate_response: {str(e)}")
+             return f"Error generating response: {str(e)}"

+ class ChatBot:
+     """Manages chat history and message processing"""
+     def __init__(self, config: ModelConfig):
+         self.config = config
+         self.model = InfiniteAttention(config)
+         self.history: List[Tuple[str, str]] = []
+         self.max_history_tokens = 4096  # Adjust based on your needs
+
+     def count_tokens(self, text: str) -> int:
+         """Count tokens in text using model's tokenizer"""
+         return len(self.model.tokenizer.encode(text))
+
+     def get_truncated_history(self) -> str:
+         """Get history truncated to max tokens"""
+         history_text = ""
+         token_count = 0

+         for msg, response in reversed(self.history):
+             new_text = f"User: {msg}\nAssistant: {response}\n"
+             new_tokens = self.count_tokens(new_text)
+
+             if token_count + new_tokens > self.max_history_tokens:
+                 break
+
+             history_text = new_text + history_text
+             token_count += new_tokens
+
+         return history_text.strip()
+
+     def process_message(self, message: str) -> Tuple[str, List[Tuple[str, str]]]:
+         """Process a message and return response with updated history"""
+         try:
+             # Skip empty messages
+             if not message.strip():
+                 return "", self.history
+
+             # Prepare context with history
+             history_text = self.get_truncated_history()
+             context = f"{history_text}\nUser: {message}\nAssistant:"
+
+             # Generate response
+             full_response = self.model.generate_response(context)
+
+             # Extract just the new response (after "Assistant:")
+             response = full_response.split("Assistant:")[-1].strip()
+
+             # Update history
+             self.history.append((message, response))
+
+             return response, self.history
+
+         except Exception as e:
+             error_msg = f"Error processing message: {str(e)}"
+             logger.error(error_msg)
+             return error_msg, self.history
+
+     def save_conversation(self, filename: str):
+         """Save conversation history to file"""
+         try:
+             with open(filename, 'w', encoding='utf-8') as f:
+                 for msg, response in self.history:
+                     f.write(f"User: {msg}\n")
+                     f.write(f"Assistant: {response}\n\n")
+         except Exception as e:
+             logger.error(f"Error saving conversation: {str(e)}")

+     def load_conversation(self, filename: str):
+         """Load conversation history from file"""
+         try:
+             with open(filename, 'r', encoding='utf-8') as f:
+                 content = f.read()
+
+             # Reset history
+             self.history = []
+
+             # Parse content
+             conversations = content.strip().split('\n\n')
+             for conv in conversations:
+                 if 'User:' in conv and 'Assistant:' in conv:
+                     parts = conv.split('Assistant:')
+                     msg = parts[0].replace('User:', '').strip()
+                     response = parts[1].strip()
+                     self.history.append((msg, response))
+
+         except Exception as e:
+             logger.error(f"Error loading conversation: {str(e)}")
+
+ def create_gradio_interface():
+     """Create and configure Gradio interface"""
+
+     # Initialize config and chatbot
+     config = ModelConfig()
+     chatbot = ChatBot(config)
+
+     def user_message(message: str, history: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[str, str]]]:
+         """Handle incoming user messages"""
+         response, updated_history = chatbot.process_message(message)
+         return response, updated_history
+
+     def save_chat(filename: str):
+         """Save chat history to file"""
+         if not filename.endswith('.txt'):
+             filename += '.txt'
+         chatbot.save_conversation(filename)
+         return f"Conversation saved to {filename}"
+
+     def load_chat(filename: str):
+         """Load chat history from file"""
+         if not filename.endswith('.txt'):
+             filename += '.txt'
+         chatbot.load_conversation(filename)
+         return f"Conversation loaded from {filename}"
+
+     # Create main chat interface
+     chat_interface = gr.ChatInterface(
+         fn=user_message,
+         title="Long Context AI Chat",
+         description="Chat with an AI that can handle very long conversations",
+         examples=[
+             ["Tell me a story about space exploration"],
+             ["What were the key points from our earlier discussion?"],
+             ["Can you summarize everything we've talked about so far?"]
+         ],
+         retry_btn=None,
+         undo_btn="Delete Last",
+         clear_btn="Clear"
+     )
+
+     # Add save/load functionality
+     with gr.Blocks() as interface:
+         chat_interface.render()

          with gr.Row():
+             save_file = gr.Textbox(
+                 label="Save conversation to file",
+                 placeholder="conversation.txt"
+             )
+             save_btn = gr.Button("Save")
+             save_output = gr.Textbox(label="Save Status")
+
+             load_file = gr.Textbox(
+                 label="Load conversation from file",
+                 placeholder="conversation.txt"
+             )
+             load_btn = gr.Button("Load")
+             load_output = gr.Textbox(label="Load Status")
+
+         save_btn.click(
+             fn=save_chat,
+             inputs=[save_file],
+             outputs=[save_output]
+         )
+
+         load_btn.click(
+             fn=load_chat,
+             inputs=[load_file],
+             outputs=[load_output]
          )

+     return interface
+
+ def main():
+     """Main application entry point"""
+     try:
+         # Create interface
+         interface = create_gradio_interface()
+
+         # Launch with configuration
+         interface.launch(
+             server_name="0.0.0.0",
+             server_port=7860,
+             share=False,
+             debug=True,
+             auth=None,  # Add authentication if needed
+             ssl_keyfile=None,  # Add SSL if needed
+             ssl_certfile=None
+         )
+     except Exception as e:
+         logger.error(f"Error launching application: {str(e)}")
+         raise

  if __name__ == "__main__":
+     main()
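For reference, the new classes can also be driven without the Gradio UI. The following is a minimal sketch, assuming app.py is importable as a module named app and that the default ModelConfig (GPT-2, with CPU fallback) is acceptable; the first call downloads the base model, and per the diff process_message returns the reply together with the accumulated history.

from app import ModelConfig, ChatBot

config = ModelConfig(model_name="gpt2")  # default config; another causal LM could be substituted
bot = ChatBot(config)

# Each call prepends the token-budgeted history before generating, so follow-up
# questions can refer back to earlier turns.
reply, history = bot.process_message("Give me a two-sentence summary of the Apollo program.")
print(reply)

reply, history = bot.process_message("Which of those two sentences mentioned a date?")
print(reply)

# Persist and restore the conversation as plain text.
bot.save_conversation("conversation.txt")
bot.load_conversation("conversation.txt")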