Anupam202224 committed
Commit af1164d · verified · 1 Parent(s): b3657e1

Update app.py

Files changed (1)
  1. app.py +10 -11
app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import shutil
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 import pandas as pd
 import torch
 import matplotlib.pyplot as plt
@@ -9,7 +9,7 @@ import seaborn as sns
 import base64
 
 # Define constants
-MODEL_NAME = "gpt2"  # Publicly accessible model suitable for CPU
+MODEL_NAME = "facebook/bart-large-cnn"  # Fine-tuned for summarization
 FIGURES_DIR = "./figures"
 EXAMPLE_DIR = "./example"
 EXAMPLE_FILE = os.path.join(EXAMPLE_DIR, "titanic.csv")
@@ -36,7 +36,7 @@ if not os.path.isfile(EXAMPLE_FILE):
 print("Loading model and tokenizer...")
 try:
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
     model.to('cpu')  # Ensure the model runs on CPU
     print("Model and tokenizer loaded successfully.")
 except Exception as e:
@@ -86,18 +86,15 @@ def generate_summary(prompt):
 
     # Generate response
     with torch.no_grad():
-        outputs = model.generate(
+        summary_ids = model.generate(
             inputs,
             max_length=500,
-            do_sample=True,
-            top_p=0.95,
-            temperature=0.7,
-            eos_token_id=tokenizer.eos_token_id,
-            pad_token_id=tokenizer.eos_token_id
+            num_beams=4,
+            early_stopping=True
         )
 
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return response
+    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+    return summary
 
 def analyze_data(data_file_path):
     """Perform data analysis on the uploaded CSV file."""
@@ -249,3 +246,5 @@ if __name__ == "__main__":
 
 
 
+
+
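
For context, below is a minimal, self-contained sketch of the summarization path this commit switches to: loading facebook/bart-large-cnn with AutoModelForSeq2SeqLM and generating with beam search instead of sampling, as in the new generate_summary. The tokenization step (truncation to BART's 1024-token input limit, the summarize helper name, and the example prompt) is an assumption, since the diff does not show how inputs is prepared.

# Sketch of the BART summarization path introduced by this commit.
# The tokenization/truncation details are assumptions; the diff does not
# show how `inputs` is built inside generate_summary().
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_NAME = "facebook/bart-large-cnn"  # fine-tuned for summarization

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model.to("cpu")  # run on CPU, as in app.py

def summarize(prompt: str) -> str:
    # Encode the prompt; truncating to BART's 1024-token limit is assumed handling.
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=1024
    ).input_ids

    with torch.no_grad():
        summary_ids = model.generate(
            inputs,
            max_length=500,       # matches the diff
            num_beams=4,          # beam search replaces top-p sampling
            early_stopping=True,
        )

    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

if __name__ == "__main__":
    print(summarize("The Titanic dataset contains passenger records ..."))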