cnmoro committed
Commit 3d5a205 · verified · 1 parent: eabeea0

Update app.py

Files changed (1)
app.py +9 -34
app.py CHANGED
@@ -1,61 +1,36 @@
 import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 from threading import Thread

 torch.set_num_threads(2)

 # Loading the tokenizer and model from Hugging Face's model hub.
-tokenizer = AutoTokenizer.from_pretrained("cnmoro/teenytinyllama-460m-text-simplification-ptbr")
-model = AutoModelForCausalLM.from_pretrained("cnmoro/teenytinyllama-460m-text-simplification-ptbr")
-
-# using CUDA for an optimal experience
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-model = model.to(device)

 def count_tokens(text):
     return len(tokenizer.tokenize(text))

-class EOSStoppingCriteria(StoppingCriteria):
-    """
-    Custom stopping criteria that stops the generation when the "</s>" token is found.
-    """
-    def __init__(self, eos_token_id):
-        self.eos_token_id = eos_token_id
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
-        # Check if the last generated token is the EOS token.
-        is_eos = input_ids[0, -1] == self.eos_token_id
-        return is_eos
-
-# Find the EOS token ID for the specific token "</s>" in your tokenizer
-eos_token_id = tokenizer.convert_tokens_to_ids("</s>")
-
 # Function to generate model predictions.
 def predict(message, history):

-    formatted_prompt = f"<s><system>O objetivo é comprimir e estruturar o texto a seguir<texto>{message}</texto>"
-    model_inputs = tokenizer([
-        formatted_prompt
-    ], return_tensors="pt").to(device)
-
-    # Instantiate your custom stopping criteria
-    stopping_criteria = EOSStoppingCriteria(eos_token_id=eos_token_id)

     streamer = TextIteratorStreamer(tokenizer, timeout=120., skip_prompt=True, skip_special_tokens=True)

     generate_kwargs = dict(
         model_inputs,
         streamer=streamer,
-        max_new_tokens=3072 - count_tokens(formatted_prompt),
         top_p=0.2,
         top_k=20,
         temperature=0.1,
         repetition_penalty=2.0,
         length_penalty=-0.5,
-        num_beams=1,
-        stopping_criteria=StoppingCriteriaList([stopping_criteria])
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start() # Starting the generation in a separate thread.
@@ -66,6 +41,6 @@ def predict(message, history):

 # Setting up the Gradio chat interface.
 gr.ChatInterface(predict,
-    title="TextStructurization_TeenyTinyLlama460m_CPU",
-    description="Pass a text to be structurized"
 ).launch() # Launching the web interface.
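The removed EOSStoppingCriteria simply checked whether the last generated token matched the `</s>` id and halted generation when it did (the old Portuguese prompt instructed the model to "compress and structure the following text"). `model.generate` can enforce the same stop condition natively through its `eos_token_id` argument, which is why the class, the token lookup, and the `stopping_criteria` kwarg could all be dropped. A minimal sketch of that built-in equivalent, assuming the old model's tokenizer and the `model_inputs` built above; this is not code from the commit:

# Built-in stop-on-</s>: pass eos_token_id to generate instead of wiring up
# a custom StoppingCriteria. Sketch only; reuses tokenizer/model/model_inputs.
eos_token_id = tokenizer.convert_tokens_to_ids("</s>")
output_ids = model.generate(
    **model_inputs,
    max_new_tokens=256,          # illustrative budget, not the app's value
    eos_token_id=eos_token_id,   # generation halts once </s> is produced
)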
 
 import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import StoppingCriteria, TextIteratorStreamer
 from threading import Thread

 torch.set_num_threads(2)

 # Loading the tokenizer and model from Hugging Face's model hub.
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
+model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")

 def count_tokens(text):
     return len(tokenizer.tokenize(text))

 # Function to generate model predictions.
 def predict(message, history):

+    formatted_prompt = f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
+    model_inputs = tokenizer(formatted_prompt, return_tensors="pt")

     streamer = TextIteratorStreamer(tokenizer, timeout=120., skip_prompt=True, skip_special_tokens=True)

     generate_kwargs = dict(
         model_inputs,
         streamer=streamer,
+        max_new_tokens=2048 - count_tokens(formatted_prompt),
         top_p=0.2,
         top_k=20,
         temperature=0.1,
         repetition_penalty=2.0,
         length_penalty=-0.5,
+        num_beams=1
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start() # Starting the generation in a separate thread.
@@ -66,6 +41,6 @@ def predict(message, history):

 # Setting up the Gradio chat interface.
 gr.ChatInterface(predict,
+    title="Gemma 2b Instruct Chat",
+    description=None
 ).launch() # Launching the web interface.
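The hand-built f-string reproduces Gemma's `<start_of_turn>` chat framing, and the elided body of `predict` (old lines 62-65, new 37-40) presumably drains the `TextIteratorStreamer` and yields the partial reply, which is the usual pattern for streaming into `gr.ChatInterface`. A hypothetical end-to-end sketch of that shape, using the tokenizer's `apply_chat_template` (which emits the same framing) instead of a hand-written prompt; `predict_sketch` and the token budget are illustrative, not from the commit:

# Hypothetical variant of predict(), sketched under the assumptions above.
def predict_sketch(message, history):
    # apply_chat_template produces the <start_of_turn>user ... <start_of_turn>model framing.
    formatted_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer(formatted_prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, timeout=120., skip_prompt=True, skip_special_tokens=True)
    Thread(target=model.generate,
           kwargs=dict(model_inputs, streamer=streamer, max_new_tokens=512)).start()
    partial_text = ""
    for new_token in streamer:   # decoded text chunks arrive as they are generated
        partial_text += new_token
        yield partial_text       # ChatInterface re-renders the growing reply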