File size: 5,329 Bytes
8d7f55c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
from typing import AsyncIterable, Iterable
from pydantic import BaseModel
from pipecat.frames.frames import CancelFrame, EndFrame, ErrorFrame, Frame, MetricsFrame, StartFrame, StopTaskFrame
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.utils import obj_count, obj_id
from loguru import logger
class PipelineParams(BaseModel):
    """Configuration for a pipeline task.

    These values are copied field-for-field into the StartFrame that
    PipelineTask sends downstream when it starts running, so every
    processor in the pipeline sees the same settings.
    """

    # Presumably allows new input to interrupt in-flight processing —
    # consumed downstream; NOTE(review): confirm against processor code.
    allow_interruptions: bool = False
    # Enables metrics reporting; when False the task still sends the
    # initial zeroed MetricsFrame but processors should stay quiet.
    # NOTE(review): downstream behavior inferred — confirm.
    enable_metrics: bool = False
    # When True, only the first time-to-first-byte measurement is
    # reported. NOTE(review): semantics inferred from the name — confirm.
    report_only_initial_ttfb: bool = False
class Source(FrameProcessor):
    """Head-of-pipeline processor owned by a PipelineTask.

    Frames traveling downstream are pushed into the pipeline as usual.
    Frames traveling upstream are diverted onto the task's upstream
    queue so the task (not the pipeline) can react to them.
    """

    def __init__(self, up_queue: asyncio.Queue):
        super().__init__()
        # Queue shared with the owning task; receives all upstream frames.
        self._up_queue = up_queue

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        if direction == FrameDirection.UPSTREAM:
            # Upstream traffic terminates here: hand it to the task.
            await self._up_queue.put(frame)
        elif direction == FrameDirection.DOWNSTREAM:
            await self.push_frame(frame, direction)
class PipelineTask:
    """Drives a pipeline by streaming frames into it.

    Frames queued with queue_frame()/queue_frames() are pushed downstream
    through an internal Source processor linked to the pipeline. Frames
    traveling upstream (e.g. ErrorFrame) arrive on a separate queue and
    are handled here. The task finishes when an EndFrame or StopTaskFrame
    is processed, or when cancel() is called.
    """

    def __init__(self, pipeline: BasePipeline, params: PipelineParams | None = None):
        """Create a task that will drive `pipeline`.

        :param pipeline: the pipeline to run.
        :param params: optional runtime parameters. A fresh PipelineParams
            is created per task when omitted (avoids the mutable-default-
            argument pitfall of sharing one instance across all tasks).
        """
        self.id: int = obj_id()
        self.name: str = f"{self.__class__.__name__}#{obj_count(self)}"

        self._pipeline = pipeline
        self._params = params if params is not None else PipelineParams()
        self._finished = False

        self._down_queue: asyncio.Queue = asyncio.Queue()
        self._up_queue: asyncio.Queue = asyncio.Queue()

        # Created in run(); None until then so cancel() can detect that
        # the task was never started instead of raising AttributeError.
        self._process_down_task: asyncio.Task | None = None
        self._process_up_task: asyncio.Task | None = None

        self._source = Source(self._up_queue)
        self._source.link(pipeline)

    def has_finished(self) -> bool:
        """Return True once run() has completed."""
        return self._finished

    async def stop_when_done(self):
        """Stop gracefully after all currently queued frames are processed."""
        logger.debug(f"Task {self} scheduled to stop when done")
        await self.queue_frame(EndFrame())

    async def cancel(self):
        """Cancel the task immediately, cleaning up downstream first."""
        logger.debug(f"Canceling pipeline task {self}")
        # Make sure everything is cleaned up downstream. This is sent
        # out-of-band from the main streaming task which is what we want since
        # we want to cancel right away.
        await self._source.process_frame(CancelFrame(), FrameDirection.DOWNSTREAM)
        # run() might not have been called yet; _cancel_task handles None.
        await self._cancel_task(self._process_down_task)
        await self._cancel_task(self._process_up_task)

    async def run(self):
        """Run the pipeline until an EndFrame/StopTaskFrame or cancel()."""
        self._process_up_task = asyncio.create_task(self._process_up_queue())
        self._process_down_task = asyncio.create_task(self._process_down_queue())
        await asyncio.gather(self._process_up_task, self._process_down_task)
        self._finished = True

    async def queue_frame(self, frame: Frame):
        """Queue a single frame to be pushed downstream."""
        await self._down_queue.put(frame)

    async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]):
        """Queue frames from a sync or async iterable, preserving order.

        :raises TypeError: if `frames` is neither iterable nor async
            iterable. (TypeError subclasses Exception, so callers catching
            the previous generic Exception still work.)
        """
        if isinstance(frames, AsyncIterable):
            async for frame in frames:
                await self.queue_frame(frame)
        elif isinstance(frames, Iterable):
            for frame in frames:
                await self.queue_frame(frame)
        else:
            raise TypeError("Frames must be an iterable or async iterable")

    @staticmethod
    async def _cancel_task(task: asyncio.Task | None):
        """Cancel `task` (if it was ever created) and wait for it to end,
        swallowing the CancelledError that may escape the task."""
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    def _initial_metrics_frame(self) -> MetricsFrame:
        """Build a zeroed MetricsFrame listing every metrics-capable
        processor, so consumers see all processor names up front."""
        processors = self._pipeline.processors_with_metrics()
        ttfb = [{"name": p.name, "time": 0.0} for p in processors]
        processing = [{"name": p.name, "time": 0.0} for p in processors]
        return MetricsFrame(ttfb=ttfb, processing=processing)

    async def _process_down_queue(self):
        """Main streaming loop: send StartFrame and initial metrics, then
        forward queued frames downstream until EndFrame/StopTaskFrame or
        cancellation."""
        start_frame = StartFrame(
            allow_interruptions=self._params.allow_interruptions,
            enable_metrics=self._params.enable_metrics,
            report_only_initial_ttfb=self._params.report_only_initial_ttfb,
        )
        await self._source.process_frame(start_frame, FrameDirection.DOWNSTREAM)
        await self._source.process_frame(self._initial_metrics_frame(), FrameDirection.DOWNSTREAM)

        running = True
        should_cleanup = True
        while running:
            try:
                frame = await self._down_queue.get()
                await self._source.process_frame(frame, FrameDirection.DOWNSTREAM)
                # EndFrame and StopTaskFrame both end the loop; only
                # StopTaskFrame skips cleanup (the pipeline may be reused).
                running = not isinstance(frame, (StopTaskFrame, EndFrame))
                should_cleanup = not isinstance(frame, StopTaskFrame)
                self._down_queue.task_done()
            except asyncio.CancelledError:
                break
        # Cleanup only if we need to.
        if should_cleanup:
            await self._source.cleanup()
            await self._pipeline.cleanup()
        # Stop the upstream watcher; nothing is left to report.
        await self._cancel_task(self._process_up_task)

    async def _process_up_queue(self):
        """Watch the upstream queue; on ErrorFrame, log it and schedule a
        CancelFrame so the whole task shuts down."""
        while True:
            try:
                frame = await self._up_queue.get()
                if isinstance(frame, ErrorFrame):
                    logger.error(f"Error running app: {frame.error}")
                    await self.queue_frame(CancelFrame())
                self._up_queue.task_done()
            except asyncio.CancelledError:
                break

    def __str__(self):
        return self.name
|