import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
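
# Toggle between the high-level transformers pipeline API and loading the
# tokenizer/model by hand; both paths run the model on CPU.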
use_pipeline = True

if use_pipeline:
    pipe = pipeline("text-generation", model="kakaocorp/kanana-nano-2.1b-base", device="cpu")
else:
    tokenizer = AutoTokenizer.from_pretrained("kakaocorp/kanana-nano-2.1b-base")
    model = AutoModelForCausalLM.from_pretrained("kakaocorp/kanana-nano-2.1b-base")
    print("Model loaded on CPU")
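

# Single generation entry point shared by both loading paths; errors are
# returned as text so they show up in the UI instead of crashing the app.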
def generate_text(prompt, max_length=50, temperature=1.0, top_k=50, top_p=1.0, no_repeat_ngram_size=0, num_return_sequences=1):
    """Generates text based on the given prompt and parameters."""
    if use_pipeline:
        try:
            # kanana-nano-2.1b-base is a base (non-chat) model, so pass the raw
            # prompt string; chat-style message lists require a chat template.
            result = pipe(
                prompt,
                max_length=max_length,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                no_repeat_ngram_size=no_repeat_ngram_size,
                num_return_sequences=num_return_sequences,
                do_sample=True,  # sampling must be on for temperature/top_k/top_p to take effect
                return_full_text=False,
                pad_token_id=pipe.tokenizer.eos_token_id,
            )
            return "\n\n".join([res['generated_text'] for res in result])
        except Exception as e:
            return f"Error during generation: {e}"
    else:
        try:
            inputs = tokenizer(prompt, return_tensors="pt")
            # Note: max_length includes the prompt tokens; max_new_tokens would
            # bound only the generated continuation.
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                no_repeat_ngram_size=no_repeat_ngram_size,
                num_return_sequences=num_return_sequences,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
            )
            generated_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
            return "\n\n".join(generated_texts)
        except Exception as e:
            return f"Error during generation: {e}"
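

# Gradio UI: prompt and sampling controls on the left, generated text on the right.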
with gr.Blocks() as demo:
    gr.Markdown("# Text Generation with kakaocorp/kanana-nano-2.1b-base")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
            with gr.Accordion("Generation Parameters", open=False):
                max_length_slider = gr.Slider(label="Max Length", minimum=10, maximum=512, value=50, step=1)
                temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
                top_k_slider = gr.Slider(label="Top K", minimum=0, maximum=100, value=50, step=1)
                top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=1.0, step=0.05)
                no_repeat_ngram_size_slider = gr.Slider(label="No Repeat N-gram Size", minimum=0, maximum=10, value=0, step=1)
                num_return_sequences_slider = gr.Slider(label="Number of Return Sequences", minimum=1, maximum=5, value=1, step=1)

            generate_button = gr.Button("Generate")

        with gr.Column():
            output_text = gr.Textbox(label="Generated Text", interactive=False)

    generate_button.click(
        generate_text,
        inputs=[
            prompt_input,
            max_length_slider,
            temperature_slider,
            top_k_slider,
            top_p_slider,
            no_repeat_ngram_size_slider,
            num_return_sequences_slider,
        ],
        outputs=output_text,
    )
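
# launch() serves the app locally (http://127.0.0.1:7860 by default);
# pass share=True for a temporary public link.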
demo.launch()