JamePeng2023 committed
Commit 98dd963 · 1 Parent(s): 6fcf8b4

Add streaming support for text generation
- Implemented streaming functionality for real-time text output.
- Added a `_decode_stream` method to handle text streaming.
- Updated the `chat` method to support a streaming mode.
- Adjusted the code to process and yield text in chunks for better responsiveness.

This update improves the user experience by allowing text to be generated and displayed incrementally.
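A minimal usage sketch of the new API, assuming the checkpoint that ships this `modeling_minicpm.py` is loaded with `trust_remote_code=True`; the model path and prompts below are placeholders, not part of the commit:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder path: point this at a MiniCPM3 checkpoint that includes this modeling file.
model_path = "path/to/MiniCPM3-4B"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, device_map="auto")

# stream=True: chat() returns a generator that yields text chunks as they are produced.
for chunk in model.chat(tokenizer, "Explain streaming generation in one sentence.", stream=True):
    print(chunk, end="", flush=True)

# stream=False (the default): chat() returns the complete response string in one piece.
response = model.chat(tokenizer, "Summarize the previous answer.", stream=False)
print(response)
```

With `stream=True` the method hands back the generator immediately; generation runs on a background thread inside `_decode_stream`, so the loop prints chunks as soon as they are decoded.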
modeling_minicpm.py  CHANGED  (+64 -8)
```diff
@@ -22,12 +22,14 @@ import math
 import warnings
 from typing import List, Optional, Tuple, Union, Dict
 
+from threading import Thread
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from transformers import TextIteratorStreamer
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import (
@@ -1248,6 +1250,9 @@ class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
+        # List of terminator tokens used to indicate the end of a sequence or conversation.
+        self.terminators = ['</s>', '<|im_end|>']
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -1426,11 +1431,52 @@ class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+    # Internal function to handle streaming of generated text using TextIteratorStreamer.
+    def _decode_stream(self, input_ids, tokenizer, **kwargs):
+        # Convert terminators to token IDs
+        terminators_ids = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
+        # Initialize TextIteratorStreamer for handling streaming output
+        streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True)
+        # Set up generation parameters, including input IDs, eos token IDs, and streamer
+        generation_kwargs = {
+            'input_ids': input_ids,
+            'eos_token_id': terminators_ids,
+            'streamer': streamer
+        }
+        generation_kwargs.update(kwargs)
+        # Run the generation task in a separate thread to enable streaming output
+        thread = Thread(target=self.generate, kwargs=generation_kwargs)
+        thread.start()
+        # Return the streamer instance for later access to streamed text
+        return streamer
+
 
     @torch.inference_mode()
-    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
-             ...
-             ...
+    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", max_length: int = 4096, num_beams=1,
+             do_sample=True, logits_processor=None, stream=False, top_p=0.8, temperature=0.3, **kwargs):
+        """
+        Main function for handling dialogue generation based on the input query and history.
+
+        Parameters:
+        - tokenizer: Tokenizer instance used for encoding and decoding.
+        - query: The user input query string.
+        - history: Dialogue history, a list of dictionaries where each dictionary contains role and content.
+        - role: The current role, default is "user".
+        - max_length: Maximum length of the generated text.
+        - num_beams: Number of beams for beam search.
+        - do_sample: Whether to use sampling for generation.
+        - logits_processor: Function for processing logits (if any).
+        - stream: Whether to use streaming output.
+        - top_p: Nucleus sampling parameter.
+        - temperature: Temperature parameter for generation.
+        - **kwargs: Additional arguments for generation.
+
+        Returns:
+        - If stream is True, returns a generator to get the generated text incrementally.
+        - If stream is False, returns the complete generated response string.
+        """
+
         if history is None:
             history = []
         if logits_processor:
@@ -1443,12 +1489,22 @@ class MiniCPM3ForCausalLM(MiniCPM3PreTrainedModel):
         history.append({"role": role, "content": query})
         history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
         inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
-        outputs = self.generate(**inputs, **gen_kwargs)
-        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
-        response = tokenizer.decode(outputs)
-        history.append({"role": "assistant", "content": response})
-        return response, history
 
+        if stream:
+            res = self._decode_stream(inputs["input_ids"], tokenizer, **gen_kwargs)
+            def stream_gen():
+                for text in res:
+                    # Remove terminators from the text
+                    for term in self.terminators:
+                        text = text.replace(term, '')
+                    yield text
+            return stream_gen()
+
+        else:
+            outputs = self.generate(**inputs, **gen_kwargs)
+            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
+            response = tokenizer.decode(outputs)
+            return response
 
 @add_start_docstrings(
 """
```