# MiniCPM-V-4_5-Demo / models/minicpmv4_5.py
import spaces
from io import BytesIO
import torch
from PIL import Image
import base64
import json
import re
import logging
from transformers import AutoModel, AutoTokenizer, AutoProcessor, set_seed
# set_seed(42)
logger = logging.getLogger(__name__)
class ModelMiniCPMV4_5:
    def __init__(self, path) -> None:
        # Load the checkpoint with SDPA attention in bfloat16 and let
        # device_map="auto" place it across the available devices.
        self.model = AutoModel.from_pretrained(
            path, trust_remote_code=True, attn_implementation='sdpa',
            torch_dtype=torch.bfloat16, device_map="auto")
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, trust_remote_code=True)
        self.processor = AutoProcessor.from_pretrained(
            path, trust_remote_code=True)
    @staticmethod
    def _strip_markup(text):
        # Remove grounding markup from decoded output: whole <box>...</box>
        # spans first, then any stray <ref>/<box> tags. Shared by the
        # non-streaming and streaming paths below.
        text = re.sub(r'<box>.*</box>', '', text)
        for tag in ('<ref>', '</ref>', '<box>', '</box>'):
            text = text.replace(tag, '')
        return text

    def __call__(self, input_data):
image = None
if "image" in input_data and len(input_data["image"]) > 10:
image = Image.open(BytesIO(base64.b64decode(
input_data["image"]))).convert('RGB')
        # "question" (the message list) and "params" arrive JSON-encoded.
        msgs = json.loads(input_data["question"])
        params = json.loads(input_data.get("params", "{}"))
        temporal_ids = input_data.get("temporal_ids", None)
        if temporal_ids:
            temporal_ids = json.loads(temporal_ids)
if params.get("max_new_tokens", 0) > 16384:
logger.info(f"make max_new_tokens=16384, reducing limit to save memory")
params["max_new_tokens"] = 16384
if params.get("max_inp_length", 0) > 2048 * 10:
logger.info(f"make max_inp_length={2048 * 10}, keeping high limit for video processing")
params["max_inp_length"] = 2048 * 10
        # Normalize messages: accept either 'content' or the alternate
        # 'contents' key, and decode inline base64 images into PIL images.
        for msg in msgs:
            if 'content' in msg:
                contents = msg['content']
            else:
                contents = msg.pop('contents')
            new_cnts = []
            for c in contents:
                if isinstance(c, dict):
                    # Dict items carry their payload under "pairs".
                    if c['type'] == 'text':
                        c = c['pairs']
                    elif c['type'] == 'image':
                        c = Image.open(
                            BytesIO(base64.b64decode(c["pairs"]))).convert('RGB')
                    else:
                        raise ValueError(
                            "content items must be of type 'text' or 'image'.")
                new_cnts.append(c)
            msg['content'] = new_cnts
        logger.info(f'msgs: {msgs}')
enable_thinking = params.pop('enable_thinking', True)
is_streaming = params.pop('stream', False)
if is_streaming:
return self._stream_chat(image, msgs, enable_thinking, params, temporal_ids)
else:
chat_kwargs = {
"image": image,
"msgs": msgs,
"tokenizer": self.tokenizer,
"processor": self.processor,
"enable_thinking": enable_thinking,
**params
}
if temporal_ids is not None:
chat_kwargs["temporal_ids"] = temporal_ids
answer = self.model.chat(**chat_kwargs)
        answer = self._strip_markup(answer)
        if not enable_thinking:
            # The model may still emit a closing </think> tag; drop it.
            answer = answer.replace('</think>', '')
        # Return the answer together with its length in tokens.
        output_tokens = len(self.tokenizer.encode(answer))
        return answer, output_tokens
    def _stream_chat(self, image, msgs, enable_thinking, params, temporal_ids=None):
        # Generator that yields cleaned text chunks as the model produces them.
        try:
            params['stream'] = True
chat_kwargs = {
"image": image,
"msgs": msgs,
"tokenizer": self.tokenizer,
"processor": self.processor,
"enable_thinking": enable_thinking,
**params
}
if temporal_ids is not None:
chat_kwargs["temporal_ids"] = temporal_ids
answer_generator = self.model.chat(**chat_kwargs)
            if isinstance(answer_generator, str):
                # model.chat() returned a full string despite stream=True.
                # (A hasattr(..., '__iter__') check cannot detect this case,
                # because strings are themselves iterable.)
                answer = self._strip_markup(answer_generator)
                if not enable_thinking:
                    answer = answer.replace('</think>', '')
                # Preserve the streaming contract by yielding per character.
                yield from answer
            else:
                for chunk in answer_generator:
                    if isinstance(chunk, str):
                        # Per-chunk cleanup; note this can miss markup tags
                        # that are split across chunk boundaries.
                        clean_chunk = self._strip_markup(chunk)
                        if not enable_thinking:
                            clean_chunk = clean_chunk.replace('</think>', '')
                        yield clean_chunk
                    else:
                        yield str(chunk)
        except Exception as e:
            logger.error(f"Stream chat error: {e}")
            yield f"Error: {e}"