#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio

from typing import List

from pipecat.frames.frames import (
    Frame,
    TextFrame,
    VisionImageRawFrame,
    LLMMessagesFrame,
    LLMFullResponseStartFrame,
    LLMResponseStartFrame,
    LLMResponseEndFrame,
    LLMFullResponseEndFrame
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext, OpenAILLMContextFrame

from loguru import logger

try:
    import google.generativeai as gai
    import google.ai.generativelanguage as glm
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error(
        "In order to use Google AI, you need to `pip install pipecat-ai[google]`. "
        "Also, set the `GOOGLE_API_KEY` environment variable.")
    raise Exception(f"Missing module: {e}")

class GoogleLLMService(LLMService):
    """This class implements inference with Google's AI models.

    This service translates internally from OpenAILLMContext to the messages format
    expected by the Google AI model. We are using the OpenAILLMContext as a lingua
    franca for all LLM services, so that it is easy to switch between different LLMs.
    """

    def __init__(self, *, api_key: str, model: str = "gemini-1.5-flash-latest", **kwargs):
        super().__init__(**kwargs)
        gai.configure(api_key=api_key)
        self._client = gai.GenerativeModel(model)

    def can_generate_metrics(self) -> bool:
        return True

    def _get_messages_from_openai_context(
            self, context: OpenAILLMContext) -> List[glm.Content]:
        openai_messages = context.get_messages()
        google_messages = []

        for message in openai_messages:
            role = message["role"]
            content = message["content"]

            # The Google API only accepts "user" and "model" roles, so map
            # OpenAI's "system" role to "user" and "assistant" to "model".
            if role == "system":
                role = "user"
            elif role == "assistant":
                role = "model"

            parts = [glm.Part(text=content)]
            # Messages that carry an image (e.g. built from a
            # VisionImageRawFrame) also get an inline binary data part.
            if "mime_type" in message:
                parts.append(
                    glm.Part(inline_data=glm.Blob(
                        mime_type=message["mime_type"],
                        data=message["data"].getvalue()
                    )))

            google_messages.append({"role": role, "parts": parts})

        return google_messages
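
    # For illustration only (an editor's sketch, not upstream code): given an
    # OpenAI-style context, the translation above remaps roles and wraps each
    # message's text in Google `parts`. The messages here are hypothetical
    # example data.
    #
    #   context = OpenAILLMContext.from_messages([
    #       {"role": "system", "content": "Be terse."},
    #       {"role": "assistant", "content": "OK."},
    #   ])
    #   service._get_messages_from_openai_context(context)
    #   # -> [{"role": "user", "parts": [glm.Part(text="Be terse.")]},
    #   #     {"role": "model", "parts": [glm.Part(text="OK.")]}]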
    async def _async_generator_wrapper(self, sync_generator):
        # generate_content(stream=True) returns a synchronous iterator; wrap
        # it so that each chunk yields control back to the event loop and
        # streaming does not starve the rest of the pipeline.
        for item in sync_generator:
            yield item
            await asyncio.sleep(0)

    async def _process_context(self, context: OpenAILLMContext):
        await self.push_frame(LLMFullResponseStartFrame())
        try:
            logger.debug(f"Generating chat: {context.get_messages_json()}")

            messages = self._get_messages_from_openai_context(context)

            await self.start_ttfb_metrics()
            response = self._client.generate_content(messages, stream=True)
            await self.stop_ttfb_metrics()

            async for chunk in self._async_generator_wrapper(response):
                try:
                    text = chunk.text
                    await self.push_frame(LLMResponseStartFrame())
                    await self.push_frame(TextFrame(text))
                    await self.push_frame(LLMResponseEndFrame())
                except Exception as e:
                    # Google LLMs seem to flag safety issues a lot! Reading
                    # chunk.text raises when the candidate has no content;
                    # finish_reason == 3 is SAFETY in the FinishReason enum.
                    if chunk.candidates[0].finish_reason == 3:
                        logger.debug(
                            f"LLM refused to generate content for safety reasons - {messages}.")
                    else:
                        logger.exception(f"{self} error: {e}")
        except Exception as e:
            logger.exception(f"{self} exception: {e}")
        finally:
            await self.push_frame(LLMFullResponseEndFrame())

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        context = None
        if isinstance(frame, OpenAILLMContextFrame):
            context: OpenAILLMContext = frame.context
        elif isinstance(frame, LLMMessagesFrame):
            context = OpenAILLMContext.from_messages(frame.messages)
        elif isinstance(frame, VisionImageRawFrame):
            context = OpenAILLMContext.from_image_frame(frame)
        else:
            await self.push_frame(frame, direction)

        if context:
            await self._process_context(context)
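

# A minimal usage sketch (an editor's addition, not part of the service
# itself): wire GoogleLLMService into a pipecat pipeline and push a single
# LLMMessagesFrame through it. The import paths below match pipecat 0.0.x and
# may differ in other versions; GOOGLE_API_KEY is an assumed environment
# variable name.
if __name__ == "__main__":
    import os

    from pipecat.frames.frames import EndFrame
    from pipecat.pipeline.pipeline import Pipeline
    from pipecat.pipeline.runner import PipelineRunner
    from pipecat.pipeline.task import PipelineTask

    async def main():
        llm = GoogleLLMService(api_key=os.environ["GOOGLE_API_KEY"])
        task = PipelineTask(Pipeline([llm]))
        # Queue one user turn, then an EndFrame so the pipeline shuts down
        # once the streamed response has been pushed.
        await task.queue_frames([
            LLMMessagesFrame([{"role": "user", "content": "Say hello."}]),
            EndFrame(),
        ])
        await PipelineRunner().run(task)

    asyncio.run(main())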