jzou19950715 committed
Commit dfc517e · verified · 1 Parent(s): 91c7f7d

Update app.py

Files changed (1)
  1. app.py +232 -371

app.py CHANGED
@@ -1,398 +1,266 @@
  import os
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from typing import List, Tuple
- from dataclasses import dataclass
  import logging

  # Configure logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- @dataclass
- class ModelConfig:
-     hidden_size: int = 768
-     num_heads: int = 8
-     segment_size: int = 512
-     memory_size: int = 1024
-     max_length: int = 2048
-     model_name: str = "gpt2"
-     device: str = "cuda" if torch.cuda.is_available() else "cpu"
-
- class CompressiveMemory(nn.Module):
-     """Long-term memory component that compresses and stores information"""
-     def __init__(self, config: ModelConfig):
-         super().__init__()
-         self.config = config
-         self.hidden_size = config.hidden_size
-         self.memory_size = config.memory_size
-
-         # Initialize memory components
-         self.memory = nn.Parameter(torch.randn(config.memory_size, config.hidden_size))
-         self.memory_key = nn.Linear(config.hidden_size, config.hidden_size)
-         self.memory_value = nn.Linear(config.hidden_size, config.hidden_size)
-
-         # Memory statistics
-         self.updates = 0
-         self.memory_usage = torch.zeros(config.memory_size)
-
-         # Initialize on specified device
-         self.to(config.device)
-
-     def forward(self, query: torch.Tensor) -> torch.Tensor:
-         """Retrieve information from memory using attention"""
-         # Scale query for stable attention
-         query = query / torch.sqrt(torch.tensor(self.hidden_size, dtype=torch.float32))
-
-         # Compute attention scores
-         attention = torch.matmul(query, self.memory.T)
-         attention_weights = F.softmax(attention, dim=-1)
-
-         # Update memory usage statistics
-         with torch.no_grad():
-             self.memory_usage += attention_weights.sum(dim=0)
-
-         # Retrieve from memory
-         retrieved = torch.matmul(attention_weights, self.memory)
-         return retrieved
-
-     def update_memory(self, keys: torch.Tensor, values: torch.Tensor):
-         """Update memory with new information"""
-         # Compress inputs
-         compressed_keys = self.memory_key(keys)
-         compressed_values = self.memory_value(values)
-
-         # Compute update
-         with torch.no_grad():
-             update = torch.matmul(compressed_keys.T, compressed_values)
-
-         # Progressive update with decay
-         decay = 0.9
-         update_rate = 0.1
-         self.memory.data = decay * self.memory.data + update_rate * update[:self.memory_size]
-
-         # Track updates
-         self.updates += 1
-
-         # Optional: Reset rarely used memory locations
-         if self.updates % 1000 == 0:
-             rarely_used = self.memory_usage < (self.memory_usage.mean() / 10)
-             self.memory.data[rarely_used] = torch.randn_like(
-                 self.memory.data[rarely_used]
-             ) * 0.1
-             self.memory_usage[rarely_used] = 0
-
-     def reset_memory(self):
-         """Reset memory to initial state"""
-         self.memory.data = torch.randn_like(self.memory.data) * 0.1
-         self.memory_usage.zero_()
-         self.updates = 0
-
- class InfiniteAttention(nn.Module):
-     """Main attention module combining local and long-term memory attention"""
-     def __init__(self, config: ModelConfig):
-         super().__init__()
-         self.config = config
-
-         # Core attention components
-         self.query = nn.Linear(config.hidden_size, config.hidden_size)
-         self.key = nn.Linear(config.hidden_size, config.hidden_size)
-         self.value = nn.Linear(config.hidden_size, config.hidden_size)
-
-         # Multi-head attention setup
-         self.num_heads = config.num_heads
-         self.head_dim = config.hidden_size // config.num_heads
-         assert self.head_dim * config.num_heads == config.hidden_size, "hidden_size must be divisible by num_heads"
-
-         # Memory component
-         self.memory = CompressiveMemory(config)
-
-         # Output and gating
-         self.output = nn.Linear(config.hidden_size * 2, config.hidden_size)
-         self.gate = nn.Parameter(torch.zeros(1))
-
-         # Load base language model and tokenizer
-         try:
-             self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
-             self.base_model = AutoModelForCausalLM.from_pretrained(config.model_name)
-             self.base_model.to(config.device)
-         except Exception as e:
-             logger.error(f"Error loading base model: {str(e)}")
-             raise
-
-         # Move model to specified device
-         self.to(config.device)
-
-     def split_heads(self, x: torch.Tensor) -> torch.Tensor:
-         """Split tensor into attention heads"""
-         batch_size, seq_length, _ = x.size()
-         return x.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
-
-     def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
-         """Merge attention heads back together"""
-         batch_size, _, seq_length, _ = x.size()
-         return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.config.hidden_size)
-
-     def get_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-         """Get embeddings from base model"""
-         return self.base_model.transformer.wte(input_ids)
-
-     def process_segment(self, segment: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
-         """Process a single segment with attention"""
-         # Compute Q, K, V
-         q = self.split_heads(self.query(segment))
-         k = self.split_heads(self.key(segment))
-         v = self.split_heads(self.value(segment))
-
-         # Scale query
-         q = q / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
-
-         # Compute local attention scores
-         local_attn = torch.matmul(q, k.transpose(-2, -1))
-
-         if mask is not None:
-             local_attn = local_attn.masked_fill(mask == 0, float('-inf'))
-
-         # Apply softmax
-         local_attn = F.softmax(local_attn, dim=-1)
-
-         # Compute local attention output
-         local_output = self.merge_heads(torch.matmul(local_attn, v))
-
-         # Get memory output
-         memory_output = self.memory(q.view(-1, self.config.hidden_size))
-         memory_output = memory_output.view(segment.size())
-
-         # Update memory
-         self.memory.update_memory(k.view(-1, self.config.hidden_size),
-                                   v.view(-1, self.config.hidden_size))
-
-         # Combine outputs using learned gate
-         gate = torch.sigmoid(self.gate)
-         combined = torch.cat([
-             gate * local_output,
-             (1 - gate) * memory_output
-         ], dim=-1)
-
-         return self.output(combined)
-
-     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
-         """Process input sequence by segments"""
-         batch_size = x.size(0)
-
-         # Split into segments
-         segments = x.unfold(1, self.config.segment_size,
-                             step=self.config.segment_size)
-         output_segments = []
-
-         # Process each segment
-         for segment in segments.unbind(1):
-             segment_output = self.process_segment(segment, mask)
-             output_segments.append(segment_output)
-
-         # Handle any remaining tokens
-         remainder_start = segments.size(1) * self.config.segment_size
-         if remainder_start < x.size(1):
-             remainder = x[:, remainder_start:]
-             if remainder.size(1) > 0:
-                 remainder_output = self.process_segment(remainder, mask)
-                 output_segments.append(remainder_output)
-
-         # Combine all segments
-         return torch.cat(output_segments, dim=1)
-
-     def generate_response(self, input_text: str, max_new_tokens: int = 100) -> str:
-         """Generate response from input text"""
-         try:
-             # Prepare input
-             inputs = self.tokenizer(input_text,
-                                     return_tensors="pt",
-                                     truncation=False)
-             input_ids = inputs["input_ids"].to(self.config.device)
-
-             # Get embeddings
-             embeddings = self.get_embeddings(input_ids)
-
-             # Process through infinite attention
-             attended = self.forward(embeddings)
-
-             # Generate response using base model with attended context
-             outputs = self.base_model.generate(
-                 input_ids,
-                 max_new_tokens=max_new_tokens,
-                 num_return_sequences=1,
-                 pad_token_id=self.tokenizer.eos_token_id,
-                 do_sample=True,
-                 temperature=0.7,
-                 top_p=0.9,
-             )
-
-             return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-         except Exception as e:
-             logger.error(f"Error in generate_response: {str(e)}")
-             return f"Error generating response: {str(e)}"
-
- class ChatBot:
-     """Manages chat history and message processing"""
-     def __init__(self, config: ModelConfig):
-         self.config = config
-         self.model = InfiniteAttention(config)
-         self.history: List[Tuple[str, str]] = []
-         self.max_history_tokens = 4096  # Adjust based on your needs
-
-     def count_tokens(self, text: str) -> int:
-         """Count tokens in text using model's tokenizer"""
-         return len(self.model.tokenizer.encode(text))
-
-     def get_truncated_history(self) -> str:
-         """Get history truncated to max tokens"""
-         history_text = ""
-         token_count = 0
-
-         for msg, response in reversed(self.history):
-             new_text = f"User: {msg}\nAssistant: {response}\n"
-             new_tokens = self.count_tokens(new_text)
-
-             if token_count + new_tokens > self.max_history_tokens:
-                 break
-
-             history_text = new_text + history_text
-             token_count += new_tokens
-
-         return history_text.strip()
-
-     def process_message(self, message: str) -> Tuple[str, List[Tuple[str, str]]]:
-         """Process a message and return response with updated history"""
-         try:
-             # Skip empty messages
-             if not message.strip():
-                 return "", self.history
-
-             # Prepare context with history
-             history_text = self.get_truncated_history()
-             context = f"{history_text}\nUser: {message}\nAssistant:"
-
-             # Generate response
-             full_response = self.model.generate_response(context)
-
-             # Extract just the new response (after "Assistant:")
-             response = full_response.split("Assistant:")[-1].strip()
-
-             # Update history
-             self.history.append((message, response))
-
-             return response, self.history
-
-         except Exception as e:
-             error_msg = f"Error processing message: {str(e)}"
-             logger.error(error_msg)
-             return error_msg, self.history
-
-     def save_conversation(self, filename: str):
-         """Save conversation history to file"""
-         try:
-             with open(filename, 'w', encoding='utf-8') as f:
-                 for msg, response in self.history:
-                     f.write(f"User: {msg}\n")
-                     f.write(f"Assistant: {response}\n\n")
-         except Exception as e:
-             logger.error(f"Error saving conversation: {str(e)}")
-
-     def load_conversation(self, filename: str):
-         """Load conversation history from file"""
-         try:
-             with open(filename, 'r', encoding='utf-8') as f:
-                 content = f.read()
-
-             # Reset history
-             self.history = []
-
-             # Parse content
-             conversations = content.strip().split('\n\n')
-             for conv in conversations:
-                 if 'User:' in conv and 'Assistant:' in conv:
-                     parts = conv.split('Assistant:')
-                     msg = parts[0].replace('User:', '').strip()
-                     response = parts[1].strip()
-                     self.history.append((msg, response))
-
-         except Exception as e:
-             logger.error(f"Error loading conversation: {str(e)}")
-
- def create_gradio_interface():
-     """Create and configure Gradio interface"""
-
-     # Initialize config and chatbot
-     config = ModelConfig()
-     chatbot = ChatBot(config)
-
-     def user_message(message: str, history: List[Tuple[str, str]]) -> Tuple[str, List[Tuple[str, str]]]:
-         """Handle incoming user messages"""
-         response, updated_history = chatbot.process_message(message)
-         return response, updated_history
-
-     def save_chat(filename: str):
-         """Save chat history to file"""
-         if not filename.endswith('.txt'):
-             filename += '.txt'
-         chatbot.save_conversation(filename)
-         return f"Conversation saved to {filename}"
-
-     def load_chat(filename: str):
-         """Load chat history from file"""
-         if not filename.endswith('.txt'):
-             filename += '.txt'
-         chatbot.load_conversation(filename)
-         return f"Conversation loaded from {filename}"
-
-     # Create main chat interface
-     chat_interface = gr.ChatInterface(
-         fn=user_message,
-         title="Long Context AI Chat",
-         description="Chat with an AI that can handle very long conversations",
-         examples=[
-             ["Tell me a story about space exploration"],
-             ["What were the key points from our earlier discussion?"],
-             ["Can you summarize everything we've talked about so far?"]
-         ],
-         retry_btn=None,
-         undo_btn="Delete Last",
-         clear_btn="Clear"
-     )
-
-     # Add save/load functionality
-     with gr.Blocks() as interface:
-         chat_interface.render()
-
-         with gr.Row():
-             save_file = gr.Textbox(
-                 label="Save conversation to file",
-                 placeholder="conversation.txt"
-             )
-             save_btn = gr.Button("Save")
-             save_output = gr.Textbox(label="Save Status")
-
-             load_file = gr.Textbox(
-                 label="Load conversation from file",
-                 placeholder="conversation.txt"
-             )
-             load_btn = gr.Button("Load")
-             load_output = gr.Textbox(label="Load Status")
-
-         save_btn.click(
-             fn=save_chat,
-             inputs=[save_file],
-             outputs=[save_output]
-         )
-
-         load_btn.click(
-             fn=load_chat,
-             inputs=[load_file],
-             outputs=[load_output]
-         )
-
-     return interface
-

+ import pandas as pd
+ import google.generativeai as genai
+ import gradio as gr
+ from typing import Dict, List, Any, Tuple
+ import json
+ class DataAnalyzer:
+     def __init__(self):
+         self.model = None
+         self.api_key = None
+         self.system_prompt = None
+         self.df = None
+
+     def configure_api(self, api_key: str):
+         """Configure the Gemini API with the provided key"""
+         try:
+             genai.configure(api_key=api_key)
+             self.model = genai.GenerativeModel('gemini-1.5-pro')
+             self.api_key = api_key
+             return True
+         except Exception as e:
+             logger.error(f"API configuration failed: {str(e)}")
+             return False
+
+     def load_data(self, file) -> Tuple[bool, str]:
+         """Load data from uploaded CSV file"""
+         try:
+             self.df = pd.read_csv(file.name)
+             return True, f"Loaded CSV with {len(self.df)} rows and {len(self.df.columns)} columns"
+         except Exception as e:
+             logger.error(f"Data loading failed: {str(e)}")
+             return False, f"Error loading data: {str(e)}"
+
+     def get_data_info(self) -> Dict[str, Any]:
+         """Get information about the loaded data"""
+         if self.df is None:
+             return {"error": "No data loaded"}
+
+         info = {
+             "columns": list(self.df.columns),
+             "rows": len(self.df),
+             "sample": self.df.head(5).to_dict('records'),
+             "dtypes": self.df.dtypes.astype(str).to_dict()
+         }
+         return info
+
+     def analyze(self, query: str) -> str:
+         """Analyze data based on user query"""
+         if self.model is None:
+             return "Please configure API key first"
+         if self.df is None:
+             return "Please upload a CSV file first"
+
+         data_info = self.get_data_info()
+
+         # Combine system prompt with data context
+         prompt = f"""{self.system_prompt}
+
+ Data Information:
+ - Columns: {data_info['columns']}
+ - Number of rows: {data_info['rows']}
+ - Sample data: {json.dumps(data_info['sample'], indent=2)}
+
+ User Query: {query}
+
+ Please analyze this data and provide:
+ 1. A clear explanation of your findings
+ 2. Key statistics relevant to the query
+ 3. If appropriate, suggest visualizations that would help understand the data better
+
+ Response Format:
+ 1. First give a direct answer to the query
+ 2. Then provide supporting statistics
+ 3. Finally, suggest any relevant additional insights
+
+ Remember to handle:
+ - Missing or null values
+ - Outliers
+ - Data type conversions if needed
+ - Basic error checking
+ """
+         try:
+             # Call Gemini API
+             response = self.model.generate_content(prompt)
+
+             # Extract and format the response
+             if response.text:
+                 formatted_response = (
+                     "## Analysis Results\n\n"
+                     f"{response.text}\n\n"
+                     "---\n"
+                     "Note: This analysis was generated using the provided data. "
+                     "Please verify any critical insights independently."
+                 )
+                 return formatted_response
+             else:
+                 return "No analysis could be generated. Please try a different query."
+
+         except Exception as e:
+             logger.error(f"Analysis failed: {str(e)}")
+             error_message = (
+                 "## Error During Analysis\n\n"
+                 f"The analysis failed with error: {str(e)}\n\n"
+                 "Please try:\n"
+                 "1. Checking your API key\n"
+                 "2. Simplifying your query\n"
+                 "3. Ensuring your data is properly formatted"
+             )
+             return error_message
+
+ def create_interface():
+     """Create the Gradio interface"""
+     analyzer = DataAnalyzer()
+
+     def process_inputs(api_key: str, system_prompt: str, file, query: str):
+         """Process user inputs and return analysis results"""
+         # Configure API
+         if api_key != analyzer.api_key:
+             if not analyzer.configure_api(api_key):
+                 return "Failed to configure API. Please check your API key."
+
+         # Update system prompt
+         analyzer.system_prompt = system_prompt
+
+         # Load data if new file provided
+         if file is not None:
+             success, message = analyzer.load_data(file)
+             if not success:
+                 return message
+
+         # Run analysis
+         return analyzer.analyze(query)
+
+     # Create Gradio interface
+     with gr.Blocks(title="Data Analysis Assistant") as interface:
+         gr.Markdown("# Data Analysis Assistant")
+         gr.Markdown("Upload your CSV file and get AI-powered analysis")
+
+         with gr.Row():
+             api_key_input = gr.Textbox(
+                 label="Gemini API Key",
+                 placeholder="Enter your Gemini API key",
+                 type="password"
+             )
+
+         with gr.Row():
+             system_prompt_input = gr.Textbox(
+                 label="System Prompt",
+                 placeholder="Enter system prompt for the AI",
+                 value="You are a data analysis expert. Analyze the provided data and answer the user's query.",
+                 lines=3
+             )
+
+         with gr.Row():
+             file_input = gr.File(
+                 label="Upload CSV",
+                 file_types=[".csv"]
+             )
+
+         with gr.Row():
+             query_input = gr.Textbox(
+                 label="Analysis Query",
+                 placeholder="What would you like to know about the data?",
+                 lines=2
+             )
+
+         with gr.Row():
+             submit_btn = gr.Button("Analyze")
+
+         with gr.Row():
+             output = gr.Markdown(label="Analysis Results")
+
+         submit_btn.click(
+             fn=process_inputs,
+             inputs=[api_key_input, system_prompt_input, file_input, query_input],
+             outputs=output
+         )
+
+     return interface
  def main():
269
  """Main application entry point"""
270
  try:
271
+ interface = create_interface()
 
 
 
272
  interface.launch(
273
+ share=True,
274
  server_name="0.0.0.0",
275
+ server_port=7860
 
 
 
 
 
276
  )
277
  except Exception as e:
278
+ logger.error(f"Application startup failed: {str(e)}")
279
  raise
280
 
281
  if __name__ == "__main__":