File size: 5,329 Bytes
8d7f55c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio

from typing import AsyncIterable, Iterable

from pydantic import BaseModel

from pipecat.frames.frames import CancelFrame, EndFrame, ErrorFrame, Frame, MetricsFrame, StartFrame, StopTaskFrame
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.utils import obj_count, obj_id

from loguru import logger


class PipelineParams(BaseModel):
    """Per-run pipeline configuration.

    These values are forwarded downstream inside the StartFrame when the
    task starts (see PipelineTask._process_down_queue).
    """
    # Whether interruptions are allowed; carried in StartFrame.allow_interruptions.
    allow_interruptions: bool = False
    # Whether metrics collection is enabled; carried in StartFrame.enable_metrics.
    enable_metrics: bool = False
    # Whether only the first TTFB measurement should be reported; carried in
    # StartFrame.report_only_initial_ttfb.
    report_only_initial_ttfb: bool = False


class Source(FrameProcessor):
    """Entry-point processor that a PipelineTask links in front of its pipeline.

    Frames traveling downstream are pushed into the linked pipeline; frames
    traveling upstream are captured into the task's queue so the owning task
    can inspect them (e.g. to react to ErrorFrame).
    """

    def __init__(self, up_queue: asyncio.Queue):
        super().__init__()
        self._up_queue = up_queue

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        if direction == FrameDirection.UPSTREAM:
            # Hand the frame back to the owning task for inspection.
            await self._up_queue.put(frame)
        elif direction == FrameDirection.DOWNSTREAM:
            # Forward the frame into the pipeline proper.
            await self.push_frame(frame, direction)


class PipelineTask:
    """Drives a pipeline: feeds it frames, relays upstream events and manages
    its lifecycle (start, metrics, end, cancellation).

    Frames queued via queue_frame()/queue_frames() are sent downstream through
    an internal Source processor. Upstream frames (e.g. ErrorFrame) are watched
    on a separate task so errors can cancel the run.
    """

    def __init__(self, pipeline: BasePipeline, params: PipelineParams | None = None):
        self.id: int = obj_id()
        self.name: str = f"{self.__class__.__name__}#{obj_count(self)}"

        self._pipeline = pipeline
        # NOTE: build a fresh PipelineParams per task instead of using a shared
        # mutable default instance (classic mutable-default-argument pitfall).
        self._params = params if params is not None else PipelineParams()
        self._finished = False

        self._down_queue = asyncio.Queue()
        self._up_queue = asyncio.Queue()

        self._source = Source(self._up_queue)
        self._source.link(pipeline)

    def has_finished(self):
        """Return True once run() has completed."""
        return self._finished

    async def stop_when_done(self):
        """Queue an EndFrame so the task stops after all pending frames run."""
        logger.debug(f"Task {self} scheduled to stop when done")
        await self.queue_frame(EndFrame())

    async def cancel(self):
        """Cancel the task immediately, bypassing any queued frames."""
        logger.debug(f"Canceling pipeline task {self}")
        # Make sure everything is cleaned up downstream. This is sent
        # out-of-band from the main streaming task which is what we want since
        # we want to cancel right away.
        await self._source.process_frame(CancelFrame(), FrameDirection.DOWNSTREAM)
        # The processing tasks only exist once run() has been called; guard so
        # canceling a never-started task doesn't raise AttributeError.
        if hasattr(self, "_process_down_task"):
            self._process_down_task.cancel()
            self._process_up_task.cancel()
            await self._process_down_task
            await self._process_up_task

    async def run(self):
        """Run the pipeline until an EndFrame/StopTaskFrame or cancellation."""
        self._process_up_task = asyncio.create_task(self._process_up_queue())
        self._process_down_task = asyncio.create_task(self._process_down_queue())
        await asyncio.gather(self._process_up_task, self._process_down_task)
        self._finished = True

    async def queue_frame(self, frame: Frame):
        """Enqueue a single frame to be sent downstream."""
        await self._down_queue.put(frame)

    async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]):
        """Enqueue every frame from a sync or async iterable.

        Raises:
            TypeError: if `frames` is neither iterable nor async iterable.
        """
        if isinstance(frames, AsyncIterable):
            async for frame in frames:
                await self.queue_frame(frame)
        elif isinstance(frames, Iterable):
            for frame in frames:
                await self.queue_frame(frame)
        else:
            # TypeError is a subclass of Exception, so existing callers
            # catching Exception still work.
            raise TypeError("Frames must be an iterable or async iterable")

    def _initial_metrics_frame(self) -> MetricsFrame:
        """Build a zeroed MetricsFrame for every metrics-enabled processor."""
        processors = self._pipeline.processors_with_metrics()
        ttfb = [{"name": p.name, "time": 0.0} for p in processors]
        processing = [{"name": p.name, "time": 0.0} for p in processors]
        return MetricsFrame(ttfb=ttfb, processing=processing)

    async def _process_down_queue(self):
        """Send StartFrame + initial metrics, then stream queued frames
        downstream until EndFrame/StopTaskFrame or cancellation."""
        start_frame = StartFrame(
            allow_interruptions=self._params.allow_interruptions,
            enable_metrics=self._params.enable_metrics,
            report_only_initial_ttfb=self._params.report_only_initial_ttfb
        )
        await self._source.process_frame(start_frame, FrameDirection.DOWNSTREAM)
        await self._source.process_frame(self._initial_metrics_frame(), FrameDirection.DOWNSTREAM)

        running = True
        should_cleanup = True
        while running:
            try:
                frame = await self._down_queue.get()
                await self._source.process_frame(frame, FrameDirection.DOWNSTREAM)
                # EndFrame and StopTaskFrame both stop the loop, but only
                # EndFrame triggers cleanup (StopTaskFrame leaves processors
                # usable for a future run).
                running = not (isinstance(frame, StopTaskFrame) or isinstance(frame, EndFrame))
                should_cleanup = not isinstance(frame, StopTaskFrame)
                self._down_queue.task_done()
            except asyncio.CancelledError:
                break
        # Cleanup only if we need to.
        if should_cleanup:
            await self._source.cleanup()
            await self._pipeline.cleanup()
        # Cancel the upstream watcher so both processing tasks finish together.
        self._process_up_task.cancel()
        await self._process_up_task

    async def _process_up_queue(self):
        """Watch upstream frames; an ErrorFrame logs the error and queues a
        CancelFrame to shut the pipeline down."""
        while True:
            try:
                frame = await self._up_queue.get()
                if isinstance(frame, ErrorFrame):
                    logger.error(f"Error running app: {frame.error}")
                    await self.queue_frame(CancelFrame())
                self._up_queue.task_done()
            except asyncio.CancelledError:
                break

    def __str__(self):
        return self.name