rxavier commited on
Commit
23c3e28
·
1 Parent(s): 0d081dc

Update off_topic.py

Browse files
Files changed (1) hide show
  1. off_topic.py +8 -26
off_topic.py CHANGED
@@ -67,13 +67,12 @@ class OffTopicDetector:
67
  self.image_size = image_size
68
  self.translator = translator
69
 
70
- def predict_probas(self, images: List[PIL.Image.Image], domain: str,
71
  title: Optional[str] = None,
72
  valid_templates: Optional[List[str]] = None,
73
  invalid_classes: Optional[List[str]] = None,
74
  autocast: bool = True):
75
- site, domain = domain.split("-")
76
- domain = re.sub("_", " ", domain).lower()
77
  if valid_templates:
78
  valid_classes = [template.format(domain) for template in valid_templates]
79
  else:
@@ -87,7 +86,7 @@ class OffTopicDetector:
87
  else:
88
  src_lang = "es"
89
  translated_title = self.translator.translate(title, src_lang=src_lang, dest_lang="en", max_length=100)[0]
90
- valid_classes.append(translated_title)
91
  if not invalid_classes:
92
  invalid_classes = ["promotional ad with store information", "promotional text", "google maps screenshot", "business card", "qr code"]
93
 
@@ -130,9 +129,9 @@ class OffTopicDetector:
130
  use_title: bool = False,
131
  valid_templates: Optional[List[str]] = None,
132
  invalid_classes: Optional[List[str]] = None):
133
- images, domain, title = self.get_item_data(url_or_id)
134
  title = title if use_title else None
135
- probas, valid_probas, invalid_probas = self.predict_probas(images, domain, title, valid_templates,
136
  invalid_classes)
137
  return images, domain, probas, valid_probas, invalid_probas
138
 
@@ -146,16 +145,17 @@ class OffTopicDetector:
146
  item_id = re.sub("-", "", url_or_id)
147
  start = time.time()
148
  response = httpx.get(f"https://api.mercadolibre.com/items/{item_id}").json()
149
- domain = response["domain_id"]
150
  title = response["title"]
 
151
  img_urls = [x["url"] for x in response["pictures"]]
152
  img_urls = [x.replace("-O.jpg", f"-{self.image_size}.jpg") for x in img_urls]
 
153
  end = time.time()
154
  duration = end - start
155
  print(f"Items API time: {round(duration * 1000, 0)} ms")
156
  images = self.get_images(img_urls)
157
  dedup_images = self._filter_dups(images)
158
- return dedup_images, domain, title
159
 
160
  def _filter_dups(self, images: List):
161
  if len(images) > 1:
@@ -190,24 +190,6 @@ class OffTopicDetector:
190
  tasks = [_process_download(url, client) for url in urls]
191
  return await asyncio.gather(*tasks)
192
 
193
- @staticmethod
194
- def _non_async_get_item_data(url_or_id: str, save_images: bool = False):
195
- if url_or_id.startswith("http"):
196
- item_id = "".join(url_or_id.split("/")[3].split("-")[:2])
197
- else:
198
- item_id = re.sub("-", "", url_or_id)
199
- response = httpx.get(f"https://api.mercadolibre.com/items/{item_id}").json()
200
- domain = re.sub("_", " ", response["domain_id"].split("-")[-1]).lower()
201
- img_urls = [x["url"] for x in response["pictures"]]
202
- images = []
203
- for img_url in img_urls:
204
- img = httpx.get(img_url)
205
- images.append(Image.open(BytesIO(img.content)))
206
- if save_images:
207
- with open(re.sub("D_NQ_NP_", "", img_url.split("/")[-1]) , "wb") as f:
208
- f.write(img.content)
209
- return images, domain
210
-
211
  def show(self, images: List[PIL.Image.Image], valid_probas: np.ndarray, n_cols: int = 3,
212
  title: Optional[str] = None, threshold: Optional[float] = None):
213
  if threshold is not None:
 
67
  self.image_size = image_size
68
  self.translator = translator
69
 
70
+ def predict_probas(self, images: List[PIL.Image.Image], domain: str, site: str,
71
  title: Optional[str] = None,
72
  valid_templates: Optional[List[str]] = None,
73
  invalid_classes: Optional[List[str]] = None,
74
  autocast: bool = True):
75
+ domain = domain.lower()
 
76
  if valid_templates:
77
  valid_classes = [template.format(domain) for template in valid_templates]
78
  else:
 
86
  else:
87
  src_lang = "es"
88
  translated_title = self.translator.translate(title, src_lang=src_lang, dest_lang="en", max_length=100)[0]
89
+ valid_classes.append(translated_title.lower())
90
  if not invalid_classes:
91
  invalid_classes = ["promotional ad with store information", "promotional text", "google maps screenshot", "business card", "qr code"]
92
 
 
129
  use_title: bool = False,
130
  valid_templates: Optional[List[str]] = None,
131
  invalid_classes: Optional[List[str]] = None):
132
+ images, domain, site, title = self.get_item_data(url_or_id)
133
  title = title if use_title else None
134
+ probas, valid_probas, invalid_probas = self.predict_probas(images, domain, site, title, valid_templates,
135
  invalid_classes)
136
  return images, domain, probas, valid_probas, invalid_probas
137
 
 
145
  item_id = re.sub("-", "", url_or_id)
146
  start = time.time()
147
  response = httpx.get(f"https://api.mercadolibre.com/items/{item_id}").json()
 
148
  title = response["title"]
149
+ site, domain = response["domain_id"].split("-")
150
  img_urls = [x["url"] for x in response["pictures"]]
151
  img_urls = [x.replace("-O.jpg", f"-{self.image_size}.jpg") for x in img_urls]
152
+ domain_name = httpx.get(f"https://api.mercadolibre.com/catalog_domains/CBT-{domain}").json()["name"]
153
  end = time.time()
154
  duration = end - start
155
  print(f"Items API time: {round(duration * 1000, 0)} ms")
156
  images = self.get_images(img_urls)
157
  dedup_images = self._filter_dups(images)
158
+ return dedup_images, domain_name, site, title
159
 
160
  def _filter_dups(self, images: List):
161
  if len(images) > 1:
 
190
  tasks = [_process_download(url, client) for url in urls]
191
  return await asyncio.gather(*tasks)
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  def show(self, images: List[PIL.Image.Image], valid_probas: np.ndarray, n_cols: int = 3,
194
  title: Optional[str] = None, threshold: Optional[float] = None):
195
  if threshold is not None: