|
import os |
|
import sys |
|
import types |
|
import importlib.machinery |
|
from typing import List, Dict |
|
|
|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
|
|
|
|
def _make_pkg_stub(fullname: str):
    """Register an empty stand-in *package* named ``fullname`` in ``sys.modules``.

    The stub has an empty ``__path__`` and a loader-less package spec, so
    ``import fullname`` succeeds without any real code behind it.

    Args:
        fullname: Dotted module name to register, e.g. ``"pkg.sub"``.

    Returns:
        The newly created module object (also stored in ``sys.modules``).
    """
    parent_name, _, short_name = fullname.rpartition(".")

    m = types.ModuleType(fullname)
    m.__file__ = f"<stub {fullname}>"
    # For a package, __package__ must equal its own name so that it agrees
    # with __spec__.parent (the spec below is created with is_package=True).
    m.__package__ = fullname
    m.__path__ = []
    m.__spec__ = importlib.machinery.ModuleSpec(fullname, loader=None, is_package=True)
    sys.modules[fullname] = m

    # The import machinery only binds a submodule on its parent when it loads
    # the submodule itself; since we bypass loading, bind it here so that
    # attribute access such as ``pkg.sub`` works after ``import pkg.sub``.
    parent = sys.modules.get(parent_name) if parent_name else None
    if parent is not None:
        setattr(parent, short_name, m)

    return m
|
|
|
# Pre-register stub packages for the flash_attn tree so that importing any of
# these names succeeds even when the real package is not installed. Names that
# are already present in sys.modules are left untouched.
# NOTE(review): presumably needed because the model's remote code imports
# flash_attn; this runs before the transformers import below — confirm.
for name in (
    "flash_attn",
    "flash_attn.ops",
    "flash_attn.layers",
    "flash_attn.functional",
    "flash_attn.bert_padding",
    "flash_attn.flash_attn_interface",
):
    if name in sys.modules:
        continue
    _make_pkg_stub(name)
|
|
|
|
|
from transformers import AutoProcessor, AutoModelForCausalLM |
|
|
|
# Model identifier; can be overridden through the MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "microsoft/Florence-2-base")

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Maps the task names used by this app to the prompt token sent to the model.
TASK_TOKENS = dict(
    caption="<CAPTION>",
    object_detection="<OBJECT_DETECTION>",
)

# Cache slots filled on first use by get_florence2().
_processor = _model = None
|
|
|
def get_florence2():
    """Return the cached ``(processor, model)`` pair, loading it on first use.

    On the first call (or whenever either cache slot is still empty) the
    processor and model for ``MODEL_ID`` are loaded with
    ``trust_remote_code=True``. The model uses eager attention, float16 on
    CUDA and float32 otherwise, is moved to ``device``, put in eval mode and
    has ``config.use_cache`` switched off.
    """
    global _processor, _model

    # Fast path: both objects are already cached.
    if _processor is not None and _model is not None:
        return _processor, _model

    _processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

    weights_dtype = torch.float16 if device == "cuda" else torch.float32
    loaded = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=weights_dtype,
    )
    _model = loaded.to(device).eval()
    _model.config.use_cache = False

    return _processor, _model
|
|
|
@torch.inference_mode()
def florence2_text(image: Image.Image, task: str = "caption"):
    """Run the model on one image and return the decoded text.

    Args:
        image: Input picture. Images whose mode is not ``"RGB"`` are converted
            to RGB before being handed to the processor.
        task: Key into ``TASK_TOKENS``; unknown keys fall back to
            ``"<CAPTION>"``.

    Returns:
        The decoded, stripped model output. If the text contains ``">"``,
        everything up to and including the first ``">"`` is dropped.
    """
    proc, mdl = get_florence2()
    prompt = TASK_TOKENS.get(task, "<CAPTION>")

    # Uploaded files may be RGBA / palette / grayscale.
    # NOTE(review): the processor is assumed to expect 3-channel input — confirm.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    batch = proc(text=prompt, images=image, return_tensors="pt")

    # Move tensors to the target device; floating-point tensors are also cast
    # to the model dtype, integer tensors keep theirs.
    inputs = {}
    for key, value in batch.items():
        if not isinstance(value, torch.Tensor):
            inputs[key] = value
        elif value.is_floating_point():
            inputs[key] = value.to(device=device, dtype=mdl.dtype)
        else:
            inputs[key] = value.to(device=device)

    tokenizer = getattr(proc, "tokenizer", None)
    ids = mdl.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False,
        num_beams=1,
        use_cache=False,
        early_stopping=False,
        eos_token_id=getattr(tokenizer, "eos_token_id", None),
    )

    out = proc.batch_decode(ids, skip_special_tokens=True)[0].strip()
    if ">" in out:
        out = out.split(">", 1)[-1].strip()
    return out
|
|
|
|
|
FOOD_DB = { |
|
"rice": {"kcal":130, "carb_g":28, "protein_g":2.4, "fat_g":0.3, "sodium_mg":0, "cat":"全榖雜糧類", "base_g":150, "tip":"主食可改糙米/全穀增加膳食纖維"}, |
|
"noodles":{"kcal":138, "carb_g":25, "protein_g":4.5, "fat_g":1.9, "sodium_mg":170, "cat":"全榖雜糧類", "base_g":180, "tip":"盡量選清湯少油,避免重鹹湯底"}, |
|
"bread": {"kcal":265, "carb_g":49, "protein_g":9.0, "fat_g":3.2, "sodium_mg":490, "cat":"全榖雜糧類", "base_g":60, "tip":"可選全麥減少抹醬、甜餡"}, |
|
"broccoli":{"kcal":35, "carb_g":7, "protein_g":2.4, "fat_g":0.4, "sodium_mg":33, "cat":"蔬菜類", "base_g":80, "tip":"川燙/清炒保留口感與維生素"}, |
|
"spinach":{"kcal":23, "carb_g":3.6,"protein_g":2.9,"fat_g":0.4,"sodium_mg":70, "cat":"蔬菜類", "base_g":80, "tip":"川燙後快炒,少鹽少油"}, |
|
"chicken":{"kcal":215,"carb_g":0, "protein_g":27, "fat_g":12, "sodium_mg":90, "cat":"豆魚蛋肉類", "base_g":120, "tip":"去皮烹調、烤/氣炸取代油炸"}, |
|
"soy_braised_chicken_leg":{"kcal":220,"carb_g":0,"protein_g":24,"fat_g":12,"sodium_mg":550,"cat":"豆魚蛋肉類","base_g":130,"tip":"減醬油與滷汁、可先汆燙再滷"}, |
|
"salmon":{"kcal":208,"carb_g":0, "protein_g":20, "fat_g":13, "sodium_mg":60, "cat":"豆魚蛋肉類", "base_g":120, "tip":"烤/蒸保留 Omega-3,少鹽少醬"}, |
|
"pork_chop":{"kcal":242,"carb_g":0,"protein_g":27,"fat_g":14,"sodium_mg":75, "cat":"豆魚蛋肉類", "base_g":120, "tip":"少裹粉油炸,改煎烤並瀝油"}, |
|
"tofu": {"kcal":76, "carb_g":1.9,"protein_g":8.1,"fat_g":4.8,"sodium_mg":7, "cat":"豆魚蛋肉類", "base_g":120, "tip":"少勾芡、少滷汁,清蒸清爽"}, |
|
"egg": {"kcal":155,"carb_g":1.1,"protein_g":13, "fat_g":11, "sodium_mg":124, "cat":"豆魚蛋肉類", "base_g":60, "tip":"水煮/荷包少油,避免重鹹醬料"}, |
|
"banana":{"kcal":89, "carb_g":23, "protein_g":1.1,"fat_g":0.3,"sodium_mg":1, "cat":"水果類", "base_g":100, "tip":"控制份量,避免一次過量"}, |
|
"miso_soup":{"kcal":36,"carb_g":4.3,"protein_g":2.0,"fat_g":1.3,"sodium_mg":550, "cat":"湯品/飲品", "base_g":200, "tip":"味噌湯偏鹹,建議少量品嚐"}, |
|
} |
|
|
|
# Alternative spellings (English phrases and Chinese names) mapped to the
# canonical food key. Built from (canonical key, aliases) groups; the
# resulting dict keeps the aliases in the order listed here.
ALIASES = {
    alias: canonical
    for canonical, group in (
        ("rice", ("white rice", "steamed rice", "飯", "白飯")),
        ("noodles", ("麵", "拉麵", "麵條", "義大利麵")),
        ("bread", ("麵包", "吐司")),
        ("chicken", ("雞肉", "雞胸", "烤雞")),
        ("soy_braised_chicken_leg", ("滷雞腿", "醬油雞腿")),
        ("salmon", ("鮭魚", "三文魚")),
        ("pork_chop", ("豬排",)),
        ("tofu", ("豆腐",)),
        ("egg", ("蛋", "水煮蛋", "荷包蛋")),
        ("broccoli", ("花椰菜", "青花菜")),
        ("spinach", ("菠菜",)),
        ("banana", ("香蕉",)),
        ("miso_soup", ("味噌湯",)),
    )
    for alias in group
}
|
|
|
# Per-meal limits checked by eval_rules(), keyed by condition code.
RULES = {
    "T2DM": {"carb_g_per_meal_max": 60},
    "HTN": {"sodium_mg_per_meal_max": 600},
}

# Weight multiplier for each portion-size choice offered in the UI.
PORTION_MUL = {
    "小": 0.8,
    "中": 1.0,
    "大": 1.2,
}
|
|
|
def detect_foods_from_text(text: str) -> List[str]:
    """Return the canonical ``FOOD_DB`` keys mentioned in ``text``.

    ``FOOD_DB`` keys are checked first, then ``ALIASES``. ASCII terms are
    matched case-insensitively as whole words (an optional plural ``s``/``es``
    is accepted, and ``_`` in a key matches spaces, hyphens or underscores in
    the text), so "egg" no longer matches "eggplant" and "pork_chop" matches
    "pork chop". Non-ASCII terms use plain substring matching.

    Args:
        text: Free text such as a model caption. Empty or ``None`` gives ``[]``.

    Returns:
        Unique canonical keys in a deterministic order: ``FOOD_DB`` hits in
        table order, followed by alias-only hits in ``ALIASES`` order.
    """
    if not text:
        return []

    lowered = text.lower()
    found = {}  # dict used as an insertion-ordered set
    for key in FOOD_DB:
        if _text_mentions(key, text, lowered):
            found[key] = None
    for alias, key in ALIASES.items():
        if _text_mentions(alias, text, lowered):
            found[key] = None
    return list(found)


def _text_mentions(term: str, text: str, lowered: str) -> bool:
    """Tell whether ``term`` occurs in ``text`` (``lowered`` is ``text.lower()``)."""
    needle = term.lower()
    if not needle.isascii():
        # No usable word boundaries for CJK terms: fall back to substring search.
        return term in text or needle in lowered

    words = needle.replace("_", " ").split()
    if not words:
        return False
    pattern = (
        r"(?<![a-z0-9])"
        + r"[\s_-]+".join(re.escape(word) for word in words)
        + r"(?:e?s)?(?![a-z0-9])"
    )
    return re.search(pattern, lowered) is not None
|
|
|
|
|
import re

# Fallback base weight (grams) for foods that have no FOOD_DB entry.
DEFAULT_BASE_G = 100

# Tokens that extract_food_terms_free() discards before picking a food term.
STOPWORDS = {
    # articles, prepositions and serving words
    "a", "an", "the", "with", "and", "of", "on", "in", "to",
    "served", "over", "side", "sides",
    # containers and generic descriptors
    "plate", "bento", "box", "set", "dish", "meal",
    "mixed", "assorted", "fresh", "hot", "cold",
    # cooking methods
    "grilled", "roasted", "fried", "deep", "steamed", "boiled",
    "braised", "stir", "stirred", "sautéed",
    # accompaniments and cuisine styles
    "sauce", "soup", "salad", "topped", "seasoned", "style",
    "japanese", "taiwanese", "korean", "chinese",
    # Chinese equivalents
    "便當", "套餐", "一盤", "一碗", "配菜", "附餐", "湯", "沙拉", "醬", "佐",
    "搭配", "附", "拌", "炒", "滷", "炸", "烤", "蒸", "煮",
}
|
def extract_food_terms_free(text: str):
    """Heuristically pull one candidate food term out of each text fragment.

    The text is split on commas, periods, semicolons, newlines and the
    connectors "and" / "with" / "served with" / "accompanied by". Each
    fragment is lowercased and tokenised into runs of ASCII letters or CJK
    ideographs; tokens shorter than two characters or listed in ``STOPWORDS``
    are dropped, and the last remaining token is taken as the fragment's term
    (mapped through ``ALIASES`` when it is a known alias).

    Args:
        text: Free text such as a model caption. Empty or ``None`` gives ``[]``.

    Returns:
        Unique terms in order of first appearance, so repeated runs on the
        same text give the same list.
    """
    if not text:
        return []

    parts = re.split(
        r"(?:,|\.|;|\band\b|\bwith\b|\bserved with\b|\baccompanied by\b|\n)+",
        text,
        flags=re.I,
    )

    hits = {}  # dict used as an insertion-ordered set
    for part in parts:
        if not part:
            continue
        tokens = [
            word
            for word in re.findall(r"[A-Za-z\u4e00-\u9fff]+", part.lower())
            if len(word) >= 2 and word not in STOPWORDS
        ]
        if not tokens:
            continue
        head = tokens[-1]
        hits[ALIASES.get(head, head)] = None
    return list(hits)
|
|
|
def estimate_weight(name: str, plate_cm: int, portion: str, reference_plate_cm: float = 24) -> int:
    """Estimate the serving weight of a food in grams.

    The food's ``base_g`` (or ``DEFAULT_BASE_G`` when the food is unknown) is
    multiplied by the portion multiplier and scaled linearly by the ratio of
    the plate diameter to the reference diameter.

    Args:
        name: Food key looked up in ``FOOD_DB``.
        plate_cm: Plate diameter in centimetres.
        portion: Key into ``PORTION_MUL``; unknown values use a factor of 1.0.
        reference_plate_cm: Plate diameter at which ``base_g`` applies
            unscaled. Defaults to 24, the value previously hard-coded.

    Returns:
        Estimated grams, truncated to an integer and never below 10.

    Raises:
        ValueError: If ``reference_plate_cm`` is not positive.
    """
    if reference_plate_cm <= 0:
        raise ValueError("reference_plate_cm must be positive")

    base = FOOD_DB.get(name, {}).get("base_g", DEFAULT_BASE_G)
    mul = PORTION_MUL.get(portion, 1.0)
    grams = int(base * mul * (plate_cm / reference_plate_cm))
    # Floor the estimate so tiny plates never yield a zero or negative weight.
    return max(10, grams)
|
|
|
def grams_to_nutrition(name: str, grams: int) -> Dict:
    """Scale a food's ``FOOD_DB`` nutrient values to the given weight.

    Args:
        name: Food key; must exist in ``FOOD_DB`` (``KeyError`` otherwise).
        grams: Serving weight in grams.

    Returns:
        A dict with ``name``, ``cat``, ``weight_g`` and ``tip`` followed by
        the five nutrient fields, each multiplied by ``grams / 100`` and
        rounded to one decimal place.
    """
    entry = FOOD_DB[name]
    scale = grams / 100.0

    result = {
        "name": name,
        "cat": entry["cat"],
        "weight_g": grams,
        "tip": entry.get("tip", ""),
    }
    result.update(
        (field, round(entry[field] * scale, 1))
        for field in ("kcal", "carb_g", "protein_g", "fat_g", "sodium_mg")
    )
    return result
|
|
|
def make_placeholder_item(name: str, plate_cm: int, portion: str):
    """Build a result item for a food that has no nutrition data.

    The weight is estimated from ``DEFAULT_BASE_G`` scaled by plate diameter
    (relative to 24 cm) and the portion multiplier. It is floored at 10 g,
    the same minimum that ``estimate_weight`` applies to known foods, so a
    placeholder can never report a zero or negative weight.

    Args:
        name: Display name of the unrecognised food.
        plate_cm: Plate diameter in centimetres.
        portion: Key into ``PORTION_MUL``; unknown values use a factor of 1.0.

    Returns:
        An item dict whose category is "未分類" and whose nutrient fields and
        tip hold the placeholder text "待新增資訊".
    """
    grams = int(DEFAULT_BASE_G * (plate_cm / 24) * PORTION_MUL.get(portion, 1.0))
    grams = max(10, grams)
    pending = "待新增資訊"
    return {
        "name": name,
        "cat": "未分類",
        "weight_g": grams,
        "kcal": pending,
        "carb_g": pending,
        "protein_g": pending,
        "fat_g": pending,
        "sodium_mg": pending,
        "tip": pending,
    }
|
|
|
def eval_rules(items: List[Dict], conditions: List[str]):
    """Total the nutrients of a meal and check them against ``RULES``.

    Only items whose ``kcal`` is numeric contribute to the totals, which
    skips placeholder items that carry text instead of numbers.

    Args:
        items: Item dicts as produced by ``grams_to_nutrition`` or
            ``make_placeholder_item``. ``None`` is treated as empty.
        conditions: Selected condition codes ("T2DM", "HTN"). ``None`` is
            treated as no conditions.

    Returns:
        A ``(totals, advice, cats)`` tuple: per-nutrient sums rounded to one
        decimal (empty when no item is numeric), a list of advice strings for
        exceeded limits, and a count of items per category.
    """
    items = items or []
    conditions = conditions or []
    nutrient_keys = ("kcal", "carb_g", "protein_g", "fat_g", "sodium_mg")

    totals = {}
    for it in items:
        if not isinstance(it.get("kcal"), (int, float)):
            continue
        for k in nutrient_keys:
            totals[k] = round(totals.get(k, 0) + float(it[k]), 1)

    advice = []
    if "T2DM" in conditions and totals.get("carb_g", 0) > RULES["T2DM"]["carb_g_per_meal_max"]:
        advice.append("【糖尿病】碳水偏高,建議主食減量或改全穀。")
    if "HTN" in conditions and totals.get("sodium_mg", 0) > RULES["HTN"]["sodium_mg_per_meal_max"]:
        advice.append("【高血壓】鈉含量偏高,少鹽、避免重口味與滷味/湯品。")

    cats = {}
    for it in items:
        cats[it["cat"]] = cats.get(it["cat"], 0) + 1

    return totals, advice, cats
|
|
|
|
|
def run_pipeline(image, plate_cm, portion, conditions, task_mode, dev_mode):
    """Analyse one meal photo and build the outputs shown in the UI.

    Args:
        image: Uploaded picture, or ``None`` when nothing was uploaded.
        plate_cm: Plate diameter in centimetres.
        portion: Portion-size choice (a key of ``PORTION_MUL``).
        conditions: Selected condition codes; ``None`` is treated as none.
        task_mode: UI task label; "描述 (Caption)" selects captioning, any
            other value selects object detection.
        dev_mode: When true, skip the model and use a fixed sample caption.

    Returns:
        A ``(report, raw_text, items, totals)`` tuple: the formatted report,
        the raw model (or sample) text, the per-food item dicts and the
        nutrient totals. Without an image, a prompt to upload one is
        returned with empty values.
    """
    if image is None:
        return "請先上傳一張照片。", "", [], {}

    if dev_mode:
        txt = "A bento with white rice, broccoli and grilled chicken thigh."
    else:
        t = "caption" if task_mode == "描述 (Caption)" else "object_detection"
        txt = florence2_text(image, task=t)

    labels_known = detect_foods_from_text(txt)
    labels_free = extract_food_terms_free(txt)

    # Merge both label sources (free-text terms first), canonicalise through
    # ALIASES and drop duplicates while keeping first-seen order.
    labels_all = []
    seen = set()
    for term in labels_free + labels_known:
        key = ALIASES.get(term, term)
        if key not in seen:
            labels_all.append(key)
            seen.add(key)

    # Only the first six labels are turned into items.
    items = []
    for name in labels_all[:6]:
        if name in FOOD_DB:
            g = estimate_weight(name, plate_cm, portion)
            items.append(grams_to_nutrition(name, g))
        else:
            items.append(make_placeholder_item(name, plate_cm, portion))

    totals, advice, cats = eval_rules(items, conditions or [])

    report = _format_report(txt, labels_all, items, totals, advice)
    return report, txt, items, totals


def _format_report(txt, labels_all, items, totals, advice):
    """Render the model text, detected labels, items and totals as one string."""
    lines = [f"模型輸出:{txt}", ""]
    if labels_all:
        lines.append("偵測到: " + ", ".join(labels_all))
    else:
        lines.append("偵測到: (無)")

    lines.append("")
    for it in items:
        # Nutrient fields are printed as stored: numbers for known foods,
        # placeholder text for unknown ones.
        lines.append(
            f"- {it['name']} ({it['cat']}) {it['weight_g']} g → "
            f"{it['kcal']} kcal, C{it['carb_g']} g, P{it['protein_g']} g, "
            f"F{it['fat_g']} g, Na{it['sodium_mg']} mg"
        )

    if totals:
        lines.append("")
        lines.append(f"總計:{totals.get('kcal',0)} kcal,碳水 {totals.get('carb_g',0)} g,蛋白 {totals.get('protein_g',0)} g,脂肪 {totals.get('fat_g',0)} g,鈉 {totals.get('sodium_mg',0)} mg")
        if advice:
            lines.append("建議:" + " ".join(advice))

    return "\n".join(lines)
|
|
|
# Gradio UI: inputs in the left column, results in the right column, and a
# button that runs run_pipeline() on the current input values.
with gr.Blocks(title="FoodAI · Florence-2 Demo") as demo:
    gr.Markdown("# 🍱 FoodAI · Florence-2 Demo\n上傳餐點 → 產生描述/偵測 → 估營養/建議\n\n> 開發模式:不跑模型,固定假字串方便測試 UI/流程。")

    with gr.Row():
        # Left column: image upload and analysis options.
        with gr.Column(scale=1):
            img = gr.Image(label="上傳圖片", type="pil")
            plate = gr.Slider(
                minimum=18,
                maximum=28,
                step=1,
                value=24,
                label="盤子直徑 (cm)",
            )
            portion = gr.Radio(
                choices=["小", "中", "大"],
                value="中",
                label="份量",
            )
            cond = gr.CheckboxGroup(choices=["T2DM", "HTN"], label="狀況")
            task_mode = gr.Radio(
                choices=["描述 (Caption)", "偵測 (Object Detection)"],
                value="描述 (Caption)",
                label="任務",
            )
            dev_mode = gr.Checkbox(value=False, label="開發模式(不跑模型)")
            btn = gr.Button("開始分析", variant="primary")

        # Right column: formatted report, raw model text and JSON details.
        with gr.Column(scale=1):
            out_md = gr.Markdown(label="結果")
            raw = gr.Textbox(lines=4, label="模型原始輸出")
            js = gr.JSON(label="逐項結果")
            total = gr.JSON(label="總計")

    btn.click(
        fn=run_pipeline,
        inputs=[img, plate, portion, cond, task_mode, dev_mode],
        outputs=[out_md, raw, js, total],
    )
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from the PORT environment
    # variable and defaults to 7860.
    PORT = int(os.environ.get("PORT", "7860"))
    demo.launch(server_port=PORT, server_name="0.0.0.0")
|
|