plaggy committed on
Commit
45dffe9
·
1 Parent(s): ff7acb4
Files changed (2) hide show
  1. home.html +1 -1
  2. src/main.py +21 -14
home.html CHANGED
@@ -4,7 +4,7 @@
4
  <meta charset="utf-8" />
5
  <meta name="viewport" content="width=device-width" />
6
  <title>Auto Re-Train</title>
7
- <link rel="stylesheet" href="style.css" />
8
  </head>
9
  <body>
10
  <div class="card">
 
4
  <meta charset="utf-8" />
5
  <meta name="viewport" content="width=device-width" />
6
  <title>Auto Re-Train</title>
7
+ <link rel="stylesheet" href="./style.css" />
8
  </head>
9
  <body>
10
  <div class="card">
src/main.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  import tempfile
8
  import requests
9
 
10
- from fastapi import FastAPI, Header, HTTPException, BackgroundTasks
11
  from fastapi.responses import FileResponse
12
 
13
  from aiohttp import ClientSession
@@ -18,7 +18,6 @@ from tqdm.asyncio import tqdm_asyncio
18
 
19
  from src.models import chunk_config, embed_config, WebhookPayload
20
 
21
-
22
  logging.basicConfig(level=logging.INFO)
23
  logger = logging.getLogger(__name__)
24
 
@@ -41,10 +40,7 @@ async def post_webhook(
41
  if not (
42
  payload.event.action == "update"
43
  and payload.event.scope.startswith("repo.content")
44
- and (
45
- payload.repo.name == embed_config.input_dataset
46
- # or payload.repo.name == chunk_config.input_dataset
47
- )
48
  and payload.repo.type == "dataset"
49
  ):
50
  # no-op
@@ -93,7 +89,7 @@ def chunk_generator(input_dataset, chunker):
93
 
94
  def chunk_dataset():
95
  logger.info("Update detected, chunking is scheduled")
96
- input_ds = load_dataset(chunk_config.input_dataset, split=chunk_config.input_splits)
97
  chunker = Chunker(
98
  strategy=chunk_config.strategy,
99
  split_seq=chunk_config.split_seq,
@@ -123,7 +119,7 @@ def chunk_dataset():
123
  EMBEDDING
124
  """
125
 
126
- async def embed_sent(sentence, semaphore, tei_url, tmp_file):
127
  async with semaphore:
128
  payload = {
129
  "inputs": sentence,
@@ -136,7 +132,7 @@ async def embed_sent(sentence, semaphore, tei_url, tmp_file):
136
  "Authorization": f"Bearer {HF_TOKEN}"
137
  }
138
  ) as session:
139
- async with session.post(tei_url, json=payload) as resp:
140
  if resp.status != 200:
141
  raise RuntimeError(await resp.text())
142
  result = await resp.json()
@@ -146,10 +142,10 @@ async def embed_sent(sentence, semaphore, tei_url, tmp_file):
146
  )
147
 
148
 
149
- async def embed(input_ds, tei_url, temp_file):
150
  semaphore = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
151
  jobs = [
152
- asyncio.create_task(embed_sent(row[chunk_config.input_text_col], semaphore, tei_url, temp_file))
153
  for row in input_ds if row[chunk_config.input_text_col].strip()
154
  ]
155
  logger.info(f"num chunks to embed: {len(jobs)}")
@@ -160,20 +156,24 @@ async def embed(input_ds, tei_url, temp_file):
160
 
161
 
162
  def wake_up_endpoint(url):
 
163
  while requests.get(
164
  url=url,
165
  headers={"Authorization": f"Bearer {HF_TOKEN}"}
166
  ).status_code != 200:
167
  time.sleep(2)
 
 
 
168
  logger.info("TEI endpoint is up")
169
 
170
 
171
  def embed_dataset():
172
  logger.info("Update detected, embedding is scheduled")
173
- wake_up_endpoint(embed_config.tei_url)
174
- input_ds = load_dataset(embed_config.input_dataset, split=embed_config.input_splits)
175
  with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
176
- asyncio.run(embed(input_ds, embed_config.tei_url, temp_file))
177
 
178
  dataset = Dataset.from_json(temp_file.name)
179
  dataset.push_to_hub(
@@ -183,3 +183,10 @@ def embed_dataset():
183
  )
184
 
185
  logger.info("Done embedding")
 
 
 
 
 
 
 
 
7
  import tempfile
8
  import requests
9
 
10
+ from fastapi import FastAPI, BackgroundTasks
11
  from fastapi.responses import FileResponse
12
 
13
  from aiohttp import ClientSession
 
18
 
19
  from src.models import chunk_config, embed_config, WebhookPayload
20
 
 
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
 
40
  if not (
41
  payload.event.action == "update"
42
  and payload.event.scope.startswith("repo.content")
43
+ # and payload.repo.name == chunk_config.input_dataset # any input dataset
 
 
 
44
  and payload.repo.type == "dataset"
45
  ):
46
  # no-op
 
89
 
90
  def chunk_dataset():
91
  logger.info("Update detected, chunking is scheduled")
92
+ input_ds = load_dataset(chunk_config.input_dataset, split="+".join(chunk_config.input_splits))
93
  chunker = Chunker(
94
  strategy=chunk_config.strategy,
95
  split_seq=chunk_config.split_seq,
 
119
  EMBEDDING
120
  """
121
 
122
+ async def embed_sent(sentence, semaphore, tmp_file):
123
  async with semaphore:
124
  payload = {
125
  "inputs": sentence,
 
132
  "Authorization": f"Bearer {HF_TOKEN}"
133
  }
134
  ) as session:
135
+ async with session.post(TEI_URL, json=payload) as resp:
136
  if resp.status != 200:
137
  raise RuntimeError(await resp.text())
138
  result = await resp.json()
 
142
  )
143
 
144
 
145
+ async def embed(input_ds, temp_file):
146
  semaphore = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
147
  jobs = [
148
+ asyncio.create_task(embed_sent(row[chunk_config.input_text_col], semaphore, temp_file))
149
  for row in input_ds if row[chunk_config.input_text_col].strip()
150
  ]
151
  logger.info(f"num chunks to embed: {len(jobs)}")
 
156
 
157
 
158
  def wake_up_endpoint(url):
159
+ n_loop = 0
160
  while requests.get(
161
  url=url,
162
  headers={"Authorization": f"Bearer {HF_TOKEN}"}
163
  ).status_code != 200:
164
  time.sleep(2)
165
+ n_loop += 1
166
+ if n_loop > 10:
167
+ raise TimeoutError("TEI endpoint is unavailable")
168
  logger.info("TEI endpoint is up")
169
 
170
 
171
  def embed_dataset():
172
  logger.info("Update detected, embedding is scheduled")
173
+ wake_up_endpoint(TEI_URL)
174
+ input_ds = load_dataset(embed_config.input_dataset, split="+".join(chunk_config.input_splits))
175
  with tempfile.NamedTemporaryFile(mode="a", suffix=".jsonl") as temp_file:
176
+ asyncio.run(embed(input_ds, temp_file))
177
 
178
  dataset = Dataset.from_json(temp_file.name)
179
  dataset.push_to_hub(
 
183
  )
184
 
185
  logger.info("Done embedding")
186
+
187
+
188
+ # For debugging
189
+
190
+ # import uvicorn
191
+ # if __name__ == "__main__":
192
+ # uvicorn.run(app, host="0.0.0.0", port=7860)