Commit · 1557ad2
Parent(s): e071b26
minor update for src/model_operations.py
src/backend/model_operations.py CHANGED

@@ -162,7 +162,7 @@ class SummaryGenerator:
         using_replicate_api = False
         replicate_api_models = ['snowflake', 'llama-3.1-405b']
         using_pipeline = False
-        pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b']
+        pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b', 'phi-3.5']

         for replicate_api_model in replicate_api_models:
             if replicate_api_model in self.model_id.lower():
@@ -375,12 +375,19 @@ class SummaryGenerator:
                 model=self.model_id,
                 model_kwargs={"torch_dtype": torch.bfloat16},
                 device_map="auto",
+                trust_remote_code=True
             )
         else:
             self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
             print("Tokenizer loaded")
-
-
+            if 'jamba' in self.model_id.lower():
+                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id,
+                                                                        torch_dtype=torch.bfloat16,
+                                                                        attn_implementation="flash_attention_2",
+                                                                        device_map="auto")
+            else:
+                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto")
+            # print(self.local_model.device)
             print("Local model loaded")


@@ -394,6 +401,8 @@ class SummaryGenerator:
                 outputs = self.local_pipeline(
                     messages,
                     max_new_tokens=250,
+                    temperature=0.0,
+                    do_sample=False
                 )
                 result = outputs[0]["generated_text"][-1]['content']
                 print(result)
@@ -435,8 +444,8 @@ class SummaryGenerator:
             result = result.split("### Assistant:\n")[-1]

         else:
-            print(prompt)
-            print('-'*50)
+            # print(prompt)
+            # print('-'*50)
             result = result.replace(prompt.strip(), '')

         print(result)
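For context, the sketch below shows how the generation path touched by this commit behaves once the new pipeline arguments are in place. It is a minimal, self-contained illustration assuming the Hugging Face transformers text-generation pipeline API; the "microsoft/Phi-3.5-mini-instruct" checkpoint and the example messages are only illustrative and are not taken from the Space's code.

```python
# Minimal sketch of the updated local-pipeline generation path.
# Assumption: transformers >= 4.44 chat-style pipeline input; the checkpoint
# below is a stand-in, not necessarily what SummaryGenerator selects.
import torch
from transformers import pipeline

model_id = "microsoft/Phi-3.5-mini-instruct"  # hypothetical example checkpoint

# Mirrors the pipeline setup in the diff: bfloat16 weights, automatic device
# placement, and (after this commit) trust_remote_code=True.
local_pipeline = pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
    trust_remote_code=True,
)

messages = [
    {"role": "system", "content": "You are a helpful summarization assistant."},
    {"role": "user", "content": "Summarize: The quick brown fox jumps over the lazy dog."},
]

# do_sample=False (with temperature=0.0) forces greedy decoding, so repeated
# runs on the same prompt produce the same summary.
outputs = local_pipeline(
    messages,
    max_new_tokens=250,
    temperature=0.0,
    do_sample=False,
)

# With chat-style input, generated_text is the message list with the assistant
# reply appended, so the last element holds the newly generated content.
result = outputs[0]["generated_text"][-1]["content"]
print(result)
```

Disabling sampling keeps the leaderboard summaries reproducible between runs; note that transformers may log a warning that temperature is ignored when do_sample is False.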