sibthinon committed on
Commit
80c9031
·
verified ·
1 Parent(s): 6133ede

change to model bge visual

Browse files
Files changed (1) hide show
  1. app.py +47 -39
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
  import time
3
  from datetime import datetime
4
- from sentence_transformers import SentenceTransformer
 
5
  from qdrant_client import QdrantClient
6
  from qdrant_client.models import Filter, FieldCondition, MatchValue
7
  import os
@@ -21,86 +22,91 @@ qdrant_client = QdrantClient(
21
  # Airtable Config
22
  AIRTABLE_API_KEY = os.environ.get("airtable_api")
23
  BASE_ID = os.environ.get("airtable_baseid")
24
- TABLE_NAME = "Feedback_search"
25
- api = Api(AIRTABLE_API_KEY)
26
- table = api.table(BASE_ID, TABLE_NAME)
27
 
28
  # Preload Models
29
- model = SentenceTransformer("BAAI/bge-m3")
30
- collection_name = "product_bge-m3"
31
- threshold = 0.5
 
 
 
 
 
32
 
33
  # Utils
34
- def is_non_thai(text):
35
  return re.match(r'^[A-Za-z0-9&\-\s]+$', text) is not None
36
 
37
  def normalize(text: str) -> str:
38
- if is_non_thai(text):
39
  return text.strip()
40
- text = unicodedata.normalize("NFC", text)
41
- return text.replace("เแ", "แ").replace("เเ", "แ").strip().lower()
42
 
43
  # Global state
44
- latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""}
45
 
46
  # Search Function
47
  def search_product(query):
48
- yield gr.update(value="🔄 กำลังค้นหา..."), ""
49
 
50
- start_time = time.time()
51
- latest_query_result["raw_query"] = query
52
 
53
- corrected_query = normalize(query)
54
- query_embed = model.encode(corrected_query)
55
 
56
  try:
 
57
  result = qdrant_client.query_points(
58
- collection_name=collection_name,
59
- query=query_embed.tolist(),
60
- with_payload=True,
61
- query_filter=Filter(must=[FieldCondition(key="type", match=MatchValue(value="product"))]),
62
- limit=50
63
  ).points
64
  except Exception as e:
65
- yield gr.update(value="❌ Qdrant error"), f"<p>❌ Qdrant error: {str(e)}</p>"
66
  return
67
 
68
  if len(result) > 0:
69
  topk = 50 # ดึงมา rerank แค่ 50 อันดับแรกจาก Qdrant
70
  result = result[:topk]
71
 
72
- scored = []
73
  for r in result:
74
- name = str(r.payload.get("name", "")).lower()
75
- brand = str(r.payload.get("brand", "")).lower()
76
- query_lower = corrected_query.lower()
77
 
78
  # ถ้า query สั้นเกินไป ให้ fuzzy_score = 0 เพื่อกันเพี้ยน
79
  if len(corrected_query) >= 3 and name:
80
- fuzzy_name_score = fuzz.partial_ratio(query_lower, name) / 100.0
81
- fuzzy_brand_score = fuzz.partial_ratio(query_lower, brand) / 100.0
82
  else:
83
  fuzzy_name_score = 0.0
84
  fuzzy_brand_score = fuzz.partial_ratio(query_lower, brand) / 100.0
85
 
86
  # รวม hybrid score
87
  if fuzzy_name_score < 0.5:
88
- hybrid_score = r.score
89
  else:
90
- hybrid_score = 0.7 * r.score + 0.3 * fuzzy_name_score
91
  if fuzzy_brand_score >= 0.8:
92
- hybrid_score = hybrid_score*1.2
93
  r.payload["score"] = hybrid_score # เก็บลง payload ใช้เทียบ treshold ตอนเเสดงผล
94
  r.payload["fuzzy_name_score"] = fuzzy_name_score # เก็บไว้เผื่อ debug
95
  r.payload["fuzzy_brand_score"] = fuzzy_brand_score # เก็บไว้เผื่อ debug
96
  r.payload['semantic_score'] = r.score # เก็บไว้เผื่อ debug
97
- scored.append((r, hybrid_score))
98
 
99
  # เรียงตาม hybrid score แล้วกรองผลลัพธ์ที่ hybrid score ต่ำเกิน
100
- scored = sorted(scored, key=lambda x: x[1], reverse=True)
101
- result = [r[0] for r in scored]
102
 
103
- elapsed = time.time() - start_time
104
  html_output = f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"
105
  if corrected_query != query:
106
  html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
@@ -108,11 +114,11 @@ def search_product(query):
108
  result_summary, found = "", False
109
 
110
  for res in result:
111
- if res.payload["score"] >= threshold:
112
- found = True
113
  name = res.payload.get("name", "ไม่ทราบชื่อสินค้า")
114
  score = f"{res.payload['score']:.4f}"
115
- img_url = res.payload.get("imageUrl", "")
116
  price = res.payload.get("price", "ไม่ระบุ")
117
  brand = res.payload.get("brand", "")
118
 
@@ -146,6 +152,8 @@ def search_product(query):
146
  def log_feedback(feedback):
147
  try:
148
  now = datetime.now().strftime("%Y-%m-%d")
 
 
149
  table.create({
150
  "model": "BGE M3",
151
  "timestamp": now,
 
1
  import gradio as gr
2
  import time
3
  from datetime import datetime
4
+ from visual_bge.modeling import Visualized_BGE
5
+ from huggingface_hub import hf_hub_download
6
  from qdrant_client import QdrantClient
7
  from qdrant_client.models import Filter, FieldCondition, MatchValue
8
  import os
 
22
  # Airtable configuration: the API key and base id are read from environment variables.
23
  AIRTABLE_API_KEY = os.environ.get("airtable_api")
24
  BASE_ID = os.environ.get("airtable_baseid")
25
+ TABLE_NAME = "Feedback_search" # name of the Airtable table that feedback rows are written to
26
+ api = Api(AIRTABLE_API_KEY) # Airtable client built from the key above (key is None if the env var is unset)
27
+ table = api.table(BASE_ID, TABLE_NAME) # table handle; log_feedback calls table.create() on it
28
 
29
  # Preload models once at import time so every search reuses the same instance.
30
+ model_weight = hf_hub_download(repo_id="BAAI/bge-visualized", filename="Visualized_m3.pth")
31
+ # Build the Visualized-BGE encoder on top of the bge-m3 text model, using the weight file fetched above.
32
+ model = Visualized_BGE(
33
+ model_name_bge="BAAI/bge-m3",
34
+ model_weight=model_weight
35
+ )
36
+ collection_name = "product_visual_bge" # Qdrant collection queried by search_product
37
+ threshold = 0.5 # minimum hybrid score (payload["score"]) a result needs in order to be displayed
38
 
39
  # Utils
40
+ def is_non_thai(text): # True when text consists only of A-Z, a-z, 0-9, '&', '-' and whitespace; empty text gives False
41
  return re.match(r'^[A-Za-z0-9&\-\s]+$', text) is not None
42
 
43
  def normalize(text: str) -> str:
44
+ if is_non_thai(text): # Latin/digit-only query: only trim whitespace, letter case is kept
45
  return text.strip()
46
+ text = unicodedata.normalize("NFC", text) # bring the text to Unicode NFC (composed) form
47
+ return text.replace("เแ", "แ").replace("เเ", "แ").strip().lower() # fix a repeated SARA E keypress: SARA E + SARA AE, or two SARA E, become one SARA AE; then trim and lowercase
48
 
49
  # Global state shared between the search and feedback handlers.
50
+ latest_query_result = {"query": "", "result": "", "raw_query": "", "time": ""} # last search record; search_product stores the raw query here, presumably for the Airtable feedback row (confirm in log_feedback)
51
 
52
  # Search Function
53
  def search_product(query):
54
+ yield gr.update(value="🔄 กำลังค้นหา..."), "" # when user search
55
 
56
+ start_time = time.time() # start timer
57
+ latest_query_result["raw_query"] = query # collect user qeary
58
 
59
+ corrected_query = normalize(query) # change query to normalize query
60
+ query_embed = model.encode(text=corrected_query)[0] # embed corrected_query to vector
61
 
62
  try:
63
+ #use qdrant search
64
  result = qdrant_client.query_points(
65
+ collection_name=collection_name, # choose collection in qdrant
66
+ query=query_embed.tolist(), # vector query
67
+ with_payload=True, # see payload
68
+ limit=50 # need 50 product
 
69
  ).points
70
  except Exception as e:
71
+ yield gr.update(value="❌ Qdrant error"), f"<p>❌ Qdrant error: {str(e)}</p>" # have problem when search
72
  return
73
 
74
  if len(result) > 0:
75
  topk = 50 # ดึงมา rerank แค่ 50 อันดับแรกจาก Qdrant
76
  result = result[:topk]
77
 
78
+ scored = [] # use to collect product and score
79
  for r in result:
80
+ name = str(r.payload.get("name", "")).lower() # get name in payload and lowercase
81
+ brand = str(r.payload.get("brand", "")).lower() # get brand in payload and lowercase
82
+ query_lower = corrected_query.lower() # lowercase corected_quey
83
 
84
  # ถ้า query สั้นเกินไป ให้ fuzzy_score = 0 เพื่อกันเพี้ยน
85
  if len(corrected_query) >= 3 and name:
86
+ fuzzy_name_score = fuzz.partial_ratio(query_lower, name) / 100.0 # query compare name score
87
+ fuzzy_brand_score = fuzz.partial_ratio(query_lower, brand) / 100.0 # query compare brand score
88
  else:
89
  fuzzy_name_score = 0.0
90
  fuzzy_brand_score = fuzz.partial_ratio(query_lower, brand) / 100.0
91
 
92
  # รวม hybrid score
93
  if fuzzy_name_score < 0.5:
94
+ hybrid_score = r.score # not change qdrant score
95
  else:
96
+ hybrid_score = 0.7 * r.score + 0.3 * fuzzy_name_score # use qdrant score 70% and fuzzy name score 30%
97
  if fuzzy_brand_score >= 0.8:
98
+ hybrid_score = hybrid_score*1.2 # มั่นใจว่าถูกเเบรนด์ เพิ่ม score 120%
99
  r.payload["score"] = hybrid_score # เก็บลง payload ใช้เทียบ treshold ตอนเเสดงผล
100
  r.payload["fuzzy_name_score"] = fuzzy_name_score # เก็บไว้เผื่อ debug
101
  r.payload["fuzzy_brand_score"] = fuzzy_brand_score # เก็บไว้เผื่อ debug
102
  r.payload['semantic_score'] = r.score # เก็บไว้เผื่อ debug
103
+ scored.append((r, hybrid_score)) # collect product and hybrid score
104
 
105
  # เรียงตาม hybrid score แล้วกรองผลลัพธ์ที่ hybrid score ต่ำเกิน
106
+ scored = sorted(scored, key=lambda x: x[1], reverse=True) # sort
107
+ result = [r[0] for r in scored] # collect new sort product
108
 
109
+ elapsed = time.time() - start_time # stop search time
110
  html_output = f"<p>⏱ <strong>{elapsed:.2f} วินาที</strong></p>"
111
  if corrected_query != query:
112
  html_output += f"<p>🔧 แก้คำค้นจาก: <code>{query}</code> → <code>{corrected_query}</code></p>"
 
114
  result_summary, found = "", False
115
 
116
  for res in result:
117
+ if res.payload["score"] >= threshold: # choose only product score more than threshold
118
+ found = True # find product
119
  name = res.payload.get("name", "ไม่ทราบชื่อสินค้า")
120
  score = f"{res.payload['score']:.4f}"
121
+ img_url = res.payload.get("image_url", "")
122
  price = res.payload.get("price", "ไม่ระบุ")
123
  brand = res.payload.get("brand", "")
124
 
 
152
  def log_feedback(feedback):
153
  try:
154
  now = datetime.now().strftime("%Y-%m-%d")
155
+ # create table for send to airtable
156
+ # คอลัมน์ต้องตรงกับบน airtable
157
  table.create({
158
  "model": "BGE M3",
159
  "timestamp": now,