File size: 5,286 Bytes
9f5b532 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import functools
import os
import re
import time

import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
from langgraph.graph import StateGraph
# --- Model and Workflow Setup ---
# Hugging Face checkpoint identifier shared by the three loaders below.
model_id = "ibm-granite/granite-4.0-tiny-preview"
# All three objects are created at import time, so importing this module
# triggers the loads.
# The processor is what generate_with_granite uses to apply the chat template.
processor = AutoProcessor.from_pretrained(model_id)
# The tokenizer supplies the EOS id used as pad id and decodes generated ids.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Causal LM that generate_with_granite moves to the chosen device and samples from.
model = AutoModelForCausalLM.from_pretrained(model_id)
def generate_with_granite(prompt: str, max_tokens: int = 200, use_gpu: bool = False) -> str:
    """Generate a sampled completion for a single user prompt.

    The prompt is wrapped as one ``user`` chat message, rendered through the
    processor's chat template, and fed to the module-level ``model``.

    Args:
        prompt: Text sent as the content of the single user message.
        max_tokens: Upper bound on newly generated tokens (``max_new_tokens``).
        use_gpu: When true and CUDA is available, run on ``cuda``; otherwise
            run on ``cpu``.

    Returns:
        The decoded continuation only (prompt tokens are sliced off), with
        special tokens skipped and surrounding whitespace stripped.
    """
    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
    # The model is a module-level singleton; this moves it to the device
    # selected for this call.
    model.to(device)

    messages = [{"role": "user", "content": prompt}]
    # return_dict=True asks for a mapping carrying both input_ids and the
    # attention mask, instead of relying on the call returning a bare tensor.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device)
    input_ids = inputs["input_ids"]
    prompt_length = input_ids.shape[-1]

    outputs = model.generate(
        input_ids=input_ids,
        # pad_token_id is set to the EOS id below, so the mask cannot be
        # inferred from the ids; pass it explicitly when it is available.
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
    # outputs[0] holds prompt + continuation; keep only the continuation.
    generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return generated.strip()
# Matches a "Genre:" / "Tone:" label regardless of case, tolerating Markdown
# emphasis around the label or colon (e.g. "**Genre:** Fantasy").
_GENRE_TONE_LABEL = re.compile(r"\b(genre|tone)\b[\s*]*:[\s*]*", re.IGNORECASE)


def _parse_genre_and_tone(response: str):
    """Extract ``(genre, tone)`` from a model reply; either may be ``None``."""
    found = {}
    for line in response.splitlines():
        # Splitting on a pattern with one capturing group yields
        # [prefix, label, value, label, value, ...], so both labels are
        # recovered even when the model puts them on the same line.
        pieces = _GENRE_TONE_LABEL.split(line)
        for label, raw_value in zip(pieces[1::2], pieces[2::2]):
            value = raw_value.strip(" \t*_`,;|")
            if value:
                # A later non-empty occurrence overrides an earlier one.
                found[label.lower()] = value
    return found.get("genre"), found.get("tone")


def select_genre_node(state: dict) -> dict:
    """Ask the model for a genre and tone that fit the user's idea.

    Args:
        state: Workflow state; must contain ``user_input``.

    Returns:
        The same ``state`` object with ``genre`` and ``tone`` set. Each is
        ``None`` when the reply has no usable ``Genre:`` / ``Tone:`` entry.

    Raises:
        KeyError: If ``state`` has no ``user_input`` key.
    """
    prompt = "\n".join(
        (
            "You are a creative assistant. The user wants to write a short animated story.",
            "Based on the following input, suggest a suitable genre and tone for the story.",
            f"User Input: {state['user_input']}",
            "Respond in this format:",
            "Genre: <genre>",
            "Tone: <tone>",
        )
    ).strip()
    response = generate_with_granite(prompt)
    state["genre"], state["tone"] = _parse_genre_and_tone(response)
    return state
def generate_outline_node(state: dict) -> dict:
    """Draft a short plot outline from the genre, tone and idea in ``state``.

    ``genre``, ``tone`` and ``user_input`` are read with ``dict.get``, so a
    missing key is rendered as ``None`` in the prompt. The reply (up to 250
    new tokens) is stored under ``state["outline"]``.

    Returns:
        The same ``state`` object, updated in place.
    """
    prompt_lines = (
        "You are a creative writing assistant helping to write a short animated screenplay.",
        "The user wants to write a story with the following details:",
        f"Genre: {state.get('genre')}",
        f"Tone: {state.get('tone')}",
        f"Idea: {state.get('user_input')}",
        "Write a brief plot outline (3–5 sentences) for the story.",
    )
    outline_prompt = "\n".join(prompt_lines).strip()
    state["outline"] = generate_with_granite(outline_prompt, max_tokens=250)
    return state
def generate_scene_node(state: dict) -> dict:
    """Write a key prose scene based on the outline held in ``state``.

    ``genre``, ``tone`` and ``outline`` are read with ``dict.get``, so a
    missing key is rendered as ``None`` in the prompt. The reply (up to 300
    new tokens) is stored under ``state["scene"]``.

    Returns:
        The same ``state`` object, updated in place.
    """
    prompt_lines = (
        "You are a screenwriter.",
        "Based on the following plot outline, write a key scene from the story.",
        "Focus on a turning point or climax moment. Make the scene vivid, descriptive, "
        "and suitable for an animated short film.",
        f"Genre: {state.get('genre')}",
        f"Tone: {state.get('tone')}",
        f"Outline: {state.get('outline')}",
        "Write the scene in prose format (not screenplay format).",
    )
    scene_prompt = "\n".join(prompt_lines).strip()
    state["scene"] = generate_with_granite(scene_prompt, max_tokens=300)
    return state
def write_dialogue_node(state: dict) -> dict:
    """Turn the scene held in ``state`` into screenplay-style dialogue.

    ``scene`` is read with ``dict.get``, so a missing key is rendered as
    ``None`` in the prompt. The reply (up to 100 new tokens) is stored under
    ``state["dialogue"]``.

    Returns:
        The same ``state`` object, updated in place.
    """
    prompt_lines = (
        "You are a dialogue writer for an animated screenplay.",
        "Below is a scene from the story:",
        f"{state.get('scene')}",
        "Write the dialogue between the characters in screenplay format.",
        "Keep it short, expressive, and suitable for a short animated film.",
        "Use character names (you may invent them if needed), and format as:",
        "CHARACTER:",
        "Dialogue line",
        "CHARACTER:",
        "Dialogue line",
    )
    dialogue_prompt = "\n".join(prompt_lines).strip()
    state["dialogue"] = generate_with_granite(dialogue_prompt, max_tokens=100)
    return state
def with_progress(fn, label, index, total):
    """Wrap a workflow step so it prints start/completion progress lines.

    Args:
        fn: Callable taking the workflow state and returning the step result.
        label: Human-readable step name shown in the progress lines.
        index: 1-based position of this step, shown as ``[index/total]``.
        total: Total number of steps, shown as ``[index/total]``.

    Returns:
        A callable with the same single-argument interface as ``fn`` that
        prints a start line, runs ``fn``, prints the elapsed seconds, and
        returns ``fn``'s result unchanged. If ``fn`` raises, the exception
        propagates and no completion line is printed.
    """

    # Preserve fn's name/docstring on the wrapper so tracebacks and tooling
    # that inspect the callable still identify the underlying step.
    @functools.wraps(fn)
    def wrapper(state):
        print(f"\n[{index}/{total}] Starting: {label}")
        # perf_counter is monotonic, so the measured interval cannot be
        # skewed by system clock adjustments the way time.time() can.
        start = time.perf_counter()
        result = fn(state)
        duration = time.perf_counter() - start
        print(f"[{index}/{total}] Completed: {label} in {duration:.2f} seconds")
        return result

    return wrapper
def build_workflow():
    """Assemble and compile the linear four-step story pipeline.

    The steps run in a fixed order: select_genre, generate_outline,
    generate_scene, write_dialogue. Each node function is wrapped with
    ``with_progress`` so it prints its position and duration when it runs.

    Returns:
        The object returned by ``StateGraph.compile()``.
    """
    steps = (
        ("select_genre", select_genre_node, "Select Genre"),
        ("generate_outline", generate_outline_node, "Generate Outline"),
        ("generate_scene", generate_scene_node, "Generate Scene"),
        ("write_dialogue", write_dialogue_node, "Write Dialogue"),
    )
    step_count = len(steps)

    pipeline = StateGraph(dict)
    for position, (node_name, node_fn, label) in enumerate(steps, start=1):
        pipeline.add_node(node_name, with_progress(node_fn, label, position, step_count))

    node_names = [node_name for node_name, _, _ in steps]
    pipeline.set_entry_point(node_names[0])
    # Chain each node to its successor to form the straight-line flow.
    for upstream, downstream in zip(node_names, node_names[1:]):
        pipeline.add_edge(upstream, downstream)
    pipeline.set_finish_point(node_names[-1])
    return pipeline.compile()
# Build the compiled graph once at import time; the request handler below
# calls invoke() on this shared instance.
workflow = build_workflow()
# --- Flask App ---
app = Flask(__name__)
# Apply flask_cors to the whole app; no options are passed, so the
# extension's defaults are in effect.
CORS(app)
@app.route("/generate-story", methods=["POST"])
def generate_story():
    """Run the story workflow for a JSON request and return its results.

    Expects a JSON object body with a non-empty ``user_input`` value.

    Returns:
        A JSON response with ``genre``, ``tone``, ``outline``, ``scene`` and
        ``dialogue`` taken from the final workflow state, or a 400 JSON error
        when the body is not a JSON object or ``user_input`` is missing or
        blank.
    """
    # silent=True makes unparsable or non-JSON bodies come back as None, so
    # they are answered by the JSON 400 below rather than a framework error.
    payload = request.get_json(silent=True)
    # A valid JSON body may still be a list, string, number or null; only an
    # object can carry "user_input".
    user_input = payload.get("user_input") if isinstance(payload, dict) else None
    if isinstance(user_input, str):
        # Treat whitespace-only input as missing.
        user_input = user_input.strip()
    if not user_input:
        return jsonify({"error": "Missing 'user_input' in request."}), 400

    final_state = workflow.invoke({"user_input": user_input})
    response_keys = ("genre", "tone", "outline", "scene", "dialogue")
    return jsonify({key: final_state.get(key) for key in response_keys})
if __name__ == "__main__":
    # The server listens on all interfaces, so the interactive debugger must
    # not be enabled unconditionally: it lets anyone who can reach the port
    # execute code. Debug mode is opt-in via FLASK_DEBUG (1/true/yes/on).
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 8000)), debug=debug_enabled)
|