max context length limit
- app/main.py +12 -4
- params.cfg +3 -0
app/main.py

```diff
@@ -23,6 +23,7 @@ config = getconfig("params.cfg")
 RETRIEVER = config.get("retriever", "RETRIEVER")
 GENERATOR = config.get("generator", "GENERATOR")
 INGESTOR = config.get("ingestor", "INGESTOR")
+MAX_CONTEXT_CHARS = config.get("general", "MAX_CONTEXT_CHARS")
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)

@@ -83,7 +84,7 @@ def ingest_node(state: GraphState) -> GraphState:
     try:
         # Call the ingestor's ingest endpoint - use gradio_client.file() for proper formatting
         ingestor_context = client.predict(
-            file,
+            file=tmp_file_path,
             api_name="/ingest"
         )
 
```
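Side note on the `/ingest` call: the inline comment mentions `gradio_client.file()` for wrapping uploads, while the new code passes the raw temp-file path as a keyword argument. A minimal sketch of the wrapped form, assuming a local path (recent `gradio_client` releases expose `handle_file` as the successor to `file()`; the path value here is a placeholder standing in for the `tmp_file_path` written earlier in `ingest_node`):

```python
from gradio_client import Client, handle_file

# Illustrative values; INGESTOR normally comes from params.cfg.
client = Client("mtyrrell/chatfed_ingestor")
tmp_file_path = "/tmp/upload.pdf"  # placeholder for the temp file written earlier

ingestor_context = client.predict(
    file=handle_file(tmp_file_path),  # wraps the path so the client uploads the file
    api_name="/ingest",
)
```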
```diff
@@ -168,13 +169,20 @@ def generate_node(state: GraphState) -> GraphState:
     retrieved_context = state.get("context", "")
     ingestor_context = state.get("ingestor_context", "")
 
+    # Limit context size to prevent token overflow
+    MAX_CONTEXT_CHARS = int(MAX_CONTEXT_CHARS)  # Adjust based on your model's limits
+
     combined_context = ""
     if ingestor_context and retrieved_context:
-        combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context}\n\n=== RETRIEVED CONTEXT ===\n{retrieved_context}"
+        # Prioritize ingestor context, truncate if needed
+        ingestor_truncated = ingestor_context[:MAX_CONTEXT_CHARS//2] if len(ingestor_context) > MAX_CONTEXT_CHARS//2 else ingestor_context
+        retrieved_truncated = retrieved_context[:MAX_CONTEXT_CHARS//2] if len(retrieved_context) > MAX_CONTEXT_CHARS//2 else retrieved_context
+        combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_truncated}\n\n=== RETRIEVED CONTEXT ===\n{retrieved_truncated}"
     elif ingestor_context:
-        combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context}"
+        ingestor_truncated = ingestor_context[:MAX_CONTEXT_CHARS] if len(ingestor_context) > MAX_CONTEXT_CHARS else ingestor_context
+        combined_context = f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_truncated}"
     elif retrieved_context:
-        combined_context = retrieved_context
+        combined_context = retrieved_context[:MAX_CONTEXT_CHARS] if len(retrieved_context) > MAX_CONTEXT_CHARS else retrieved_context
 
     client = Client(GENERATOR)
     result = client.predict(
```
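For reference, the new truncation logic can be read as a small helper. This is a sketch under the diff's own assumptions (character-based budgeting, an even split when both sources are present); the function name `combine_contexts` is illustrative, not part of the codebase. The `int(...)` cast above is needed because `config.get(...)` returns the raw string from params.cfg.

```python
def combine_contexts(ingestor_context: str, retrieved_context: str,
                     max_chars: int = 15000) -> str:
    """Merge the two context sources under a character budget.

    Mirrors the generate_node logic: when both sources are present each
    gets half the budget, otherwise the single source gets all of it.
    """
    if ingestor_context and retrieved_context:
        half = max_chars // 2
        return (f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context[:half]}"
                f"\n\n=== RETRIEVED CONTEXT ===\n{retrieved_context[:half]}")
    if ingestor_context:
        return f"=== UPLOADED DOCUMENT CONTEXT ===\n{ingestor_context[:max_chars]}"
    if retrieved_context:
        return retrieved_context[:max_chars]
    return ""
```

Because Python slicing is clamped (`s[:n]` simply returns `s` when it is shorter than `n` characters), the `len(...) >` guards in the diff are redundant but harmless.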
params.cfg

```diff
@@ -6,3 +6,6 @@ GENERATOR = giz/chatfed_generator
 
 [ingestor]
 INGESTOR = mtyrrell/chatfed_ingestor
+
+[general]
+MAX_CONTEXT_CHARS = 15000
```
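If `getconfig` wraps the standard library's `configparser` (an assumption; the helper's implementation isn't shown in this diff), the new `[general]` section could also be read with `getint`, which would make the explicit `int(...)` cast in `generate_node` unnecessary:

```python
import configparser

config = configparser.ConfigParser()
config.read("params.cfg")

# configparser returns strings from .get(), hence the int() cast in app/main.py;
# getint() performs the conversion at read time instead.
max_context_chars = config.getint("general", "MAX_CONTEXT_CHARS")  # -> 15000
```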