Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			L40S
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			L40S
	| import base64 | |
| import ctypes | |
| import io | |
| import json | |
| import os | |
| import struct | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from typing import AsyncGenerator, Union | |
| import httpx | |
| import numpy as np | |
| import ormsgpack | |
| import soundfile as sf | |
| from .schema import ( | |
| ServeMessage, | |
| ServeRequest, | |
| ServeTextPart, | |
| ServeVQGANDecodeRequest, | |
| ServeVQGANEncodeRequest, | |
| ServeVQPart, | |
| ) | |
class CustomAudioFrame:
    """Immutable-ish PCM audio frame: interleaved signed 16-bit samples plus
    metadata (sample rate, channel count, samples per channel).

    Fix: the accessors are properties. Without ``@property``, ``duration``
    divided two bound methods (TypeError), ``__repr__`` formatted a method
    with ``:.3f``, and callers reading ``frame.data`` as an attribute got a
    bound method instead of the sample buffer.
    """

    def __init__(self, data, sample_rate, num_channels, samples_per_channel):
        # The buffer must hold at least num_channels * samples_per_channel
        # int16 samples; extra trailing bytes are tolerated.
        if len(data) < num_channels * samples_per_channel * ctypes.sizeof(
            ctypes.c_int16
        ):
            raise ValueError(
                "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
            )

        self._data = bytearray(data)
        self._sample_rate = sample_rate
        self._num_channels = num_channels
        self._samples_per_channel = samples_per_channel

    @property
    def data(self):
        # "h" = signed 16-bit; zero-copy view over the underlying buffer.
        return memoryview(self._data).cast("h")

    @property
    def sample_rate(self):
        return self._sample_rate

    @property
    def num_channels(self):
        return self._num_channels

    @property
    def samples_per_channel(self):
        return self._samples_per_channel

    @property
    def duration(self):
        # Seconds of audio per channel.
        return self.samples_per_channel / self.sample_rate

    def __repr__(self):
        return (
            f"CustomAudioFrame(sample_rate={self.sample_rate}, "
            f"num_channels={self.num_channels}, "
            f"samples_per_channel={self.samples_per_channel}, "
            f"duration={self.duration:.3f})"
        )
# Event kinds the agent stream can emit. Built with the Enum functional API;
# values auto-number 1..6 in declaration order, matching the original.
FishE2EEventType = Enum(
    "FishE2EEventType",
    [
        "SPEECH_SEGMENT",
        "TEXT_SEGMENT",
        "END_OF_TEXT",
        "END_OF_SPEECH",
        "ASR_RESULT",
        "USER_CODES",
    ],
)
@dataclass
class FishE2EEvent:
    """Event yielded by FishE2EAgent.stream().

    Fix: the class was instantiated throughout with keyword arguments
    (e.g. ``FishE2EEvent(type=..., vq_codes=...)``) but had no generated
    ``__init__``; the ``@dataclass`` decorator (already imported in this
    file) provides it.

    Which optional fields are set depends on ``type``: SPEECH_SEGMENT
    carries ``frame`` and ``vq_codes``, TEXT_SEGMENT/ASR_RESULT carry
    ``text``, USER_CODES carries ``vq_codes``.
    """

    type: FishE2EEventType
    # NOTE(review): annotated np.ndarray, but stream() assigns a
    # CustomAudioFrame here — annotation kept as-is for compatibility.
    frame: np.ndarray = None
    text: str = None
    vq_codes: list[list[int]] = None
# Module-level shared HTTP client: no request timeout (suits long-lived
# streaming responses) and unlimited connection pooling.
# NOTE(review): FishE2EAgent builds its own AsyncClient; this global is not
# referenced anywhere in this file — possibly used by importers, confirm.
client = httpx.AsyncClient(
    timeout=None,
    limits=httpx.Limits(
        max_connections=None,
        max_keepalive_connections=None,
        keepalive_expiry=None,
    ),
)
class FishE2EAgent:
    """Async client for an end-to-end voice chat stack.

    Audio is turned into discrete VQ codes by a VQGAN HTTP endpoint, the
    codes are sent to a streaming chat-LLM endpoint (length-prefixed msgpack
    frames), and VQ codes coming back are decoded into int16 PCM frames.
    """

    def __init__(self):
        # NOTE(review): endpoints are hard-coded to a local deployment.
        self.llm_url = "http://localhost:8080/v1/chat"
        self.vqgan_url = "http://localhost:8080"
        # No timeout: responses are long-lived streams.
        self.client = httpx.AsyncClient(timeout=None)

    async def get_codes(self, audio_data, sample_rate):
        """Encode ``audio_data`` to VQGAN token codes.

        Args:
            audio_data: samples in any layout accepted by ``soundfile.write``.
            sample_rate: sample rate of ``audio_data`` in Hz.

        Returns:
            Token codes for the single encoded clip (``tokens[0]`` of the
            /v1/vqgan/encode response).
        """
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
        audio_buffer.seek(0)
        # Step 1: Encode audio using VQGAN
        encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
        encode_request_bytes = ormsgpack.packb(
            encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
        )
        encode_response = await self.client.post(
            f"{self.vqgan_url}/v1/vqgan/encode",
            data=encode_request_bytes,
            headers={"Content-Type": "application/msgpack"},
        )
        encode_response_data = ormsgpack.unpackb(encode_response.content)
        codes = encode_response_data["tokens"][0]
        return codes

    async def stream(
        self,
        system_audio_data: np.ndarray | None,
        user_audio_data: np.ndarray | None,
        sample_rate: int,
        num_channels: int,  # NOTE(review): currently unused in this method
        chat_ctx: dict | None = None,
    ) -> AsyncGenerator["FishE2EEvent", None]:
        """Run one conversational turn and yield FishE2EEvent objects.

        Emits USER_CODES (if user audio given), interleaved TEXT_SEGMENT /
        SPEECH_SEGMENT events as the LLM streams, then END_OF_TEXT and
        END_OF_SPEECH. ``chat_ctx``, when given, must contain "messages"
        and an "added_sysaudio" flag.
        """
        if system_audio_data is not None:
            sys_codes = await self.get_codes(system_audio_data, sample_rate)
        else:
            sys_codes = None
        if user_audio_data is not None:
            user_codes = await self.get_codes(user_audio_data, sample_rate)

        # Step 2: Prepare LLM request
        if chat_ctx is None:
            # Fresh conversation: system message = text prompt plus optional
            # system-audio VQ codes.
            sys_parts = [
                ServeTextPart(
                    text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
                ),
            ]
            if system_audio_data is not None:
                sys_parts.append(ServeVQPart(codes=sys_codes))
            chat_ctx = {
                "messages": [
                    ServeMessage(
                        role="system",
                        parts=sys_parts,
                    ),
                ],
            }
        else:
            # Continuing conversation: attach system-audio codes exactly once.
            # NOTE(review): assumes caller-provided chat_ctx always has an
            # "added_sysaudio" key — KeyError otherwise; confirm with callers.
            if chat_ctx["added_sysaudio"] is False and sys_codes:
                chat_ctx["added_sysaudio"] = True
                chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))

        prev_messages = chat_ctx["messages"].copy()
        if user_audio_data is not None:
            # Surface the user's VQ codes so the caller can persist them.
            yield FishE2EEvent(
                type=FishE2EEventType.USER_CODES,
                vq_codes=user_codes,
            )
        else:
            user_codes = None

        request = ServeRequest(
            messages=prev_messages
            + (
                [
                    ServeMessage(
                        role="user",
                        parts=[ServeVQPart(codes=user_codes)],
                    )
                ]
                if user_codes
                else []
            ),
            streaming=True,
            num_samples=1,
        )

        # Step 3: Stream LLM response and decode audio
        buffer = b""
        vq_codes = []
        current_vq = False

        async def decode_send():
            # Flush accumulated VQ codes: decode to PCM, emit one
            # SPEECH_SEGMENT event, then reset the accumulator.
            nonlocal current_vq
            nonlocal vq_codes

            # Codebook chunks are concatenated along the time axis (axis=1).
            data = np.concatenate(vq_codes, axis=1).tolist()
            # Decode VQ codes to audio
            decode_request = ServeVQGANDecodeRequest(tokens=[data])
            decode_response = await self.client.post(
                f"{self.vqgan_url}/v1/vqgan/decode",
                data=ormsgpack.packb(
                    decode_request,
                    option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
                ),
                headers={"Content-Type": "application/msgpack"},
            )
            decode_data = ormsgpack.unpackb(decode_response.content)

            # Convert float16 audio data to int16
            audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
            audio_data = (audio_data * 32768).astype(np.int16).tobytes()

            # NOTE(review): 44100 Hz mono is hard-coded here; confirm against
            # the VQGAN decoder's actual output rate.
            audio_frame = CustomAudioFrame(
                data=audio_data,
                samples_per_channel=len(audio_data) // 2,  # 2 bytes per int16
                sample_rate=44100,
                num_channels=1,
            )
            yield FishE2EEvent(
                type=FishE2EEventType.SPEECH_SEGMENT,
                frame=audio_frame,
                vq_codes=data,
            )

            current_vq = False
            vq_codes = []

        async with self.client.stream(
            "POST",
            self.llm_url,
            data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
            headers={"Content-Type": "application/msgpack"},
        ) as response:
            async for chunk in response.aiter_bytes():
                buffer += chunk

                # Length-prefixed framing: 4-byte length, then a msgpack body.
                # NOTE(review): "I" is native byte order/size — assumes server
                # and client agree (typically little-endian); confirm.
                while len(buffer) >= 4:
                    read_length = struct.unpack("I", buffer[:4])[0]
                    if len(buffer) < 4 + read_length:
                        break  # incomplete frame; wait for more bytes

                    body = buffer[4 : 4 + read_length]
                    buffer = buffer[4 + read_length :]
                    data = ormsgpack.unpackb(body)

                    if data["delta"] and data["delta"]["part"]:
                        # A text part arriving after VQ parts marks a segment
                        # boundary: flush pending audio before the text.
                        if current_vq and data["delta"]["part"]["type"] == "text":
                            async for event in decode_send():
                                yield event
                        if data["delta"]["part"]["type"] == "text":
                            yield FishE2EEvent(
                                type=FishE2EEventType.TEXT_SEGMENT,
                                text=data["delta"]["part"]["text"],
                            )
                        elif data["delta"]["part"]["type"] == "vq":
                            vq_codes.append(np.array(data["delta"]["part"]["codes"]))
                            current_vq = True

        # Flush any trailing audio, then signal the end of both streams.
        if current_vq and vq_codes:
            async for event in decode_send():
                yield event

        yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
        yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
# Example usage:
async def main():
    """Smoke-test driver: stream a local audio file through the agent and
    dump the results (text to stdout, audio to audio_segment.wav).

    Fixes: ``agent.stream`` was called with 3 positional arguments but
    requires 4 (system audio, user audio, sample rate, channel count) —
    pass ``None`` for the system audio and the recording as the user turn.
    Also removed a dead ``open(...).read()`` whose result was immediately
    overwritten by ``torchaudio.load``.
    """
    import torchaudio

    agent = FishE2EAgent()

    # Replace this with actual audio data loading
    audio_data, sample_rate = torchaudio.load("uz_story_en.m4a")
    # torchaudio returns float samples in [-1, 1]; convert to int16 PCM.
    audio_data = (audio_data.numpy() * 32768).astype(np.int16)

    # No system-prompt audio; the recording is the user turn.
    stream = agent.stream(None, audio_data, sample_rate, 1)

    # Start from a clean output file; segments are appended below.
    if os.path.exists("audio_segment.wav"):
        os.remove("audio_segment.wav")

    async for event in stream:
        if event.type == FishE2EEventType.SPEECH_SEGMENT:
            # Handle speech segment (e.g., play audio or save to file)
            with open("audio_segment.wav", "ab+") as f:
                f.write(event.frame.data)
        elif event.type == FishE2EEventType.ASR_RESULT:
            print(event.text, flush=True)
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            print(event.text, flush=True, end="")
        elif event.type == FishE2EEventType.END_OF_TEXT:
            print("\nEnd of text reached.")
        elif event.type == FishE2EEventType.END_OF_SPEECH:
            print("End of speech reached.")
# Script entry point: drive the async example with asyncio.run.
if __name__ == "__main__":
    import asyncio

    asyncio.run(main())

