Vx2-3y committed on
Commit a08b2e2 · 1 Parent(s): f9abbde

Integrate Redis job queue: async job submission, status, and model flexibility

Files changed (2)
  1. main.py +66 -4
  2. requirements.txt +4 -1
main.py CHANGED
@@ -5,6 +5,11 @@ import os
 import logging
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from huggingface_hub import login
+import threading
+import time
+import uuid
+import json
+import redis
 
 app = FastAPI(
     title="NCOS Compliance LLM API",
@@ -63,6 +68,51 @@ except Exception as e:
     logger.error(f"Model loading failed: {e}")
     ncos_pipeline = None
 
+# --- Redis Connection ---
+REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")  # Set your cloud Redis URL in env
+redis_client = redis.Redis.from_url(REDIS_URL)
+
+# --- Job Queue Logic ---
+JOB_QUEUE = "ncos_job_queue"
+JOB_RESULT_PREFIX = "ncos_job_result:"
+
+# --- Model Cache ---
+model_cache = {"name": None, "pipeline": None}
+
+# --- Background Worker Thread ---
+def job_worker():
+    while True:
+        job_data = redis_client.lpop(JOB_QUEUE)
+        if job_data:
+            job = json.loads(job_data)  # Jobs are JSON-serialized in submit_job; never eval() queue payloads
+            job_id = job["job_id"]
+            input_text = job["input_text"]
+            parameters = job.get("parameters", {})
+            model_name = job.get("model_name", "gpt2")
+            try:
+                # Load the requested model only if it differs from the cached one
+                if model_cache["name"] != model_name:
+                    logger.info(f"Loading model for job: {model_name}")
+                    tokenizer = AutoTokenizer.from_pretrained(model_name)
+                    model = AutoModelForCausalLM.from_pretrained(model_name)
+                    model_cache["pipeline"] = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
+                    model_cache["name"] = model_name
+                pipe = model_cache["pipeline"]
+                params = parameters or {}
+                params.setdefault("max_new_tokens", 128)
+                params.setdefault("temperature", 0.7)
+                output = pipe(input_text, **params)
+                result_text = output[0]["generated_text"] if output and "generated_text" in output[0] else str(output)
+                redis_client.set(JOB_RESULT_PREFIX + job_id, result_text)
+            except Exception as e:
+                logger.error(f"Job {job_id} failed: {e}")
+                redis_client.set(JOB_RESULT_PREFIX + job_id, f"ERROR: {e}")
+        else:
+            time.sleep(1)
+
+# Start background worker thread
+threading.Thread(target=job_worker, daemon=True).start()
+
 # --- Endpoints ---
 
 @app.post("/infer", response_model=InferResponse, summary="Run model inference", description="Run LLM inference on the input text and return the result.")
@@ -109,8 +159,14 @@ def submit_job(request: QueueRequest):
     - **parameters**: Optional model parameters.
     Returns a job ID and status.
     """
-    # TODO: Integrate with Redis queue
-    job_id = "job_123" # Placeholder
+    job_id = str(uuid.uuid4())
+    job = {
+        "job_id": job_id,
+        "input_text": request.input_text,
+        "parameters": request.parameters,
+        "model_name": os.getenv("HF_MODEL_NAME", "gpt2")  # Allow override per job in future
+    }
+    redis_client.rpush(JOB_QUEUE, json.dumps(job))
     return QueueResponse(job_id=job_id, status="queued")
 
 @app.get("/queue", response_model=QueueResponse, summary="Get job status/result", description="Get the status or result of a queued job by job_id.")
@@ -120,8 +176,14 @@ def get_job_status(job_id: str):
     - **job_id**: The job identifier.
     Returns the job status and result if available.
    """
-    # TODO: Query Redis for job status/result
-    return QueueResponse(job_id=job_id, status="pending")
+    result = redis_client.get(JOB_RESULT_PREFIX + job_id)
+    if result:
+        result_str = result.decode("utf-8")
+        if result_str.startswith("ERROR:"):
+            return QueueResponse(job_id=job_id, status="error", error=result_str)
+        return QueueResponse(job_id=job_id, status="done", result=result_str)
+    else:
+        return QueueResponse(job_id=job_id, status="pending")
 
 @app.get("/")
 def root():
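
For context (not part of this commit), here is a minimal client sketch against the new queue endpoints. It assumes the submit route is mounted at POST /queue (the decorator for submit_job falls outside this diff; only GET /queue is visible above), that QueueRequest carries the input_text and parameters fields used in submit_job, and that the API is served at http://localhost:8000.

# client_sketch.py -- hypothetical client for the new queue endpoints; paths
# and port are assumptions, not confirmed by this diff.
import time
import requests

BASE = "http://localhost:8000"  # assumed local uvicorn default

def run_job(text, parameters=None, timeout_s=120.0):
    # Submit the job; mirrors the QueueRequest fields used in submit_job().
    resp = requests.post(f"{BASE}/queue", json={"input_text": text, "parameters": parameters})
    resp.raise_for_status()
    job_id = resp.json()["job_id"]

    # Poll GET /queue until the worker writes a result (status "done" or "error").
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        status = requests.get(f"{BASE}/queue", params={"job_id": job_id}).json()
        if status["status"] == "done":
            return status["result"]
        if status["status"] == "error":
            raise RuntimeError(status["error"])
        time.sleep(1)  # matches the worker's 1 s idle poll interval
    raise TimeoutError(f"job {job_id} still pending after {timeout_s}s")

if __name__ == "__main__":
    print(run_job("Summarize the NCOS compliance requirements."))

Polling once per second mirrors the worker's own 1 s idle sleep; a real client would likely add backoff.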
requirements.txt CHANGED
@@ -23,4 +23,7 @@ huggingface_hub==0.23.1
 torch==2.2.2
 
 # Pin numpy to <2 to avoid incompatibility with modules compiled against numpy 1.x
-numpy<2
+numpy<2
+
+# For Redis job queue integration
+redis==5.0.3
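
One note on the new redis dependency: redis_client.lpop runs outside the worker's try/except, so a connection failure would kill the daemon thread silently rather than surface at startup. A pre-flight sketch (not part of this commit) that fails fast on a bad REDIS_URL:

# check_redis.py -- hypothetical pre-flight check, mirroring main.py's connection setup.
import os
import sys
import redis

REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")  # same default as main.py

try:
    # ping() round-trips to the server and raises on connection or auth failure
    redis.Redis.from_url(REDIS_URL).ping()
    print(f"Redis reachable at {REDIS_URL}")
except redis.RedisError as e:
    sys.exit(f"Redis check failed for {REDIS_URL}: {e}")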