BarBar288 committed on
Commit
7a37d39
·
verified ·
1 Parent(s): 325e7ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -52,15 +52,14 @@ text_to_image_pipelines = {}
52
  text_to_speech_pipelines = {}
53
 
54
  # Initialize pipelines for other tasks
55
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
- logger.info(f"Device set to use {device}")
57
 
58
- visual_qa_pipeline = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa", device=device)
59
- document_qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2", device=device)
60
- image_classification_pipeline = pipeline("image-classification", model="facebook/deit-base-distilled-patch16-224", device=device)
61
- object_detection_pipeline = pipeline("object-detection", model="facebook/detr-resnet-50", device=device)
62
- video_classification_pipeline = pipeline("video-classification", model="facebook/timesformer-base-finetuned-k400", device=device)
63
- summarization_pipeline = pipeline("summarization", model="facebook/bart-large-cnn", device=device)
 
64
 
65
  # Load speaker embeddings for text-to-audio
66
  def load_speaker_embeddings(model_name):
@@ -74,14 +73,14 @@ def load_speaker_embeddings(model_name):
74
 
75
  # Use a different model for text-to-audio if stabilityai/stable-audio-open-1.0 is not supported
76
  try:
77
- text_to_audio_pipeline = pipeline("text-to-audio", model="stabilityai/stable-audio-open-1.0", device=device)
78
  except ValueError as e:
79
  logger.error(f"Error loading stabilityai/stable-audio-open-1.0: {e}")
80
  logger.info("Falling back to a different text-to-audio model.")
81
- text_to_audio_pipeline = pipeline("text-to-audio", model="microsoft/speecht5_tts", device=device)
82
  speaker_embeddings = load_speaker_embeddings("microsoft/speecht5_tts")
83
 
84
- audio_classification_pipeline = pipeline("audio-classification", model="facebook/wav2vec2-base", device=device)
85
 
86
  def load_conversational_model(model_name):
87
  if model_name not in conversational_models_loaded:
@@ -115,7 +114,7 @@ def chat(model_name, user_input, history=[]):
115
  tokenizer, model = load_conversational_model(model_name)
116
 
117
  # Encode the input
118
- input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt").to(device)
119
 
120
  # Generate a response
121
  with torch.no_grad():
 
52
  text_to_speech_pipelines = {}
53
 
54
  # Initialize pipelines for other tasks
 
 
55
 
56
+
57
+ visual_qa_pipeline = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
58
+ document_qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
59
+ image_classification_pipeline = pipeline("image-classification", model="facebook/deit-base-distilled-patch16-224")
60
+ object_detection_pipeline = pipeline("object-detection", model="facebook/detr-resnet-50")
61
+ video_classification_pipeline = pipeline("video-classification", model="facebook/timesformer-base-finetuned-k400")
62
+ summarization_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
63
 
64
  # Load speaker embeddings for text-to-audio
65
  def load_speaker_embeddings(model_name):
 
73
 
74
  # Use a different model for text-to-audio if stabilityai/stable-audio-open-1.0 is not supported
75
  try:
76
+ text_to_audio_pipeline = pipeline("text-to-audio", model="stabilityai/stable-audio-open-1.0")
77
  except ValueError as e:
78
  logger.error(f"Error loading stabilityai/stable-audio-open-1.0: {e}")
79
  logger.info("Falling back to a different text-to-audio model.")
80
+ text_to_audio_pipeline = pipeline("text-to-audio", model="microsoft/speecht5_tts")
81
  speaker_embeddings = load_speaker_embeddings("microsoft/speecht5_tts")
82
 
83
+ audio_classification_pipeline = pipeline("audio-classification", model="facebook/wav2vec2-base")
84
 
85
  def load_conversational_model(model_name):
86
  if model_name not in conversational_models_loaded:
 
114
  tokenizer, model = load_conversational_model(model_name)
115
 
116
  # Encode the input
117
+ input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
118
 
119
  # Generate a response
120
  with torch.no_grad():