Spaces:
Runtime error
Runtime error
from secrets_key import OPENAI_KEY, RANDOM_SEED | |
from openai import OpenAI | |
import json | |
import pandas as pd | |
from pprint import pprint | |
# OpenAI client used by `analysis` for every chat-completion request. It is
# constructed at import time with the key imported from `secrets_key`, so
# importing this module requires that module to be importable.
client = OpenAI(api_key=OPENAI_KEY)
# Prompt templates for the three visibility conditions run by `analysis`:
#   ti_prompt - story text and images are both shown   (fills story/agent/entity)
#   t_prompt  - only the story text is shown           (fills story/agent/entity)
#   i_prompt  - only the images are shown              (fills agent/entity)
# The three conditions are compared against each other, so the templates
# should differ only in which inputs they mention. The response-length rule
# is therefore included in all three; otherwise output length would be an
# extra, unintended difference between conditions.
# NOTE(review): "related to that goal" in ti_prompt has no antecedent; kept
# verbatim in case it mirrors the original HIT wording - TODO confirm.
ti_prompt = """
In the following HIT, you will be presented with a collection of images and a story that is, in some manner, related to that goal. You are also given a specific agent and an entity (generally, person or object).
1. Identify a goal that the agent in the story is trying to achieve.
2. Identify and write a condition that is necessary for goal completion. The condition should be related to the entity.
3. Write an alternate condition that will make the goal unlikely to happen. It is likely that this alternate condition will contradict information provided in the images and/or story.
Make each response minimum 5 words long and maximum 25 words long.
Story: {story}
Agent: {agent}
Entity: {entity}
"""
t_prompt = """
In the following HIT, you will be presented with a story. You are also given a specific agent and an entity (generally, person or object).
1. Identify a goal that the agent in the story is trying to achieve.
2. Identify and write a condition that is necessary for goal completion. The condition should be related to the entity.
3. Write an alternate condition that will make the goal unlikely to happen. It is likely that this alternate condition will contradict information provided in the story.
Make each response minimum 5 words long and maximum 25 words long.
Story: {story}
Agent: {agent}
Entity: {entity}
"""
i_prompt = """
In the following HIT, you will be presented with a collection of images of a story. You are also given a specific agent and an entity (generally, person or object).
1. Identify a goal that the agent in the story is trying to achieve.
2. Identify and write a condition that is necessary for goal completion. The condition should be related to the entity.
3. Write an alternate condition that will make the goal unlikely to happen. It is likely that this alternate condition will contradict information provided in the images.
Make each response minimum 5 words long and maximum 25 words long.
Agent: {agent}
Entity: {entity}
"""
def analysis(story, agent, entity, images, text=True, image=True,
             model="gpt-4-vision-preview"):
    """Ask the chat model to answer the goal/condition HIT for one item.

    The prompt template depends on which inputs are visible: ``ti_prompt``
    when both text and images are, ``t_prompt`` for text only, and
    ``i_prompt`` for images only. Images are attached as image content parts
    only when ``image`` is true.

    Args:
        story: Story text; substituted into the prompt only when ``text`` is true.
        agent: Agent substituted into the prompt.
        entity: Entity substituted into the prompt.
        images: Iterable of image URLs; used only when ``image`` is true.
        text: Whether the story text is shown to the model.
        image: Whether the images are shown to the model.
        model: Chat Completions model name (defaults to the original model).

    Returns:
        The content of the first completion choice. It is also printed.

    Raises:
        ValueError: If ``text`` and ``image`` are both False.
    """
    if text and image:
        now_prompt = ti_prompt.format(story=story, agent=agent, entity=entity)
    elif text:
        now_prompt = t_prompt.format(story=story, agent=agent, entity=entity)
    elif image:
        now_prompt = i_prompt.format(agent=agent, entity=entity)
    else:
        raise ValueError("text and image cannot both be False")

    content = [{"type": "text", "text": now_prompt}]
    if image:
        # Chat Completions image parts take an object {"url": ...}; a bare
        # string was only accepted by the old preview endpoint. The loop
        # variable is `url` so it no longer shadows the `image` flag.
        content.extend(
            {"type": "image_url", "image_url": {"url": url}} for url in images
        )

    response = client.chat.completions.create(
        model=model,
        seed=RANDOM_SEED,  # imported from secrets_key alongside the API key
        messages=[{"role": "user", "content": content}],
        temperature=1,
        max_tokens=256,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    out = response.choices[0].message.content
    print("OUTPUT:", out)
    print()
    return out
if __name__ == '__main__':
    # Visibility conditions run for every item, in the original order:
    # (text visible, images visible, section header written to the output).
    # The text-hidden/images-hidden combination is never run because
    # `analysis` raises ValueError for it.
    CONDITIONS = (
        (True, True, "### Both text and image are visible.\n"),
        (True, False, "### Only text is visible.\n"),
        (False, True, "### Only image is visible.\n"),
    )
    MAX_ITEMS = 10  # stop once this many items have been processed successfully
    INPUT_PATH = './results.csv'
    OUTPUT_PATH = './results_with_gpt4v.csv'

    df = pd.read_csv(INPUT_PATH)
    # Fixed shuffle seed: with an unchanged input CSV, the same items are
    # selected on every run.
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    count = 0
    # The CSV can contain several rows with the same item_id (presumably one
    # per HIT assignment - TODO confirm); each item is sent to the model once.
    done = set()
    data = []
    finished = False
    try:
        for _, row in df.iterrows():
            item_id = row['Input.item_id']
            if item_id in done:
                continue
            done.add(item_id)
            try:
                story = row['Input.story']
                agent = row['Input.agent']
                entity = row['Input.entity']
                images = [row[f'Input.image{i}'] for i in range(1, 4)]
                print("HITId:", row['HITId'])
                print("Prompt:", ti_prompt.format(story=story, agent=agent, entity=entity))
                print()
                for i, url in enumerate(images, start=1):
                    print(f"Image{i}:", url)
                print()
                gpt4v_out = ""
                for text, image, header in CONDITIONS:
                    print("Text Visible:", text)
                    print("Image Visible:", image)
                    out = analysis(story, agent, entity, images, text=text, image=image)
                    gpt4v_out += header
                    gpt4v_out += "#### Output: \n"
                    gpt4v_out += out
                    gpt4v_out += "\n"
                obj = row.to_dict()
                obj['GPT4V.out'] = gpt4v_out
                data.append(obj)
                print("====================================")
                print()
                count += 1
            except Exception as e:
                # Forget the item so a later row with the same item_id can retry it.
                done.remove(item_id)
                print("ERROR:", e)
                print("====================================")
                print()
            if count == MAX_ITEMS:
                break
        finished = True
    finally:
        # On normal completion, always write the output. If the loop is
        # interrupted (e.g. Ctrl-C) or an error escapes it, still write what
        # was collected so already-paid API responses are not lost. An empty
        # partial result is skipped so it cannot clobber a previous output file.
        if finished or data:
            pd.DataFrame(data).to_csv(OUTPUT_PATH, index=False)