import os
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def find_latest_finetuned_model() -> str | None:
    """
    Look under:
      1. <script dir>/results/auto_antislop_runs
      2. /results/auto_antislop_runs
    Return the most recent `…/finetuned_model*/merged_16bit` directory or None.
    """
    candidate_bases = [
        Path(__file__).resolve().parent / "results" / "auto_antislop_runs",
        Path("/results/auto_antislop_runs"),
    ]

    latest: tuple[float, Path] | None = None
    for base in candidate_bases:
        if not base.is_dir():
            continue

        # Matches e.g. run_<timestamp>/finetuned_model/merged_16bit under each base.
        for merged_dir in base.glob("run_*/finetuned_model*/merged_16bit"):
            if not merged_dir.is_dir():
                continue
            # Rank runs by the modification time of the finetuned_model* directory.
            mtime = merged_dir.parent.stat().st_mtime
            if latest is None or mtime > latest[0]:
                latest = (mtime, merged_dir.resolve())

    return str(latest[1]) if latest else None
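
# Optional convenience (hypothetical, not in the original script): let an
# explicit checkpoint path override auto-discovery, e.g. via an env var:
#   model_path = os.environ.get("MODEL_PATH") or find_latest_finetuned_model() or "."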

model_path = find_latest_finetuned_model() or "."
print(f"Loading model from: {os.path.abspath(model_path)}")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

try:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Load in bfloat16; device_map="auto" shards across available GPUs.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto" if device == "cuda" else None,
    )

    messages = [
        {"role": "system", "content": "You are a creative storyteller."},
        {"role": "user", "content": "Write a short, engaging story about a princess."},
    ]
    # add_generation_prompt=True appends the assistant turn header so the model
    # continues as the assistant instead of predicting another user turn.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print("\nApplied chat template:\n", prompt)

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    generated_ids = model.generate(
        input_ids,
        max_new_tokens=500,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,  # silences the missing-pad-token warning
    )
    # Slice off the prompt tokens so only the new text is decoded; this is more
    # robust than trimming the decoded string by the prompt's character length.
    response = tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
    print("\n--- Generated Story ---\n", response)
    print("\nNew tokens generated:", generated_ids.shape[1] - input_ids.shape[1])

except Exception as e:
    print(f"Error: {e}")
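
# Optional, untested sketch: stream tokens to stdout as they are generated,
# using transformers' TextStreamer, instead of waiting for the full completion.
# Uncomment to try; it reuses the model, tokenizer, and input_ids from above.
#
# from transformers import TextStreamer
# streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# model.generate(
#     input_ids,
#     max_new_tokens=500,
#     do_sample=True,
#     temperature=0.7,
#     top_p=0.9,
#     streamer=streamer,
# )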