Upload with huggingface_hub
- app.py +33 -0
- image2caption.py +70 -0
- predict.py +47 -0
- requirements.txt +8 -0
- summary_reverse_pred_eng_native.py +227 -0
app.py
ADDED
@@ -0,0 +1,33 @@
+from summary_reverse_pred_eng_native import *
+import gradio as gr
+import os
+
+#text0 = "Hurricane Gert was a large tropical cyclone that brought severe flooding to Mexico and much of Central America in September 1993, originating from an easterly wave over the southwestern Caribbean Sea on September 14. It made landfall in Nicaragua the next day, crossed Honduras, and regained tropical-storm strength in the Gulf of Honduras on September 17, only to weaken back into a tropical depression over Belize the following day. After crossing the Yucatan Peninsula it intensified into a Category 2 hurricane on September 20 and struck Mexico near Tuxpan, Veracruz. It had weakened to a tropical depression again by the time it crossed from Nayarit into the Pacific on September 21, and it dissipated over open water five days later."
+#text1 = "Shanhuba is a floodplain shoal in the Yangtze River, lying on the left side of the main channel along the river's Yuzhong District section in Chongqing[1], close to the Yuzhong Peninsula. It formerly straddled the Caiyuanba and Shibanpo subdistricts of Chongqing's Shizhong District[2] and now belongs to the Shibanpo community of Caiyuanba subdistrict, Yuzhong District[3]. A natural alluvial sandbar in a buffer reach of the upper Yangtze, it is roughly spindle-shaped[4] or elliptical, about 1,800 m long and 600 m wide, and covered with pebbles and water grass. It is submerged during the summer floods every year, usually exposed the rest of the time, and joins the left bank of the Yangtze in the dry season[5]."
+
+text0 = "The Wisconsin Territorial Centennial half dollar was designed by David Parsons and Benjamin Hawkins and minted by the United States Bureau of the Mint in 1936. The obverse (pictured) depicts a pick axe and lead ore, referring to the lead mining in early Wisconsin"
+#text1 = ""
+
+example_sample = [
+    [text0, False],
+    #[text1, False],
+]
+
+def demo_func(prefix, do_sample):
+    l = simple_pred(prefix, do_sample = do_sample)
+    return {
+        "Dialogue Context": l
+    }
+
+demo = gr.Interface(
+    fn=demo_func,
+    inputs=[gr.Text(label = "Context"),
+        gr.Checkbox(label="do sample"),
+    ],
+    outputs="json",
+    title="English Context Dialogue Generator 🦅 demonstration",
+    examples=example_sample if example_sample else None,
+    cache_examples = False
+)
+
+demo.launch(server_name=None, server_port=None)
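For orientation: demo_func simply forwards the context to simple_pred and wraps the result for the JSON output component. A minimal sketch of the round trip, with an illustrative input and output (not real model output):

out = demo_func("Two collectors discuss a commemorative coin.", do_sample=False)
# simple_pred returns the generated dialogue as a list of utterance strings,
# with the speaker prefixes ("Tom:", "Jack:") already split off and stripped, e.g.
# {"Dialogue Context": ["Have you seen this half dollar?", "It shows a pick axe and lead ore."]}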
image2caption.py
ADDED
@@ -0,0 +1,70 @@
+##### image pred
+from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
+import torch
+from PIL import Image
+import pathlib
+import pandas as pd
+import numpy as np
+from IPython.core.display import HTML
+import os
+import requests
+
+class Image2Caption(object):
+    def __init__(self, model_path = "nlpconnect/vit-gpt2-image-captioning",
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+        overwrite_encoder_checkpoint_path = None,
+        overwrite_token_model_path = None
+    ):
+        assert isinstance(overwrite_token_model_path, str) or overwrite_token_model_path is None
+        assert isinstance(overwrite_encoder_checkpoint_path, str) or overwrite_encoder_checkpoint_path is None
+        if overwrite_token_model_path is None:
+            overwrite_token_model_path = model_path
+        if overwrite_encoder_checkpoint_path is None:
+            overwrite_encoder_checkpoint_path = model_path
+        self.device = device
+        self.model = VisionEncoderDecoderModel.from_pretrained(model_path)
+        self.feature_extractor = ViTFeatureExtractor.from_pretrained(overwrite_encoder_checkpoint_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(overwrite_token_model_path)
+        self.model = self.model.to(self.device)
+
+    def predict_to_df(self, image_paths):
+        img_caption_pred = self.predict_step(image_paths)
+        img_caption_df = pd.DataFrame(list(zip(image_paths, img_caption_pred)))
+        img_caption_df.columns = ["img", "caption"]
+        return img_caption_df
+        #img_caption_df.to_html(escape=False, formatters=dict(Country=path_to_image_html))
+
+    def predict_step(self, image_paths, max_length = 128, num_beams = 4):
+        gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+        images = []
+        for image_path in image_paths:
+            #i_image = Image.open(image_path)
+            if image_path.startswith("http"):
+                i_image = Image.open(
+                    requests.get(image_path, stream=True).raw
+                )
+            else:
+                i_image = Image.open(image_path)
+
+            if i_image.mode != "RGB":
+                i_image = i_image.convert(mode="RGB")
+            images.append(i_image)
+
+        pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
+        pixel_values = pixel_values.to(self.device)
+
+        output_ids = self.model.generate(pixel_values, **gen_kwargs)
+
+        preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        preds = [pred.strip() for pred in preds]
+        return preds
+
+def path_to_image_html(path):
+    return '<img src="' + path + '" width="60" >'
+
+if __name__ == "__main__":
+    i2c_obj = Image2Caption()
+    i2c_tiny_zh_obj = Image2Caption("svjack/vit-gpt-diffusion-zh",
+        overwrite_encoder_checkpoint_path = "google/vit-base-patch16-224",
+        overwrite_token_model_path = "IDEA-CCNL/Wenzhong-GPT2-110M"
+    )
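A quick usage sketch for the captioner; the URL is a placeholder and the caption only indicative:

i2c = Image2Caption()  # downloads nlpconnect/vit-gpt2-image-captioning on first use
caps = i2c.predict_step(["https://example.com/cat.jpg"])  # local paths work too
print(caps[0])  # e.g. "a cat sitting on a couch"
df = i2c.predict_to_df(["https://example.com/cat.jpg"])  # same predictions as an (img, caption) DataFrame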
predict.py
ADDED
@@ -0,0 +1,47 @@
+class Obj:
+    def __init__(self, model, tokenizer, device = "cpu"):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.device = device
+        self.model = self.model.to(self.device)
+
+    def predict(
+        self,
+        source_text: str,
+        max_length: int = 512,
+        num_return_sequences: int = 1,
+        num_beams: int = 2,
+        top_k: int = 50,
+        top_p: float = 0.95,
+        do_sample: bool = True,
+        repetition_penalty: float = 2.5,
+        length_penalty: float = 1.0,
+        early_stopping: bool = True,
+        skip_special_tokens: bool = True,
+        clean_up_tokenization_spaces: bool = True,
+    ):
+        input_ids = self.tokenizer.encode(
+            source_text, return_tensors="pt", add_special_tokens=True
+        )
+        input_ids = input_ids.to(self.device)
+        generated_ids = self.model.generate(
+            input_ids=input_ids,
+            num_beams=num_beams,
+            max_length=max_length,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            early_stopping=early_stopping,
+            top_p=top_p,
+            top_k=top_k,
+            num_return_sequences=num_return_sequences,
+            do_sample=do_sample
+        )
+        preds = [
+            self.tokenizer.decode(
+                g,
+                skip_special_tokens=skip_special_tokens,
+                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            )
+            for g in generated_ids
+        ]
+        return preds
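Obj is a thin wrapper around tokenizer.encode and model.generate that works for any seq2seq checkpoint. A minimal sketch with t5-small as a stand-in (the app itself wraps svjack/summary-dialogue-eng in summary_reverse_pred_eng_native.py):

from transformers import T5ForConditionalGeneration, T5TokenizerFast

model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5TokenizerFast.from_pretrained("t5-small")
obj = Obj(model, tokenizer, device="cpu")
# predict() returns a list with one decoded string per returned sequence
print(obj.predict("summarize: The quick brown fox jumps over the lazy dog.", max_length=32))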
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+torch
+transformers
+jieba
+rapidfuzz
+ipykernel
+gradio
+pandas
+Pillow
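Setup is the usual pip install -r requirements.txt. gradio, pandas, and Pillow are listed because app.py and image2caption.py import them directly; numpy, requests, and tqdm arrive as transformers dependencies.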
summary_reverse_pred_eng_native.py
ADDED
@@ -0,0 +1,227 @@
+#### English scope
+#device = "cuda:0"
+device = "cpu"
+assert device.startswith("cpu") or device.startswith("cuda")
+
+import sys
+from predict import *
+
+from transformers import (
+    T5ForConditionalGeneration,
+    MT5ForConditionalGeneration,
+    ByT5Tokenizer,
+    PreTrainedTokenizer,
+    T5TokenizerFast as T5Tokenizer,
+    MT5TokenizerFast as MT5Tokenizer,
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    BertTokenizer,
+    GPT2LMHeadModel,
+)
+
+import pandas as pd
+import numpy as np
+import re
+from rapidfuzz import fuzz
+from tqdm import tqdm
+import numpy as np
+from transformers import pipeline
+import os
+
+def shorten_exists(l, sim_threshold = 80, slice_size = 5):
+    req = []
+    for ele in l:
+        if not req:
+            req.append(ele)
+        else:
+            if max(map(lambda x: fuzz.ratio(x[:slice_size], ele[:slice_size]), req)) < sim_threshold:
+                req.append(ele)
+    return req
+
+model_path = "svjack/summary-dialogue-eng"
+tokenizer0 = T5Tokenizer.from_pretrained(model_path)
+model0 = T5ForConditionalGeneration.from_pretrained(model_path)
+
+if device.startswith("cuda"):
+    model = Obj(model0, tokenizer0, device = "cuda:0")
+else:
+    model = Obj(model0, tokenizer0, device = "cpu")
+
+if device.startswith("cuda"):
+    prompt_expand_model = pipeline('text-generation', model='daspartho/prompt-extend',
+        device = 0
+    )
+else:
+    prompt_expand_model = pipeline('text-generation', model='daspartho/prompt-extend',
+    )
+
+def loop_add(l, names = ["Tom", "Jack"]):
+    req = []
+    for i in range(len(l)):
+        ii = int(i % len(names))
+        req.append(
+            "{}:{}".format(names[ii], l[i])
+        )
+    return req
+
+#### needs some names present in the context (the context may not contain ":")
+#### 'Amy Adams in Sleepy Hollow, full body, big two-colored eyes, gritted teeth, horror, intricate details, cinematic, epic, realistic, anatomy, Tom Hanuka, glazed, ArtStation, photorealistic, terrifying'
+def guess_name_candidates(context, cnt_threshold = 1):
+    from copy import deepcopy
+    assert isinstance(context, str)
+    import re
+    l = re.findall(r"[\u4e00-\u9fa5a-zA-Z]+:", context)
+    l = list(filter(lambda x: x.strip(), l))
+    ori_l = deepcopy(l)
+    if not l:
+        return []
+    s = pd.Series(l).value_counts()
+    l = pd.Series(s[s > cnt_threshold].index.values.tolist()).map(lambda x: x[:-1]).values.tolist()
+    for ele in ori_l:
+        if ele[:-1] not in l and (len(ele[:-1]) <= 3 or (
+            sum(map(len, re.findall(r"[a-zA-Z]+:", ele))) == len(ele)
+        )):
+            l.append(ele[:-1])
+    l = list(set(l))
+    return l
+
+def stdf_prompt_expander(x):
+    assert isinstance(x, str)
+    return prompt_expand_model(x, num_return_sequences=1)[0]["generated_text"]
+
+def simple_pred(summary, candidates = ["Tom", "Jack"], shorten_it = False,
+    summary_expander = lambda _:_, do_sample = True):
+    assert callable(summary_expander)
+    summary = summary_expander(summary)
+    pred_text = model.predict(
+        "{}\nCandidates:{}".format(summary, " ".join(candidates)),
+        do_sample = do_sample
+    )[0]
+    candidates_ = guess_name_candidates(pred_text)
+    l = re.split("{}".format("|".join(map(lambda x: "{}:".format(x), candidates_))), pred_text)
+    l = list(filter(lambda x: x.strip(), l))
+    if shorten_it:
+        l = shorten_exists(l)
+    #l = loop_add(l, candidates)
+    l = list(map(lambda x: x.strip(), l))
+    return l
+
+def percentile_sort(df, perc_num = 101):
+    score_tuple_s = df["score_tuple"]
+    score_array = np.asarray(score_tuple_s.values.tolist())
+    perc_list = np.linspace(0, 100, perc_num).tolist()
+    low_to_high_perc_array = np.stack(list(map(lambda p: np.percentile(score_array, p, axis = 0), perc_list)))
+
+    def get_rank(array_):
+        lookup_list = pd.DataFrame(array_ - low_to_high_perc_array[::-1]).apply(lambda s: min(s) >= 0, axis = 1).tolist()
+        if True not in lookup_list:
+            return len(lookup_list)
+        return lookup_list.index(True)
+
+    rank_list = []
+    for i in range(score_array.shape[0]):
+        rank_list.append(get_rank(score_array[i, :]))
+
+    rank_s = pd.Series(rank_list)
+    return df.iloc[np.argsort(rank_s.values)]
+
+def repeat_score(l, slice_size = 200, sim_threshold = 70):
+    from copy import deepcopy
+    assert isinstance(l, list)
+    l = deepcopy(l)
+    l = sorted(l)
+    cnt_num = 0
+    set0 = set([])
+    for ele in l:
+        if ":" in ele:
+            ele = "".join(ele.split(":")[1:])
+        if set0 and max(map(lambda x: fuzz.ratio(x[:slice_size], ele[:slice_size]), set0)) > sim_threshold:
+            #if ele in set0:
+            cnt_num += 1
+        set0.add(ele)
+    return cnt_num
+
+def sample_pred(context, times = 5, stdf_prompt_expander = lambda _: _):
+    df_req = []
+    for i in tqdm(range(times)):
+        ele = stdf_prompt_expander(context)
+        #ele = context
+        l = simple_pred(ele, do_sample = True)
+        df_req.append(
+            [ele, l]
+        )
+    df = pd.DataFrame(df_req)
+    df.columns = ["context", "dialogue"]
+    df["fuzz"] = df["dialogue"].map(
+        lambda x: fuzz.ratio(context, " ".join(x))
+    )
+    df["max_fuzz"] = df["dialogue"].map(
+        lambda x: max(map(lambda y: fuzz.ratio(y, context), x))
+    )
+    df["length"] = df["dialogue"].map(len)
+    df["rpt_score"] = df["dialogue"].map(repeat_score)
+    df["score_tuple"] = df.apply(
+        lambda x: (x["fuzz"], -1 * x["max_fuzz"], x["length"], -1 * x["rpt_score"]), axis = 1
+    )
+    df = percentile_sort(df)
+    return df
+
+def sample_pred_wrapper(context, i2c_obj, times = 5, extend_by_diffusion = False):
+    assert isinstance(context, str)
+    if any(map(lambda x: context.endswith(x), [".jpg", ".png", ".jpeg"])):
+        img_path = context
+        i2c_df = i2c_obj.predict_to_df([img_path])
+        assert i2c_df.size > 0
+        context = i2c_df["caption"].iloc[0]
+    else:
+        pass
+    assert isinstance(context, str)
+    if extend_by_diffusion:
+        req_df = sample_pred(context, times = times, stdf_prompt_expander = stdf_prompt_expander)
+    else:
+        req_df = sample_pred(context, times = times, stdf_prompt_expander = lambda _: _)
+    return req_df
+
+from image2caption import *
+i2c_obj = Image2Caption(device = device)
+
+if __name__ == "__main__":
+    from image2caption import *
+    i2c_obj = Image2Caption(device = device)
+
+    img_path = "../pic/bug.jpg"
+    img_path = "../pic/baobao.jpeg"
+    img_path = "../pic/cat0.jpg"
+    img_path = "../pic/cat.jpg"
+    os.path.exists(img_path)
+
+    df = sample_pred_wrapper(img_path, i2c_obj = i2c_obj)
+    df["dialogue"].values.tolist()
+
+    img_url = "https://datasets-server.huggingface.co/assets/metashift/--/metashift/train/2/image/image.jpg"
+    img_url = "https://datasets-server.huggingface.co/assets/metashift/--/metashift/train/6/image/image.jpg"
+
+    df = sample_pred_wrapper(img_url, i2c_obj = i2c_obj)
+    df["dialogue"].values.tolist()
+
+
+    text = "Goldfinger is the seventh novel in Ian Fleming's James Bond series. First published in 1959, it centres on Bond's investigation into the gold-smuggling activities of Auric Goldfinger, who is suspected of being connected to Soviet counter-intelligence. "
+    text
+
+    df = sample_pred_wrapper(text, i2c_obj = i2c_obj, times = 6)
+    df["dialogue"].values.tolist()
+
+    en_l = ['a statue of a bird on top of a rock',
+        'a woman standing in front of a flower arrangement',
+        'people walking down a dirt road',
+        'two pictures of a man with a beard',
+        'a sign that is on top of a sign',
+        'a woman dressed in a costume holding an umbrella',
+        'a woman in a red dress holding a flower in her hand',
+        'a little girl in a pink dress with a pink flower in her hair']
+
+    df = sample_pred(en_l[0], 5)
+    df["dialogue"].values.tolist()
+
+    df = sample_pred(en_l[0], 5, stdf_prompt_expander = stdf_prompt_expander)
+    df["dialogue"].values.tolist()
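To make simple_pred's post-processing concrete, a small sketch of the name-guessing and splitting steps on a hand-written string (not real model output):

raw = "Tom: Look at this coin. Jack: It shows a pick axe and lead ore."
names = guess_name_candidates(raw)  # -> ["Tom", "Jack"] (set order may vary)
parts = re.split("|".join(map(lambda x: "{}:".format(x), names)), raw)
parts = [p.strip() for p in parts if p.strip()]
# -> ["Look at this coin.", "It shows a pick axe and lead ore."]

sample_pred then draws several such dialogues, scores each by fuzz overlap with the context, the maximum per-utterance overlap, the utterance count, and repeat_score, and percentile_sort orders the candidates by where their score tuples sit against the column-wise percentile curves.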