# predict_turkce.py — Turkish visual-question-answering wrapper around LLaVA.
# Provenance (from the original Hugging Face upload page):
#   author: nezahatkorkmaz · commit: "Create predict_turkce.py" (5365632, verified)
#   file size: 1.55 kB
from PIL import Image
import torch
def predict_turkish(image, turkish_question, tokenizer, model, image_processor, tr_en_tokenizer, tr_en_model, en_tr_tokenizer, en_tr_model, device="cuda"):
    """Answer a Turkish question about an image via an English VQA model.

    Pipeline: translate the Turkish question to English, run the LLaVA
    vision-language model on the image plus the English question, then
    translate the English answer back to Turkish.

    Args:
        image: A ``PIL.Image.Image`` or a filesystem path (``str``) to an image.
        turkish_question: The question text, in Turkish.
        tokenizer: Text tokenizer for the LLaVA model.
        model: The LLaVA model (must accept ``images=`` in ``generate``).
        image_processor: Image processor providing ``preprocess``.
        tr_en_tokenizer, tr_en_model: Turkish->English translation pair.
        en_tr_tokenizer, en_tr_model: English->Turkish translation pair.
        device: Torch device string; defaults to ``"cuda"``. NOTE(review):
            the image tensor is cast to half precision below, which
            generally assumes a GPU — confirm before running on CPU.

    Returns:
        str: The model's answer, translated back into Turkish.
    """
    # 1) Translate the Turkish question into English.
    inputs = tr_en_tokenizer([turkish_question], return_tensors="pt").to(device)
    english_question = tr_en_tokenizer.decode(
        tr_en_model.generate(**inputs)[0], skip_special_tokens=True
    )

    # 2) Accept either a file path or an already-loaded PIL image; force RGB.
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    else:
        image = image.convert("RGB")

    # Lazy project-local imports, kept inside the function as in the original.
    from llava.conversation import conv_templates
    from llava.constants import DEFAULT_IMAGE_TOKEN

    # 3) Build the LLaVA v1 conversation prompt with the image placeholder token.
    conv = conv_templates["llava-v1"].copy()
    conv.messages = []
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + english_question)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    image_tensor = image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(device)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # 4) Greedy generation. NOTE: with do_sample=False the `temperature`
    # argument is ignored by transformers (it only applies to sampling);
    # it is kept here to preserve the original call exactly.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=input_ids,
            images=image_tensor,
            do_sample=False,
            temperature=0.2,
            max_new_tokens=512,
        )

    # 5) Decode only the newly generated tokens (skip the echoed prompt).
    english_answer = tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    )

    # 6) Translate the English answer back to Turkish.
    answer_inputs = en_tr_tokenizer([english_answer], return_tensors="pt").to(device)
    return en_tr_tokenizer.decode(
        en_tr_model.generate(**answer_inputs)[0], skip_special_tokens=True
    )