svjack commited on
Commit
32fcb0c
·
1 Parent(s): 87feba0

Upload with huggingface_hub

Browse files
Files changed (5) hide show
  1. app.py +33 -0
  2. image2caption.py +70 -0
  3. predict.py +47 -0
  4. requirements.txt +5 -0
  5. summary_reverse_pred_eng_native.py +227 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from summary_reverse_pred_eng_native import *
2
+ import gradio as gr
3
+ import os
4
+
5
+ #text0 = "飓风格特是1993年9月在墨西哥和整个中美洲引发严重洪灾的大规模热带气旋,源于9月14日西南加勒比海上空一股东风波。次日从尼加拉瓜登岸,经过洪都拉斯后于9月17日在洪都拉斯湾再次达到热带风暴标准,但次日进入伯利兹上空后就减弱成热带低气压。穿过尤卡坦半岛后,在9月20日强化成二级飓风,从韦拉克鲁斯州的图斯潘附近登陆墨西哥。9月21日从纳亚里特州进入太平洋时已降级成热带低气压,最终于5天后在开放水域上空消散。"
6
+ #text1 = "珊瑚坝是长江中的一处河漫滩,位于长江重庆市渝中区区段主航道左侧[1],靠近渝中半岛,原分属重庆市市中区菜园坝街道和石板坡街道[2],现属渝中区菜园坝街道石板坡社区[3],是长江上游缓冲地段自然冲积沙洲,略呈纺锤形[4]或椭圆形,长约1800米,宽约600米,坝上遍布鹅卵石和水草。每年夏季洪水时均被淹没,其余时间常露水面,枯水期则与长江左岸相连[5]。"
7
+
8
+ text0 = "The Wisconsin Territorial Centennial half dollar was designed by David Parsons and Benjamin Hawkins and minted by the United States Bureau of the Mint in 1936. The obverse (pictured) depicts a pick axe and lead ore, referring to the lead mining in early Wisconsin"
9
+ #text1 = ""
10
+
11
+ example_sample = [
12
+ [text0, False],
13
+ #[text1, False],
14
+ ]
15
+
16
+ def demo_func(prefix, do_sample):
17
+ l = simple_pred(prefix, do_sample = do_sample)
18
+ return {
19
+ "Dialogue Context": l
20
+ }
21
+
22
+ demo = gr.Interface(
23
+ fn=demo_func,
24
+ inputs=[gr.Text(label = "Context"),
25
+ gr.Checkbox(label="do sample"),
26
+ ],
27
+ outputs="json",
28
+ title=f"English Context Dialogue Generator 🦅 demonstration",
29
+ examples=example_sample if example_sample else None,
30
+ cache_examples = False
31
+ )
32
+
33
+ demo.launch(server_name=None, server_port=None)
image2caption.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##### image pred
2
+ from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
3
+ import torch
4
+ from PIL import Image
5
+ import pathlib
6
+ import pandas as pd
7
+ import numpy as np
8
+ from IPython.core.display import HTML
9
+ import os
10
+ import requests
11
+
12
+ class Image2Caption(object):
13
+ def __init__(self ,model_path = "nlpconnect/vit-gpt2-image-captioning",
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
15
+ overwrite_encoder_checkpoint_path = None,
16
+ overwrite_token_model_path = None
17
+ ):
18
+ assert type(overwrite_token_model_path) == type("") or overwrite_token_model_path is None
19
+ assert type(overwrite_encoder_checkpoint_path) == type("") or overwrite_encoder_checkpoint_path is None
20
+ if overwrite_token_model_path is None:
21
+ overwrite_token_model_path = model_path
22
+ if overwrite_encoder_checkpoint_path is None:
23
+ overwrite_encoder_checkpoint_path = model_path
24
+ self.device = device
25
+ self.model = VisionEncoderDecoderModel.from_pretrained(model_path)
26
+ self.feature_extractor = ViTFeatureExtractor.from_pretrained(overwrite_encoder_checkpoint_path)
27
+ self.tokenizer = AutoTokenizer.from_pretrained(overwrite_token_model_path)
28
+ self.model = self.model.to(self.device)
29
+
30
+ def predict_to_df(self, image_paths):
31
+ img_caption_pred = self.predict_step(image_paths)
32
+ img_cation_df = pd.DataFrame(list(zip(image_paths, img_caption_pred)))
33
+ img_cation_df.columns = ["img", "caption"]
34
+ return img_cation_df
35
+ #img_cation_df.to_html(escape=False, formatters=dict(Country=path_to_image_html))
36
+
37
+ def predict_step(self ,image_paths, max_length = 128, num_beams = 4):
38
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
39
+ images = []
40
+ for image_path in image_paths:
41
+ #i_image = Image.open(image_path)
42
+ if image_path.startswith("http"):
43
+ i_image = Image.open(
44
+ requests.get(image_path, stream=True).raw
45
+ )
46
+ else:
47
+ i_image = Image.open(image_path)
48
+
49
+ if i_image.mode != "RGB":
50
+ i_image = i_image.convert(mode="RGB")
51
+ images.append(i_image)
52
+
53
+ pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
54
+ pixel_values = pixel_values.to(self.device)
55
+
56
+ output_ids = self.model.generate(pixel_values, **gen_kwargs)
57
+
58
+ preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
59
+ preds = [pred.strip() for pred in preds]
60
+ return preds
61
+
62
+ def path_to_image_html(path):
63
+ return '<img src="'+ path + '" width="60" >'
64
+
65
+ if __name__ == "__main__":
66
+ i2c_obj = Image2Caption()
67
+ i2c_tiny_zh_obj = Image2Caption("svjack/vit-gpt-diffusion-zh",
68
+ overwrite_encoder_checkpoint_path = "google/vit-base-patch16-224",
69
+ overwrite_token_model_path = "IDEA-CCNL/Wenzhong-GPT2-110M"
70
+ )
predict.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Obj:
2
+ def __init__(self, model, tokenizer, device = "cpu"):
3
+ self.model = model
4
+ self.tokenizer = tokenizer
5
+ self.device = device
6
+ self.model = self.model.to(self.device)
7
+
8
+ def predict(
9
+ self,
10
+ source_text: str,
11
+ max_length: int = 512,
12
+ num_return_sequences: int = 1,
13
+ num_beams: int = 2,
14
+ top_k: int = 50,
15
+ top_p: float = 0.95,
16
+ do_sample: bool = True,
17
+ repetition_penalty: float = 2.5,
18
+ length_penalty: float = 1.0,
19
+ early_stopping: bool = True,
20
+ skip_special_tokens: bool = True,
21
+ clean_up_tokenization_spaces: bool = True,
22
+ ):
23
+ input_ids = self.tokenizer.encode(
24
+ source_text, return_tensors="pt", add_special_tokens=True
25
+ )
26
+ input_ids = input_ids.to(self.device)
27
+ generated_ids = self.model.generate(
28
+ input_ids=input_ids,
29
+ num_beams=num_beams,
30
+ max_length=max_length,
31
+ repetition_penalty=repetition_penalty,
32
+ length_penalty=length_penalty,
33
+ early_stopping=early_stopping,
34
+ top_p=top_p,
35
+ top_k=top_k,
36
+ num_return_sequences=num_return_sequences,
37
+ do_sample = do_sample
38
+ )
39
+ preds = [
40
+ self.tokenizer.decode(
41
+ g,
42
+ skip_special_tokens=skip_special_tokens,
43
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
44
+ )
45
+ for g in generated_ids
46
+ ]
47
+ return preds
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ jieba
4
+ rapidfuzz
5
+ ipykernel
summary_reverse_pred_eng_native.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### English scope
2
+ #device = "cuda:0"
3
+ device = "cpu"
4
+ assert device.startswith("cpu") or device.startswith("cuda")
5
+
6
+ import sys
7
+ from predict import *
8
+
9
+ from transformers import (
10
+ T5ForConditionalGeneration,
11
+ MT5ForConditionalGeneration,
12
+ ByT5Tokenizer,
13
+ PreTrainedTokenizer,
14
+ T5TokenizerFast as T5Tokenizer,
15
+ MT5TokenizerFast as MT5Tokenizer,
16
+ AutoModelForSeq2SeqLM,
17
+ AutoTokenizer,
18
+ BertTokenizer,
19
+ GPT2LMHeadModel,
20
+ )
21
+
22
+ import pandas as pd
23
+ import numpy as np
24
+ import re
25
+ from rapidfuzz import fuzz
26
+ from tqdm import tqdm
27
+ import numpy as np
28
+ from transformers import pipeline
29
+ import os
30
+
31
+ def shorten_exists(l, sim_threshold = 80, slice_size = 5):
32
+ req = []
33
+ for ele in l:
34
+ if not req:
35
+ req.append(ele)
36
+ else:
37
+ if max(map(lambda x: fuzz.ratio(x[:slice_size], ele[:slice_size]), req)) < sim_threshold:
38
+ req.append(ele)
39
+ return req
40
+
41
+ model_path = "svjack/summary-dialogue-eng"
42
+ tokenizer0 = T5Tokenizer.from_pretrained(model_path)
43
+ model0 = T5ForConditionalGeneration.from_pretrained(model_path)
44
+
45
+ if device.startswith("cuda"):
46
+ model = Obj(model0, tokenizer0, device = "cuda:0")
47
+ else:
48
+ model = Obj(model0, tokenizer0, device = "cpu")
49
+
50
+ if device.startswith("cuda"):
51
+ prompt_expand_model = pipeline('text-generation', model='daspartho/prompt-extend',
52
+ device = 0
53
+ )
54
+ else:
55
+ prompt_expand_model = pipeline('text-generation', model='daspartho/prompt-extend',
56
+ )
57
+
58
+ def loop_add(l, names = ["Tom", "Jack"]):
59
+ req = []
60
+ for i in range(len(l)):
61
+ ii = int(i % len(names))
62
+ req.append(
63
+ "{}:{}".format(names[ii], l[i])
64
+ )
65
+ return req
66
+
67
+ #### need some names drop in context(may not have ":")
68
+ #### '艾米-亚当斯在《沉睡的空洞》中,全身,双色大眼睛,咬牙切齿,恐怖,复杂的细节,电影,史诗,现实,解剖,汤姆-哈努卡,上光,艺术站,逼真,可怕'
69
+ def guess_name_candidates(context, cnt_threshold = 1):
70
+ from copy import deepcopy
71
+ assert type(context) == type("")
72
+ import re
73
+ l = re.findall(r"[\u4e00-\u9fa5a-zA-Z]+:", context)
74
+ l = list(filter(lambda x: x.strip(), l))
75
+ ori_l = deepcopy(l)
76
+ if not l:
77
+ return []
78
+ s = pd.Series(l).value_counts()
79
+ l = pd.Series(s[s > cnt_threshold].index.values.tolist()).map(lambda x: x[:-1]).values.tolist()
80
+ for ele in ori_l:
81
+ if len(ele[:-1]) not in l and (len(ele[:-1]) <= 3 or (
82
+ sum(map(len ,re.findall(r"[a-zA-Z]+:", ele))) == len(ele)
83
+ )):
84
+ l.append(ele[:-1])
85
+ l = list(set(l))
86
+ return l
87
+
88
+ def stdf_prompt_expander(x):
89
+ assert type(x) == type("")
90
+ return prompt_expand_model(x, num_return_sequences=1)[0]["generated_text"]
91
+
92
+ def simple_pred(summary, candidates = ["Tom", "Jack"], shorten_it = False,
93
+ summary_expander = lambda _:_, do_sample = True):
94
+ assert callable(summary_expander)
95
+ summary = summary_expander(summary)
96
+ pred_text = model.predict(
97
+ "{}\nCandidates:{}".format(summary, " ".join(candidates)),
98
+ do_sample = do_sample
99
+ )[0]
100
+ candidates_ = guess_name_candidates(pred_text)
101
+ l = re.split("{}".format("|".join(map(lambda x: "{}:".format(x), candidates_))) ,pred_text)
102
+ l = list(filter(lambda x: x.strip(), l))
103
+ if shorten_it:
104
+ l = shorten_exists(l)
105
+ #l = loop_add(l, candidates)
106
+ l = list(map(lambda x: x.strip(), l))
107
+ return l
108
+
109
+ def percentile_sort(df, perc_num = 101):
110
+ score_tuple_s = df["score_tuple"]
111
+ score_array = np.asarray(score_tuple_s.values.tolist())
112
+ perc_list = np.linspace(0, 100, perc_num).tolist()
113
+ low_to_high_perc_array = np.stack(list(map(lambda p: np.percentile(score_array, p, axis = 0), perc_list)))
114
+
115
+ def get_rank(array_):
116
+ lookup_list = pd.DataFrame(array_ - low_to_high_perc_array[::-1]).apply(lambda s: min(s) >= 0, axis = 1).tolist()
117
+ if True not in lookup_list:
118
+ return len(lookup_list)
119
+ return lookup_list.index(True)
120
+
121
+ rank_list = []
122
+ for i in range(score_array.shape[0]):
123
+ rank_list.append(get_rank(score_array[i, :]))
124
+
125
+ rank_s = pd.Series(rank_list)
126
+ return df.iloc[np.argsort(rank_s.values)]
127
+
128
+ def repeat_score(l, slice_size = 200 ,sim_threshold = 70):
129
+ from copy import deepcopy
130
+ assert type(l) == type([])
131
+ l = deepcopy(l)
132
+ l = sorted(l)
133
+ cnt_num = 0
134
+ set0 = set([])
135
+ for ele in l:
136
+ if ":" in ele:
137
+ ele = "".join(ele.split(":")[1:])
138
+ if set0 and max(map(lambda x: fuzz.ratio(x[:slice_size], ele[:slice_size]), set0)) > sim_threshold:
139
+ #if ele in set0:
140
+ cnt_num += 1
141
+ set0.add(ele)
142
+ return cnt_num
143
+
144
+ def sample_pred(context, times = 5, stdf_prompt_expander = lambda _: _):
145
+ df_req = []
146
+ for i in tqdm(range(times)):
147
+ ele = stdf_prompt_expander(context)
148
+ #ele = context
149
+ l = simple_pred(ele, do_sample = True)
150
+ df_req.append(
151
+ [ele, l]
152
+ )
153
+ df = pd.DataFrame(df_req)
154
+ df.columns = ["context", "dialogue"]
155
+ df["fuzz"] = df["dialogue"].map(
156
+ lambda x: fuzz.ratio(context, " ".join(x))
157
+ )
158
+ df["max_fuzz"] = df["dialogue"].map(
159
+ lambda x: max(map(lambda y: fuzz.ratio(y, context), x))
160
+ )
161
+ df["length"] = df["dialogue"].map(len)
162
+ df["rpt_score"] = df["dialogue"].map(repeat_score)
163
+ df["score_tuple"] = df.apply(
164
+ lambda x: (x["fuzz"], -1 * x["max_fuzz"], x["length"], -1 * x["rpt_score"]), axis = 1
165
+ )
166
+ df = percentile_sort(df)
167
+ return df
168
+
169
+ def sample_pred_wrapper(context, i2c_obj, times = 5, extend_by_diffusion = False):
170
+ assert type(context) == type("")
171
+ if any(map(lambda x: context.endswith(x), [".jpg", ".png", ".jpeg"])):
172
+ img_path = context
173
+ i2c_df = i2c_obj.predict_to_df([img_path])
174
+ assert i2c_df.size > 0
175
+ context = i2c_df["caption"].iloc[0]
176
+ else:
177
+ pass
178
+ assert type(context) == type("")
179
+ if extend_by_diffusion:
180
+ req_df = sample_pred(context, times = times, stdf_prompt_expander = stdf_prompt_expander)
181
+ else:
182
+ req_df = sample_pred(context, times = times, stdf_prompt_expander = lambda _: _)
183
+ return req_df
184
+
185
+ from image2caption import *
186
+ i2c_obj = Image2Caption(device = device)
187
+
188
+ if __name__ == "__main__":
189
+ from image2caption import *
190
+ i2c_obj = Image2Caption(device = device)
191
+
192
+ img_path = "../pic/bug.jpg"
193
+ img_path = "../pic/baobao.jpeg"
194
+ img_path = "../pic/cat0.jpg"
195
+ img_path = "../pic/cat.jpg"
196
+ os.path.exists(img_path)
197
+
198
+ df = sample_pred_wrapper(img_path, i2c_obj = i2c_obj)
199
+ df["dialogue"].values.tolist()
200
+
201
+ img_url = "https://datasets-server.huggingface.co/assets/metashift/--/metashift/train/2/image/image.jpg"
202
+ img_url = "https://datasets-server.huggingface.co/assets/metashift/--/metashift/train/6/image/image.jpg"
203
+
204
+ df = sample_pred_wrapper(img_url, i2c_obj = i2c_obj)
205
+ df["dialogue"].values.tolist()
206
+
207
+
208
+ text = "Goldfinger is the seventh novel in Ian Fleming's James Bond series. First published in 1959, it centres on Bond's investigation into the gold-smuggling activities of Auric Goldfinger, who is suspected of being connected to Soviet counter-intelligence. "
209
+ text
210
+
211
+ df = sample_pred_wrapper(text, i2c_obj = i2c_obj, times = 6)
212
+ df["dialogue"].values.tolist()
213
+
214
+ en_l = ['a statue of a bird on top of a rock',
215
+ 'a woman standing in front of a flower arrangement',
216
+ 'people walking down a dirt road',
217
+ 'two pictures of a man with a beard',
218
+ 'a sign that is on top of a sign',
219
+ 'a woman dressed in a costume holding an umbrella',
220
+ 'a woman in a red dress holding a flower in her hand',
221
+ 'a little girl in a pink dress with a pink flower in her hair']
222
+
223
+ df = sample_pred(en_l[0], 5)
224
+ df["dialogue"].values.tolist()
225
+
226
+ df = sample_pred(en_l[0], 5, stdf_prompt_expander = stdf_prompt_expander)
227
+ df["dialogue"].values.tolist()