Vx2-3y committed
Commit 04ba0b0 · 1 Parent(s): 3fa9baf

Implement cloud-ready model loading and inference logic; update requirements for HF Spaces deployment

Files changed (2)
  1. main.py +78 -20
  2. requirements.txt +10 -1
main.py CHANGED
@@ -1,6 +1,10 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from typing import Optional, Any
+import os
+import logging
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from huggingface_hub import login
 
 app = FastAPI(
     title="NCOS Compliance LLM API",
@@ -11,40 +15,85 @@ app = FastAPI(
 # --- Pydantic models for request/response ---
 
 class InferRequest(BaseModel):
-    input_text: str
-    parameters: Optional[dict] = None  # e.g., temperature, max_tokens
+    input_text: str  # The text to run inference on
+    parameters: Optional[dict] = None  # Optional model parameters (e.g., temperature, max_tokens)
 
 class InferResponse(BaseModel):
-    result: str
-    status: str
-    error: Optional[str] = None
+    result: str  # The model's output
+    status: str  # 'success' or 'error'
+    error: Optional[str] = None  # Error message if status is 'error'
 
 class QueueRequest(BaseModel):
-    input_text: str
-    parameters: Optional[dict] = None
+    input_text: str  # The text to enqueue for inference
+    parameters: Optional[dict] = None  # Optional model parameters
 
 class QueueResponse(BaseModel):
-    job_id: str
-    status: str
-    result: Optional[str] = None
-    error: Optional[str] = None
+    job_id: str  # Unique job identifier
+    status: str  # 'queued', 'pending', 'done', or 'error'
+    result: Optional[str] = None  # Model output if available
+    error: Optional[str] = None  # Error message if status is 'error'
+
+# --- Model Loading (Cloud-Ready) ---
+
+# Read model name and token from environment variables for security
+MODEL_NAME = os.getenv("HF_MODEL_NAME", "ACATECH/ncos")  # Default to ACATECH/ncos
+HF_TOKEN = os.getenv("HF_TOKEN")  # Should be set in Hugging Face Space secrets
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("ncos-backend")
+
+# Login to Hugging Face Hub if token is provided
+if HF_TOKEN:
+    try:
+        login(token=HF_TOKEN)
+        logger.info("Logged in to Hugging Face Hub.")
+    except Exception as e:
+        logger.error(f"Failed to login to Hugging Face Hub: {e}")
+
+# Load model and tokenizer at startup
+try:
+    logger.info(f"Loading model: {MODEL_NAME}")
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+    # Use pipeline for simple inference
+    ncos_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
+    logger.info("Model and tokenizer loaded successfully.")
+except Exception as e:
+    logger.error(f"Model loading failed: {e}")
+    ncos_pipeline = None
 
 # --- Endpoints ---
 
-@app.post("/infer", response_model=InferResponse)
+@app.post("/infer", response_model=InferResponse, summary="Run model inference", description="Run LLM inference on the input text and return the result.")
 def infer(request: InferRequest):
     """
     Run model inference on the input text.
+    - **input_text**: The text to run inference on.
+    - **parameters**: Optional model parameters (e.g., temperature, max_tokens).
+    Returns the model's output or an error message.
     """
-    # Placeholder logic for now
+    if ncos_pipeline is None:
+        # Model failed to load
+        logger.error("Inference requested but model is not loaded.")
+        return InferResponse(result="", status="error", error="Model not loaded.")
     try:
-        # TODO: Call your model here
-        output = f"Echo: {request.input_text}"
-        return InferResponse(result=output, status="success")
+        # Prepare parameters for the pipeline
+        params = request.parameters or {}
+        # Set sensible defaults if not provided
+        params.setdefault("max_new_tokens", 128)
+        params.setdefault("temperature", 0.7)
+        # Run inference
+        logger.info(f"Running inference for input: {request.input_text}")
+        output = ncos_pipeline(request.input_text, **params)
+        # output is a list of dicts with 'generated_text'
+        result_text = output[0]["generated_text"] if output and "generated_text" in output[0] else str(output)
+        return InferResponse(result=result_text, status="success")
     except Exception as e:
+        logger.error(f"Error during inference: {e}")
         return InferResponse(result="", status="error", error=str(e))
 
-@app.get("/healthz")
+@app.get("/healthz", summary="Health check", description="Check if the backend service is healthy.")
 def healthz():
     """
     Health check endpoint.
@@ -52,23 +101,32 @@ def healthz():
     """
     return {"status": "ok"}
 
-@app.post("/queue", response_model=QueueResponse)
+@app.post("/queue", response_model=QueueResponse, summary="Submit job to queue", description="Submit a job to the Redis queue for asynchronous processing.")
 def submit_job(request: QueueRequest):
     """
     Submit a job to the queue (e.g., Redis).
+    - **input_text**: The text to enqueue for inference.
+    - **parameters**: Optional model parameters.
+    Returns a job ID and status.
     """
     # TODO: Integrate with Redis queue
     job_id = "job_123"  # Placeholder
     return QueueResponse(job_id=job_id, status="queued")
 
-@app.get("/queue", response_model=QueueResponse)
+@app.get("/queue", response_model=QueueResponse, summary="Get job status/result", description="Get the status or result of a queued job by job_id.")
 def get_job_status(job_id: str):
     """
     Get the status/result of a queued job.
+    - **job_id**: The job identifier.
+    Returns the job status and result if available.
     """
     # TODO: Query Redis for job status/result
     return QueueResponse(job_id=job_id, status="pending")
 
 # --- End of API contract skeleton ---
 
-# FastAPI will auto-generate OpenAPI docs at /docs and /openapi.json
+# FastAPI will auto-generate OpenAPI docs at /docs and /openapi.json
+#
+# To test endpoints:
+# - Use curl, httpie, or Postman to send requests to /infer, /healthz, /queue
+# - Visit /docs for interactive API documentation
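The closing comment above suggests testing with curl, httpie, or Postman. As a sketch only, here is a minimal Python client for the new endpoints; the base URL and the example payload are assumptions (uvicorn serves on port 8000 locally, while HF Spaces typically expose port 7860), and the requests package is assumed to be installed:

import requests

BASE_URL = "http://localhost:8000"  # assumed local dev URL; adjust for your Space

# Health check: expect {'status': 'ok'}
print(requests.get(f"{BASE_URL}/healthz").json())

# Inference: 'parameters' is forwarded to the text-generation pipeline
payload = {
    "input_text": "Summarize the NCOS compliance requirements.",  # illustrative input
    "parameters": {"max_new_tokens": 64, "temperature": 0.5},
}
data = requests.post(f"{BASE_URL}/infer", json=payload, timeout=120).json()
if data["status"] == "success":
    print(data["result"])
else:
    print("Inference failed:", data["error"])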
requirements.txt CHANGED
@@ -11,4 +11,13 @@ supabase
 python-dotenv
 
 # Data validation and settings management
-pydantic
+pydantic
+
+# For loading and running LLMs from Hugging Face
+transformers==4.40.2
+
+# For model hub integration
+huggingface_hub==0.23.1
+
+# For GPU inference (update if needed for CUDA compatibility)
+torch==2.2.2
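Note that the pipeline in main.py is created with device=0, which requires a GPU. A quick sanity check after installing these pins (expected versions are taken from requirements.txt above; CUDA availability depends on the Space's hardware tier):

# Verify the pinned versions and GPU availability
import torch
import transformers
import huggingface_hub

print(transformers.__version__)     # expected: 4.40.2
print(huggingface_hub.__version__)  # expected: 0.23.1
print(torch.__version__)            # expected: 2.2.2
print(torch.cuda.is_available())    # must be True for the device=0 pipeline to initialize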