Upload modeling_minicpm.py
Browse files- modeling_minicpm.py +28 -1
modeling_minicpm.py
CHANGED
@@ -20,7 +20,7 @@
|
|
20 |
""" PyTorch MiniCPM model."""
|
21 |
import math
|
22 |
import warnings
|
23 |
-
from typing import List, Optional, Tuple, Union
|
24 |
|
25 |
import torch
|
26 |
import torch.nn.functional as F
|
@@ -49,6 +49,7 @@ from transformers.utils import (
|
|
49 |
)
|
50 |
from transformers.utils.import_utils import is_torch_fx_available
|
51 |
from .configuration_minicpm import MiniCPMConfig
|
|
|
52 |
|
53 |
|
54 |
if is_flash_attn_2_available():
|
@@ -1302,6 +1303,32 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
|
|
1302 |
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
1303 |
)
|
1304 |
return reordered_past
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1305 |
|
1306 |
|
1307 |
@add_start_docstrings(
|
|
|
20 |
""" PyTorch MiniCPM model."""
|
21 |
import math
|
22 |
import warnings
|
23 |
+
from typing import List, Optional, Tuple, Union, Dict
|
24 |
|
25 |
import torch
|
26 |
import torch.nn.functional as F
|
|
|
49 |
)
|
50 |
from transformers.utils.import_utils import is_torch_fx_available
|
51 |
from .configuration_minicpm import MiniCPMConfig
|
52 |
+
import re
|
53 |
|
54 |
|
55 |
if is_flash_attn_2_available():
|
|
|
1303 |
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
1304 |
)
|
1305 |
return reordered_past
|
1306 |
+
|
1307 |
+
@torch.inference_mode()
|
1308 |
+
def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
|
1309 |
+
max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
|
1310 |
+
**kwargs):
|
1311 |
+
if history is None:
|
1312 |
+
history = []
|
1313 |
+
if logits_processor:
|
1314 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1315 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1316 |
+
else:
|
1317 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1318 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1319 |
+
|
1320 |
+
history.append({"role": role, "content": query})
|
1321 |
+
history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
|
1322 |
+
inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
|
1323 |
+
outputs = self.generate(**inputs, **gen_kwargs)
|
1324 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
1325 |
+
response = tokenizer.decode(outputs)
|
1326 |
+
pattern = re.compile(r".*?(?=<AI>|<用户>)", re.DOTALL)
|
1327 |
+
matches = pattern.findall(response)
|
1328 |
+
if len(matches) > 0:
|
1329 |
+
response = matches[0]
|
1330 |
+
history.append({"role": "assistant", "content": response})
|
1331 |
+
return response, history
|
1332 |
|
1333 |
|
1334 |
@add_start_docstrings(
|