# gemma-3-4b-it-antislop-exp62 / test_inference.py
import os
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer


def find_latest_finetuned_model() -> str | None:
    """
    Look under:
      1. <script dir>/results/auto_antislop_runs
      2. /results/auto_antislop_runs
    Return the most recent `…/finetuned_model*/merged_16bit` directory, or None.
    """
    candidate_bases = [
        Path(__file__).resolve().parent / "results" / "auto_antislop_runs",
        Path("/results/auto_antislop_runs"),
    ]
    latest: tuple[float, Path] | None = None
    for base in candidate_bases:
        if not base.is_dir():
            continue
        # run_*/finetuned_model*/merged_16bit
        for merged_dir in base.glob("run_*/finetuned_model*/merged_16bit"):
            if not merged_dir.is_dir():
                continue
            mtime = merged_dir.parent.stat().st_mtime  # use the finetuned_model* dir mtime
            if latest is None or mtime > latest[0]:
                latest = (mtime, merged_dir.resolve())
    return str(latest[1]) if latest else None
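

# Resolve the model path: prefer the newest merged fine-tuned checkpoint,
# otherwise fall back to the current directory.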
model_path = find_latest_finetuned_model() or "."
print(f"Loading model from: {os.path.abspath(model_path)}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
try:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto" if device == "cuda" else None,
    )
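
    # Build a short chat prompt using the model's chat template.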
    messages = [
        {"role": "system", "content": "You are a creative storyteller."},
        {"role": "user", "content": "Write a short, engaging story about a princess."},
    ]
    # add_generation_prompt=True appends the assistant-turn marker so the model
    # continues as the assistant instead of extending the user turn.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print("\nApplied chat template:\n", prompt)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    generated_ids = model.generate(
        input_ids,
        max_new_tokens=500,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
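
    # Decode only the newly generated tokens (everything after the prompt).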
    new_tokens = generated_ids[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    print("\n--- Generated Story ---\n", response)
    print("\nNew tokens generated:", len(new_tokens))
except Exception as e:
    print(f"Error: {e}")