File size: 1,854 Bytes
583755a
 
 
 
 
 
9ac0790
 
79a2070
f38d4d9
 
0f569ec
 
 
 
583755a
 
 
 
 
 
 
7a6576c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583755a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
import re
import os


# Checkpoint on the Hugging Face Hub that backs call_llm().
model_id = "google/gemma-3n-E4B"
# Access token taken from the environment; None when HUGGINGFACE_TOKEN is unset.
hf_token = os.getenv("HUGGINGFACE_TOKEN")


# Both the tokenizer and the weights use this directory as their download cache.
cache_dir = "/tmp/hf_cache"
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    token=hf_token,
    cache_dir=cache_dir,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_token,
    cache_dir=cache_dir,
)


def call_llm(prompt, *, max_new_tokens=2048, return_prompt=False):
    """Generate a completion for ``prompt`` with the module-level model.

    Args:
        prompt: Text fed to ``tokenizer``. Truncation is enabled, so inputs
            longer than the tokenizer's maximum length are cut.
        max_new_tokens: Upper bound on the number of tokens generated.
        return_prompt: When True, decode the prompt tokens followed by the
            completion (the historical behaviour). By default only the newly
            generated tokens are decoded.

    Returns:
        The decoded text with special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Keep the input tensors on the same device as the model weights so this
    # also works when the model has been moved to an accelerator.
    inputs = inputs.to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    generated = outputs[0]
    if not return_prompt:
        # For a causal LM, generate() returns the prompt ids followed by the
        # new ids. Drop the prompt so callers that parse the reply (e.g.
        # searching for JSON) don't also match the instructions and examples
        # they sent in.
        generated = generated[inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)

# Prompt template asking the LLM for 3-5 visualizations of a dataset summary,
# returned as JSON with a top-level "visualizations" list of
# {"title", "description", "code"} objects.
#
# {column_data} is the only intended placeholder. The JSON example also
# contains literal "{" and "}" characters, so str.format() would treat them as
# replacement fields and raise a KeyError. Substitute the placeholder with
# str.replace() instead, or escape the example braces as "{{" and "}}".
# The "..." line means the example itself is not literally valid JSON.
visual_prompt = """
You are a data visualization expert. You will be given a summary of a cleaned dataset.
Your tasks:
1. Suggest 3–5 interesting visualizations that would help uncover patterns or relationships.
2. For each, describe what insight it may reveal.
3. For each, write Python code using pandas/seaborn/matplotlib to generate the plot. Use 'df' as the dataframe and be precise with column names.
4. Always be careful and precise with column names 
Output JSON in this exact format:
{
  "visualizations": [
    {
      "title": "Histogram of Age",
      "description": "Shows the distribution of age",
      "code": "sns.histplot(df['age'], kde=True); plt.title('Age Distribution'); plt.savefig('charts/age.png'); plt.clf()"
    },
    ...
  ]
}
Dataset Summary:
{column_data}
"""

def _find_visualizations(text):
    """Return the "visualizations" list of the first JSON object in *text* that has one, else None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            # Not a complete JSON value at this brace (e.g. prose or the
            # prompt's "..." example); try the next opening brace.
            pass
        else:
            if isinstance(obj, dict) and isinstance(obj.get("visualizations"), list):
                return obj["visualizations"]
        start = text.find("{", start + 1)
    return None


def generate_visual_plan(column_data):
    """Ask the LLM for a visualization plan for a dataset summary.

    Args:
        column_data: JSON-serialisable summary of the dataset. It is embedded
            in ``visual_prompt`` as pretty-printed JSON.

    Returns:
        The ``"visualizations"`` list from the first JSON object in the model's
        reply that contains one, or ``[]`` if none can be parsed. On failure
        the raw reply is printed for debugging.
    """
    # visual_prompt contains literal JSON braces, so str.format() would read
    # them as replacement fields and raise. Substitute the single placeholder.
    prompt = visual_prompt.replace(
        "{column_data}", json.dumps(column_data, indent=2)
    )
    response = call_llm(prompt)

    # Decode JSON values brace by brace rather than greedily regex-matching
    # from the first "{" to the last "}": surrounding text (or an echoed
    # prompt) may contain other braces that make such a span invalid JSON.
    plan = _find_visualizations(response)
    if plan is None:
        print("Failed to parse visualization JSON.")
        print(response)
        return []
    return plan