# Captured from a Hugging Face Spaces page; the Space's status was shown as "Runtime error".
import concurrent.futures

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Use half precision only on a GPU. On CPU, float16 matmuls have historically
# been unsupported ("... not implemented for 'Half'") or very slow, so
# generation would fail or crawl; float32 is the safe CPU choice.
model_dtype = torch.float16 if device.type == 'cuda' else torch.float32


def _load_model(checkpoint):
    """Load the tokenizer and causal-LM weights for ``checkpoint``.

    The model weights are created in ``model_dtype`` and moved to ``device``.

    Args:
        checkpoint: Hugging Face Hub model id to load.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=model_dtype)
    model.to(device)
    return tokenizer, model


# Load Models (same order as before: Bloom first, then GPT-Neo)
# Model 1: Bloom 560M
tokenizer1, model1 = _load_model('bigscience/bloom-560m')
# Model 2: GPT-Neo 1.3B
tokenizer2, model2 = _load_model('EleutherAI/gpt-neo-1.3B')
# Text-generation functions, one per model. Both use the same sampling
# settings so the side-by-side comparison is like-for-like.

# Sampling settings passed to `generate` for both models.
# `max_new_tokens` bounds only the generated continuation. The former
# `max_length=50` also counted the prompt's tokens, so a prompt of 50 or more
# tokens left no room for any new text.
_GENERATION_KWARGS = {
    'max_new_tokens': 50,
    'num_return_sequences': 1,
    'no_repeat_ngram_size': 2,
    'do_sample': True,
    'top_k': 50,
    'top_p': 0.95,
    'temperature': 0.8,
}


def _generate_text(model, tokenizer, prompt):
    """Sample one continuation of a non-empty ``prompt`` and return it decoded.

    Args:
        model: Causal LM to sample from; inputs are moved to the module-level
            ``device``, so the model is expected to live there too.
        tokenizer: Tokenizer matching ``model``.
        prompt: Non-empty input text.

    Returns:
        The first returned sequence, decoded with special tokens skipped.
    """
    # Calling the tokenizer (rather than `encode`) also yields an
    # attention_mask, so `generate` does not have to infer one.
    encoded = tokenizer(prompt, return_tensors='pt').to(device)
    # Some tokenizers (presumably GPT-Neo's GPT-2-style one) define no pad
    # token; fall back to EOS so `generate` gets an explicit pad id.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    with torch.no_grad():
        outputs = model.generate(**encoded, pad_token_id=pad_id, **_GENERATION_KWARGS)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def generate_text_model1(prompt):
    """Generate text for ``prompt`` with Bloom 560M (``model1``/``tokenizer1``).

    Returns "" for an empty prompt instead of passing an empty ``input_ids``
    tensor to ``generate``.
    """
    if not prompt:
        return ""
    return _generate_text(model1, tokenizer1, prompt)


def generate_text_model2(prompt):
    """Generate text for ``prompt`` with GPT-Neo 1.3B (``model2``/``tokenizer2``).

    Returns "" for an empty prompt instead of passing an empty ``input_ids``
    tensor to ``generate``.
    """
    if not prompt:
        return ""
    return _generate_text(model2, tokenizer2, prompt)
# Run both generators side by side in worker threads.
def compare_models(prompt):
    """Generate text for ``prompt`` with both models concurrently.

    Each generator is submitted to its own worker in a thread pool. Results
    are collected in submission order, so the returned tuple is
    ``(generate_text_model1 output, generate_text_model2 output)``. If a
    generator raised, collecting its result re-raises that exception.
    """
    generators = (generate_text_model1, generate_text_model2)
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = [pool.submit(fn, prompt) for fn in generators]
        results = tuple(job.result() for job in pending)
    return results
# Create Gradio Interface: one prompt box in, one output box per model out
# (output order matches the tuple returned by compare_models).
prompt_box = gr.Textbox(lines=2, placeholder='Enter a prompt here...')
output_boxes = [
    gr.Textbox(label='Bloom 560M Output'),
    gr.Textbox(label='GPT-Neo 1.3B Output'),
]

iface = gr.Interface(
    fn=compare_models,
    inputs=prompt_box,
    outputs=output_boxes,
    title='Compare Text Generation Models',
    description='Enter a prompt and see how two different models generate text.',
)

# Launch Interface
iface.launch()