import os
import time
from typing import TypedDict

import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer
from langgraph.graph import StateGraph

# --- Model and Workflow Setup ---
model_id = "ibm-granite/granite-4.0-tiny-preview"

# The tokenizer handles both chat templating and decoding for this text-only model.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()  # inference only

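# Single entry point for all model calls: formats a one-turn chat prompt,
# samples a completion, and returns just the newly generated text.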
def generate_with_granite(prompt: str, max_tokens: int = 200, use_gpu: bool = False) -> str:
    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
    model.to(device)  # no-op when the model is already on the requested device

    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt"
    ).to(device)

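    # Sample with moderate temperature/top_p for creative but coherent output.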
    outputs = model.generate(
        input_ids=inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )

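    # Keep only the newly generated tokens (everything after the prompt).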
    generated = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
    return generated.strip()


# Shared workflow state. LangGraph derives its state channels from the
# schema's annotations, so every key the nodes read or write is listed here.
class StoryState(TypedDict, total=False):
    user_input: str
    genre: str
    tone: str
    outline: str
    scene: str
    dialogue: str

# Ask the model for a genre and tone, then parse them out of the reply.
def select_genre_node(state: StoryState) -> StoryState:
    prompt = f"""
You are a creative assistant. The user wants to write a short animated story.
Based on the following input, suggest a suitable genre and tone for the story.
User Input: {state['user_input']}
Respond in this format:
Genre: <genre>
Tone: <tone>
""".strip()
    response = generate_with_granite(prompt)
    genre, tone = None, None
    for line in response.splitlines():
        if "Genre:" in line:
            genre = line.split("Genre:")[1].strip()
        elif "Tone:" in line:
            tone = line.split("Tone:")[1].strip()
    state["genre"] = genre
    state["tone"] = tone
    return state

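# Expand the idea, genre, and tone into a brief plot outline.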
def generate_outline_node(state: StoryState) -> StoryState:
    prompt = f"""
You are a creative writing assistant helping to write a short animated screenplay.
The user wants to write a story with the following details:
Genre: {state.get('genre')}
Tone: {state.get('tone')}
Idea: {state.get('user_input')}
Write a brief plot outline (3–5 sentences) for the story.
""".strip()
    response = generate_with_granite(prompt, max_tokens=250)
    state["outline"] = response
    return state

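# Dramatize a turning point from the outline as a vivid prose scene.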
def generate_scene_node(state: StoryState) -> StoryState:
    prompt = f"""
You are a screenwriter.
Based on the following plot outline, write a key scene from the story.
Focus on a turning point or climax moment. Make the scene vivid, descriptive, and suitable for an animated short film.
Genre: {state.get('genre')}
Tone: {state.get('tone')}
Outline: {state.get('outline')}
Write the scene in prose format (not screenplay format).
""".strip()
    response = generate_with_granite(prompt, max_tokens=300)
    state["scene"] = response
    return state

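# Layer screenplay-format dialogue onto the generated scene.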
def write_dialogue_node(state: StoryState) -> StoryState:
    prompt = f"""
You are a dialogue writer for an animated screenplay.
Below is a scene from the story:
{state.get('scene')}
Write the dialogue between the characters in screenplay format.
Keep it short, expressive, and suitable for a short animated film.
Use character names (you may invent them if needed), and format as:
CHARACTER:
Dialogue line
CHARACTER:
Dialogue line
""".strip()
    response = generate_with_granite(prompt, max_tokens=100)
    state["dialogue"] = response
    return state

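# Wrap a node so each step logs its start, completion, and elapsed time.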
def with_progress(fn, label, index, total):
    def wrapper(state):
        print(f"\n[{index}/{total}] Starting: {label}")
        start = time.time()
        result = fn(state)
        duration = time.time() - start
        print(f"[{index}/{total}] Completed: {label} in {duration:.2f} seconds")
        return result
    return wrapper

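# Wire the four nodes into a linear pipeline:
# select_genre -> generate_outline -> generate_scene -> write_dialogue.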
def build_workflow():
    graph = StateGraph(StoryState)
    graph.add_node("select_genre", with_progress(select_genre_node, "Select Genre", 1, 4))
    graph.add_node("generate_outline", with_progress(generate_outline_node, "Generate Outline", 2, 4))
    graph.add_node("generate_scene", with_progress(generate_scene_node, "Generate Scene", 3, 4))
    graph.add_node("write_dialogue", with_progress(write_dialogue_node, "Write Dialogue", 4, 4))

    graph.set_entry_point("select_genre")
    graph.add_edge("select_genre", "generate_outline")
    graph.add_edge("generate_outline", "generate_scene")
    graph.add_edge("generate_scene", "write_dialogue")
    graph.set_finish_point("write_dialogue")

    return graph.compile()

workflow = build_workflow()

# --- Flask App ---
app = Flask(__name__)
CORS(app)

@app.route("/generate-story", methods=["POST"])
def generate_story():
    data = request.get_json(silent=True) or {}
    user_input = data.get("user_input")

    if not user_input:
        return jsonify({"error": "Missing 'user_input' in request."}), 400

    initial_state = {"user_input": user_input}
    final_state = workflow.invoke(initial_state)

    return jsonify({
        "genre": final_state.get("genre"),
        "tone": final_state.get("tone"),
        "outline": final_state.get("outline"),
        "scene": final_state.get("scene"),
        "dialogue": final_state.get("dialogue")
    })
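
# Example request (assuming the server is running locally on port 8000):
#   curl -X POST http://localhost:8000/generate-story \
#     -H "Content-Type: application/json" \
#     -d '{"user_input": "A lonely robot tends the last garden on Mars"}'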

if __name__ == "__main__":
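    # Flask's built-in development server; fine for local testing only.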
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 8000)), debug=True)