Spaces:
Runtime error
Upload with huggingface_hub
- app.py +33 -0
- image2caption.py +70 -0
- predict.py +47 -0
- requirements.txt +5 -0
- summary_reverse_pred_eng_native.py +227 -0
app.py
ADDED
@@ -0,0 +1,33 @@
+from summary_reverse_pred_eng_native import *
+import gradio as gr
+import os
+
+#text0 = "Hurricane Gert was a large tropical cyclone that brought severe flooding to Mexico and all of Central America in September 1993. It formed from a tropical wave over the southwestern Caribbean Sea on September 14, came ashore in Nicaragua the next day, and after crossing Honduras regained tropical storm strength over the Gulf of Honduras on September 17, only to weaken to a tropical depression upon moving over Belize the following day. After crossing the Yucatán Peninsula it strengthened into a Category 2 hurricane and made landfall in Mexico near Tuxpan, Veracruz. It had weakened back to a tropical depression by the time it entered the Pacific from Nayarit on September 21, and finally dissipated over open water five days later."
+#text1 = "Shanhuba is a floodplain shoal in the Yangtze River, lying on the left side of the main channel in the Yuzhong District section of Chongqing [1], close to the Yuzhong Peninsula. It originally belonged to the Caiyuanba and Shibanpo subdistricts of Chongqing's Shizhong District [2] and now belongs to the Shibanpo community of Caiyuanba subdistrict, Yuzhong District [3]. It is a naturally deposited alluvial sandbar in a slack stretch of the upper Yangtze, roughly spindle-shaped [4] or elliptical, about 1,800 m long and 600 m wide, covered with pebbles and waterweed. It is submerged during the summer floods each year, above water most of the rest of the time, and joined to the left bank of the Yangtze at low water [5]."
+
+text0 = "The Wisconsin Territorial Centennial half dollar was designed by David Parsons and Benjamin Hawkins and minted by the United States Bureau of the Mint in 1936. The obverse (pictured) depicts a pick axe and lead ore, referring to the lead mining in early Wisconsin"
+#text1 = ""
+
+example_sample = [
+    [text0, False],
+    #[text1, False],
+]
+
+def demo_func(prefix, do_sample):
+    l = simple_pred(prefix, do_sample = do_sample)
+    return {
+        "Dialogue Context": l
+    }
+
+demo = gr.Interface(
+    fn=demo_func,
+    inputs=[gr.Text(label = "Context"),
+        gr.Checkbox(label="do sample"),
+    ],
+    outputs="json",
+    title="English Context Dialogue Generator 🦅 demonstration",
+    examples=example_sample if example_sample else None,
+    cache_examples = False
+)
+
+demo.launch(server_name=None, server_port=None)
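Usage note (not part of the commit): the callback can be smoke-tested without Gradio. A minimal sketch, assuming the checkpoints referenced by summary_reverse_pred_eng_native download and load; the context string is invented:

# Sketch only: exercises demo_func's logic directly.
from summary_reverse_pred_eng_native import simple_pred

turns = simple_pred("A short test context about a commemorative coin.", do_sample=False)
print({"Dialogue Context": turns})  # the same JSON payload the Interface returns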
image2caption.py
ADDED
@@ -0,0 +1,70 @@
+##### image pred
+from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
+import torch
+from PIL import Image
+import pathlib
+import pandas as pd
+import numpy as np
+from IPython.core.display import HTML
+import os
+import requests
+
+class Image2Caption(object):
+    def __init__(self, model_path = "nlpconnect/vit-gpt2-image-captioning",
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+        overwrite_encoder_checkpoint_path = None,
+        overwrite_token_model_path = None
+    ):
+        assert type(overwrite_token_model_path) == type("") or overwrite_token_model_path is None
+        assert type(overwrite_encoder_checkpoint_path) == type("") or overwrite_encoder_checkpoint_path is None
+        if overwrite_token_model_path is None:
+            overwrite_token_model_path = model_path
+        if overwrite_encoder_checkpoint_path is None:
+            overwrite_encoder_checkpoint_path = model_path
+        self.device = device
+        self.model = VisionEncoderDecoderModel.from_pretrained(model_path)
+        self.feature_extractor = ViTFeatureExtractor.from_pretrained(overwrite_encoder_checkpoint_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(overwrite_token_model_path)
+        self.model = self.model.to(self.device)
+
+    def predict_to_df(self, image_paths):
+        img_caption_pred = self.predict_step(image_paths)
+        img_caption_df = pd.DataFrame(list(zip(image_paths, img_caption_pred)))
+        img_caption_df.columns = ["img", "caption"]
+        return img_caption_df
+        #img_caption_df.to_html(escape=False, formatters=dict(Country=path_to_image_html))
+
+    def predict_step(self, image_paths, max_length = 128, num_beams = 4):
+        gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+        images = []
+        for image_path in image_paths:
+            #i_image = Image.open(image_path)
+            if image_path.startswith("http"):
+                i_image = Image.open(
+                    requests.get(image_path, stream=True).raw
+                )
+            else:
+                i_image = Image.open(image_path)
+
+            if i_image.mode != "RGB":
+                i_image = i_image.convert(mode="RGB")
+            images.append(i_image)
+
+        pixel_values = self.feature_extractor(images=images, return_tensors="pt").pixel_values
+        pixel_values = pixel_values.to(self.device)
+
+        output_ids = self.model.generate(pixel_values, **gen_kwargs)
+
+        preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+        preds = [pred.strip() for pred in preds]
+        return preds
+
+def path_to_image_html(path):
+    return '<img src="' + path + '" width="60" >'
+
+if __name__ == "__main__":
+    i2c_obj = Image2Caption()
+    i2c_tiny_zh_obj = Image2Caption("svjack/vit-gpt-diffusion-zh",
+        overwrite_encoder_checkpoint_path = "google/vit-base-patch16-224",
+        overwrite_token_model_path = "IDEA-CCNL/Wenzhong-GPT2-110M"
+    )
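Usage note (not part of the commit): a minimal sketch of driving Image2Caption directly; "cat.jpg" is a placeholder path:

# Sketch only: caption one local image and read the resulting DataFrame.
from image2caption import Image2Caption

i2c = Image2Caption()                # loads nlpconnect/vit-gpt2-image-captioning
df = i2c.predict_to_df(["cat.jpg"])  # DataFrame with columns ["img", "caption"]
print(df["caption"].iloc[0])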
predict.py
ADDED
@@ -0,0 +1,47 @@
+class Obj:
+    def __init__(self, model, tokenizer, device = "cpu"):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.device = device
+        self.model = self.model.to(self.device)
+
+    def predict(
+        self,
+        source_text: str,
+        max_length: int = 512,
+        num_return_sequences: int = 1,
+        num_beams: int = 2,
+        top_k: int = 50,
+        top_p: float = 0.95,
+        do_sample: bool = True,
+        repetition_penalty: float = 2.5,
+        length_penalty: float = 1.0,
+        early_stopping: bool = True,
+        skip_special_tokens: bool = True,
+        clean_up_tokenization_spaces: bool = True,
+    ):
+        input_ids = self.tokenizer.encode(
+            source_text, return_tensors="pt", add_special_tokens=True
+        )
+        input_ids = input_ids.to(self.device)
+        generated_ids = self.model.generate(
+            input_ids=input_ids,
+            num_beams=num_beams,
+            max_length=max_length,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            early_stopping=early_stopping,
+            top_p=top_p,
+            top_k=top_k,
+            num_return_sequences=num_return_sequences,
+            do_sample = do_sample
+        )
+        preds = [
+            self.tokenizer.decode(
+                g,
+                skip_special_tokens=skip_special_tokens,
+                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            )
+            for g in generated_ids
+        ]
+        return preds
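Usage note (not part of the commit): Obj is a thin wrapper over model.generate(), so it works with any seq2seq checkpoint. A sketch with "t5-small" as a stand-in, not the model this Space uses:

# Sketch only: "t5-small" stands in for any encoder-decoder checkpoint.
from transformers import T5ForConditionalGeneration, T5TokenizerFast
from predict import Obj

obj = Obj(T5ForConditionalGeneration.from_pretrained("t5-small"),
          T5TokenizerFast.from_pretrained("t5-small"), device="cpu")
print(obj.predict("translate English to German: Good morning.", do_sample=False)[0])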
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+torch
+transformers
+jieba
+rapidfuzz
+ipykernel
summary_reverse_pred_eng_native.py
ADDED
@@ -0,0 +1,227 @@
+#### English scope
+#device = "cuda:0"
+device = "cpu"
+assert device.startswith("cpu") or device.startswith("cuda")
+
+import sys
+from predict import *
+
+from transformers import (
+    T5ForConditionalGeneration,
+    MT5ForConditionalGeneration,
+    ByT5Tokenizer,
+    PreTrainedTokenizer,
+    T5TokenizerFast as T5Tokenizer,
+    MT5TokenizerFast as MT5Tokenizer,
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    BertTokenizer,
+    GPT2LMHeadModel,
+)
+
+import pandas as pd
+import numpy as np
+import re
+from rapidfuzz import fuzz
+from tqdm import tqdm
+import numpy as np
+from transformers import pipeline
+import os
+
+def shorten_exists(l, sim_threshold = 80, slice_size = 5):
+    req = []
+    for ele in l:
+        if not req:
+            req.append(ele)
+        else:
+            if max(map(lambda x: fuzz.ratio(x[:slice_size], ele[:slice_size]), req)) < sim_threshold:
+                req.append(ele)
+    return req
+
+model_path = "svjack/summary-dialogue-eng"
+tokenizer0 = T5Tokenizer.from_pretrained(model_path)
+model0 = T5ForConditionalGeneration.from_pretrained(model_path)
+
+if device.startswith("cuda"):
+    model = Obj(model0, tokenizer0, device = "cuda:0")
+else:
+    model = Obj(model0, tokenizer0, device = "cpu")
+
+if device.startswith("cuda"):
+    prompt_expand_model = pipeline('text-generation', model='daspartho/prompt-extend',
+        device = 0
+    )
+else:
+    prompt_expand_model = pipeline('text-generation', model='daspartho/prompt-extend',
+    )
+
+def loop_add(l, names = ["Tom", "Jack"]):
+    req = []
+    for i in range(len(l)):
+        ii = int(i % len(names))
+        req.append(
+            "{}:{}".format(names[ii], l[i])
+        )
+    return req
+
+#### Guess speaker names dropped into the context (the context may not contain ":"), e.g.:
+#### 'Amy Adams in Sleepy Hollow, full body, big bicolored eyes, gritted teeth, horror, intricate details, cinematic, epic, realistic, anatomical, Tom Hanuka, glazed, ArtStation, photorealistic, terrifying'
+def guess_name_candidates(context, cnt_threshold = 1):
+    from copy import deepcopy
+    assert type(context) == type("")
+    import re
+    l = re.findall(r"[\u4e00-\u9fa5a-zA-Z]+:", context)
+    l = list(filter(lambda x: x.strip(), l))
+    ori_l = deepcopy(l)
+    if not l:
+        return []
+    s = pd.Series(l).value_counts()
+    l = pd.Series(s[s > cnt_threshold].index.values.tolist()).map(lambda x: x[:-1]).values.tolist()
+    for ele in ori_l:
+        if ele[:-1] not in l and (len(ele[:-1]) <= 3 or (
+            sum(map(len, re.findall(r"[a-zA-Z]+:", ele))) == len(ele)
+        )):
+            l.append(ele[:-1])
+    l = list(set(l))
+    return l
+
+def stdf_prompt_expander(x):
+    assert type(x) == type("")
+    return prompt_expand_model(x, num_return_sequences=1)[0]["generated_text"]
+
+def simple_pred(summary, candidates = ["Tom", "Jack"], shorten_it = False,
+    summary_expander = lambda _:_, do_sample = True):
+    assert callable(summary_expander)
+    summary = summary_expander(summary)
+    pred_text = model.predict(
+        "{}\nCandidates:{}".format(summary, " ".join(candidates)),
+        do_sample = do_sample
+    )[0]
+    candidates_ = guess_name_candidates(pred_text)
+    l = re.split("{}".format("|".join(map(lambda x: "{}:".format(x), candidates_))), pred_text)
+    l = list(filter(lambda x: x.strip(), l))
+    if shorten_it:
+        l = shorten_exists(l)
+    #l = loop_add(l, candidates)
+    l = list(map(lambda x: x.strip(), l))
+    return l
+
+def percentile_sort(df, perc_num = 101):
+    score_tuple_s = df["score_tuple"]
+    score_array = np.asarray(score_tuple_s.values.tolist())
+    perc_list = np.linspace(0, 100, perc_num).tolist()
+    low_to_high_perc_array = np.stack(list(map(lambda p: np.percentile(score_array, p, axis = 0), perc_list)))
+
+    def get_rank(array_):
+        lookup_list = pd.DataFrame(array_ - low_to_high_perc_array[::-1]).apply(lambda s: min(s) >= 0, axis = 1).tolist()
+        if True not in lookup_list:
+            return len(lookup_list)
+        return lookup_list.index(True)
+
+    rank_list = []
+    for i in range(score_array.shape[0]):
+        rank_list.append(get_rank(score_array[i, :]))
+
+    rank_s = pd.Series(rank_list)
+    return df.iloc[np.argsort(rank_s.values)]
+
+def repeat_score(l, slice_size = 200, sim_threshold = 70):
+    from copy import deepcopy
+    assert type(l) == type([])
+    l = deepcopy(l)
+    l = sorted(l)
+    cnt_num = 0
+    set0 = set([])
+    for ele in l:
+        if ":" in ele:
+            ele = "".join(ele.split(":")[1:])
+        if set0 and max(map(lambda x: fuzz.ratio(x[:slice_size], ele[:slice_size]), set0)) > sim_threshold:
+            #if ele in set0:
+            cnt_num += 1
+        set0.add(ele)
+    return cnt_num
+
+def sample_pred(context, times = 5, stdf_prompt_expander = lambda _: _):
+    df_req = []
+    for i in tqdm(range(times)):
+        ele = stdf_prompt_expander(context)
+        #ele = context
+        l = simple_pred(ele, do_sample = True)
+        df_req.append(
+            [ele, l]
+        )
+    df = pd.DataFrame(df_req)
+    df.columns = ["context", "dialogue"]
+    df["fuzz"] = df["dialogue"].map(
+        lambda x: fuzz.ratio(context, " ".join(x))
+    )
+    df["max_fuzz"] = df["dialogue"].map(
+        lambda x: max(map(lambda y: fuzz.ratio(y, context), x))
+    )
+    df["length"] = df["dialogue"].map(len)
+    df["rpt_score"] = df["dialogue"].map(repeat_score)
+    df["score_tuple"] = df.apply(
+        lambda x: (x["fuzz"], -1 * x["max_fuzz"], x["length"], -1 * x["rpt_score"]), axis = 1
+    )
+    df = percentile_sort(df)
+    return df
+
+def sample_pred_wrapper(context, i2c_obj, times = 5, extend_by_diffusion = False):
+    assert type(context) == type("")
+    if any(map(lambda x: context.endswith(x), [".jpg", ".png", ".jpeg"])):
+        img_path = context
+        i2c_df = i2c_obj.predict_to_df([img_path])
+        assert i2c_df.size > 0
+        context = i2c_df["caption"].iloc[0]
+    else:
+        pass
+    assert type(context) == type("")
+    if extend_by_diffusion:
+        req_df = sample_pred(context, times = times, stdf_prompt_expander = stdf_prompt_expander)
+    else:
+        req_df = sample_pred(context, times = times, stdf_prompt_expander = lambda _: _)
+    return req_df
+
+from image2caption import *
+i2c_obj = Image2Caption(device = device)
+
+if __name__ == "__main__":
+    from image2caption import *
+    i2c_obj = Image2Caption(device = device)
+
+    img_path = "../pic/bug.jpg"
+    img_path = "../pic/baobao.jpeg"
+    img_path = "../pic/cat0.jpg"
+    img_path = "../pic/cat.jpg"
+    os.path.exists(img_path)
+
+    df = sample_pred_wrapper(img_path, i2c_obj = i2c_obj)
+    df["dialogue"].values.tolist()
+
+    img_url = "https://datasets-server.huggingface.co/assets/metashift/--/metashift/train/2/image/image.jpg"
+    img_url = "https://datasets-server.huggingface.co/assets/metashift/--/metashift/train/6/image/image.jpg"
+
+    df = sample_pred_wrapper(img_url, i2c_obj = i2c_obj)
+    df["dialogue"].values.tolist()
+
+    text = "Goldfinger is the seventh novel in Ian Fleming's James Bond series. First published in 1959, it centres on Bond's investigation into the gold-smuggling activities of Auric Goldfinger, who is suspected of being connected to Soviet counter-intelligence. "
+    text
+
+    df = sample_pred_wrapper(text, i2c_obj = i2c_obj, times = 6)
+    df["dialogue"].values.tolist()
+
+    en_l = ['a statue of a bird on top of a rock',
+        'a woman standing in front of a flower arrangement',
+        'people walking down a dirt road',
+        'two pictures of a man with a beard',
+        'a sign that is on top of a sign',
+        'a woman dressed in a costume holding an umbrella',
+        'a woman in a red dress holding a flower in her hand',
+        'a little girl in a pink dress with a pink flower in her hair']
+
+    df = sample_pred(en_l[0], 5)
+    df["dialogue"].values.tolist()
+
+    df = sample_pred(en_l[0], 5, stdf_prompt_expander = stdf_prompt_expander)
+    df["dialogue"].values.tolist()
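Usage note (not part of the commit): sample_pred scores each sampled dialogue with the tuple (fuzz, -max_fuzz, length, -rpt_score), and percentile_sort ranks rows by the highest percentile band a row clears in every dimension at once. A toy sketch of that ordering, with invented values and percentile_sort imported from this module:

# Sketch only: toy score tuples to illustrate percentile_sort's ordering.
import pandas as pd
toy = pd.DataFrame({
    "dialogue":    [["a", "b"], ["c"], ["d", "e", "f"]],
    "score_tuple": [(60, -40, 2, 0), (20, -80, 1, -1), (80, -30, 3, 0)],
})
print(percentile_sort(toy)["dialogue"].tolist())
# (80, -30, 3, 0) clears the 100th percentile in every dimension, so its row sorts first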