from fastapi import FastAPI, Form
from fastapi.responses import HTMLResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from html import escape
app = FastAPI()
# Load Granite 2B model
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # fp16 on GPU; "auto" keeps the checkpoint's default dtype on CPU
    torch_dtype=torch.float16 if torch.cuda.is_available() else "auto",
    device_map="auto",
)
@app.get("/", response_class=HTMLResponse)
def index():
return """
Granite 2B Summarizer
Granite 2B Summarization Demo
"""
@app.post("/summarize", response_class=HTMLResponse)
def summarize(text: str = Form(...)):
prompt = f"Summarize the following text:\n{text.strip()}\nSummary:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
summary_ids = model.generate(
**inputs,
max_new_tokens=150,
do_sample=False,
temperature=0.7
)
output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Remove the prompt from the output
summary = output.replace(prompt, "").strip()
return f"Summary
{summary}
Back"