0xShonen committed (verified)
Commit 18b3855 · 1 Parent(s): 9f13f62

Upload vllm_template_gptoss.py

Files changed (1)
  1. vllm_template_gptoss.py +540 -0
vllm_template_gptoss.py ADDED
@@ -0,0 +1,540 @@
import asyncio
import json
import os
from enum import Enum
from pydantic import BaseModel, Field
from typing import Dict, Any, Callable, Literal, Optional, Union, List
from chutes.image import Image
from chutes.image.standard.vllm import VLLM
from chutes.chute import Chute, ChutePack, NodeSelector

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


class DefaultRole(Enum):
    user = "user"
    assistant = "assistant"


class ChatMessage(BaseModel):
    role: str
    content: str


class Logprob(BaseModel):
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


class ResponseFormat(BaseModel):
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[Dict] = None


class BaseRequest(BaseModel):
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None, ge=0, le=9223372036854775807)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    best_of: Optional[int] = None
    use_beam_search: bool = False
    top_k: int = -1
    min_p: float = 0.0
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    prompt_logprobs: Optional[int] = None


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class TokenizeRequest(BaseRequest):
    model: str
    prompt: str
    add_special_tokens: bool


class DetokenizeRequest(BaseRequest):
    model: str
    tokens: List[int]


class ChatCompletionRequest(BaseRequest):
    messages: List[ChatMessage]


class CompletionRequest(BaseRequest):
    prompt: str

class ChatCompletionLogProb(BaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(BaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = "stop"
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(BaseModel):
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class TokenizeResponse(BaseRequest):
    count: int
    max_model_len: int
    tokens: List[int]


class DetokenizeResponse(BaseRequest):
    prompt: str


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class CompletionLogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(BaseModel):
    id: str
    object: str = "text_completion"
    created: int
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"
        ),
    )

class CompletionStreamResponse(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class VLLMChute(ChutePack):
    chat: Callable
    completion: Callable
    chat_stream: Callable
    completion_stream: Callable
    models: Callable


def build_vllm_chute(
    username: str,
    model_name: str,
    node_selector: NodeSelector,
    image: str | Image = VLLM,
    tagline: str = "",
    readme: str = "",
    concurrency: int = 32,
    engine_args: Dict[str, Any] = {},
    revision: str = None,
):
    if engine_args.get("revision"):
        raise ValueError("revision is now a top-level argument to build_vllm_chute!")
    if not revision:
        from chutes.chute.template.helpers import get_current_hf_commit

        suggested_commit = None
        try:
            suggested_commit = get_current_hf_commit(model_name)
        except Exception:
            ...
        suggestion = (
            "Unable to fetch the current refs/heads/main commit from HF, please check the model name."
            if not suggested_commit
            else f"The current refs/heads/main commit is: {suggested_commit}"
        )
        raise ValueError(
            f"You must specify revision= to properly lock a model to a given huggingface revision. {suggestion}"
        )

    chute = Chute(
        username=username,
        name=model_name,
        tagline=tagline,
        readme=readme,
        image=image,
        node_selector=node_selector,
        concurrency=concurrency,
        standard_template="vllm",
        revision=revision,
    )

    # Semi-optimized defaults for cold starts (but not overall perf once hot).
    defaults = {}
    for key, value in defaults.items():
        if key not in engine_args:
            engine_args[key] = value

    # Minimal input schema with defaults.
    class MinifiedMessage(BaseModel):
        role: DefaultRole = DefaultRole.user
        content: str = Field("")

    class MinifiedStreamChatCompletion(BaseModel):
        messages: List[MinifiedMessage] = [MinifiedMessage()]
        temperature: float = Field(0.7)
        seed: int = Field(42)
        stream: bool = Field(True)
        max_tokens: int = Field(1024)
        model: str = Field(model_name)

    class MinifiedChatCompletion(MinifiedStreamChatCompletion):
        stream: bool = Field(False)

    # Minimal completion input.
    class MinifiedStreamCompletion(BaseModel):
        prompt: str
        temperature: float = Field(0.7)
        seed: int = Field(42)
        stream: bool = Field(True)
        max_tokens: int = Field(1024)
        model: str = Field(model_name)

    class MinifiedCompletion(MinifiedStreamCompletion):
        stream: bool = Field(False)

    @chute.on_startup()
    async def initialize_vllm(self):
        nonlocal engine_args
        nonlocal model_name
        nonlocal image

        # Imports here to avoid needing torch/vllm/etc. to just perform inference/build remotely.
        import torch
        import multiprocessing
        from vllm import AsyncEngineArgs, AsyncLLMEngine
        import vllm.entrypoints.openai.api_server as vllm_api_server
        from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
        from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
        import vllm.version as vv

        # Force download in initializer with some retries.
        from huggingface_hub import snapshot_download

        download_path = None
        for attempt in range(5):
            download_kwargs = {}
            if self.revision:
                download_kwargs["revision"] = self.revision
            try:
                print(f"Attempting to download {model_name} to cache...")
                download_path = await asyncio.to_thread(
                    snapshot_download, repo_id=model_name, **download_kwargs
                )
                print(f"Successfully downloaded {model_name} to {download_path}")
                break
            except Exception as exc:
                print(f"Failed downloading {model_name} {download_kwargs or ''}: {exc}")
                await asyncio.sleep(60)
        if not download_path:
            raise Exception(f"Failed to download {model_name} after 5 attempts")

        try:
            from vllm.entrypoints.openai.serving_engine import BaseModelPath
        except Exception:
            from vllm.entrypoints.openai.serving_models import (
                BaseModelPath,
                OpenAIServingModels,
            )
        from vllm.entrypoints.openai.serving_tokenization import (
            OpenAIServingTokenization,
        )

        # Reset torch.
        torch.cuda.empty_cache()
        torch.cuda.init()
        torch.cuda.set_device(0)
        multiprocessing.set_start_method("spawn", force=True)

        # Tool args.
        if chat_template := engine_args.pop("chat_template", None):
            if len(chat_template) <= 1024 and os.path.exists(chat_template):
                with open(chat_template) as infile:
                    chat_template = infile.read()
        extra_args = dict(
            tool_parser=engine_args.pop("tool_call_parser", None),
            enable_auto_tools=engine_args.pop("enable_auto_tool_choice", False),
            chat_template=chat_template,
            chat_template_content_format=engine_args.pop("chat_template_content_format", None),
        )

        # Configure engine arguments
        gpu_count = int(os.getenv("CUDA_DEVICE_COUNT", str(torch.cuda.device_count())))
        engine_args = AsyncEngineArgs(
            model=model_name,
            tensor_parallel_size=gpu_count,
            **engine_args,
        )

        # Initialize engine directly in the main process
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        model_config = await self.engine.get_model_config()

        base_model_paths = [
            BaseModelPath(name=chute.name, model_path=chute.name),
        ]

        self.include_router(vllm_api_server.router)
        extra_token_args = {}
        version_parts = vv.__version__.split(".")
        old_vllm = False
        if (
            not vv.__version__.startswith("0.1.dev")
            and int(version_parts[0]) == 0
            and int(version_parts[1]) < 7
        ):
            old_vllm = True
        if old_vllm:
            extra_args["lora_modules"] = []
            extra_args["prompt_adapters"] = []
            extra_token_args["lora_modules"] = []
            extra_args["base_model_paths"] = base_model_paths
        else:
            extra_args["models"] = OpenAIServingModels(
                engine_client=self.engine,
                model_config=model_config,
                base_model_paths=base_model_paths,
                lora_modules=[],
            )
            extra_token_args.update(
                {
                    "chat_template": extra_args.get("chat_template"),
                    "chat_template_content_format": extra_args.get("chat_template_content_format"),
                }
            )

        vllm_api_server.chat = lambda s: OpenAIServingChat(
            self.engine,
            model_config=model_config,
            response_role="assistant",
            request_logger=None,
            return_tokens_as_token_ids=True,
            **extra_args,
        )
        vllm_api_server.completion = lambda s: OpenAIServingCompletion(
            self.engine,
            model_config=model_config,
            request_logger=None,
            return_tokens_as_token_ids=True,
            **{
                k: v
                for k, v in extra_args.items()
                if k
                not in (
                    "chat_template",
                    "chat_template_content_format",
                    "tool_parser",
                    "enable_auto_tools",
                )
            },
        )
        models_arg = base_model_paths if old_vllm else extra_args["models"]
        vllm_api_server.tokenization = lambda s: OpenAIServingTokenization(
            self.engine,
            model_config,
            models_arg,
            request_logger=None,
            **extra_token_args,
        )
        self.state.openai_serving_tokenization = OpenAIServingTokenization(
            self.engine,
            model_config,
            models_arg,
            request_logger=None,
            **extra_token_args,
        )
        setattr(self.state, "enable_server_load_tracking", False)
        if not old_vllm:
            self.state.openai_serving_models = extra_args["models"]

    def _parse_stream_chunk(encoded_chunk):
        chunk = encoded_chunk if isinstance(encoded_chunk, str) else encoded_chunk.decode()
        if "data: {" in chunk:
            return json.loads(chunk[6:])
        return None

    @chute.cord(
        passthrough_path="/v1/chat/completions",
        public_api_path="/v1/chat/completions",
        method="POST",
        passthrough=True,
        stream=True,
        input_schema=ChatCompletionRequest,
        minimal_input_schema=MinifiedStreamChatCompletion,
    )
    async def chat_stream(encoded_chunk) -> ChatCompletionStreamResponse:
        return _parse_stream_chunk(encoded_chunk)

    @chute.cord(
        passthrough_path="/v1/completions",
        public_api_path="/v1/completions",
        method="POST",
        passthrough=True,
        stream=True,
        input_schema=CompletionRequest,
        minimal_input_schema=MinifiedStreamCompletion,
    )
    async def completion_stream(encoded_chunk) -> CompletionStreamResponse:
        return _parse_stream_chunk(encoded_chunk)

    @chute.cord(
        passthrough_path="/v1/chat/completions",
        public_api_path="/v1/chat/completions",
        method="POST",
        passthrough=True,
        input_schema=ChatCompletionRequest,
        minimal_input_schema=MinifiedChatCompletion,
    )
    async def chat(data) -> ChatCompletionResponse:
        return data

    @chute.cord(
        path="/do_tokenize",
        passthrough_path="/tokenize",
        public_api_path="/tokenize",
        method="POST",
        passthrough=True,
        input_schema=TokenizeRequest,
        minimal_input_schema=TokenizeRequest,
    )
    async def do_tokenize(data) -> TokenizeResponse:
        return data

    @chute.cord(
        path="/do_detokenize",
        passthrough_path="/detokenize",
        public_api_path="/detokenize",
        method="POST",
        passthrough=True,
        input_schema=DetokenizeRequest,
        minimal_input_schema=DetokenizeRequest,
    )
    async def do_detokenize(data) -> DetokenizeResponse:
        return data

    @chute.cord(
        passthrough_path="/v1/completions",
        public_api_path="/v1/completions",
        method="POST",
        passthrough=True,
        input_schema=CompletionRequest,
        minimal_input_schema=MinifiedCompletion,
    )
    async def completion(data) -> CompletionResponse:
        return data

    @chute.cord(
        passthrough_path="/v1/models",
        public_api_path="/v1/models",
        public_api_method="GET",
        method="GET",
        passthrough=True,
    )
    async def get_models(data):
        return data

    return VLLMChute(
        chute=chute,
        chat=chat,
        chat_stream=chat_stream,
        completion=completion,
        completion_stream=completion_stream,
        models=get_models,
    )
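
For reference, a minimal usage sketch of the build_vllm_chute template defined above, assuming the chutes SDK is installed. The username, model name, revision hash, NodeSelector fields, and engine_args shown here are illustrative placeholders, not values taken from this commit; consult the chutes documentation for the exact NodeSelector parameters.

from chutes.chute import NodeSelector
from vllm_template_gptoss import build_vllm_chute

# All values below are hypothetical placeholders for illustration only.
chute = build_vllm_chute(
    username="myuser",                        # your chutes username
    model_name="openai/gpt-oss-120b",         # example HF repo id
    revision="<hf-commit-sha>",               # required: build_vllm_chute raises unless a revision is pinned
    node_selector=NodeSelector(gpu_count=8),  # assumed NodeSelector field; check the chutes docs
    engine_args={"max_model_len": 32768},     # forwarded into AsyncEngineArgs at startup
    concurrency=32,
)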