|
|
|
|
|
|
|
|
|
|
|
|
|
import dataclasses
|
|
|
|
import pipecat.frames.protobufs.frames_pb2 as frame_protos
|
|
|
|
from pipecat.frames.frames import AudioRawFrame, Frame, TextFrame, TranscriptionFrame
|
|
from pipecat.serializers.base_serializer import FrameSerializer
|
|
|
|
from loguru import logger
|
|
|
|
|
|
class ProtobufFrameSerializer(FrameSerializer):
    """Convert pipeline frames to and from ``frame_protos.Frame`` messages.

    Only the exact frame classes listed in ``SERIALIZABLE_TYPES`` are
    supported. Each class maps to the name of the member of the protobuf
    ``frame`` oneof that carries its data.
    """

    # Frame class -> name of the oneof member that holds its fields.
    SERIALIZABLE_TYPES = {
        TextFrame: "text",
        AudioRawFrame: "audio",
        TranscriptionFrame: "transcription",
    }

    # Reverse lookup (oneof member name -> frame class), derived from
    # SERIALIZABLE_TYPES so the two tables cannot drift apart.
    SERIALIZABLE_FIELDS = {v: k for k, v in SERIALIZABLE_TYPES.items()}

    def __init__(self):
        """Create a serializer; no per-instance state is set up here."""
        pass

    def serialize(self, frame: Frame) -> str | bytes | None:
        """Encode ``frame`` as a serialized ``frame_protos.Frame`` message.

        Every dataclass field of ``frame`` is copied, by name, onto the
        sub-message selected through ``SERIALIZABLE_TYPES``.

        Args:
            frame: Frame to encode. Its exact type must be a key of
                ``SERIALIZABLE_TYPES``; subclasses are not matched.

        Returns:
            The result of ``SerializeToString()`` on the populated message.

        Raises:
            ValueError: If ``type(frame)`` is not in ``SERIALIZABLE_TYPES``.
        """
        frame_type = type(frame)
        if frame_type not in self.SERIALIZABLE_TYPES:
            # SERIALIZABLE_FIELDS is derived from SERIALIZABLE_TYPES, so the
            # latter is the table a maintainer actually has to extend.
            raise ValueError(
                f"Frame type {frame_type} is not serializable. "
                "You may need to add it to ProtobufFrameSerializer.SERIALIZABLE_TYPES."
            )

        proto_frame = frame_protos.Frame()
        # Resolve the target sub-message once instead of on every field.
        payload = getattr(proto_frame, self.SERIALIZABLE_TYPES[frame_type])
        for field in dataclasses.fields(frame):
            setattr(payload, field.name, getattr(frame, field.name))

        return proto_frame.SerializeToString()

    def deserialize(self, data: str | bytes) -> Frame | None:
        """Returns a Frame object from a Frame protobuf. Used to convert frames
        passed over the wire as protobufs to Frame objects used in pipelines
        and frame processors.

        Args:
            data: Serialized ``frame_protos.Frame`` message.

        Returns:
            An instance of the frame class registered for the populated
            ``frame`` oneof member, or ``None`` (after logging an error) when
            no registered member is set.

        NOTE(review): the examples below are illustrative and have not been
        checked against the current frame definitions. In particular,
        ``AudioFrame`` is not a key of ``SERIALIZABLE_TYPES`` (the audio
        entry is ``AudioRawFrame``).

        >>> serializer = ProtobufFrameSerializer()
        >>> serializer.deserialize(
        ...     serializer.serialize(AudioFrame(data=b'1234567890')))
        AudioFrame(data=b'1234567890')

        >>> serializer.deserialize(
        ...     serializer.serialize(TextFrame(text='hello world')))
        TextFrame(text='hello world')

        >>> serializer.deserialize(serializer.serialize(TranscriptionFrame(
        ...     text="Hello there!", participantId="123", timestamp="2021-01-01")))
        TranscriptionFrame(text='Hello there!', participantId='123', timestamp='2021-01-01')
        """
        proto = frame_protos.Frame.FromString(data)
        which = proto.WhichOneof("frame")
        if which not in self.SERIALIZABLE_FIELDS:
            logger.error("Unable to deserialize a valid frame")
            return None

        frame_class = self.SERIALIZABLE_FIELDS[which]
        payload = getattr(proto, which)

        # Collect every field declared on the sub-message as a keyword
        # argument for the frame constructor.
        message_type = proto.DESCRIPTOR.fields_by_name[which].message_type
        kwargs = {
            field.name: getattr(payload, field.name)
            for field in message_type.fields
        }

        # "id" and "name" are special: they are re-applied by assignment
        # after construction, so they are always kept out of the constructor
        # call (previously they were only dropped when empty).
        # NOTE(review): presumably these are init=False dataclass fields on
        # Frame, in which case passing them would raise TypeError -- confirm
        # against pipecat.frames.frames.
        frame_id = kwargs.pop("id", None)
        frame_name = kwargs.pop("name", None)

        instance = frame_class(**kwargs)

        # Restore the identity carried on the wire, if any was set.
        if frame_id:
            instance.id = frame_id
        if frame_name:
            instance.name = frame_name

        return instance
|
|
|