stibiumghost committed on
Commit 1659ccb · 1 Parent(s): 7edf8d6

Update text_gen.py

Files changed (1): text_gen.py (+4 -4)
text_gen.py CHANGED
@@ -3,15 +3,15 @@ import string
 
 model_names = ['microsoft/GODEL-v1_1-large-seq2seq',
                'facebook/blenderbot-1B-distill',
-               'satvikag/chatbot']
+               'facebook/blenderbot_small-90M']
 
 tokenizers = [transformers.AutoTokenizer.from_pretrained(model_names[0]),
               transformers.BlenderbotTokenizer.from_pretrained(model_names[1]),
-              transformers.GPT2Tokenizer.from_pretrained(model_names[2])]
+              transformers.BlenderbotSmallTokenizer.from_pretrained(model_names[2])]
 
 model = [transformers.AutoModelForSeq2SeqLM.from_pretrained(model_names[0]),
          transformers.BlenderbotForConditionalGeneration.from_pretrained(model_names[1]),
-         transformers.GPT2LMHeadModel.from_pretrained(model_names[2])]
+         transformers.BlenderbotSmallForConditionalGeneration.from_pretrained(model_names[2])]
 
 
 def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
@@ -24,7 +24,7 @@ def generate_text(text, context, model_name, model, tokenizer, minimum=15, maximum=300):
     input_ids = tokenizer(text, return_tensors="pt").input_ids
     outputs = model.generate(input_ids, max_new_tokens=maximum, min_new_tokens=minimum, top_p=0.9, do_sample=True)
     output = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return model_name + capitalization(output)
+    return capitalization(output)
 
 
 def capitalization(line):
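
For reference, a minimal sketch of how the third model slot wired in by this commit could be exercised on its own. It assumes only that the transformers package is installed; the prompt string is illustrative, and the sampling settings mirror the defaults in generate_text():

import transformers

# Load the checkpoint this commit swaps in at index 2.
name = 'facebook/blenderbot_small-90M'
tokenizer = transformers.BlenderbotSmallTokenizer.from_pretrained(name)
model = transformers.BlenderbotSmallForConditionalGeneration.from_pretrained(name)

# Same encode -> sample -> decode path as generate_text().
input_ids = tokenizer('Hello, how are you today?', return_tensors='pt').input_ids
outputs = model.generate(input_ids, max_new_tokens=300, min_new_tokens=15,
                         top_p=0.9, do_sample=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Unlike the decoder-only GPT-2 head it replaces, BlenderbotSmall is a seq2seq model, so it fits the same ConditionalGeneration call pattern as the other two entries in the list.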