Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import os
|
2 |
import shutil
|
3 |
import gradio as gr
|
4 |
-
from transformers import
|
5 |
import pandas as pd
|
6 |
import torch
|
7 |
import matplotlib.pyplot as plt
|
@@ -9,7 +9,7 @@ import seaborn as sns
|
|
9 |
import base64
|
10 |
|
11 |
# Define constants
|
12 |
-
MODEL_NAME = "
|
13 |
FIGURES_DIR = "./figures"
|
14 |
EXAMPLE_DIR = "./example"
|
15 |
EXAMPLE_FILE = os.path.join(EXAMPLE_DIR, "titanic.csv")
|
@@ -36,7 +36,7 @@ if not os.path.isfile(EXAMPLE_FILE):
|
|
36 |
print("Loading model and tokenizer...")
|
37 |
try:
|
38 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
39 |
-
model =
|
40 |
model.to('cpu') # Ensure the model runs on CPU
|
41 |
print("Model and tokenizer loaded successfully.")
|
42 |
except Exception as e:
|
@@ -86,18 +86,15 @@ def generate_summary(prompt):
|
|
86 |
|
87 |
# Generate response
|
88 |
with torch.no_grad():
|
89 |
-
|
90 |
inputs,
|
91 |
max_length=500,
|
92 |
-
|
93 |
-
|
94 |
-
temperature=0.7,
|
95 |
-
eos_token_id=tokenizer.eos_token_id,
|
96 |
-
pad_token_id=tokenizer.eos_token_id
|
97 |
)
|
98 |
|
99 |
-
|
100 |
-
return
|
101 |
|
102 |
def analyze_data(data_file_path):
|
103 |
"""Perform data analysis on the uploaded CSV file."""
|
@@ -249,3 +246,5 @@ if __name__ == "__main__":
|
|
249 |
|
250 |
|
251 |
|
|
|
|
|
|
1 |
import os
|
2 |
import shutil
|
3 |
import gradio as gr
|
4 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
5 |
import pandas as pd
|
6 |
import torch
|
7 |
import matplotlib.pyplot as plt
|
|
|
9 |
import base64
|
10 |
|
11 |
# Define constants
|
12 |
+
MODEL_NAME = "facebook/bart-large-cnn" # Fine-tuned for summarization
|
13 |
FIGURES_DIR = "./figures"
|
14 |
EXAMPLE_DIR = "./example"
|
15 |
EXAMPLE_FILE = os.path.join(EXAMPLE_DIR, "titanic.csv")
|
|
|
36 |
print("Loading model and tokenizer...")
|
37 |
try:
|
38 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
39 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
40 |
model.to('cpu') # Ensure the model runs on CPU
|
41 |
print("Model and tokenizer loaded successfully.")
|
42 |
except Exception as e:
|
|
|
86 |
|
87 |
# Generate response
|
88 |
with torch.no_grad():
|
89 |
+
summary_ids = model.generate(
|
90 |
inputs,
|
91 |
max_length=500,
|
92 |
+
num_beams=4,
|
93 |
+
early_stopping=True
|
|
|
|
|
|
|
94 |
)
|
95 |
|
96 |
+
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
97 |
+
return summary
|
98 |
|
99 |
def analyze_data(data_file_path):
|
100 |
"""Perform data analysis on the uploaded CSV file."""
|
|
|
246 |
|
247 |
|
248 |
|
249 |
+
|
250 |
+
|