import asyncio

from typing import List

from pipecat.frames.frames import (
    Frame,
    TextFrame,
    VisionImageRawFrame,
    LLMMessagesFrame,
    LLMFullResponseStartFrame,
    LLMResponseStartFrame,
    LLMResponseEndFrame,
    LLMFullResponseEndFrame
)

from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame

from loguru import logger

try:
    import google.generativeai as gai
    import google.ai.generativelanguage as glm
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error(
        "In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set the `GOOGLE_API_KEY` environment variable.")
    raise Exception(f"Missing module: {e}")


class GoogleLLMService(LLMService):
    """This class implements inference with Google's AI models.

    This service translates internally from OpenAILLMContext to the messages format
    expected by the Google AI model. We are using the OpenAILLMContext as a lingua
    franca for all LLM services, so that it is easy to switch between different LLMs.
    """

    def __init__(self, *, api_key: str, model: str = "gemini-1.5-flash-latest", **kwargs):
        super().__init__(**kwargs)
        gai.configure(api_key=api_key)
        self._client = gai.GenerativeModel(model)

    def can_generate_metrics(self) -> bool:
        return True

    def _get_messages_from_openai_context(
            self, context: OpenAILLMContext) -> List[glm.Content]:
        openai_messages = context.get_messages()
        google_messages = []

        for message in openai_messages:
            role = message["role"]
            content = message["content"]
            # Google's content API accepts only "user" and "model" roles, so fold
            # system prompts into a user turn and map OpenAI's "assistant" to "model".
            if role == "system":
                role = "user"
            elif role == "assistant":
                role = "model"

            parts = [glm.Part(text=content)]
            if "mime_type" in message:
                parts.append(
                    glm.Part(inline_data=glm.Blob(
                        mime_type=message["mime_type"],
                        data=message["data"].getvalue()
                    )))
            google_messages.append({"role": role, "parts": parts})

        return google_messages
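
    # For illustration, a hypothetical OpenAI-style message such as
    #
    #     {"role": "assistant", "content": "Hello there!"}
    #
    # is translated by the method above into Google's format:
    #
    #     {"role": "model", "parts": [glm.Part(text="Hello there!")]}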

    async def _async_generator_wrapper(self, sync_generator):
        for item in sync_generator:
            yield item
            await asyncio.sleep(0)
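
    # Design note: the google-generativeai call in `_process_context` streams
    # synchronously, so the wrapper above yields control back to the event loop
    # after each chunk to avoid starving other pipeline tasks. If the installed
    # SDK version exposes an async variant (e.g. `generate_content_async`), that
    # could replace this wrapper; that is an assumption about the SDK version,
    # so the wrapper is kept as-is here.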

    async def _process_context(self, context: OpenAILLMContext):
        await self.push_frame(LLMFullResponseStartFrame())
        try:
            logger.debug(f"Generating chat: {context.get_messages_json()}")

            messages = self._get_messages_from_openai_context(context)

            await self.start_ttfb_metrics()

            response = self._client.generate_content(messages, stream=True)

            await self.stop_ttfb_metrics()

            async for chunk in self._async_generator_wrapper(response):
                try:
                    text = chunk.text
                    await self.push_frame(LLMResponseStartFrame())
                    await self.push_frame(TextFrame(text))
                    await self.push_frame(LLMResponseEndFrame())
                except Exception as e:
                    # `chunk.text` raises when the candidate carries no text part,
                    # e.g. when generation was blocked. A finish_reason of 3 maps
                    # to SAFETY in the generativelanguage API.
                    if chunk.candidates[0].finish_reason == 3:
                        logger.debug(
                            f"LLM refused to generate content for safety reasons - {messages}.")
                    else:
                        logger.exception(f"{self} error: {e}")

        except Exception as e:
            logger.exception(f"{self} exception: {e}")
        finally:
            await self.push_frame(LLMFullResponseEndFrame())
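
    # Taken together, each generation is pushed downstream as:
    #     LLMFullResponseStartFrame,
    #     then per streamed chunk: LLMResponseStartFrame, TextFrame, LLMResponseEndFrame,
    #     and finally LLMFullResponseEndFrame (even on errors, via `finally`).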

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        context = None

        if isinstance(frame, OpenAILLMContextFrame):
            context: OpenAILLMContext = frame.context
        elif isinstance(frame, LLMMessagesFrame):
            context = OpenAILLMContext.from_messages(frame.messages)
        elif isinstance(frame, VisionImageRawFrame):
            context = OpenAILLMContext.from_image_frame(frame)
        else:
            # Frames this service does not handle are passed through unchanged.
            await self.push_frame(frame, direction)

        if context:
            await self._process_context(context)
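

# A minimal usage sketch, kept as a comment to avoid import-time side effects.
# Pipeline, PipelineTask, transport, context_aggregator, and tts are assumptions
# about the surrounding application, not exports of this module:
#
#     llm = GoogleLLMService(api_key=os.getenv("GOOGLE_API_KEY"))
#     pipeline = Pipeline([transport.input(), context_aggregator, llm, tts, transport.output()])
#     task = PipelineTask(pipeline)
#     await task.queue_frame(LLMMessagesFrame([{"role": "user", "content": "Hello!"}]))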