sbicy's picture
Update app.py
1c8e8b3 verified
raw
history blame
2.41 kB
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import concurrent.futures
# Set Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Half precision is only safe on GPU: CPU kernels used by generate()
# (e.g. addmm) are not implemented for float16, so loading the models in
# fp16 on a CPU-only host crashes at inference time. Pick dtype per device.
dtype = torch.float16 if device.type == 'cuda' else torch.float32

# Load Models
# Model 1: Bloom 560M
tokenizer1 = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
model1 = AutoModelForCausalLM.from_pretrained('bigscience/bloom-560m', torch_dtype=dtype)
model1.to(device)
model1.eval()  # inference only; disables dropout etc.

# Model 2: GPT-Neo 1.3B
tokenizer2 = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-1.3B')
model2 = AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-1.3B', torch_dtype=dtype)
model2.to(device)
model2.eval()
# Define Functions with Improved Parameters
def generate_text_model1(prompt):
    """Generate a continuation of ``prompt`` with Bloom 560M.

    Args:
        prompt: Text to continue.

    Returns:
        The decoded text (prompt followed by up to 50 newly sampled tokens).
    """
    # Tokenize via __call__ so we also get an attention_mask; without it
    # (and an explicit pad_token_id) generate() emits warnings and may
    # attend over padding positions.
    encoded = tokenizer1(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model1.generate(
            **encoded,
            # max_new_tokens bounds only the generated continuation.
            # The original max_length=50 counted prompt tokens too, so
            # prompts near 50 tokens produced little or no new text.
            max_new_tokens=50,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.8,
            pad_token_id=tokenizer1.eos_token_id,
        )
    return tokenizer1.decode(outputs[0], skip_special_tokens=True)
def generate_text_model2(prompt):
    """Generate a continuation of ``prompt`` with GPT-Neo 1.3B.

    Args:
        prompt: Text to continue.

    Returns:
        The decoded text (prompt followed by up to 50 newly sampled tokens).
    """
    # Tokenize via __call__ so we also get an attention_mask; without it
    # (and an explicit pad_token_id) generate() emits warnings and may
    # attend over padding positions.
    encoded = tokenizer2(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model2.generate(
            **encoded,
            # max_new_tokens bounds only the generated continuation.
            # The original max_length=50 counted prompt tokens too, so
            # prompts near 50 tokens produced little or no new text.
            max_new_tokens=50,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.8,
            pad_token_id=tokenizer2.eos_token_id,
        )
    return tokenizer2.decode(outputs[0], skip_special_tokens=True)
# Use ThreadPoolExecutor to Process in Parallel
def compare_models(prompt):
    """Run both generators on the same prompt concurrently.

    Args:
        prompt: Text to feed to both models.

    Returns:
        Tuple of (Bloom 560M output, GPT-Neo 1.3B output).
    """
    generators = (generate_text_model1, generate_text_model2)
    with concurrent.futures.ThreadPoolExecutor() as pool:
        # Submit both jobs first so they overlap, then collect in order.
        pending = [pool.submit(gen, prompt) for gen in generators]
        bloom_text, neo_text = (job.result() for job in pending)
    return bloom_text, neo_text
# Create Gradio Interface
# compare_models returns a 2-tuple, which Gradio maps onto the two
# output Textboxes in order (Bloom first, GPT-Neo second).
iface = gr.Interface(
    fn=compare_models,
    inputs=gr.Textbox(lines=2, placeholder='Enter a prompt here...'),
    outputs=[gr.Textbox(label='Bloom 560M Output'), gr.Textbox(label='GPT-Neo 1.3B Output')],
    title='Compare Text Generation Models',
    description='Enter a prompt and see how two different models generate text.'
)

# Launch Interface
# Blocks here serving the web UI until the process is stopped.
iface.launch()