simondh committed on
Commit
e5c1bae
·
1 Parent(s): 36183d4
Files changed (1) hide show
  1. classifiers/llm.py +25 -10
classifiers/llm.py CHANGED
@@ -7,6 +7,11 @@ import random
7
  import json
8
  import asyncio
9
  from typing import List, Dict, Any, Optional
 
 
 
 
 
10
  from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT
11
 
12
  from .base import BaseClassifier
@@ -28,11 +33,16 @@ class LLMClassifier(BaseClassifier):
28
  )
29
 
30
  try:
31
- response = await self.client.chat.completions.create(
32
- model=self.model,
33
- messages=[{"role": "user", "content": prompt}],
34
- temperature=0,
35
- max_tokens=200,
 
 
 
 
 
36
  )
37
 
38
  # Parse JSON response
@@ -87,11 +97,16 @@ class LLMClassifier(BaseClassifier):
87
  prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
88
 
89
  try:
90
- response = await self.client.chat.completions.create(
91
- model=self.model,
92
- messages=[{"role": "user", "content": prompt}],
93
- temperature=0.2,
94
- max_tokens=100,
 
 
 
 
 
95
  )
96
 
97
  # Parse response to get categories
 
7
  import json
8
  import asyncio
9
  from typing import List, Dict, Any, Optional
10
+ import sys
11
+ import os
12
+
13
+ # Add the project root to the Python path
14
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15
  from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT
16
 
17
  from .base import BaseClassifier
 
33
  )
34
 
35
  try:
36
+ # Use the synchronous client method but run it in a thread pool
37
+ loop = asyncio.get_event_loop()
38
+ response = await loop.run_in_executor(
39
+ None,
40
+ lambda: self.client.chat.completions.create(
41
+ model=self.model,
42
+ messages=[{"role": "user", "content": prompt}],
43
+ temperature=0,
44
+ max_tokens=200,
45
+ )
46
  )
47
 
48
  # Parse JSON response
 
97
  prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
98
 
99
  try:
100
+ # Use the synchronous client method but run it in a thread pool
101
+ loop = asyncio.get_event_loop()
102
+ response = await loop.run_in_executor(
103
+ None,
104
+ lambda: self.client.chat.completions.create(
105
+ model=self.model,
106
+ messages=[{"role": "user", "content": prompt}],
107
+ temperature=0.2,
108
+ max_tokens=100,
109
+ )
110
  )
111
 
112
  # Parse response to get categories