import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

# Get the Hugging Face token from the environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
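# (On Hugging Face Spaces, secrets such as HF_TOKEN are exposed to the app as environment variables.)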
# Load the base model's tokenizer and the fine-tuned SentimentGPT model
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/GPT-NeoX-20B', token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained('skylersterling/SentimentGPT', token=HF_TOKEN)
model.eval()
model.to('cpu')  # Run inference on CPU
# Define the function that generates a few tokens and maps them to a sentiment label
def generate_text(prompt, temperature=0.1, top_p=0.2, max_tokens=3):
    prompt_with_eos = prompt + " >"  # Append the " >" marker that separates the prompt from the model's answer
    input_tokens = tokenizer.encode(prompt_with_eos, return_tensors='pt')
    print(prompt_with_eos)
    input_tokens = input_tokens.to('cpu')
    generated_text = prompt_with_eos  # Start with the prompt plus the " >" marker
    for _ in range(max_tokens):  # max_tokens controls the number of tokens generated
        with torch.no_grad():
            outputs = model(input_tokens)
        predictions = outputs.logits[:, -1, :] / temperature  # Scale the next-token logits by the temperature
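        # Nucleus (top-p) filtering: sort the logits, keep the smallest set of tokens
        # whose cumulative probability exceeds top_p, and mask out everything else so
        # that sampling only draws from that nucleus.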
        sorted_logits, sorted_indices = torch.sort(predictions, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0  # Always keep the most probable token
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        predictions[:, indices_to_remove] = -float('Inf')
        next_token = torch.multinomial(torch.softmax(predictions, dim=-1), 1)
        input_tokens = torch.cat((input_tokens, next_token), dim=1)
        decoded_token = tokenizer.decode(next_token.item())
        generated_text += decoded_token  # Append the new token to the generated text
        if decoded_token == "<":  # Stop if the end-of-sequence token is generated
            break
    # Check whether the generated text contains "1" or "0" and return the appropriate sentiment message
    if "1" in generated_text:
        return "The sentiment is positive."
    elif "0" in generated_text:
        return "The sentiment is negative."
    else:
        return "Invalid"
# Create a Gradio interface with a text input, sliders for temperature and top_p, and a text output
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Top-p")
    ],
    outputs=gr.Textbox(),
    live=False,
    description="SentimentGPT processes the sequence and returns a reasonably accurate guess of whether the sentiment behind the input is positive or negative."
)
interface.launch()
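# To run this app locally (a sketch; assumes the HF_TOKEN environment variable holds a
# token with access to the checkpoints loaded above):
#   export HF_TOKEN=<your token>
#   python app.py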