Paridhim committed on
Commit
21c18e3
1 Parent(s): 6ff8a5a

Create app.py

Files changed (1)
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+
+ # Load the model; the usable context length is governed by the model config
+ # (max_position_embeddings), so no context-length argument is passed here.
+ model = AutoModelForCausalLM.from_pretrained(
+     "mistralai/Mistral-7B-Instruct-v0.1",
+     torch_dtype=torch.bfloat16,
+     trust_remote_code=True,
+     device_map="auto",
+ )
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
+
+ def generate_text(input_text):
+     # Tokenize the prompt and move it to the model's device
+     inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+
+     output = model.generate(
+         inputs.input_ids,
+         attention_mask=inputs.attention_mask,
+         max_length=200,
+         do_sample=True,
+         top_k=10,
+         num_return_sequences=1,
+         eos_token_id=tokenizer.eos_token_id,
+     )
+
+     output_text = tokenizer.decode(output[0], skip_special_tokens=True)
+     print(output_text)
+
+     # Remove prompt echo from the generated text
+     cleaned_output_text = output_text.replace(input_text, "")
+     return cleaned_output_text
+
+
+ text_generation_interface = gr.Interface(
+     fn=generate_text,
+     inputs=gr.Textbox(label="Input Text"),
+     outputs=gr.Textbox(label="Generated Text"),
+ ).launch()