marcosremar2 committed on
Commit
34b8b49
·
1 Parent(s): c3907b6
.cursor/rules/principal.mdc CHANGED
@@ -3,49 +3,3 @@ description:
3
  globs:
4
  alwaysApply: false
5
  ---
6
- . Ship as few files as possible; the remaining files should be downloaded during initialization
-
- # LLaMA-Omni2 Project Summary for Hugging Face Spaces
-
- I am setting up a demo application for LLaMA-Omni2, a speech-language assistant, so that it can be deployed easily on Hugging Face Spaces. Here is a summary of what has been implemented:
-
- ## Project Goal
- Build an interactive web interface that demonstrates the capabilities of LLaMA-Omni2, letting users interact with the model through text and speech and receive responses in both formats.
-
- ## Main Components
-
- 1. **Gradio Interface**: A user-friendly web UI with two tabs:
- - **Audio Input**: Lets users speak or upload audio files
- - **Text Input**: Supports text-based interactions
-
- 2. **Speech Recognition Pipeline**:
- - Uses the Whisper (tiny) model to transcribe audio to text
- - Configured to load directly from Hugging Face
-
- 3. **Text and Speech Generation**:
- - Uses the LLaMA-Omni2-0.5B model to generate responses
- - Supports two speech generation methods: `generate_with_speech` and `generate_speech`
- - Handles converting text responses into audio
-
- 4. **Optimizations for Hugging Face Spaces**:
- - Dynamic model loading (models are not included in the repository)
- - Configured to use a GPU when available
- - Comprehensive logging for debugging
-
- 5. **Repository Management**:
- - `.gitignore` set up to exclude large models and unnecessary artifacts
- - Large files removed from the git history
- - Clean, organized project structure
-
- ## Main Files
- - `app.py`: Contains the main application logic and the Gradio interface
- - `requirements.txt`: Lists all required dependencies
- - `.huggingface-space`: Configuration for the Hugging Face Spaces environment
- - `.gitignore`: Keeps large and temporary files out of version control
-
- ## Technologies Used
- - **Frameworks**: PyTorch, Transformers, Gradio
- - **Models**: LLaMA-Omni2-0.5B (text/speech), Whisper-tiny (speech recognition)
- - **Infrastructure**: Hugging Face Spaces for hosting
-
- The project is configured to download the models dynamically when deployed, instead of committing them to the repository, which keeps the codebase lean and easy to share and deploy.
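The dynamic download flow described above (check for a local copy, otherwise fetch from the Hub) can be sketched with `huggingface_hub`. This is only a minimal illustration; the helper name and directory layout are assumptions, not the repository's actual `model_downloader.py`:

```python
import os
from huggingface_hub import snapshot_download  # pulls a full model repo from the Hugging Face Hub

MODELS_DIR = os.environ.get("MODELS_DIR", "models")

def ensure_model(repo_id: str, local_name: str) -> str:
    """Download a model repository into models/<local_name> only when it is missing."""
    local_dir = os.path.join(MODELS_DIR, local_name)
    if not os.path.isdir(local_dir):  # already present? then skip the download
        snapshot_download(repo_id=repo_id, local_dir=local_dir)
    return local_dir  # the app then loads the model from this path

# Example (hypothetical usage): fetch the LLaMA-Omni2 checkpoint on the first run only
llama_path = ensure_model("ICTNLP/LLaMA-Omni2-0.5B", "LLaMA-Omni2-0.5B")
```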
 
3
  globs:
4
  alwaysApply: false
5
  ---
.gitignore CHANGED
@@ -20,12 +20,12 @@ var/
20
  .installed.cfg
21
  *.egg
22
 
23
- # Ambientes virtuais
24
  venv/
25
  ENV/
26
  .env/
27
 
28
- # Modelos e dados
29
  models/
30
  *.pt
31
  *.pth
@@ -40,7 +40,8 @@ models/
40
  whisper-large-v3/
41
  cosy2_decoder/
42
  speech_encoder/
43
- # Excluir todos os arquivos grandes de modelos de forma explícita
 
44
  flow.decoder.estimator.fp32.onnx
45
  flow.decoder.estimator.fp16.A10.plan
46
  flow.encoder.fp32.zip
@@ -58,8 +59,9 @@ model.safetensors.index.fp32.json
58
  .idea/
59
  *.swp
60
  *.swo
 
61
 
62
- # Sistema operacional
63
  .DS_Store
64
  Thumbs.db
65
 
@@ -67,7 +69,7 @@ Thumbs.db
67
  logs/
68
  *.log
69
 
70
- # Arquivos grandes
71
  *.dylib
72
  *.js.map
73
  *.so
 
20
  .installed.cfg
21
  *.egg
22
 
23
+ # Environments
24
  venv/
25
  ENV/
26
  .env/
27
 
28
+ # Model files and data
29
  models/
30
  *.pt
31
  *.pth
 
40
  whisper-large-v3/
41
  cosy2_decoder/
42
  speech_encoder/
43
+
44
+ # Ignore all large model files
45
  flow.decoder.estimator.fp32.onnx
46
  flow.decoder.estimator.fp16.A10.plan
47
  flow.encoder.fp32.zip
 
59
  .idea/
60
  *.swp
61
  *.swo
62
+ .cursor/
63
 
64
+ # OS
65
  .DS_Store
66
  Thumbs.db
67
 
 
69
  logs/
70
  *.log
71
 
72
+ # Large files
73
  *.dylib
74
  *.js.map
75
  *.so
.huggingface-space DELETED
@@ -1,9 +0,0 @@
1
- name: llama-omni
2
- sdk: gradio
3
- sdk_version: 5.29.0
4
- python_version: "3.10"
5
- gpu: true
6
- hardware: a100-sxm
7
- datasets:
8
- - openai/whisper-tiny
9
- - ICTNLP/LLaMA-Omni2-0.5B
Dockerfile DELETED
@@ -1,47 +0,0 @@
1
- FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime
-
- WORKDIR /app
-
- # Install system dependencies
- RUN apt-get update && apt-get install -y \
- git \
- wget \
- ffmpeg \
- libsndfile1 \
- build-essential \
- ninja-build \
- && rm -rf /var/lib/apt/lists/*
-
- # Copy the source files
- COPY . .
-
- # Prepare the models directory
- RUN mkdir -p models
-
- # Install Python requirements
- RUN pip install --no-cache-dir -r requirements.txt
-
- # Install LLaMA-Omni2 directly (if present)
- RUN if [ -d "./LLaMA-Omni2" ]; then \
- cd LLaMA-Omni2 && \
- pip install -e . \
- ; fi
-
- # Install fairseq if needed
- RUN pip install gitpython
- RUN if [ -d "./LLaMA-Omni2" ]; then \
- pip install fairseq --no-build-isolation \
- ; fi
-
- # Try to install flash-attention (tolerating failure)
- RUN pip install flash-attn --no-build-isolation || echo "Failed to install flash-attn, continuing without it"
-
- # Expose the Gradio port
- EXPOSE 7860
-
- # Set environment variables
- ENV PYTHONUNBUFFERED=1
- ENV MODELS_DIR=/app/models
-
- # Command to start the server
- CMD ["python", "app.py"]
README.md CHANGED
@@ -1,162 +1,106 @@
1
- ---
- title: LLaMA-Omni Demo
- emoji: 🚀
- colorFrom: indigo
- colorTo: green
- sdk: gradio
- python_version: "3.10"
- sdk_version: "5.29.0"
- app_file: app.py
- pinned: false
- # Consider adding hardware if needed (GPU)
- # e.g.: hardware: nvidia-t4
- ---
-
- # LLaMA-Omni2 Interface
-
- Interface for the LLaMA-Omni2 model, providing audio input and output with natural language processing.
-
- ## Features
-
- - Audio transcription using Whisper
- - Text processing with LLaMA-Omni2
- - Speech synthesis using CosyVoice 2
- - Real-time text and speech generation
- - Automatic model download during initialization
-
- ## Requirements
-
- - Python 3.8+
- - PyTorch 2.0+
- - Transformers 4.36+
- - Gradio 3.50+
- - CUDA (optional, but recommended for better performance)
-
- ## Model Setup
-
- This project downloads models automatically during initialization, avoiding the need to store large files in the Git repository.
-
- The models are downloaded automatically on the first run:
-
- - **Whisper Large V3** - Speech recognition model
- - **CosyVoice 2** - Vocoder for speech synthesis
- - **LLaMA-Omni2** - Multimodal language model
-
- All models are stored in the `models/` folder, which is listed in `.gitignore` to avoid committing large files.
-
- ## Setup
-
- 1. Clone the repository:
- ```bash
- git clone https://github.com/seu-usuario/llama-omni2.git
- cd llama-omni2
- ```
-
- 2. Install the dependencies:
- ```bash
- pip install -r requirements.txt
- ```
-
- 3. Run the application:
- ```bash
- python app.py
- ```
-
- On the first run, the models are downloaded automatically. This can take some time depending on your internet connection.
-
- ## Usage
-
- After starting the application, open the web interface at http://localhost:7860 to interact with the model.
-
- - **Audio Input**: Record or upload an audio file
- - **Text Output**: View the transcription and the model's response
- - **Audio Output**: Listen to the synthesized response
-
- ## Using the launcher
-
- You can also use the launcher to start the full application:
-
- ```bash
- python launch_llama_omni2.py
- ```
-
- Launcher options:
- - `--skip-download`: Skips downloading the dependencies
- - `--extraction-dir`: Sets the extraction directory (default: extraction_dir)
- - `--models-dir`: Sets the models directory (default: models)
- - `--controller-only`: Starts only the controller
- - `--worker-only`: Starts only the model worker
- - `--gradio-only`: Starts only the Gradio interface
-
- ## Project Structure
-
- - `app.py` - Main Gradio application
- - `audio_interface.py` - Audio interface for LLaMA-Omni2
- - `launch_llama_omni2.py` - Script that launches all the components
- - `model_downloader.py` - Automatic model download system
- - `models/` - Directory for the downloaded models
- - `requirements.txt` - Project dependencies
-
- ## How the Automatic Download Works
-
- The automatic download system works as follows:
-
- 1. On startup, the script checks whether the required models exist locally
- 2. If a model is not found, it is downloaded automatically from the Hugging Face Hub
- 3. After the download, the model is loaded by the application as usual
-
- This makes it possible to:
- - Keep the Git repository light, without large files
- - Simplify deployment across different environments
- - Ensure users always have the correct models
-
- ## No-Download Mode
-
- This project supports a "no download" mode that uses the models directly from the Hugging Face Hub without downloading them locally. This is useful for:
-
- - Development and testing where the full models are not needed
- - Environments with limited disk space
- - Continuous integration and deployment scenarios where models are accessed remotely
-
- To enable no-download mode, you can:
-
- 1. **Use the no_download.py Python script (recommended)**:
- ```bash
- # Run app.py without downloads
- python no_download.py app.py
-
- # Run another script without downloads
- python no_download.py audio_interface.py
- ```
-
- 2. **Use the helper script**:
- ```bash
- ./run_without_downloads.sh
- ```
-
- 3. **Set the environment variable**:
- ```bash
- export NO_DOWNLOAD=1
- python app.py
- ```
-
- 4. **Use the launcher command-line option**:
- ```bash
- python launch_llama_omni2.py --no-model-download
- ```
-
- In no-download mode, the application uses the models directly from the Hugging Face Hub without storing files locally. This can be slower for sustained use, but it starts faster and takes up no disk space.
-
- ## Contributing
-
- Contributions are welcome! Please follow these guidelines:
-
- 1. Fork the repository
- 2. Create a branch for your feature (`git checkout -b feature/new-feature`)
- 3. Commit your changes (`git commit -am 'Add new feature'`)
- 4. Push to the branch (`git push origin feature/new-feature`)
- 5. Open a new Pull Request
-
- ## License
-
- This project is licensed under the terms of the MIT license.
1
+ # 🦙🎧 LLaMA-Omni: Seamless Speech Interaction with Large Language Models
2
 
3
+ This is a Gradio deployment of [LLaMA-Omni](https://github.com/ictnlp/LLaMA-Omni), a speech-language model built upon Llama-3.1-8B-Instruct. It supports low-latency and high-quality speech interactions, simultaneously generating both text and speech responses based on speech instructions.
4
 
5
+ ## 💡 Highlights
6
 
7
+ * 💪 **Built on Llama-3.1-8B-Instruct, ensuring high-quality responses.**
8
+ * 🚀 **Low-latency speech interaction with a latency as low as 226ms.**
9
+ * 🎧 **Simultaneous generation of both text and speech responses.**
 
 
 
 
 
 
 
10
 
11
+ ## 📋 Prerequisites
 
 
 
12
 
13
+ - Python 3.10+
14
+ - PyTorch 2.0+
15
+ - CUDA-compatible GPU (for optimal performance)
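A quick way to verify these prerequisites from Python before launching the app (a small sketch; the thresholds simply mirror the list above):

```python
import sys
import torch  # PyTorch 2.0+ expected

# Confirm the interpreter, PyTorch version, and GPU visibility match the prerequisites
assert sys.version_info >= (3, 10), "Python 3.10+ is required"
assert int(torch.__version__.split(".")[0]) >= 2, "PyTorch 2.0+ is required"

if torch.cuda.is_available():
    print(f"CUDA GPU detected: {torch.cuda.get_device_name(0)}")
else:
    print("No CUDA GPU detected; the demo will fall back to CPU and run slowly")
```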
16
 
17
+ ## 🛠️ Setup
18
 
19
+ 1. Clone this repository:
20
+ ```bash
21
+ git clone https://github.com/your-username/llama-omni.git
22
+ cd llama-omni
23
+ ```
24
 
25
+ 2. Create a virtual environment and install dependencies:
26
+ ```bash
27
+ conda create -n llama-omni python=3.10
28
+ conda activate llama-omni
29
+ pip install -e .
30
+ ```
31
 
32
+ 3. Install fairseq:
33
+ ```bash
34
+ git clone https://github.com/pytorch/fairseq
35
+ cd fairseq
36
+ pip install -e . --no-build-isolation
37
+ ```
38
 
39
+ 4. Install flash-attention:
40
+ ```bash
41
+ pip install flash-attn --no-build-isolation
42
+ ```
43
 
44
+ ## 🚀 Deployment
 
 
45
 
46
+ This repository is configured for deployment on Gradio. The model weights and required components will be downloaded automatically during the first initialization.
 
 
 
 
 
 
47
 
48
+ ### Gradio Spaces Deployment
49
 
50
+ To deploy on Gradio Spaces:
 
 
 
 
 
51
 
52
+ 1. Create a new Gradio Space
53
+ 2. Connect this GitHub repository
54
+ 3. Set the environment requirements (Python 3.10)
55
+ 4. Deploy!
56
 
57
+ The app will automatically:
58
+ - Download the required models (Whisper, LLaMA-Omni, vocoder)
59
+ - Start the controller
60
+ - Start the model worker
61
+ - Launch the web interface
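The startup sequence above relies on fixed `time.sleep` waits in `app.py`; a deployment can instead poll the controller until it responds. A minimal sketch of that check (port and `/status` endpoint follow the configuration used in `app.py`):

```python
import time
import requests

CONTROLLER_URL = "http://localhost:10000"  # controller port configured in app.py

def wait_for_controller(timeout_s: int = 60) -> bool:
    """Poll the controller's /status endpoint instead of sleeping for a fixed time."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(f"{CONTROLLER_URL}/status", timeout=2).json().get("status") == "ok":
                return True
        except requests.RequestException:
            pass  # controller not reachable yet, keep waiting
        time.sleep(1)
    return False

# e.g. start the model worker only once wait_for_controller() returns True
```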
62
 
63
+ ## 🖥️ Local Usage
 
 
64
 
65
+ If you want to run the application locally:
 
 
 
66
 
67
+ ```bash
68
+ python app.py
69
+ ```
70
 
71
+ This will:
72
+ 1. Start the controller
73
+ 2. Start a model worker that loads LLaMA-Omni
74
+ 3. Launch a web interface
75
 
76
+ You can then access the interface at: http://localhost:8000
 
 
77
 
78
+ ## 📝 Example Usage
79
 
80
+ ### Speech-to-Speech
 
 
 
 
 
 
 
81
 
82
+ 1. Select the "Speech Input" tab
83
+ 2. Record or upload audio
84
+ 3. Click "Submit"
85
+ 4. Receive both text and speech responses
86
 
87
+ ### Text-to-Speech
 
 
 
 
88
 
89
+ 1. Select the "Text Input" tab
90
+ 2. Type your message
91
+ 3. Click "Submit"
92
+ 4. Receive both text and speech responses
93
 
94
+ ## 📚 Development
95
 
96
+ To contribute to this project:
97
 
98
+ 1. Fork the repository
99
+ 2. Make your changes
100
+ 3. Submit a pull request
101
 
102
+ ## 📄 LICENSE
 
 
 
 
103
 
104
+ This code is released under the Apache-2.0 License. The model is intended for academic research purposes only and may **NOT** be used for commercial purposes.
105
 
106
+ Original work by Qingkai Fang, Shoutao Guo, Yan Zhou, Zhengrui Ma, Shaolei Zhang, Yang Feng.
SETUP_INSTRUCTIONS.md ADDED
@@ -0,0 +1,67 @@
1
+ # LLaMA-Omni Setup Instructions
2
+
3
+ This repository contains the code structure for deploying LLaMA-Omni on Gradio. The actual model files will be downloaded automatically during deployment.
4
+
5
+ ## Repository Structure
6
+
7
+ ```
8
+ llama-omni/
9
+ ├── app.py # Main application entry point
10
+ ├── app_gradio_spaces.py # Entry point for Gradio Spaces
11
+ ├── check_setup.py # Checks if the environment is properly set up
12
+ ├── cog.yaml # Configuration for Cog (container deployment)
13
+ ├── gradio_app.py # Simplified Gradio app for testing
14
+ ├── predict.py # Predictor for Cog deployment
15
+ ├── pyproject.toml # Project configuration
16
+ ├── requirements.txt # Dependencies for pip
17
+ ├── README.md # Project documentation
18
+ ├── SETUP_INSTRUCTIONS.md # This file
19
+ └── omni_speech/ # Main package
20
+ ├── __init__.py
21
+ ├── infer/ # Inference code
22
+ │ ├── __init__.py
23
+ │ ├── examples/ # Example inputs
24
+ │ │ └── example.json
25
+ │ ├── inference.py # Inference logic
26
+ │ └── run.sh # Script for running inference
27
+ └── serve/ # Serving code
28
+ ├── __init__.py
29
+ ├── controller.py # Controller for managing workers
30
+ ├── model_worker.py # Worker for serving the model
31
+ └── gradio_web_server.py # Gradio web interface
32
+ ```
33
+
34
+ ## Deployment Options
35
+
36
+ 1. **Gradio Spaces**:
37
+ - Connect this repository to a Gradio Space
38
+ - The application will automatically download required models
39
+ - Use `app_gradio_spaces.py` as the entry point
40
+
41
+ 2. **Local Deployment**:
42
+ - Clone this repository
43
+ - Install dependencies: `pip install -r requirements.txt`
44
+ - Run the application: `python app.py`
45
+
46
+ 3. **Container Deployment with Cog**:
47
+ - Install Cog: `curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_$(uname -s)_$(uname -m)`
48
+ - Build the container: `cog build`
49
+ - Run the container: `cog predict -i [email protected]`
50
+
51
+ ## Important Notes
52
+
53
+ - The actual model files are not included in this repository
54
+ - During deployment, the application will download:
55
+ - Whisper speech recognition model
56
+ - LLaMA-Omni model (simulated in this setup)
57
+ - HiFi-GAN vocoder
58
+
59
+ ## Testing the Setup
60
+
61
+ Run the setup check script to verify your environment:
62
+
63
+ ```bash
64
+ python check_setup.py
65
+ ```
66
+
67
+ This will check for required directories, files, and Python packages.
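The contents of `check_setup.py` are not shown in this commit; the sketch below illustrates the kind of check it performs. The directory, file, and package names follow the structure described above and are otherwise assumptions:

```python
import importlib.util
import os

# Items the deployment described above depends on (illustrative lists)
REQUIRED_DIRS = ["omni_speech", "omni_speech/serve", "omni_speech/infer"]
REQUIRED_FILES = ["app.py", "requirements.txt"]
REQUIRED_PACKAGES = ["gradio", "torch", "whisper", "requests"]

def check_setup() -> bool:
    ok = True
    for path in REQUIRED_DIRS + REQUIRED_FILES:
        if not os.path.exists(path):  # missing directory or file
            print(f"Missing: {path}")
            ok = False
    for pkg in REQUIRED_PACKAGES:
        if importlib.util.find_spec(pkg) is None:  # package not importable
            print(f"Package not installed: {pkg}")
            ok = False
    return ok

if __name__ == "__main__":
    raise SystemExit(0 if check_setup() else 1)
```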
app.py CHANGED
@@ -1,377 +1,132 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
4
  import os
5
- import warnings
6
- import importlib
7
- import sys
8
  import subprocess
9
- import numpy as np
10
- import tempfile
11
- import soundfile as sf
12
- import logging
13
- import huggingface_hub
14
- from huggingface_hub import snapshot_download
15
-
16
- # Configurar logging
17
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
18
- logger = logging.getLogger(__name__)
19
-
20
- # Verificar modo sem download (primeiro, antes de importar model_downloader)
21
- NO_DOWNLOAD = os.environ.get("NO_DOWNLOAD", "0").lower() in ("1", "true", "yes")
22
- logger.info(f"Inicializando app.py com NO_DOWNLOAD={NO_DOWNLOAD} (valor da env: {os.environ.get('NO_DOWNLOAD', 'não definido')})")
23
-
24
- # Import do novo model_downloader
25
- try:
26
- from model_downloader import download_model_if_needed, download_all_models, get_model_repo_id, NO_DOWNLOAD as DOWNLOADER_NO_DOWNLOAD
27
- # Verificar se os valores são consistentes
28
- if NO_DOWNLOAD != DOWNLOADER_NO_DOWNLOAD:
29
- logger.warning(f"Inconsistência detectada: NO_DOWNLOAD no app.py={NO_DOWNLOAD}, mas NO_DOWNLOAD no model_downloader.py={DOWNLOADER_NO_DOWNLOAD}")
30
- # Atualizar para o valor no model_downloader.py
31
- NO_DOWNLOAD = DOWNLOADER_NO_DOWNLOAD
32
- except ImportError:
33
- logger.warning("model_downloader não pôde ser importado, trabalhando sem ele")
34
- # Definir funções vazias para manter compatibilidade
35
- def download_model_if_needed(model_key): return False
36
- def download_all_models(): pass
37
- def get_model_repo_id(model_key): return None
38
-
39
- # Configuração do caminho para os modelos
40
- MODELS_DIR = os.environ.get("MODELS_DIR", "models")
41
- os.makedirs(MODELS_DIR, exist_ok=True)
42
-
43
- # --- Model Configuration ---
44
- whisper_model_id = "openai/whisper-tiny"
45
- llama_omni_model_id = "ICTNLP/LLaMA-Omni2-0.5B" # Modelo específico que queremos usar
46
- HF_TOKEN = os.environ.get("HF_TOKEN", None) # Token para acessar modelos privados, se necessário
47
-
48
- # --- Device Configuration ---
49
- if torch.cuda.is_available():
50
- device_for_pipelines = 0 # Use the first GPU for Hugging Face pipelines
51
- torch_device = "cuda:0" # PyTorch device string
52
- dtype_for_pipelines = torch.float16
53
- else:
54
- device_for_pipelines = -1 # Use CPU for Hugging Face pipelines
55
- torch_device = "cpu"
56
- dtype_for_pipelines = torch.float32
57
-
58
- logger.info(f"Using device: {torch_device} for model loading.")
59
- logger.info(f"Pipelines will use device_id: {device_for_pipelines} and dtype: {dtype_for_pipelines}")
60
-
61
- # --- Check Download Mode ---
62
- if NO_DOWNLOAD:
63
- logger.warning("Modo NO_DOWNLOAD ativado. Os modelos não serão baixados, usando diretamente do Hugging Face Hub.")
64
- # Usar IDs dos modelos diretamente do Hugging Face
65
- whisper_repo_id = get_model_repo_id("speech_encoder") or "openai/whisper-large-v3"
66
- llama_omni_repo_id = get_model_repo_id("llama_omni2") or llama_omni_model_id
67
-
68
- # Definir caminhos para modelo
69
- whisper_path_to_use = whisper_repo_id
70
- model_path_to_use = llama_omni_repo_id
71
-
72
- logger.info(f"Usando modelo whisper direto do HF: {whisper_path_to_use}")
73
- logger.info(f"Usando modelo LLaMA-Omni2 direto do HF: {model_path_to_use}")
74
- else:
75
- # --- Download Models if Needed ---
76
- logger.info("Verificando se os modelos estão disponíveis localmente...")
77
-
78
- # Download do modelo de speech recognition (Whisper)
79
- download_model_if_needed("speech_encoder")
80
-
81
- # Download do modelo de síntese de voz
82
- download_model_if_needed("cosy2_decoder")
83
-
84
- # Download do modelo LLaMA-Omni2
85
- download_model_if_needed("llama_omni2")
86
-
87
- # Configurar caminhos para modelos locais
88
- whisper_local_path = os.path.join(MODELS_DIR, "speech_encoder", "whisper-large-v3")
89
- whisper_path_to_use = whisper_local_path if os.path.exists(whisper_local_path) else whisper_model_id
90
-
91
- local_model_path = os.path.join(MODELS_DIR, "LLaMA-Omni2-0.5B")
92
- model_path_to_use = local_model_path if os.path.exists(local_model_path) and os.path.isdir(local_model_path) else llama_omni_model_id
93
-
94
- # --- Load Speech-to-Text (ASR) Pipeline ---
95
- asr_pipeline_instance = None
96
- try:
97
- logger.info(f"Loading ASR model: {whisper_path_to_use}...")
98
-
99
- asr_pipeline_instance = pipeline(
100
- "automatic-speech-recognition",
101
- model=whisper_path_to_use,
102
- torch_dtype=dtype_for_pipelines,
103
- device=device_for_pipelines
104
- )
105
- logger.info(f"ASR model loaded successfully.")
106
- except Exception as e:
107
- logger.error(f"Error loading ASR model: {e}")
108
- asr_pipeline_instance = None
109
-
110
- # --- Load Text Generation Model ---
111
- text_gen_pipeline_instance = None
112
- text_generation_model_id = None # Will be set to the model that successfully loads
113
-
114
- try:
115
- logger.info(f"Attempting to load LLaMA-Omni2 model: {model_path_to_use}...")
116
- # LLaMA models often require specific loading configurations
117
- tokenizer = AutoTokenizer.from_pretrained(
118
- model_path_to_use,
119
- trust_remote_code=True,
120
- use_fast=False,
121
- token=HF_TOKEN
122
- )
123
-
124
- model = AutoModelForCausalLM.from_pretrained(
125
- model_path_to_use,
126
- torch_dtype=dtype_for_pipelines,
127
- trust_remote_code=True,
128
- device_map="auto" if torch.cuda.is_available() else None,
129
- low_cpu_mem_usage=True,
130
- token=HF_TOKEN
131
- )
132
-
133
- # Check if this is a specialized Omni2 model with audio capabilities
134
- is_omni2_speech_model = hasattr(model, "generate_with_speech") or hasattr(model, "generate_speech")
135
 
136
- text_gen_pipeline_instance = pipeline(
137
- "text-generation",
138
- model=model,
139
- tokenizer=tokenizer,
140
- torch_dtype=dtype_for_pipelines,
141
- device=device_for_pipelines if not torch.cuda.is_available() else None
142
- )
143
- text_generation_model_id = llama_omni_model_id
144
- logger.info(f"LLaMA-Omni2 model loaded successfully.")
145
- logger.info(f"Model has speech generation capabilities: {is_omni2_speech_model}")
146
 
147
- except Exception as e:
148
- logger.error(f"Error loading LLaMA-Omni2 model: {e}")
149
- logger.error("Não foi possível carregar o modelo LLaMA-Omni2. Verifique se o modelo está disponível ou se há erro nas configurações.")
150
- text_gen_pipeline_instance = None
151
-
152
- # --- Core Functions ---
153
- def transcribe_audio_input(audio_filepath):
154
- if not asr_pipeline_instance:
155
- return "ASR model not available. Please check startup logs.", ""
156
- if audio_filepath is None:
157
- return "No audio file provided for transcription.", ""
158
  try:
159
- logger.info(f"Transcribing: {audio_filepath}")
160
- result = asr_pipeline_instance(audio_filepath, chunk_length_s=30)
161
- transcribed_text = result["text"]
162
- logger.info(f"Transcription: '{transcribed_text}'")
163
- return transcribed_text, transcribed_text
164
- except Exception as e:
165
- logger.error(f"Transcription error: {e}")
166
- return f"Error during transcription: {str(e)}", ""
167
-
168
- def generate_text_response(prompt_text):
169
- """Generate both text and speech response if possible"""
170
- if not text_gen_pipeline_instance:
171
- logger.error("Text generation model not available for response generation")
172
- return f"Text generation model not available. Check logs.", None
173
- if not prompt_text or not prompt_text.strip():
174
- return "Prompt is empty. Please provide text for generation.", None
175
 
176
- try:
177
- logger.info(f"Generating response for prompt (first 100 chars): '{prompt_text[:100]}...'")
178
-
179
- # Try to use special speech generation if available
180
- model = text_gen_pipeline_instance.model
181
-
182
- # Check if model has speech generation capability
183
- if hasattr(model, "generate_with_speech") or hasattr(model, "generate_speech"):
184
- try:
185
- # Prepare inputs
186
- inputs = text_gen_pipeline_instance.tokenizer(prompt_text, return_tensors="pt").to(model.device)
187
-
188
- # Generate with speech
189
- if hasattr(model, "generate_with_speech"):
190
- logger.info("Using generate_with_speech method")
191
- outputs = model.generate_with_speech(
192
- **inputs,
193
- max_new_tokens=150,
194
- do_sample=True,
195
- temperature=0.7,
196
- top_p=0.9
197
- )
198
- text_response = text_gen_pipeline_instance.tokenizer.decode(outputs["sequences"][0], skip_special_tokens=True)
199
- audio_data = outputs.get("speech_output", None)
200
- elif hasattr(model, "generate_speech"):
201
- logger.info("Using generate_speech method")
202
- # Text generation first
203
- output_ids = model.generate(
204
- **inputs,
205
- max_new_tokens=150,
206
- do_sample=True,
207
- temperature=0.7,
208
- top_p=0.9
209
- )
210
- text_response = text_gen_pipeline_instance.tokenizer.decode(output_ids[0], skip_special_tokens=True)
211
-
212
- # Then speech generation
213
- audio_data = model.generate_speech(output_ids)
214
-
215
- # Save audio if we got it
216
- if audio_data is not None:
217
- audio_path = save_audio_to_temp_file(audio_data)
218
- return text_response, audio_path
219
- else:
220
- logger.warning("No audio data was generated")
221
- return text_response, None
222
-
223
- except Exception as speech_error:
224
- logger.error(f"Error generating speech with LLaMA-Omni2: {speech_error}")
225
- logger.info("Falling back to text-only generation")
226
-
227
- # Parameters optimized for LLaMA-Omni2 text-only generation
228
- logger.info("Using text-only generation")
229
- generated_outputs = text_gen_pipeline_instance(
230
- prompt_text,
231
- max_new_tokens=150,
232
- do_sample=True,
233
- temperature=0.7,
234
- top_p=0.9,
235
- num_return_sequences=1
236
- )
237
-
238
- response_text = generated_outputs[0]["generated_text"]
239
- logger.info(f"Generated text-only response with length: {len(response_text)}")
240
- return response_text, None
241
- except Exception as e:
242
- logger.error(f"Text generation error: {e}")
243
- return f"Error during text generation: {str(e)}", None
244
-
245
- def save_audio_to_temp_file(audio_data):
246
- """Save audio data to a temporary file and return the path"""
247
- try:
248
- # Create a temporary file
249
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
250
- temp_path = tmp_file.name
251
-
252
- # Convert audio data to the right format if needed and save
253
- if isinstance(audio_data, np.ndarray):
254
- # Assuming sample rate of 16000 Hz, which is common for speech models
255
- sf.write(temp_path, audio_data, 16000)
256
- elif isinstance(audio_data, torch.Tensor):
257
- # Convert tensor to numpy array
258
- audio_np = audio_data.cpu().numpy()
259
- sf.write(temp_path, audio_np, 16000)
260
- else:
261
- print(f"Unknown audio data type: {type(audio_data)}")
262
- return None
263
-
264
- print(f"Audio saved to temporary file: {temp_path}")
265
- return temp_path
266
- except Exception as e:
267
- print(f"Error saving audio to file: {e}")
268
- return None
269
-
270
- def combined_pipeline_process(audio_filepath):
271
- if audio_filepath is None:
272
- return "No audio input.", "No audio input.", None
273
-
274
- transcribed_text, _ = transcribe_audio_input(audio_filepath)
275
-
276
- if not asr_pipeline_instance or "Error during transcription" in transcribed_text or not transcribed_text.strip():
277
- error_msg_for_generation = "Cannot generate response: Transcription failed or was empty."
278
- if not asr_pipeline_instance:
279
- error_msg_for_generation = "Cannot generate response: ASR model not loaded."
280
- return transcribed_text, error_msg_for_generation, None
281
-
282
- if not text_gen_pipeline_instance:
283
- return transcribed_text, f"Cannot generate response: No text generation model available.", None
284
 
285
- final_response, audio_path = generate_text_response(transcribed_text)
286
- return transcribed_text, final_response, audio_path
287
-
288
- # Determine model status for UI
289
- if text_generation_model_id == llama_omni_model_id:
290
- llama_model_status = "LLaMA-Omni2-0.5B loaded successfully"
291
- using_model = "LLaMA-Omni2-0.5B"
292
- else:
293
- llama_model_status = "Failed to load LLaMA-Omni2 model"
294
- using_model = "No model available"
295
-
296
- # --- Gradio Interface Definition ---
297
- with gr.Blocks(theme=gr.themes.Soft(), title="Whisper + LLaMA-Omni2 Demo") as app_interface:
298
- gr.Markdown(
299
- f"""
300
- # Speech-to-Text and Text/Speech Generation Demo
301
 
302
- Esta aplicação usa **OpenAI Whisper Tiny** para reconhecimento de fala e **LLaMA-Omni2-0.5B** para geração de texto e fala.
 
303
 
304
- **Modelo em uso:** {using_model}
 
305
 
306
- Envie um arquivo de áudio para transcrevê-lo. O texto transcrito será então usado como prompt para o modelo de geração de texto/fala.
307
- """
308
- )
309
-
310
- with gr.Tab("Pipeline Completo: Áudio -> Transcrição -> Geração"):
311
- gr.Markdown("### Etapa 1: Envie Áudio -> Etapa 2: Transcrição -> Etapa 3: Geração de Texto/Fala")
312
- input_audio_pipeline = gr.Audio(type="filepath", label="Envie seu arquivo de áudio (.wav, .mp3)")
313
- submit_button_full = gr.Button("Executar Processo Completo", variant="primary")
314
- output_transcription_pipeline = gr.Textbox(label="Texto Transcrito (do Whisper)", lines=5)
315
- model_label = f"Texto Gerado (do {using_model})"
316
- output_generation_pipeline = gr.Textbox(label=model_label, lines=7)
317
- output_audio_pipeline = gr.Audio(label="Fala Gerada (se disponível)", visible=True)
318
 
319
- submit_button_full.click(
320
- fn=combined_pipeline_process,
321
- inputs=[input_audio_pipeline],
322
- outputs=[output_transcription_pipeline, output_generation_pipeline, output_audio_pipeline]
323
- )
324
-
325
- with gr.Tab("Testar Reconhecimento de Fala (Whisper Tiny)"):
326
- gr.Markdown("### Transcreva áudio para texto usando Whisper Tiny.")
327
- input_audio_asr = gr.Audio(type="filepath", label="Envie Áudio para Reconhecimento")
328
- submit_button_asr = gr.Button("Transcrever Áudio", variant="secondary")
329
- output_transcription_asr = gr.Textbox(label="Resultado da Transcrição", lines=10)
330
-
331
- def asr_only_ui(audio_file):
332
- if audio_file is None: return "Por favor, envie um arquivo de áudio."
333
- transcription, _ = transcribe_audio_input(audio_file)
334
- return transcription
335
-
336
- submit_button_asr.click(
337
- fn=asr_only_ui,
338
- inputs=[input_audio_asr],
339
- outputs=[output_transcription_asr]
340
- )
341
-
342
- with gr.Tab(f"Testar Geração de Texto/Fala"):
343
- model_name_gen = using_model
344
- gr.Markdown(f"### Gere texto e fala a partir de um prompt usando {model_name_gen}.")
345
- input_text_prompt_gen = gr.Textbox(label="Seu Prompt de Texto", placeholder="Digite seu texto aqui...", lines=5)
346
- submit_button_gen = gr.Button("Gerar Texto e Fala", variant="secondary")
347
- output_generation_gen = gr.Textbox(label="Resultado do Texto Gerado", lines=10)
348
- output_audio_gen = gr.Audio(label="Fala Gerada (se disponível)")
349
-
350
- def text_generation_ui(prompt):
351
- if not prompt or not prompt.strip():
352
- return "Por favor, forneça um prompt primeiro.", None
353
- response_text, audio_path = generate_text_response(prompt)
354
- return response_text, audio_path
355
 
356
- submit_button_gen.click(
357
- fn=text_generation_ui,
358
- inputs=[input_text_prompt_gen],
359
- outputs=[output_generation_gen, output_audio_gen]
360
- )
361
-
362
- gr.Markdown("--- ")
363
- gr.Markdown("### Status do Carregamento do Modelo (na inicialização do aplicativo):")
364
- asr_load_status = "Carregado com sucesso" if asr_pipeline_instance else "Falha ao carregar (verifique os logs)"
365
 
366
- gr.Markdown(f"* **Modelo Whisper ({whisper_model_id}):** `{asr_load_status}`")
367
- gr.Markdown(f"* **Modelo LLaMA-Omni2 ({llama_omni_model_id}):** `{llama_model_status}`")
 
 
 
 
 
 
368
 
369
- # --- Launch the Gradio App ---
370
  if __name__ == "__main__":
371
- print("Launching Gradio demo...")
372
- try:
373
- app_interface.launch(share=True, server_name="0.0.0.0")
374
- except Exception as e:
375
- print(f"Error launching with share=True: {e}")
376
- print("Trying to launch without sharing...")
377
- app_interface.launch(server_name="0.0.0.0")
 
 
 
 
1
  import os
 
 
 
2
  import subprocess
3
+ import threading
4
+ import time
5
+ import gradio as gr
6
+ import whisper
7
+ import requests
8
+
9
+ # Configuration
10
+ MODEL_NAME = "Llama-3.1-8B-Omni"
11
+ CONTROLLER_PORT = 10000
12
+ WEB_SERVER_PORT = 8000
13
+ MODEL_WORKER_PORT = 40000
14
+
15
+ # Paths
16
+ VOCODER_PATH = "vocoder/g_00500000"
17
+ VOCODER_CFG = "vocoder/config.json"
18
+
19
+ def download_models():
20
+ """Ensure that required models are available"""
21
+ os.makedirs("models/speech_encoder", exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # Download Whisper model if needed (this will happen during deployment)
24
+ print("Setting up Whisper model...")
25
+ whisper.load_model("large-v3", download_root="models/speech_encoder/")
 
 
 
 
 
 
 
26
 
27
+ # Download vocoder if needed
28
+ if not os.path.exists(VOCODER_PATH):
29
+ print("Downloading vocoder...")
30
+ os.makedirs("vocoder", exist_ok=True)
31
+ subprocess.run([
32
+ "wget", "https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/g_00500000",
33
+ "-P", "vocoder/"
34
+ ])
35
+ subprocess.run([
36
+ "wget", "https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/config.json",
37
+ "-P", "vocoder/"
38
+ ])
39
+
40
+ def start_controller():
41
+ """Start the controller process"""
42
+ print("Starting controller...")
43
+ controller_process = subprocess.Popen([
44
+ "python", "-m", "omni_speech.serve.controller",
45
+ "--host", "0.0.0.0",
46
+ "--port", str(CONTROLLER_PORT)
47
+ ])
48
+ time.sleep(5) # Wait for controller to start
49
+ return controller_process
50
+
51
+ def start_model_worker():
52
+ """Start the model worker process"""
53
+ print("Starting model worker...")
54
+ worker_process = subprocess.Popen([
55
+ "python", "-m", "omni_speech.serve.model_worker",
56
+ "--host", "0.0.0.0",
57
+ "--controller", f"http://localhost:{CONTROLLER_PORT}",
58
+ "--port", str(MODEL_WORKER_PORT),
59
+ "--worker", f"http://localhost:{MODEL_WORKER_PORT}",
60
+ "--model-path", MODEL_NAME,
61
+ "--model-name", MODEL_NAME,
62
+ "--s2s"
63
+ ])
64
+ time.sleep(10) # Wait for model worker to start
65
+ return worker_process
66
+
67
+ def start_web_server():
68
+ """Start the web server process"""
69
+ print("Starting web server...")
70
+ web_process = subprocess.Popen([
71
+ "python", "-m", "omni_speech.serve.gradio_web_server",
72
+ "--controller", f"http://localhost:{CONTROLLER_PORT}",
73
+ "--port", str(WEB_SERVER_PORT),
74
+ "--model-list-mode", "reload",
75
+ "--vocoder", VOCODER_PATH,
76
+ "--vocoder-cfg", VOCODER_CFG
77
+ ])
78
+ return web_process
79
+
80
+ def check_services():
81
+ """Check if all services are running"""
82
  try:
83
+ controller_resp = requests.get(f"http://localhost:{CONTROLLER_PORT}/status").json()
84
+ web_server_resp = requests.get(f"http://localhost:{WEB_SERVER_PORT}/").status_code
85
+ return controller_resp["status"] == "ok" and web_server_resp == 200
86
+ except Exception:
87
+ return False
88
+
89
+ def main():
90
+ # Download required models
91
+ download_models()
 
 
 
 
 
 
 
92
 
93
+ # Start all services
94
+ controller = start_controller()
95
+ worker = start_model_worker()
96
+ web_server = start_web_server()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ # Create a simple redirection interface
99
+ with gr.Blocks() as demo:
100
+ gr.Markdown("# 🦙🎧 LLaMA-Omni")
101
+ gr.Markdown("## Starting LLaMA-Omni services...")
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
+ with gr.Row():
104
+ status = gr.Textbox(value="Initializing...", label="Status")
105
 
106
+ with gr.Row():
107
+ redirect_btn = gr.Button("Go to LLaMA-Omni Interface")
108
 
109
+ def update_status():
110
+ if check_services():
111
+ return "All services running! Click the button below to access the interface."
112
+ else:
113
+ return "Still starting services... Please wait."
 
 
 
 
 
 
 
114
 
115
+ def redirect():
116
+ return gr.Redirect(f"http://localhost:{WEB_SERVER_PORT}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ # Update status every 5 seconds
119
+ demo.load(update_status, outputs=status, every=5)
120
+ redirect_btn.click(redirect)
 
 
 
 
 
 
121
 
122
+ # Launch the Gradio interface
123
+ try:
124
+ demo.launch(server_name="0.0.0.0")
125
+ finally:
126
+ # Clean up processes when Gradio is closed
127
+ controller.terminate()
128
+ worker.terminate()
129
+ web_server.terminate()
130
 
 
131
  if __name__ == "__main__":
132
+ main()
 
 
 
 
 
 
app.yaml DELETED
@@ -1,23 +0,0 @@
1
- sdk: docker
2
- build_config:
3
- gpu: true
4
- cuda: "11.8"
5
- python_version: "3.10"
6
- system_packages:
7
- - "ffmpeg"
8
- - "libsndfile1"
9
- - "build-essential"
10
- - "ninja-build"
11
- - "git"
12
- resources:
13
- gpu: A10G
14
- cpu: 4
15
- memory: "30G"
16
- disk: "10G"
17
- models:
18
- - "openai/whisper-tiny"
19
- - "ICTNLP/LLaMA-Omni2-0.5B"
20
- secrets:
21
- - name: HF_TOKEN
22
- help: "Token de autenticação do Hugging Face (opcional)"
23
- required: false
app_gradio_spaces.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import subprocess
4
+ import threading
5
+ import time
6
+ import gradio as gr
7
+
8
+ def run_background_process(cmd, name):
9
+ """Run a background process and return the process object."""
10
+ print(f"Starting {name}...")
11
+ process = subprocess.Popen(
12
+ cmd,
13
+ stdout=subprocess.PIPE,
14
+ stderr=subprocess.STDOUT,
15
+ text=True,
16
+ bufsize=1,
17
+ universal_newlines=True,
18
+ shell=True
19
+ )
20
+ return process
21
+
22
+ def read_process_output(process, output_box, name):
23
+ """Read and update the output from a process."""
24
+ full_output = f"### {name} Output:\n\n"
25
+ for line in process.stdout:
26
+ full_output += line
27
+ output_box.update(value=full_output)
28
+
29
+ # Process ended
30
+ return_code = process.wait()
31
+ full_output += f"\n\nProcess exited with code {return_code}"
32
+ output_box.update(value=full_output)
33
+
34
+ def setup_environment():
35
+ """Set up the environment by installing dependencies and downloading models."""
36
+ # Create necessary directories
37
+ os.makedirs("models/speech_encoder", exist_ok=True)
38
+ os.makedirs("vocoder", exist_ok=True)
39
+
40
+ # Download whisper model
41
+ os.system("pip install openai-whisper>=20231117")
42
+ os.system("pip install fairseq==0.12.2")
43
+
44
+ # Download vocoder
45
+ if not os.path.exists("vocoder/g_00500000"):
46
+ os.system("wget https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/g_00500000 -P vocoder/")
47
+ os.system("wget https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/config.json -P vocoder/")
48
+
49
+ # Initialize Whisper (it will be downloaded automatically)
50
+ os.system("python -c \"import whisper; whisper.load_model('large-v3', download_root='models/speech_encoder/')\"")
51
+
52
+ return "✅ Environment setup complete!"
53
+
54
+ def start_services(controller_output, model_worker_output, web_server_output):
55
+ """Start the controller, model worker, and web server."""
56
+ # Start the controller
57
+ controller_process = run_background_process(
58
+ "python -m omni_speech.serve.controller --host 0.0.0.0 --port 10000",
59
+ "Controller"
60
+ )
61
+
62
+ # Start a thread to read controller output
63
+ controller_thread = threading.Thread(
64
+ target=read_process_output,
65
+ args=(controller_process, controller_output, "Controller"),
66
+ daemon=True
67
+ )
68
+ controller_thread.start()
69
+
70
+ # Wait for controller to start
71
+ time.sleep(5)
72
+
73
+ # Start the model worker
74
+ model_worker_process = run_background_process(
75
+ "python -m omni_speech.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path Llama-3.1-8B-Omni --model-name Llama-3.1-8B-Omni --s2s",
76
+ "Model Worker"
77
+ )
78
+
79
+ # Start a thread to read model worker output
80
+ model_worker_thread = threading.Thread(
81
+ target=read_process_output,
82
+ args=(model_worker_process, model_worker_output, "Model Worker"),
83
+ daemon=True
84
+ )
85
+ model_worker_thread.start()
86
+
87
+ # Wait for model worker to start
88
+ time.sleep(10)
89
+
90
+ # Start the web server
91
+ web_server_process = run_background_process(
92
+ "python -m omni_speech.serve.gradio_web_server --controller http://localhost:10000 --port 8001 --model-list-mode reload --vocoder vocoder/g_00500000 --vocoder-cfg vocoder/config.json",
93
+ "Web Server"
94
+ )
95
+
96
+ # Start a thread to read web server output
97
+ web_server_thread = threading.Thread(
98
+ target=read_process_output,
99
+ args=(web_server_process, web_server_output, "Web Server"),
100
+ daemon=True
101
+ )
102
+ web_server_thread.start()
103
+
104
+ # Wait for web server to start
105
+ time.sleep(5)
106
+
107
+ return "✅ All services started! Click the 'Open Interface' button below."
108
+
109
+ def build_ui():
110
+ """Build the Gradio UI."""
111
+ with gr.Blocks() as demo:
112
+ gr.Markdown("# 🦙🎧 LLaMA-Omni Deployment")
113
+
114
+ with gr.Tab("Setup"):
115
+ setup_btn = gr.Button("Setup Environment")
116
+ setup_output = gr.Textbox(label="Setup Output", value="Click 'Setup Environment' to start.")
117
+ setup_btn.click(setup_environment, outputs=setup_output)
118
+
119
+ with gr.Tab("Services"):
120
+ start_btn = gr.Button("Start LLaMA-Omni Services")
121
+ status_output = gr.Textbox(label="Status", value="Click 'Start LLaMA-Omni Services' to begin.")
122
+
123
+ with gr.Accordion("Service Logs", open=False):
124
+ controller_output = gr.Markdown("Controller not started")
125
+ model_worker_output = gr.Markdown("Model Worker not started")
126
+ web_server_output = gr.Markdown("Web Server not started")
127
+
128
+ start_btn.click(
129
+ start_services,
130
+ inputs=[],
131
+ outputs=[status_output, controller_output, model_worker_output, web_server_output]
132
+ )
133
+
134
+ interface_btn = gr.Button("Open Interface")
135
+ interface_btn.click(lambda: gr.update(value="http://localhost:8001"), None, None)
136
+
137
+ with gr.Tab("About"):
138
+ gr.Markdown("""
139
+ # About LLaMA-Omni
140
+
141
+ LLaMA-Omni is a speech-language model built upon Llama-3.1-8B-Instruct. It supports low-latency and high-quality speech interactions, simultaneously generating both text and speech responses based on speech instructions.
142
+
143
+ ## Features
144
+
145
+ * Built on Llama-3.1-8B-Instruct, ensuring high-quality responses
146
+ * Low-latency speech interaction with a latency as low as 226ms
147
+ * Simultaneous generation of both text and speech responses
148
+
149
+ ## License
150
+
151
+ This code is released under the Apache-2.0 License. The model is intended for academic research purposes only and may NOT be used for commercial purposes.
152
+
153
+ Original work by Qingkai Fang, Shoutao Guo, Yan Zhou, Zhengrui Ma, Shaolei Zhang, Yang Feng.
154
+ """)
155
+
156
+ return demo
157
+
158
+ if __name__ == "__main__":
159
+ demo = build_ui()
160
+ demo.launch(server_port=7860)
audio_interface.py DELETED
@@ -1,451 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Audio interface for LLaMA-Omni2 that accepts audio input and returns audio output.
4
- This interface:
5
- 1. Transcribes audio input using Whisper
6
- 2. Processes the transcription with LLaMA-Omni2 model
7
- 3. Synthesizes the response back to audio using CosyVoice 2
8
-
9
- Enhanced with streaming generation and read-write scheduling for real-time response.
10
- """
11
-
12
- import os
13
- import sys
14
- import argparse
15
- import logging
16
- import time
17
- import asyncio
18
- import tempfile
19
- from pathlib import Path
20
- from queue import Queue
21
- from threading import Thread
22
- import json
23
-
24
- import torch
25
- import torchaudio
26
- import gradio as gr
27
- import whisper
28
- import aiohttp
29
- import numpy as np
30
-
31
- # Import model downloader
32
- try:
33
- from model_downloader import download_model_if_needed, download_all_models, get_model_repo_id, NO_DOWNLOAD
34
- has_model_downloader = True
35
- except ImportError:
36
- has_model_downloader = False
37
- NO_DOWNLOAD = False
38
-
39
- # Configure logging
40
- logging.basicConfig(level=logging.INFO)
41
- logger = logging.getLogger(__name__)
42
-
43
- class AudioInterface:
44
- def __init__(
45
- self,
46
- controller_url: str,
47
- whisper_model_path: str,
48
- vocoder_dir: str,
49
- model_name: str = "LLaMA-Omni2-7B-Bilingual",
50
- read_tokens: int = 3,
51
- write_tokens: int = 10
52
- ):
53
- self.controller_url = controller_url
54
- self.whisper_model_path = whisper_model_path
55
- self.vocoder_dir = vocoder_dir
56
- self.model_name = model_name
57
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
58
-
59
- # Read-write scheduling parameters for streaming generation
60
- self.read_tokens = read_tokens # Number of text tokens to read
61
- self.write_tokens = write_tokens # Number of speech tokens to write
62
-
63
- # Download required models if needed
64
- self._ensure_models_available()
65
-
66
- # Load Whisper model
67
- try:
68
- # Se NO_DOWNLOAD estiver ativado, usar diretamente o modelo do Hugging Face
69
- if has_model_downloader and NO_DOWNLOAD:
70
- whisper_model_path = "openai/whisper-large-v3"
71
- logger.info(f"Modo NO_DOWNLOAD: Carregando Whisper direto do Hugging Face: {whisper_model_path}")
72
-
73
- logger.info(f"Loading Whisper model from {whisper_model_path}")
74
- self.whisper_model = whisper.load_model("large-v3",
75
- download_root=whisper_model_path if not NO_DOWNLOAD else None,
76
- device=self.device)
77
- logger.info("Whisper model loaded successfully")
78
- except Exception as e:
79
- logger.error(f"Failed to load Whisper model: {e}")
80
- self.whisper_model = None
81
-
82
- # Load CosyVoice vocoder
83
- try:
84
- # Se NO_DOWNLOAD estiver ativado, usar diretamente o modelo do Hugging Face
85
- if has_model_downloader and NO_DOWNLOAD:
86
- logger.warning("Modo NO_DOWNLOAD ativado. O vocoder CosyVoice pode não funcionar corretamente sem os arquivos locais.")
87
-
88
- sys.path.insert(0, vocoder_dir)
89
- from cosy_voice_2.inference import CosyVoice
90
-
91
- self.vocoder = CosyVoice(
92
- device=self.device,
93
- model_path=vocoder_dir
94
- )
95
- logger.info(f"CosyVoice vocoder loaded from {vocoder_dir}")
96
- except Exception as e:
97
- logger.error(f"Failed to load CosyVoice vocoder: {e}")
98
- self.vocoder = None
99
-
100
- logger.info(f"Using LLaMA-Omni2 model: {model_name}")
101
-
102
- def _ensure_models_available(self):
103
- """Garante que os modelos necessários estão disponíveis"""
104
- # Verificar se temos o model_downloader disponível
105
- if has_model_downloader:
106
- if NO_DOWNLOAD:
107
- logger.info("Modo NO_DOWNLOAD ativado. Pulando verificação de modelos locais.")
108
- return
109
-
110
- logger.info("Verificando modelos necessários...")
111
-
112
- # Baixar modelo Whisper
113
- download_model_if_needed("speech_encoder")
114
-
115
- # Baixar modelo CosyVoice
116
- download_model_if_needed("cosy2_decoder")
117
-
118
- logger.info("Verificação de modelos concluída")
119
- else:
120
- logger.warning("model_downloader não está disponível. Assumindo que os modelos já estão disponíveis localmente.")
121
-
122
- async def get_worker_address(self):
123
- """Get the address of the worker serving the model"""
124
- try:
125
- async with aiohttp.ClientSession() as session:
126
- async with session.get(
127
- f"{self.controller_url}/get_worker_address?model_name={self.model_name}",
128
- timeout=30
129
- ) as response:
130
- if response.status == 200:
131
- data = await response.json()
132
- return data.get("address")
133
- else:
134
- logger.error(f"Failed to get worker address: {await response.text()}")
135
- return None
136
- except Exception as e:
137
- logger.error(f"Error getting worker address: {e}")
138
- return None
139
-
140
- async def generate_text(self, prompt: str, streaming=False):
141
- """Generate text from LLaMA-Omni2 model"""
142
- worker_addr = await self.get_worker_address()
143
- if not worker_addr:
144
- return f"Error: No worker available for model {self.model_name}"
145
-
146
- try:
147
- async with aiohttp.ClientSession() as session:
148
- # For streaming generation
149
- if streaming:
150
- async with session.post(
151
- f"{worker_addr}/generate_stream",
152
- json={"prompt": prompt},
153
- timeout=120
154
- ) as response:
155
- if response.status == 200:
156
- response_text = ""
157
- async for line in response.content:
158
- if line:
159
- data = json.loads(line)
160
- chunk = data.get("text", "")
161
- response_text += chunk
162
- yield response_text
163
- return response_text
164
- else:
165
- error_text = await response.text()
166
- logger.error(f"Failed to generate text stream: {error_text}")
167
- return f"Error: {error_text}"
168
- # For non-streaming generation
169
- else:
170
- async with session.post(
171
- f"{worker_addr}/generate",
172
- json={"prompt": prompt},
173
- timeout=120
174
- ) as response:
175
- if response.status == 200:
176
- data = await response.json()
177
- return data.get("response", "No response received from model")
178
- else:
179
- error_text = await response.text()
180
- logger.error(f"Failed to generate text: {error_text}")
181
- return f"Error: {error_text}"
182
- except Exception as e:
183
- logger.error(f"Error generating text: {e}")
184
- return f"Error: {str(e)}"
185
-
186
- def transcribe_audio(self, audio_path):
187
- """Transcribe audio using Whisper"""
188
- if self.whisper_model is None:
189
- return "Error: Whisper model not loaded"
190
-
191
- try:
192
- logger.info(f"Transcribing audio from {audio_path}")
193
- result = self.whisper_model.transcribe(audio_path)
194
- logger.info("Transcription completed")
195
- return result["text"]
196
- except Exception as e:
197
- logger.error(f"Error transcribing audio: {e}")
198
- return f"Error transcribing audio: {str(e)}"
199
-
200
- def synthesize_speech(self, text):
201
- """Synthesize speech from text using CosyVoice"""
202
- if self.vocoder is None:
203
- return None, 16000, "Error: Vocoder not loaded"
204
-
205
- try:
206
- logger.info("Synthesizing speech from text response")
207
- # Generate speech using CosyVoice
208
- waveform = self.vocoder.inference(text)
209
- sample_rate = self.vocoder.sample_rate
210
-
211
- # Convert to numpy array for Gradio
212
- if isinstance(waveform, torch.Tensor):
213
- waveform = waveform.cpu().numpy()
214
-
215
- logger.info("Speech synthesis completed")
216
- return waveform, sample_rate, None
217
- except Exception as e:
218
- logger.error(f"Error synthesizing speech: {e}")
219
- return None, 16000, f"Error synthesizing speech: {str(e)}"
220
-
221
- async def synthesize_speech_chunk(self, text_chunk):
222
- """Synthesize speech for a single text chunk"""
223
- if self.vocoder is None:
224
- return None, 16000, "Error: Vocoder not loaded"
225
-
226
- try:
227
- # Generate speech using CosyVoice for this chunk
228
- waveform = self.vocoder.inference(text_chunk)
229
- sample_rate = self.vocoder.sample_rate
230
-
231
- # Convert to numpy array
232
- if isinstance(waveform, torch.Tensor):
233
- waveform = waveform.cpu().numpy()
234
-
235
- return waveform, sample_rate, None
236
- except Exception as e:
237
- logger.error(f"Error synthesizing speech chunk: {e}")
238
- return None, 16000, f"Error synthesizing speech chunk: {str(e)}"
239
-
240
- async def stream_text_to_speech(self, text_generator):
241
- """Stream text to speech using read-write scheduling"""
242
- buffer = ""
243
- audio_chunks = []
244
-
245
- try:
246
- async for text in text_generator:
247
- # Accumulate text until we have enough to synthesize
248
- buffer += text
249
-
250
- # When we have enough tokens for synthesis (approximate by characters)
251
- if len(buffer.split()) >= self.read_tokens:
252
- # Process the buffer
253
- chunk_to_process = buffer
254
- buffer = ""
255
-
256
- # Synthesize this chunk
257
- audio_chunk, sample_rate, error = await self.synthesize_speech_chunk(chunk_to_process)
258
- if error:
259
- logger.error(f"Error in streaming synthesis: {error}")
260
- continue
261
-
262
- # Add to our collection of audio chunks
263
- audio_chunks.append(audio_chunk)
264
-
265
- # Yield the current concatenated audio
266
- if audio_chunks:
267
- # Concatenate audio chunks
268
- full_audio = np.concatenate(audio_chunks)
269
- yield full_audio, sample_rate, chunk_to_process
270
-
271
- # Process any remaining text in the buffer
272
- if buffer:
273
- audio_chunk, sample_rate, error = await self.synthesize_speech_chunk(buffer)
274
- if not error and audio_chunk is not None:
275
- audio_chunks.append(audio_chunk)
276
-
277
- # Final audio output
278
- if audio_chunks:
279
- full_audio = np.concatenate(audio_chunks)
280
- return full_audio, sample_rate, None
281
- else:
282
- return None, 16000, "No audio generated"
283
-
284
- except Exception as e:
285
- logger.error(f"Error in streaming text to speech: {e}")
286
- return None, 16000, f"Error in streaming text to speech: {str(e)}"
287
-
288
- async def process_audio(self, audio_data, sample_rate, streaming=False):
289
- """Process audio input and return audio output"""
290
- # Save the input audio to a temporary file
291
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
292
- temp_path = temp_audio.name
293
- # Convert sample rate if needed
294
- if sample_rate != 16000:
295
- resampler = torchaudio.transforms.Resample(
296
- orig_freq=sample_rate, new_freq=16000
297
- )
298
- audio_tensor = torch.tensor(audio_data).unsqueeze(0)
299
- audio_tensor = resampler(audio_tensor)
300
- audio_data = audio_tensor.squeeze(0).numpy()
301
- sample_rate = 16000
302
-
303
- # Save as WAV
304
- torchaudio.save(temp_path, torch.tensor(audio_data).unsqueeze(0), sample_rate)
305
-
306
- try:
307
- # Step 1: Transcribe audio
308
- transcription = self.transcribe_audio(temp_path)
309
- if transcription.startswith("Error"):
310
- return None, sample_rate, transcription, "Error occurred during transcription", transcription
311
-
312
- # Step 2: Process with LLaMA-Omni2
313
- if streaming:
314
- # For streaming mode, we use a generator
315
- text_generator = self.generate_text(transcription, streaming=True)
316
- audio_generator = self.stream_text_to_speech(text_generator)
317
- return audio_generator, transcription
318
- else:
319
- # For non-streaming mode
320
- response_text = await self.generate_text(transcription)
321
- if response_text.startswith("Error"):
322
- return None, sample_rate, transcription, response_text, response_text
323
-
324
- # Step 3: Synthesize speech
325
- audio_output, out_sample_rate, error = self.synthesize_speech(response_text)
326
- if error:
327
- return None, sample_rate, transcription, response_text, error
328
-
329
- return audio_output, out_sample_rate, transcription, response_text, None
330
- finally:
331
- # Clean up temporary file
332
- if os.path.exists(temp_path):
333
- os.unlink(temp_path)
334
-
335
- def build_interface(self):
336
- """Build Gradio interface"""
337
- with gr.Blocks(title="LLaMA-Omni2 Audio Interface") as demo:
338
- gr.Markdown("# LLaMA-Omni2 Audio Interface")
339
- gr.Markdown("Speak to LLaMA-Omni2 and hear its response in real-time")
340
-
341
- with gr.Row():
342
- with gr.Column():
343
- audio_input = gr.Audio(
344
- sources=["microphone", "upload"],
345
- type="numpy",
346
- label="Input Audio"
347
- )
348
- with gr.Row():
349
- submit_button = gr.Button("Process Audio", variant="primary")
350
- stream_button = gr.Button("Stream Audio Response", variant="secondary")
351
-
352
- with gr.Column():
353
- transcription = gr.Textbox(
354
- label="Transcription",
355
- interactive=False
356
- )
357
- response_text = gr.Textbox(
358
- label="Response Text",
359
- interactive=False
360
- )
361
- audio_output = gr.Audio(
362
- label="Response Audio",
363
- type="numpy",
364
- interactive=False
365
- )
366
- error_text = gr.Textbox(
367
- label="Errors (if any)",
368
- interactive=False,
369
- visible=False
370
- )
371
-
372
- async def process_wrapper(audio_data):
373
- if audio_data is None:
374
- return None, "No audio input detected", "Please record or upload audio", "No audio input detected"
375
-
376
- audio_array, sample_rate = audio_data
377
- output, out_sample_rate, trans, resp, error = await self.process_audio(audio_array, sample_rate, streaming=False)
378
-
379
- if error:
380
- gr.update(visible=True)
381
- return None, trans, resp, error
382
-
383
- return (output, out_sample_rate), trans, resp, ""
384
-
385
- async def stream_wrapper(audio_data):
386
- if audio_data is None:
387
- return None, "No audio input detected", "Please record or upload audio", "No audio input detected"
388
-
389
- audio_array, sample_rate = audio_data
390
- generator, transcription = await self.process_audio(audio_array, sample_rate, streaming=True)
391
-
392
- # Update transcription immediately
393
- yield None, transcription, "", ""
394
-
395
- # Start streaming
396
- current_text = ""
397
- async for audio_chunk, sr, text_chunk in generator:
398
- current_text += text_chunk
399
- yield (audio_chunk, sr), transcription, current_text, ""
400
-
401
- submit_button.click(
402
- fn=lambda audio: asyncio.create_task(process_wrapper(audio)),
403
- inputs=[audio_input],
404
- outputs=[audio_output, transcription, response_text, error_text]
405
- )
406
-
407
- stream_button.click(
408
- fn=lambda audio: stream_wrapper(audio),
409
- inputs=[audio_input],
410
- outputs=[audio_output, transcription, response_text, error_text]
411
- )
412
-
413
- return demo
414
-
415
-
416
- def main():
417
- parser = argparse.ArgumentParser(description="Audio interface for LLaMA-Omni2")
418
- parser.add_argument("--host", type=str, default="0.0.0.0")
419
- parser.add_argument("--port", type=int, default=7860)
420
- parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
421
- parser.add_argument("--whisper-model-path", type=str, default="models/speech_encoder")
422
- parser.add_argument("--vocoder-dir", type=str, default="models/cosy2_decoder")
423
- parser.add_argument("--model-name", type=str, default="LLaMA-Omni2-7B-Bilingual")
424
- parser.add_argument("--read-tokens", type=int, default=3,
425
- help="Number of text tokens to read before generating speech")
426
- parser.add_argument("--write-tokens", type=int, default=10,
427
- help="Number of speech tokens to write for each read")
428
- parser.add_argument("--share", action="store_true", help="Create a public link")
429
- args = parser.parse_args()
430
-
431
- # Create the interface
432
- interface = AudioInterface(
433
- controller_url=args.controller_url,
434
- whisper_model_path=args.whisper_model_path,
435
- vocoder_dir=args.vocoder_dir,
436
- model_name=args.model_name,
437
- read_tokens=args.read_tokens,
438
- write_tokens=args.write_tokens
439
- )
440
-
441
- # Build and launch the interface
442
- demo = interface.build_interface()
443
- demo.queue()
444
- demo.launch(
445
- server_name=args.host,
446
- server_port=args.port,
447
- share=args.share
448
- )
449
-
450
- if __name__ == "__main__":
451
- main()
check_setup.py ADDED
@@ -0,0 +1,120 @@
1
+ import os
2
+ import sys
3
+ import importlib.util
4
+ import subprocess
5
+
6
+ # Define required directories
7
+ required_dirs = [
8
+ "omni_speech",
9
+ "omni_speech/serve",
10
+ "omni_speech/infer",
11
+ "vocoder"
12
+ ]
13
+
14
+ # Define required files
15
+ required_files = [
16
+ "app.py",
17
+ "omni_speech/__init__.py",
18
+ "omni_speech/serve/__init__.py",
19
+ "omni_speech/serve/controller.py",
20
+ "omni_speech/serve/model_worker.py",
21
+ "omni_speech/serve/gradio_web_server.py",
22
+ "omni_speech/infer/__init__.py",
23
+ "omni_speech/infer/inference.py",
24
+ "omni_speech/infer/run.sh"
25
+ ]
26
+
27
+ # Define required packages
28
+ required_packages = [
29
+ "torch",
30
+ "transformers",
31
+ "gradio",
32
+ "fastapi",
33
+ "uvicorn",
34
+ "pydantic",
35
+ "numpy",
36
+ "tqdm"
37
+ ]
38
+
39
+ def check_directory_structure():
40
+ """Check if all required directories exist."""
41
+ print("Checking directory structure...")
42
+ missing_dirs = []
43
+
44
+ for dir_path in required_dirs:
45
+ if not os.path.isdir(dir_path):
46
+ missing_dirs.append(dir_path)
47
+
48
+ if missing_dirs:
49
+ print(f"❌ Missing directories: {', '.join(missing_dirs)}")
50
+ return False
51
+ else:
52
+ print("✅ All required directories exist.")
53
+ return True
54
+
55
+ def check_required_files():
56
+ """Check if all required files exist."""
57
+ print("Checking required files...")
58
+ missing_files = []
59
+
60
+ for file_path in required_files:
61
+ if not os.path.isfile(file_path):
62
+ missing_files.append(file_path)
63
+
64
+ if missing_files:
65
+ print(f"❌ Missing files: {', '.join(missing_files)}")
66
+ return False
67
+ else:
68
+ print("✅ All required files exist.")
69
+ return True
70
+
71
+ def check_packages():
72
+ """Check if all required packages are installed."""
73
+ print("Checking required packages...")
74
+ missing_packages = []
75
+
76
+ for package in required_packages:
77
+ if importlib.util.find_spec(package) is None:
78
+ missing_packages.append(package)
79
+
80
+ if missing_packages:
81
+ print(f"❌ Missing packages: {', '.join(missing_packages)}")
82
+ return False
83
+ else:
84
+ print("✅ All required packages are installed.")
85
+ return True
86
+
87
+ def check_python_version():
88
+ """Check if Python version is compatible."""
89
+ print("Checking Python version...")
90
+ major, minor = sys.version_info[:2]
91
+
92
+ if major != 3 or minor < 10:
93
+ print(f"❌ Incompatible Python version: {major}.{minor}. Python 3.10+ is required.")
94
+ return False
95
+ else:
96
+ print(f"✅ Python version is compatible: {major}.{minor}")
97
+ return True
98
+
99
+ def main():
100
+ """Run all checks."""
101
+ print("🔍 Checking LLaMA-Omni setup...")
102
+ print("-" * 50)
103
+
104
+ checks = [
105
+ check_directory_structure(),
106
+ check_required_files(),
107
+ check_packages(),
108
+ check_python_version()
109
+ ]
110
+
111
+ print("-" * 50)
112
+
113
+ if all(checks):
114
+ print("✅ All checks passed! LLaMA-Omni is set up correctly.")
115
+ print("🚀 Run 'python app.py' to start the application.")
116
+ else:
117
+ print("❌ Some checks failed. Please fix the issues before running the application.")
118
+
119
+ if __name__ == "__main__":
120
+ main()
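
Note (not part of the commit): check_setup.py decides whether a dependency is present by resolving its import name with importlib.util.find_spec. A minimal sketch of that idiom on its own is shown below; the package names are illustrative, and the check only works when the pip name matches the import name, as it does for the packages listed above.

import importlib.util

def is_installed(import_name: str) -> bool:
    # Mirrors the check in check_setup.check_packages(): a package counts as
    # installed when Python can resolve a module spec for its import name.
    return importlib.util.find_spec(import_name) is not None

if __name__ == "__main__":
    for name in ("torch", "gradio", "definitely_not_installed"):
        print(name, "->", "installed" if is_installed(name) else "missing")
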
cog.yaml ADDED
@@ -0,0 +1,27 @@
1
+ build:
2
+ gpu: true
3
+ python_version: "3.10"
4
+ python_packages:
5
+ - "torch==2.0.1"
6
+ - "transformers==4.34.0"
7
+ - "accelerate==0.21.0"
8
+ - "gradio==3.50.2"
9
+ - "fastapi==0.104.0"
10
+ - "uvicorn==0.23.2"
11
+ - "pydantic==2.3.0"
12
+ - "openai-whisper==20231117"
13
+ - "numpy==1.24.0"
14
+ - "tqdm==4.66.1"
15
+ - "flash-attn==2.3.0"
16
+ - "requests==2.31.0"
17
+ system_packages:
18
+ - "wget"
19
+ - "ffmpeg"
20
+ - "libsndfile1"
21
+ run:
22
+ - "pip install -e git+https://github.com/pytorch/fairseq.git#egg=fairseq"
23
+ - "mkdir -p vocoder"
24
+ - "wget https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/g_00500000 -P vocoder/"
25
+ - "wget https://dl.fbaipublicfiles.com/fairseq/speech_to_speech/vocoder/code_hifigan/mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj/config.json -P vocoder/"
26
+
27
+ predict: "predict.py:Predictor"
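
Note (not part of the commit): cog.yaml points Cog at predict.py:Predictor, a file that is not included here. As a rough, hypothetical sketch only, the class Cog expects follows the BasePredictor pattern below; the model loading and synthesis steps are placeholders, not the project's actual implementation.

# predict.py, hypothetical skeleton; the real Predictor is not part of this commit.
from cog import BasePredictor, Input, Path

class Predictor(BasePredictor):
    def setup(self):
        # One-time heavy work (loading LLaMA-Omni and the vocoder) belongs here.
        # Left as a placeholder because the real loading code is not in this commit.
        self.model = None

    def predict(self, audio: Path = Input(description="Input speech (wav file)")) -> str:
        # Placeholder: the real implementation would transcribe the audio,
        # run the language model, and synthesize a spoken response.
        return f"Received audio file: {audio}"
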
gradio_app.py ADDED
@@ -0,0 +1,73 @@
1
+ import gradio as gr
2
+ import subprocess
3
+ import threading
4
+ import time
5
+ import os
6
+
7
+ def check_dependencies():
8
+ """Check and install missing dependencies."""
9
+ print("Checking and installing dependencies...")
10
+
11
+ # Create necessary directories
12
+ os.makedirs("models/speech_encoder", exist_ok=True)
13
+ os.makedirs("vocoder", exist_ok=True)
14
+
15
+ # Download vocoder if needed (this will be done on deployment)
16
+ if not os.path.exists("vocoder/g_00500000"):
17
+ print("Vocoder will be downloaded when deployed")
18
+
19
+ # Return success message
20
+ return "✅ Setup ready for deployment!"
21
+
22
+ def launch_services():
23
+ """Prepare to launch all services."""
24
+ return """
25
+ # LLaMA-Omni Services
26
+
27
+ When deployed to Gradio Spaces, this app will:
28
+
29
+ 1. Download required models (Whisper, LLaMA-Omni, vocoder)
30
+ 2. Start the controller
31
+ 3. Start the model worker
32
+ 4. Launch the web interface
33
+
34
+ ## Notes
35
+ - The model will be loaded automatically during deployment
36
+ - Audio can be processed via both speech input and text input
37
+ - The full system allows for seamless speech interaction
38
+ """
39
+
40
+ # Create the demo
41
+ with gr.Blocks() as demo:
42
+ gr.Markdown("# 🦙🎧 LLaMA-Omni Deployment Setup")
43
+
44
+ with gr.Tab("Status"):
45
+ status = gr.Markdown(launch_services())
46
+
47
+ with gr.Tab("Setup"):
48
+ check_btn = gr.Button("Check Dependencies")
49
+ result = gr.Textbox(label="Setup Status")
50
+ check_btn.click(check_dependencies, outputs=result)
51
+
52
+ with gr.Tab("About"):
53
+ gr.Markdown("""
54
+ # About LLaMA-Omni
55
+
56
+ LLaMA-Omni is a speech-language model built upon Llama-3.1-8B-Instruct. It supports low-latency and high-quality speech interactions, simultaneously generating both text and speech responses based on speech instructions.
57
+
58
+ ## Features
59
+
60
+ * Built on Llama-3.1-8B-Instruct, ensuring high-quality responses
61
+ * Low-latency speech interaction with a latency as low as 226ms
62
+ * Simultaneous generation of both text and speech responses
63
+
64
+ ## License
65
+
66
+ This code is released under the Apache-2.0 License. The model is intended for academic research purposes only and may NOT be used for commercial purposes.
67
+
68
+ Original work by Qingkai Fang, Shoutao Guo, Yan Zhou, Zhengrui Ma, Shaolei Zhang, Yang Feng.
69
+ """)
70
+
71
+ # Launch the app
72
+ if __name__ == "__main__":
73
+ demo.launch()
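
Note (not part of the commit): the Status tab above describes the deployment sequence (start the controller, then a model worker, then the web interface). Below is a minimal local sketch of that sequence, assuming the omni_speech.serve modules added in this commit are run with their default arguments; model_worker.py is referenced by check_setup.py but not shown here, so it is started without flags as an assumption.

# Sketch only: launches the services described above as subprocesses.
import subprocess
import sys
import time

controller = subprocess.Popen(
    [sys.executable, "-m", "omni_speech.serve.controller", "--port", "10000"]
)
time.sleep(5)  # give the controller time to start accepting registrations

# model_worker.py is not included in this commit, so its flags are assumed.
worker = subprocess.Popen([sys.executable, "-m", "omni_speech.serve.model_worker"])

web_ui = subprocess.Popen(
    [sys.executable, "-m", "omni_speech.serve.gradio_web_server", "--port", "8000"]
)

web_ui.wait()
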
launch_llama_omni2.py DELETED
@@ -1,486 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- LLaMA-Omni2 Direct Launcher
4
- ---------------------------
5
- This script extracts and directly runs the LLaMA-Omni2 components without
6
- relying on package imports.
7
- """
8
-
9
- import os
10
- import sys
11
- import subprocess
12
- import time
13
- import argparse
14
- import shutil
15
- import importlib.util
16
- import tempfile
17
- import logging
18
-
19
- # Configure logging
20
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
21
- logger = logging.getLogger(__name__)
22
-
23
- # Define paths
24
- EXTRACTION_DIR = "/home/user/app/llama_omni2_extracted"
25
- MODELS_DIR = "/home/user/app/models"
26
- LLAMA_OMNI2_MODEL_NAME = "LLaMA-Omni2-0.5B"
27
- LLAMA_OMNI2_MODEL_PATH = f"{MODELS_DIR}/{LLAMA_OMNI2_MODEL_NAME}"
28
- COSYVOICE_PATH = f"{MODELS_DIR}/cosy2_decoder"
29
-
30
- # Importe o model_downloader se disponível
31
- try:
32
- from model_downloader import download_model_if_needed, download_all_models, get_model_repo_id, NO_DOWNLOAD
33
- has_model_downloader = True
34
- except ImportError:
35
- has_model_downloader = False
36
- NO_DOWNLOAD = False
37
-
38
- # Garantir que os modelos estão disponíveis
39
- def ensure_models_available():
40
- """Garante que os modelos necessários estão disponíveis"""
41
- if has_model_downloader:
42
- if NO_DOWNLOAD:
43
- logger.info("Modo NO_DOWNLOAD ativado. Os modelos não serão baixados, usando diretamente do Hugging Face Hub.")
44
- return
45
-
46
- logger.info("Verificando modelos necessários para o LLaMA-Omni2...")
47
- download_model_if_needed("llama_omni2")
48
- download_model_if_needed("cosy2_decoder")
49
- download_model_if_needed("speech_encoder")
50
- logger.info("Verificação de modelos concluída")
51
- else:
52
- logger.warning("model_downloader não está disponível. Os modelos devem estar disponíveis em: " + MODELS_DIR)
53
-
54
- # Additional imports
55
- def download_dependencies():
56
- """Download and install required Python packages for LLaMA-Omni2"""
57
- print("Installing required dependencies...")
58
- dependencies = [
59
- "gradio>=3.50.2",
60
- "fastapi",
61
- "uvicorn",
62
- "pydantic",
63
- "transformers>=4.36.2",
64
- "sentencepiece",
65
- "huggingface_hub"
66
- ]
67
-
68
- try:
69
- subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade"] + dependencies, check=True)
70
- print("Dependencies installed successfully")
71
- return True
72
- except subprocess.CalledProcessError as e:
73
- print(f"Error installing dependencies: {e}")
74
- return False
75
-
76
- def ensure_module_structure(extraction_dir):
77
- """Ensure that the extracted module has the necessary structure"""
78
- print("Ensuring proper module structure...")
79
-
80
- # Create __init__.py files if they don't exist
81
- module_dirs = [
82
- os.path.join(extraction_dir, "llama_omni2"),
83
- os.path.join(extraction_dir, "llama_omni2", "serve"),
84
- os.path.join(extraction_dir, "llama_omni2", "model"),
85
- os.path.join(extraction_dir, "llama_omni2", "common")
86
- ]
87
-
88
- for dir_path in module_dirs:
89
- os.makedirs(dir_path, exist_ok=True)
90
- init_file = os.path.join(dir_path, "__init__.py")
91
- if not os.path.exists(init_file):
92
- with open(init_file, 'w') as f:
93
- f.write("# Auto-generated __init__.py file\n")
94
- print(f"Created {init_file}")
95
-
96
- # Create missing module files with required constants and functions
97
- dummy_modules = {
98
- # Utils module
99
- os.path.join(extraction_dir, "llama_omni2", "utils.py"): """
100
- # Dummy utils module
101
- def dummy_function():
102
- pass
103
- """,
104
- # Constants module - required by controller.py and model_worker.py
105
- os.path.join(extraction_dir, "llama_omni2", "constants.py"): """
106
- # Constants required by LLaMA-Omni2 modules
107
-
108
- # Controller constants
109
- CONTROLLER_HEART_BEAT_EXPIRATION = 120
110
- CONTROLLER_STATUS_POLLING_INTERVAL = 15
111
-
112
- # Worker constants
113
- WORKER_HEART_BEAT_INTERVAL = 30
114
- WORKER_API_TIMEOUT = 100
115
-
116
- # Other constants that might be needed
117
- DEFAULT_PORT = 8000
118
- """
119
- }
120
-
121
- for file_path, content in dummy_modules.items():
122
- if not os.path.exists(file_path):
123
- with open(file_path, 'w') as f:
124
- f.write(content)
125
- print(f"Created {file_path}")
126
-
127
- return True
128
-
129
- def start_controller():
130
- """Start the LLaMA-Omni2 controller directly"""
131
- print("=== Starting LLaMA-Omni2 Controller ===")
132
-
133
- # First try to use our custom implementation
134
- direct_controller_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "controller.py")
135
- if os.path.exists(direct_controller_path):
136
- print(f"Using custom controller implementation: {direct_controller_path}")
137
- cmd = [
138
- sys.executable, direct_controller_path,
139
- "--host", "0.0.0.0",
140
- "--port", "10000"
141
- ]
142
-
143
- env = os.environ.copy()
144
- process = subprocess.Popen(cmd, env=env)
145
- print(f"Controller started with PID: {process.pid}")
146
- return process
147
-
148
- # Fall back to a simple controller implementation
149
- print("No controller script found. Implementing a simple controller...")
150
-
151
- try:
152
- from fastapi import FastAPI, HTTPException
153
- import uvicorn
154
- from pydantic import BaseModel
155
- import threading
156
-
157
- app = FastAPI()
158
-
159
- class ModelInfo(BaseModel):
160
- model_name: str
161
- worker_name: str
162
- worker_addr: str
163
-
164
- # Simple in-memory storage
165
- registered_models = {}
166
-
167
- @app.get("/")
168
- def read_root():
169
- return {"status": "ok", "models": list(registered_models.keys())}
170
-
171
- @app.get("/api/v1/models")
172
- def list_models():
173
- return {"models": list(registered_models.keys())}
174
-
175
- @app.post("/api/v1/register_worker")
176
- def register_worker(model_info: ModelInfo):
177
- registered_models[model_info.model_name] = {
178
- "worker_name": model_info.worker_name,
179
- "worker_addr": model_info.worker_addr
180
- }
181
- return {"status": "ok"}
182
-
183
- # Start a simple controller
184
- def run_controller():
185
- uvicorn.run(app, host="0.0.0.0", port=10000)
186
-
187
- thread = threading.Thread(target=run_controller, daemon=True)
188
- thread.start()
189
-
190
- print("Simple controller started on port 10000")
191
- # Return a dummy process for compatibility
192
- class DummyProcess:
193
- def __init__(self):
194
- self.pid = 0
195
- def terminate(self):
196
- pass
197
- def poll(self):
198
- return None
199
- def wait(self, timeout=None):
200
- pass
201
-
202
- return DummyProcess()
203
-
204
- except ImportError as e:
205
- print(f"Failed to create simple controller: {e}")
206
- return None
207
-
208
- def start_model_worker():
209
- """Start the LLaMA-Omni2 model worker directly"""
210
- print("=== Starting LLaMA-Omni2 Model Worker ===")
211
-
212
- # First try to use our custom implementation
213
- direct_worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model_worker.py")
214
- if os.path.exists(direct_worker_path):
215
- print(f"Using custom model worker implementation: {direct_worker_path}")
216
- cmd = [
217
- sys.executable, direct_worker_path,
218
- "--host", "0.0.0.0",
219
- "--controller", "http://localhost:10000",
220
- "--port", "40000",
221
- "--worker", "http://localhost:40000",
222
- "--model-path", LLAMA_OMNI2_MODEL_PATH,
223
- "--model-name", LLAMA_OMNI2_MODEL_NAME
224
- ]
225
-
226
- env = os.environ.copy()
227
- process = subprocess.Popen(cmd, env=env)
228
- print(f"Model worker started with PID: {process.pid}")
229
- return process
230
-
231
- # Fall back to a simple implementation
232
- print("No model worker script found. Will try to start Gradio directly with the model.")
233
-
234
- class DummyProcess:
235
- def __init__(self):
236
- self.pid = 0
237
- def terminate(self):
238
- pass
239
- def poll(self):
240
- return None
241
- def wait(self, timeout=None):
242
- pass
243
-
244
- return DummyProcess()
245
-
246
- def start_gradio_server():
247
- """Start the LLaMA-Omni2 Gradio web server directly"""
248
- print("=== Starting LLaMA-Omni2 Gradio Server ===")
249
-
250
- # First try to use our custom implementation
251
- direct_gradio_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "gradio_web_server.py")
252
- if os.path.exists(direct_gradio_path):
253
- print(f"Using custom Gradio server implementation: {direct_gradio_path}")
254
- cmd = [
255
- sys.executable, direct_gradio_path,
256
- "--host", "0.0.0.0",
257
- "--port", "7860",
258
- "--controller-url", "http://localhost:10000",
259
- "--vocoder-dir", COSYVOICE_PATH
260
- ]
261
-
262
- env = os.environ.copy()
263
- process = subprocess.Popen(cmd, env=env)
264
- print(f"Gradio server started with PID: {process.pid}")
265
- return process
266
-
267
- # Fall back to a simple Gradio implementation
268
- print("No Gradio server found. Attempting to create a simple interface...")
269
-
270
- try:
271
- import gradio as gr
272
- import threading
273
- from transformers import AutoModelForCausalLM, AutoTokenizer
274
- import torch
275
-
276
- # Simple function to launch a basic Gradio interface
277
- def launch_simple_gradio():
278
- try:
279
- print(f"Loading model from {LLAMA_OMNI2_MODEL_PATH}...")
280
- # Check for CUDA availability
281
- device = "cuda" if torch.cuda.is_available() else "cpu"
282
- print(f"Using device: {device}")
283
-
284
- if device == "cuda":
285
- print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
286
- print(f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
287
-
288
- tokenizer = AutoTokenizer.from_pretrained(LLAMA_OMNI2_MODEL_PATH)
289
- model = AutoModelForCausalLM.from_pretrained(LLAMA_OMNI2_MODEL_PATH).to(device)
290
-
291
- def generate_text(input_text):
292
- inputs = tokenizer(input_text, return_tensors="pt").to(device)
293
- outputs = model.generate(inputs.input_ids, max_length=100)
294
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
295
-
296
- with gr.Blocks() as demo:
297
- gr.Markdown("# LLaMA-Omni2 Simple Interface")
298
- with gr.Tab("Text Generation"):
299
- input_text = gr.Textbox(label="Input Text")
300
- output_text = gr.Textbox(label="Generated Text")
301
- generate_btn = gr.Button("Generate")
302
- generate_btn.click(generate_text, inputs=input_text, outputs=output_text)
303
-
304
- demo.launch(server_name="0.0.0.0", server_port=7860)
305
-
306
- except Exception as e:
307
- print(f"Error in simple Gradio interface: {e}")
308
-
309
- thread = threading.Thread(target=launch_simple_gradio, daemon=True)
310
- thread.start()
311
-
312
- print("Simple Gradio interface started on port 7860")
313
-
314
- class DummyProcess:
315
- def __init__(self):
316
- self.pid = 0
317
- def terminate(self):
318
- pass
319
- def poll(self):
320
- return None
321
- def wait(self, timeout=None):
322
- pass
323
-
324
- return DummyProcess()
325
-
326
- except ImportError as e:
327
- print(f"Failed to create simple Gradio interface: {e}")
328
- return None
329
-
330
- def patch_extracted_files(extraction_dir):
331
- """Patch the extracted Python files to handle missing imports"""
332
- print("Patching extracted Python files to handle missing imports...")
333
-
334
- # Define files to patch and their imports to check/fix
335
- files_to_patch = {
336
- os.path.join(extraction_dir, "llama_omni2", "serve", "controller.py"): [
337
- "from llama_omni2.constants import",
338
- "from llama_omni2.model import",
339
- "from llama_omni2.common import",
340
- ],
341
- os.path.join(extraction_dir, "llama_omni2", "serve", "model_worker.py"): [
342
- "from llama_omni2.constants import",
343
- "from llama_omni2.model import",
344
- "from llama_omni2.common import",
345
- ],
346
- os.path.join(extraction_dir, "llama_omni2", "serve", "gradio_web_server.py"): [
347
- "from llama_omni2.constants import",
348
- "from llama_omni2.model import",
349
- "from llama_omni2.common import",
350
- ]
351
- }
352
-
353
- patched_files = []
354
- for file_path, imports_to_check in files_to_patch.items():
355
- if not os.path.exists(file_path):
356
- print(f"Warning: File {file_path} not found, skipping patch")
357
- continue
358
-
359
- with open(file_path, 'r') as f:
360
- content = f.read()
361
-
362
- original_content = content
363
- modified = False
364
-
365
- # Add try-except blocks around problematic imports
366
- for import_line in imports_to_check:
367
- if import_line in content:
368
- # Find the full line containing this import
369
- import_lines = [line for line in content.split('\n') if import_line in line]
370
-
371
- for full_line in import_lines:
372
- # Extract the variable names being imported
373
- try:
374
- imported_vars = full_line.split('import')[1].strip().split(',')
375
- imported_vars = [var.strip() for var in imported_vars]
376
-
377
- # Create a try-except block with fallback definitions
378
- replacement = f"""try:
379
- {full_line}
380
- except ImportError:
381
- # Auto-generated fallback for missing import
382
- print("Warning: Creating fallback for missing import: {full_line}")
383
- """
384
- for var in imported_vars:
385
- if var: # Skip empty strings
386
- replacement += f" {var} = object() # Dummy placeholder\n"
387
-
388
- # Replace the original import with the try-except block
389
- content = content.replace(full_line, replacement)
390
- modified = True
391
- except Exception as e:
392
- print(f"Error processing import line '{full_line}': {e}")
393
-
394
- # Write the modified content back if changes were made
395
- if modified:
396
- with open(file_path, 'w') as f:
397
- f.write(content)
398
- patched_files.append(file_path)
399
- print(f"Patched file: {file_path}")
400
-
401
- if patched_files:
402
- print(f"Successfully patched {len(patched_files)} files")
403
- else:
404
- print("No files needed patching")
405
-
406
- return patched_files
407
-
408
- def main():
409
- """Main entry point for the launcher script"""
410
- parser = argparse.ArgumentParser(description="LLaMA-Omni2 Direct Launcher")
411
- parser.add_argument("--skip-download", action="store_true", help="Skip downloading dependencies")
412
- parser.add_argument("--no-model-download", action="store_true", help="Don't download models, use them directly from HF Hub")
413
- parser.add_argument("--extraction-dir", type=str, default=EXTRACTION_DIR, help="Directory to extract LLaMA-Omni2 to")
414
- parser.add_argument("--models-dir", type=str, default=MODELS_DIR, help="Directory containing models")
415
- parser.add_argument("--skip-modules", action="store_true", help="Skip module structure creation")
416
- parser.add_argument("--controller-only", action="store_true", help="Start only the controller")
417
- parser.add_argument("--worker-only", action="store_true", help="Start only the model worker")
418
- parser.add_argument("--gradio-only", action="store_true", help="Start only the Gradio interface")
419
- args = parser.parse_args()
420
-
421
- # Update paths based on arguments
422
- global EXTRACTION_DIR, MODELS_DIR, LLAMA_OMNI2_MODEL_PATH, COSYVOICE_PATH
423
- EXTRACTION_DIR = args.extraction_dir
424
- MODELS_DIR = args.models_dir
425
- LLAMA_OMNI2_MODEL_PATH = f"{MODELS_DIR}/{LLAMA_OMNI2_MODEL_NAME}"
426
- COSYVOICE_PATH = f"{MODELS_DIR}/cosy2_decoder"
427
-
428
- # Set NO_DOWNLOAD environment variable if --no-model-download is specified
429
- if args.no_model_download:
430
- os.environ["NO_DOWNLOAD"] = "1"
431
- global NO_DOWNLOAD
432
- NO_DOWNLOAD = True
433
- logger.info("Modo NO_DOWNLOAD ativado via linha de comando")
434
-
435
- print("=== LLaMA-Omni2 Direct Launcher ===")
436
- print(f"Extraction directory: {EXTRACTION_DIR}")
437
- print(f"Models directory: {MODELS_DIR}")
438
- print(f"Downloading models: {'No' if NO_DOWNLOAD else 'Yes'}")
439
-
440
- # Ensure models are available
441
- ensure_models_available()
442
-
443
- # Download dependencies if needed
444
- if not args.skip_download:
445
- download_dependencies()
446
-
447
- # Create module structure if needed
448
- if not args.skip_modules:
449
- ensure_module_structure(EXTRACTION_DIR)
450
-
451
- # Start the controller if needed
452
- controller_process = None
453
- if not args.worker_only and not args.gradio_only:
454
- controller_process = start_controller()
455
- # Give the controller time to start up
456
- time.sleep(5)
457
-
458
- # Start the model worker if needed
459
- worker_process = None
460
- if not args.controller_only and not args.gradio_only:
461
- worker_process = start_model_worker()
462
- # Give the worker time to start up
463
- time.sleep(5)
464
-
465
- # Start the Gradio interface if needed
466
- gradio_process = None
467
- if not args.controller_only and not args.worker_only:
468
- gradio_process = start_gradio_server()
469
-
470
- # Keep the main process running to maintain subprocesses
471
- try:
472
- print("Press Ctrl+C to exit...")
473
- while True:
474
- time.sleep(1)
475
- except KeyboardInterrupt:
476
- print("Shutting down...")
477
- if controller_process:
478
- controller_process.terminate()
479
- if worker_process:
480
- worker_process.terminate()
481
- if gradio_process:
482
- gradio_process.terminate()
483
- print("Shutdown complete")
484
-
485
- if __name__ == "__main__":
486
- sys.exit(main())
model_downloader.py DELETED
@@ -1,219 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Model Downloader for LLaMA-Omni2
4
- ---------------------------------
5
- This script manages the automatic download of the models required by LLaMA-Omni2.
6
- Models are downloaded only when needed, during start-up.
7
- """
8
-
9
- import os
10
- import sys
11
- import logging
12
- import huggingface_hub
13
- from huggingface_hub import snapshot_download, hf_hub_download
14
- from pathlib import Path
15
- import torch
16
- import shutil
17
-
18
- # Configurar logging
19
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
20
- logger = logging.getLogger(__name__)
21
-
22
- # Configurações de modelos
23
- MODELS_DIR = os.environ.get("MODELS_DIR", "models")
24
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
25
-
26
- # Modo sem download (NO_DOWNLOAD=1)
27
- NO_DOWNLOAD = os.environ.get("NO_DOWNLOAD", "0").lower() in ("1", "true", "yes")
28
-
29
- # Mensagem de debug para verificar o status da variável
30
- logger.info(f"Inicializando model_downloader.py com NO_DOWNLOAD={NO_DOWNLOAD} (valor da env: {os.environ.get('NO_DOWNLOAD', 'não definido')})")
31
-
32
- # Modelos necessários
33
- MODEL_CONFIGS = {
34
- "speech_encoder": {
35
- "repo_id": "openai/whisper-large-v3",
36
- "local_dir": os.path.join(MODELS_DIR, "speech_encoder", "whisper-large-v3"),
37
- "files": None, # None significa baixar o modelo completo
38
- },
39
- "cosy2_decoder": {
40
- "repo_id": "ICTNLP/cosy2_decoder",
41
- "local_dir": os.path.join(MODELS_DIR, "cosy2_decoder"),
42
- "files": [
43
- "flow.decoder.estimator.fp32.onnx",
44
- "flow.decoder.estimator.fp16.A10.plan",
45
- "flow.encoder.fp32.zip",
46
- "flow.decoder.estimator.fp16.Volta.plan",
47
- "hift.pt",
48
- "campplus.onnx",
49
- "cosyvoice.yaml",
50
- ],
51
- },
52
- "llama_omni2": {
53
- "repo_id": "ICTNLP/LLaMA-Omni2-0.5B",
54
- "local_dir": os.path.join(MODELS_DIR, "LLaMA-Omni2-0.5B"),
55
- "files": None, # None significa baixar o modelo completo
56
- }
57
- }
58
-
59
- def ensure_model_dir():
60
- """Garante que o diretório models existe"""
61
- if NO_DOWNLOAD:
62
- logger.info("Modo NO_DOWNLOAD ativado. Pulando criação de diretórios.")
63
- return
64
-
65
- os.makedirs(MODELS_DIR, exist_ok=True)
66
- for model_config in MODEL_CONFIGS.values():
67
- os.makedirs(model_config["local_dir"], exist_ok=True)
68
-
69
- def is_model_downloaded(model_key):
70
- """Verifica se um modelo já foi baixado"""
71
- # No modo sem download, sempre retorna False para pular a verificação
72
- if NO_DOWNLOAD:
73
- logger.info(f"Modo NO_DOWNLOAD ativado. Pulando verificação para {model_key}.")
74
- return False
75
-
76
- config = MODEL_CONFIGS[model_key]
77
- local_dir = config["local_dir"]
78
-
79
- # Se não temos uma lista específica de arquivos, verificar apenas se o diretório existe
80
- if config["files"] is None:
81
- # Verificar se o diretório existe e tem arquivos
82
- if os.path.exists(local_dir) and any(os.listdir(local_dir)):
83
- logger.info(f"Modelo {model_key} já parece estar baixado em {local_dir}")
84
- return True
85
- return False
86
-
87
- # Verificar se todos os arquivos específicos existem
88
- for file in config["files"]:
89
- file_path = os.path.join(local_dir, file)
90
- if not os.path.exists(file_path):
91
- logger.info(f"Arquivo {file} não encontrado para o modelo {model_key}")
92
- return False
93
-
94
- logger.info(f"Todos os arquivos para o modelo {model_key} já estão disponíveis em {local_dir}")
95
- return True
96
-
97
- def download_model(model_key):
98
- """Baixa um modelo específico do Hugging Face Hub"""
99
- # Verificar o modo sem download
100
- if NO_DOWNLOAD:
101
- logger.warning(f"Modo NO_DOWNLOAD ativado. Pulando download de {model_key}")
102
- return False
103
-
104
- config = MODEL_CONFIGS[model_key]
105
- repo_id = config["repo_id"]
106
- local_dir = config["local_dir"]
107
- files = config["files"]
108
-
109
- try:
110
- logger.info(f"Baixando modelo {model_key} do repo {repo_id}...")
111
-
112
- # Se temos uma lista específica de arquivos, baixar um por um
113
- if files is not None:
114
- for file in files:
115
- file_path = os.path.join(local_dir, file)
116
-
117
- # Pular se o arquivo já existe
118
- if os.path.exists(file_path):
119
- logger.info(f"Arquivo {file} já existe, pulando download")
120
- continue
121
-
122
- logger.info(f"Baixando arquivo {file} para {file_path}")
123
- try:
124
- hf_hub_download(
125
- repo_id=repo_id,
126
- filename=file,
127
- local_dir=local_dir,
128
- local_dir_use_symlinks=False,
129
- token=HF_TOKEN
130
- )
131
- except Exception as e:
132
- logger.warning(f"Erro ao baixar arquivo {file}: {e}. Tentando continuar.")
133
- else:
134
- # Baixar o modelo completo
135
- snapshot_download(
136
- repo_id=repo_id,
137
- local_dir=local_dir,
138
- local_dir_use_symlinks=False,
139
- token=HF_TOKEN
140
- )
141
-
142
- logger.info(f"Modelo {model_key} baixado com sucesso para {local_dir}")
143
- return True
144
- except Exception as e:
145
- logger.error(f"Erro ao baixar modelo {model_key}: {e}")
146
- return False
147
-
148
- def cleanup_model_dir(model_key):
149
- """Remove arquivos incompletos ou corruptos de um diretório de modelo"""
150
- # Verificar o modo sem download
151
- if NO_DOWNLOAD:
152
- logger.info(f"Modo NO_DOWNLOAD ativado. Pulando limpeza de diretório para {model_key}.")
153
- return True
154
-
155
- config = MODEL_CONFIGS[model_key]
156
- local_dir = config["local_dir"]
157
-
158
- try:
159
- # Procurar por arquivos .incomplete e removê-los
160
- for root, dirs, files in os.walk(local_dir):
161
- for file in files:
162
- if file.endswith(".incomplete"):
163
- file_path = os.path.join(root, file)
164
- logger.info(f"Removendo arquivo incompleto: {file_path}")
165
- os.remove(file_path)
166
-
167
- return True
168
- except Exception as e:
169
- logger.error(f"Erro ao limpar diretório do modelo {model_key}: {e}")
170
- return False
171
-
172
- def download_all_models():
173
- """Baixa todos os modelos configurados, se necessário"""
174
- # Verificar o modo sem download
175
- if NO_DOWNLOAD:
176
- logger.warning("Modo NO_DOWNLOAD ativado. Nenhum modelo será baixado.")
177
- return
178
-
179
- ensure_model_dir()
180
-
181
- for model_key in MODEL_CONFIGS:
182
- if not is_model_downloaded(model_key):
183
- logger.info(f"Iniciando download do modelo {model_key}")
184
- cleanup_model_dir(model_key)
185
- download_model(model_key)
186
- else:
187
- logger.info(f"Modelo {model_key} já está disponível localmente")
188
-
189
- def download_model_if_needed(model_key):
190
- """Baixa um modelo específico se ele não estiver disponível"""
191
- # Verificar o modo sem download
192
- if NO_DOWNLOAD:
193
- logger.info(f"Modo NO_DOWNLOAD ativado. Usando repo_id diretamente para {model_key}")
194
- return False
195
-
196
- ensure_model_dir()
197
-
198
- if model_key not in MODEL_CONFIGS:
199
- logger.error(f"Modelo {model_key} não está configurado para download")
200
- return False
201
-
202
- if not is_model_downloaded(model_key):
203
- logger.info(f"Modelo {model_key} não encontrado localmente. Iniciando download...")
204
- cleanup_model_dir(model_key)
205
- return download_model(model_key)
206
- else:
207
- logger.info(f"Modelo {model_key} já está disponível localmente")
208
- return True
209
-
210
- def get_model_repo_id(model_key):
211
- """Retorna o repo_id do modelo para uso direto sem download"""
212
- if model_key not in MODEL_CONFIGS:
213
- logger.error(f"Modelo {model_key} não está configurado")
214
- return None
215
- return MODEL_CONFIGS[model_key]["repo_id"]
216
-
217
- if __name__ == "__main__":
218
- # Se executado diretamente, baixar todos os modelos
219
- download_all_models()
no_download.py DELETED
@@ -1,55 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Script to start applications in no-download mode.
4
- This script explicitly sets the NO_DOWNLOAD=1 variable in the Python environment,
5
- ensuring that no models are downloaded.
6
- """
7
-
8
- import os
9
- import sys
10
- import subprocess
11
- import logging
12
-
13
- # Configurar logging
14
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
15
- logger = logging.getLogger("no_download")
16
-
17
- # Definir a variável NO_DOWNLOAD no ambiente
18
- os.environ["NO_DOWNLOAD"] = "1"
19
- logger.info(f"Variável NO_DOWNLOAD definida como: {os.environ.get('NO_DOWNLOAD')}")
20
-
21
- # Verificar argumentos de linha de comando
22
- if len(sys.argv) < 2:
23
- logger.info("Nenhum script especificado. Executando app.py por padrão.")
24
- target_script = "app.py"
25
- else:
26
- target_script = sys.argv[1]
27
- logger.info(f"Executando script: {target_script}")
28
-
29
- # Lista de argumentos extras
30
- args = sys.argv[2:]
31
-
32
- # Exibir informações
33
- print("=" * 70)
34
- print(f"Executando {target_script} no modo SEM DOWNLOAD (NO_DOWNLOAD=1)")
35
- print("Os modelos serão usados diretamente do Hugging Face Hub, sem baixar localmente")
36
- print("=" * 70)
37
-
38
- # Executar o script alvo com os mesmos argumentos
39
- try:
40
- # Criar um dicionário de ambiente com NO_DOWNLOAD definido
41
- env = os.environ.copy()
42
- env["NO_DOWNLOAD"] = "1"
43
-
44
- # Construir o comando
45
- command = [sys.executable, target_script] + args
46
- logger.info(f"Executando comando: {' '.join(command)}")
47
-
48
- # Execute o comando com o ambiente modificado
49
- process = subprocess.Popen(command, env=env)
50
- process.wait()
51
-
52
- sys.exit(process.returncode)
53
- except Exception as e:
54
- logger.error(f"Erro ao executar {target_script}: {e}")
55
- sys.exit(1)
omni_speech/__init__.py ADDED
File without changes
omni_speech/infer/__init__.py ADDED
File without changes
omni_speech/infer/examples/example.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "instructions": [
3
+ {
4
+ "id": "001",
5
+ "input_type": "speech",
6
+ "audio_path": "input_audio.wav",
7
+ "transcription": "What is the weather like today?"
8
+ },
9
+ {
10
+ "id": "002",
11
+ "input_type": "text",
12
+ "text": "Tell me about the history of artificial intelligence."
13
+ }
14
+ ]
15
+ }
omni_speech/infer/inference.py ADDED
@@ -0,0 +1,125 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import logging
5
+ from typing import Dict, List, Optional
6
+
7
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
8
+ logger = logging.getLogger(__name__)
9
+
10
+ def load_model():
11
+ """Load LLaMA-Omni model for inference (placeholder)."""
12
+ logger.info("Loading LLaMA-Omni model...")
13
+ logger.info("Note: In a real deployment, the model would be downloaded from Hugging Face")
14
+ return "PLACEHOLDER_MODEL"
15
+
16
+ def load_vocoder():
17
+ """Load vocoder for speech synthesis (placeholder)."""
18
+ logger.info("Loading vocoder...")
19
+ logger.info("Note: In a real deployment, the vocoder would be downloaded")
20
+ return "PLACEHOLDER_VOCODER"
21
+
22
+ def transcribe_audio(audio_path):
23
+ """Transcribe audio using Whisper (placeholder)."""
24
+ logger.info(f"Transcribing audio: {audio_path}")
25
+ # In a real implementation, this would use the Whisper model
26
+ return f"Placeholder transcription for {os.path.basename(audio_path)}"
27
+
28
+ def process_instruction(instruction, model, vocoder):
29
+ """Process a single instruction."""
30
+ instruction_id = instruction.get("id", "unknown")
31
+ input_type = instruction.get("input_type")
32
+
33
+ logger.info(f"Processing instruction {instruction_id}, type: {input_type}")
34
+
35
+ if input_type == "speech":
36
+ audio_path = instruction.get("audio_path")
37
+ if not audio_path:
38
+ logger.error(f"Instruction {instruction_id}: Missing audio path")
39
+ return None
40
+
41
+ # Check if transcription is provided, otherwise transcribe
42
+ transcription = instruction.get("transcription")
43
+ if not transcription:
44
+ transcription = transcribe_audio(audio_path)
45
+
46
+ # In a real implementation, this would process the transcription through the model
47
+ text_response = f"Placeholder response to: {transcription}"
48
+
49
+ # In a real implementation, this would generate speech from the text response
50
+ speech_output = "PLACEHOLDER_SPEECH_OUTPUT"
51
+
52
+ return {
53
+ "id": instruction_id,
54
+ "input_type": input_type,
55
+ "transcription": transcription,
56
+ "text_response": text_response,
57
+ "speech_output": speech_output
58
+ }
59
+
60
+ elif input_type == "text":
61
+ text = instruction.get("text")
62
+ if not text:
63
+ logger.error(f"Instruction {instruction_id}: Missing text")
64
+ return None
65
+
66
+ # In a real implementation, this would process the text through the model
67
+ text_response = f"Placeholder response to: {text}"
68
+
69
+ # In a real implementation, this would generate speech from the text response
70
+ speech_output = "PLACEHOLDER_SPEECH_OUTPUT"
71
+
72
+ return {
73
+ "id": instruction_id,
74
+ "input_type": input_type,
75
+ "text": text,
76
+ "text_response": text_response,
77
+ "speech_output": speech_output
78
+ }
79
+
80
+ else:
81
+ logger.error(f"Instruction {instruction_id}: Unknown input type: {input_type}")
82
+ return None
83
+
84
+ def process_instructions(input_file, output_dir):
85
+ """Process instructions from input file and save results to output directory."""
86
+ # Create output directory if it doesn't exist
87
+ os.makedirs(output_dir, exist_ok=True)
88
+
89
+ # Load input JSON
90
+ with open(input_file, 'r') as f:
91
+ data = json.load(f)
92
+
93
+ instructions = data.get("instructions", [])
94
+ logger.info(f"Loaded {len(instructions)} instructions from {input_file}")
95
+
96
+ # Load model and vocoder
97
+ model = load_model()
98
+ vocoder = load_vocoder()
99
+
100
+ # Process each instruction
101
+ results = []
102
+ for instruction in instructions:
103
+ result = process_instruction(instruction, model, vocoder)
104
+ if result:
105
+ results.append(result)
106
+
107
+ # Save results
108
+ output_file = os.path.join(output_dir, f"{os.path.basename(input_file)}_results.json")
109
+ with open(output_file, 'w') as f:
110
+ json.dump({"results": results}, f, indent=2)
111
+
112
+ logger.info(f"Saved {len(results)} results to {output_file}")
113
+
114
+ def main():
115
+ """Run inference."""
116
+ parser = argparse.ArgumentParser(description="LLaMA-Omni inference")
117
+ parser.add_argument("--input", type=str, required=True, help="Input JSON file with instructions")
118
+ parser.add_argument("--output", type=str, required=True, help="Output directory for results")
119
+
120
+ args = parser.parse_args()
121
+
122
+ process_instructions(args.input, args.output)
123
+
124
+ if __name__ == "__main__":
125
+ main()
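
Note (not part of the commit): the placeholder pipeline above can also be driven programmatically; a minimal example using the example.json shipped in this commit is shown below.

from omni_speech.infer.inference import process_instructions

# Reads the bundled example instructions and writes *_results.json into results/example.
process_instructions("omni_speech/infer/examples/example.json", "results/example")
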
omni_speech/infer/run.sh ADDED
@@ -0,0 +1,32 @@
1
+ #!/bin/bash
2
+
3
+ # Run inference on LLaMA-Omni model
4
+ # Usage: bash run.sh <examples_directory>
5
+
6
+ EXAMPLES_DIR=$1
7
+
8
+ if [ -z "$EXAMPLES_DIR" ]; then
9
+ echo "Error: Examples directory not specified"
10
+ echo "Usage: bash run.sh <examples_directory>"
11
+ exit 1
12
+ fi
13
+
14
+ if [ ! -d "$EXAMPLES_DIR" ]; then
15
+ echo "Error: Directory $EXAMPLES_DIR does not exist"
16
+ exit 1
17
+ fi
18
+
19
+ # Check if the model and vocoder exist (placeholders for real implementation)
20
+ echo "Checking if required models are available..."
21
+ echo "Note: In a real deployment, the model would be downloaded from Hugging Face"
22
+
23
+ # Process each JSON file in the examples directory
24
+ for json_file in "$EXAMPLES_DIR"/*.json; do
25
+ if [ -f "$json_file" ]; then
26
+ echo "Processing $json_file..."
27
+ # In a real implementation, this would call a Python script
28
+ echo "python -m omni_speech.infer.inference --input $json_file --output results/$(basename $json_file .json)"
29
+ fi
30
+ done
31
+
32
+ echo "Inference complete."
omni_speech/serve/__init__.py ADDED
File without changes
omni_speech/serve/controller.py ADDED
@@ -0,0 +1,104 @@
1
+ import argparse
2
+ import asyncio
3
+ import json
4
+ import time
5
+ from fastapi import FastAPI, WebSocket, HTTPException
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ import uvicorn
8
+ from typing import Dict, List, Optional, Union
9
+
10
+ app = FastAPI()
11
+
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"],
15
+ allow_credentials=True,
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+ # Store worker information
21
+ worker_info = {}
22
+
23
+ @app.get("/status")
24
+ async def get_status():
25
+ """Get the status of the controller."""
26
+ return {"status": "ok", "worker_count": len(worker_info)}
27
+
28
+ @app.get("/worker_info")
29
+ async def get_worker_info():
30
+ """Get information about all registered workers."""
31
+ return {"worker_info": worker_info}
32
+
33
+ @app.post("/register_worker")
34
+ async def register_worker(worker_info_data: Dict):
35
+ """Register a new worker."""
36
+ worker_name = worker_info_data.get("name")
37
+ worker_url = worker_info_data.get("url")
38
+
39
+ if not worker_name or not worker_url:
40
+ raise HTTPException(status_code=400, detail="Missing name or URL for worker")
41
+
42
+ models = worker_info_data.get("models", [])
43
+
44
+ worker_info[worker_name] = {
45
+ "url": worker_url,
46
+ "models": models,
47
+ "status": "alive",
48
+ "last_heartbeat": time.time()
49
+ }
50
+
51
+ return {"status": "registered", "worker_name": worker_name}
52
+
53
+ @app.post("/unregister_worker")
54
+ async def unregister_worker(worker_name: str):
55
+ """Unregister a worker."""
56
+ if worker_name in worker_info:
57
+ del worker_info[worker_name]
58
+ return {"status": "unregistered", "worker_name": worker_name}
59
+ else:
60
+ raise HTTPException(status_code=404, detail=f"Worker {worker_name} not found")
61
+
62
+ @app.post("/heartbeat")
63
+ async def heartbeat(worker_data: Dict):
64
+ """Process worker heartbeat."""
65
+ worker_name = worker_data.get("name")
66
+
67
+ if worker_name in worker_info:
68
+ worker_info[worker_name]["last_heartbeat"] = time.time()
69
+ worker_info[worker_name]["status"] = "alive"
70
+ return {"status": "received"}
71
+ else:
72
+ raise HTTPException(status_code=404, detail=f"Worker {worker_name} not found")
73
+
74
+ @app.get("/get_worker_address")
75
+ async def get_worker_address(model_name: str):
76
+ """Get the address of a worker that hosts the requested model."""
77
+ for name, info in worker_info.items():
78
+ if model_name in info["models"] and info["status"] == "alive":
79
+ return {"worker_address": info["url"]}
80
+
81
+ raise HTTPException(status_code=404, detail=f"No available worker found for model {model_name}")
82
+
83
+ @app.get("/list_models")
84
+ async def list_models():
85
+ """List all available models across workers."""
86
+ available_models = []
87
+ for name, info in worker_info.items():
88
+ if info["status"] == "alive":
89
+ available_models.extend(info["models"])
90
+
91
+ return {"models": list(set(available_models))}
92
+
93
+ def main():
94
+ """Run the controller server."""
95
+ parser = argparse.ArgumentParser(description="LLaMA-Omni controller for managing worker nodes")
96
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server")
97
+ parser.add_argument("--port", type=int, default=10000, help="Port to bind the server")
98
+
99
+ args = parser.parse_args()
100
+
101
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
102
+
103
+ if __name__ == "__main__":
104
+ main()
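
Note (not part of the commit): a worker can announce itself to this controller with plain HTTP calls against the endpoints defined above. The sketch below assumes the controller is running locally on its default port; the worker name, URL, and model name are illustrative.

import time
import requests

CONTROLLER = "http://localhost:10000"  # default host/port from controller.py

# Register the worker and the models it serves (values here are illustrative).
requests.post(f"{CONTROLLER}/register_worker", json={
    "name": "worker-1",
    "url": "http://localhost:40000",
    "models": ["LLaMA-Omni2-0.5B"],
})

# Keep the registration alive; the controller marks the worker "alive" on each beat.
for _ in range(3):
    requests.post(f"{CONTROLLER}/heartbeat", json={"name": "worker-1"})
    time.sleep(30)

print(requests.get(f"{CONTROLLER}/list_models").json())
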
omni_speech/serve/gradio_web_server.py ADDED
@@ -0,0 +1,234 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import time
5
+ import requests
6
+ import gradio as gr
7
+ import uuid
8
+ import logging
9
+ from typing import Dict, List, Optional, Tuple, Union
10
+ import tempfile
11
+
12
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Global variables
16
+ controller_url = None
17
+ vocoder_path = None
18
+ vocoder_cfg = None
19
+ model_list_mode = "once" # "once" or "reload"
20
+ avatars = {}
21
+ message_history = {}
22
+
23
+ def list_models():
24
+ """Get list of available models from the controller."""
25
+ try:
26
+ response = requests.get(f"{controller_url}/list_models")
27
+ if response.status_code == 200:
28
+ models = response.json().get("models", [])
29
+ return models
30
+ else:
31
+ logger.error(f"Failed to list models: {response.text}")
32
+ return []
33
+ except Exception as e:
34
+ logger.error(f"Error listing models: {str(e)}")
35
+ return []
36
+
37
+ def get_worker_address(model_name):
38
+ """Get address of a worker that serves the requested model."""
39
+ try:
40
+ response = requests.get(f"{controller_url}/get_worker_address", params={"model_name": model_name})
41
+ if response.status_code == 200:
42
+ return response.json().get("worker_address")
43
+ else:
44
+ logger.error(f"Failed to get worker address: {response.text}")
45
+ return None
46
+ except Exception as e:
47
+ logger.error(f"Error getting worker address: {str(e)}")
48
+ return None
49
+
50
+ def transcribe_audio(audio_path):
51
+ """Placeholder for audio transcription."""
52
+ # In a real implementation, this would use the Whisper model
53
+ logger.info(f"Transcribing audio from {audio_path}...")
54
+ # Simulated transcription
55
+ return f"This is a placeholder transcription for audio file {os.path.basename(audio_path)}"
56
+
57
+ def process_speech_to_speech(audio_path, model_name):
58
+ """Process speech to speech generation."""
59
+ if not audio_path:
60
+ return "Error: No audio provided", None
61
+
62
+ try:
63
+ # Transcribe the audio
64
+ transcription = transcribe_audio(audio_path)
65
+
66
+ # Get worker address
67
+ worker_address = get_worker_address(model_name)
68
+ if not worker_address:
69
+ return f"Error: No worker available for model {model_name}", None
70
+
71
+ # Send request to worker
72
+ response = requests.post(
73
+ f"{worker_address}/generate_speech",
74
+ json={"prompt": transcription}
75
+ )
76
+
77
+ if response.status_code == 200:
78
+ result = response.json()
79
+ text_response = result.get("text", "No text response generated")
80
+ speech_url = result.get("speech_url")
81
+
82
+ # In a real implementation, we would handle the audio file
83
+ # For now, we'll just return the text response
84
+ return text_response, speech_url
85
+ else:
86
+ return f"Error: {response.text}", None
87
+ except Exception as e:
88
+ logger.error(f"Error in speech-to-speech processing: {str(e)}")
89
+ return f"Error: {str(e)}", None
90
+
91
+ def process_text_to_speech(text, model_name):
92
+ """Process text to speech generation."""
93
+ if not text:
94
+ return "Error: No text provided", None
95
+
96
+ try:
97
+ # Get worker address
98
+ worker_address = get_worker_address(model_name)
99
+ if not worker_address:
100
+ return f"Error: No worker available for model {model_name}", None
101
+
102
+ # Send request to worker
103
+ response = requests.post(
104
+ f"{worker_address}/generate_speech",
105
+ json={"prompt": text}
106
+ )
107
+
108
+ if response.status_code == 200:
109
+ result = response.json()
110
+ text_response = result.get("text", "No text response generated")
111
+ speech_url = result.get("speech_url")
112
+
113
+ # In a real implementation, we would handle the audio file
114
+ # For now, we'll just return the text response
115
+ return text_response, speech_url
116
+ else:
117
+ return f"Error: {response.text}", None
118
+ except Exception as e:
119
+ logger.error(f"Error in text-to-speech processing: {str(e)}")
120
+ return f"Error: {str(e)}", None
121
+
122
+ def create_chat_ui():
123
+ """Create the Gradio chat UI."""
124
+ available_models = list_models()
125
+ logger.info(f"Available models: {available_models}")
126
+
127
+ with gr.Blocks(css="footer {visibility: hidden}") as demo:
128
+ gr.Markdown("# 🦙🎧 LLaMA-Omni Speech Interaction Demo")
129
+
130
+ with gr.Row():
131
+ with gr.Column(scale=3):
132
+ # Input area
133
+ with gr.Tab("Speech Input"):
134
+ audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record or upload audio")
135
+ transcription_output = gr.Textbox(label="Transcription", interactive=False)
136
+
137
+ with gr.Tab("Text Input"):
138
+ text_input = gr.Textbox(label="Text Input", placeholder="Type your message here...")
139
+
140
+ # Common controls
141
+ with gr.Row():
142
+ model_selector = gr.Dropdown(choices=available_models, label="Model", value=available_models[0] if available_models else None)
143
+ submit_btn = gr.Button("Submit")
144
+
145
+ if model_list_mode == "reload":
146
+ refresh_btn = gr.Button("Refresh Models")
147
+
148
+ with gr.Column(scale=4):
149
+ # Output area
150
+ chatbot = gr.Chatbot(label="Conversation", height=500)
151
+
152
+ with gr.Row():
153
+ audio_output = gr.Audio(label="Generated Speech", interactive=False)
154
+
155
+ # Event handlers
156
+ def on_audio_input(audio):
157
+ if audio:
158
+ transcription = transcribe_audio(audio)
159
+ return transcription
160
+ return ""
161
+
162
+ def on_speech_submit(audio, model_name, chat_history):
163
+ if not audio:
164
+ return chat_history, None
165
+
166
+ transcription = transcribe_audio(audio)
167
+ text_response, speech_url = process_speech_to_speech(audio, model_name)
168
+
169
+ # Update chat history
170
+ new_history = chat_history.copy()
171
+ new_history.append((transcription, text_response))
172
+
173
+ # In a real implementation, we would handle the audio file
174
+ # For now, we'll just return None for audio output
175
+ return new_history, None
176
+
177
+ def on_text_submit(text, model_name, chat_history):
178
+ if not text:
179
+ return chat_history, None
180
+
181
+ text_response, speech_url = process_text_to_speech(text, model_name)
182
+
183
+ # Update chat history
184
+ new_history = chat_history.copy()
185
+ new_history.append((text, text_response))
186
+
187
+ # In a real implementation, we would handle the audio file
188
+ # For now, we'll just return None for audio output
189
+ return new_history, None
190
+
191
+ def on_refresh_models():
192
+ return gr.update(choices=list_models())
193
+
194
+ # Connect events
195
+ audio_input.change(on_audio_input, [audio_input], [transcription_output])
196
+
197
+ submit_btn.click(
198
+ fn=lambda audio, text, model, chat: on_speech_submit(audio, model, chat) if audio else on_text_submit(text, model, chat),
199
+ inputs=[audio_input, text_input, model_selector, chatbot],
200
+ outputs=[chatbot, audio_output]
201
+ )
202
+
203
+ if model_list_mode == "reload":
204
+ refresh_btn.click(on_refresh_models, [], [model_selector])
205
+
206
+ return demo
207
+
208
+ def main():
209
+ """Run the Gradio web server."""
210
+ global controller_url, vocoder_path, vocoder_cfg, model_list_mode
211
+
212
+ parser = argparse.ArgumentParser(description="LLaMA-Omni Gradio web server")
213
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server")
214
+ parser.add_argument("--port", type=int, default=8000, help="Port to bind the server")
215
+ parser.add_argument("--controller", type=str, required=True, help="Controller URL")
216
+ parser.add_argument("--vocoder", type=str, required=True, help="Path to vocoder model")
217
+ parser.add_argument("--vocoder-cfg", type=str, required=True, help="Path to vocoder config")
218
+ parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"], help="Model listing mode")
219
+
220
+ args = parser.parse_args()
221
+
222
+ controller_url = args.controller
223
+ vocoder_path = args.vocoder
224
+ vocoder_cfg = args.vocoder_cfg
225
+ model_list_mode = args.model_list_mode
226
+
227
+ # Create the demo
228
+ demo = create_chat_ui()
229
+
230
+ # Launch the server
231
+ demo.launch(server_name=args.host, server_port=args.port, share=False)
232
+
233
+ if __name__ == "__main__":
234
+ main()
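
For reference, the controller endpoints that `list_models()` and `get_worker_address()` above rely on can be exercised directly. A minimal sketch, assuming a controller is already running on `http://localhost:10000` (the default address used elsewhere in this repo); only the endpoint names are taken from the code above:

```python
import requests

CONTROLLER_URL = "http://localhost:10000"  # assumed default controller address

# Ask the controller which models are currently registered by workers.
models = requests.get(f"{CONTROLLER_URL}/list_models", timeout=10).json().get("models", [])
print("Registered models:", models)

# Resolve a worker address for the first model, mirroring get_worker_address().
if models:
    worker = requests.get(
        f"{CONTROLLER_URL}/get_worker_address",
        params={"model_name": models[0]},
        timeout=10,
    ).json().get("worker_address")
    print("Worker serving", models[0], "->", worker)
```
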
omni_speech/serve/model_worker.py ADDED
@@ -0,0 +1,211 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import time
5
+ import uuid
6
+ import requests
7
+ import threading
8
+ import transformers
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
+ import torch
11
+ from typing import Dict, List, Optional, Union
12
+ import traceback
13
+ from fastapi import FastAPI, HTTPException
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+ import uvicorn
16
+ import logging
17
+
18
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
19
+ logger = logging.getLogger(__name__)
20
+
21
+ app = FastAPI()
22
+
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"],
26
+ allow_credentials=True,
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
+ # Global variables
32
+ model = None
33
+ tokenizer = None
34
+ model_name = None
35
+ model_path = None
36
+ device = "cuda" if torch.cuda.is_available() else "cpu"
37
+ controller_url = None
38
+ worker_url = None
39
+ worker_id = str(uuid.uuid4())[:8]
40
+ support_s2s = False
41
+
42
+ def load_model(model_path_arg, s2s=False):
43
+ """Load LLaMA-Omni model and tokenizer."""
44
+ global model, tokenizer, model_name, model_path, support_s2s
45
+
46
+ model_name = os.path.basename(model_path_arg)
47
+ model_path = model_path_arg
48
+ support_s2s = s2s
49
+
50
+ logger.info(f"Loading model {model_name} from {model_path}...")
51
+
52
+ # This is a placeholder for downloading the model
53
+ # In a real implementation, it would download from HuggingFace or another source
54
+ logger.info(f"Model would be downloaded from huggingface.co/ictnlp/Llama-3.1-8B-Omni")
55
+
56
+ try:
57
+ # Use placeholder values since we're not actually loading the model in this setup
58
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
59
+ model = "PLACEHOLDER - Model would be loaded during actual deployment"
60
+
61
+ logger.info(f"Model {model_name} loaded successfully")
62
+ return True
63
+ except Exception as e:
64
+ logger.error(f"Error loading model: {str(e)}")
65
+ logger.error(traceback.format_exc())
66
+ return False
67
+
68
+ def register_worker():
69
+ """Register with the controller."""
70
+ global worker_id, controller_url, worker_url, model_name
71
+
72
+ logger.info(f"Registering worker {worker_id} with controller at {controller_url}")
73
+
74
+ while True:
75
+ try:
76
+ response = requests.post(
77
+ f"{controller_url}/register_worker",
78
+ json={
79
+ "name": worker_id,
80
+ "url": worker_url,
81
+ "models": [model_name] if model_name else []
82
+ }
83
+ )
84
+
85
+ if response.status_code == 200:
86
+ logger.info(f"Worker {worker_id} registered successfully")
87
+ break
88
+ else:
89
+ logger.error(f"Failed to register worker: {response.text}")
90
+ except Exception as e:
91
+ logger.error(f"Error registering worker: {str(e)}")
92
+
93
+ # Retry after a short delay
94
+ time.sleep(5)
95
+
96
+ def heartbeat_sender():
97
+ """Send heartbeats to the controller."""
98
+ global worker_id, controller_url
99
+
100
+ while True:
101
+ try:
102
+ response = requests.post(
103
+ f"{controller_url}/heartbeat",
104
+ json={"name": worker_id}
105
+ )
106
+
107
+ if response.status_code == 200:
108
+ logger.debug(f"Heartbeat sent successfully")
109
+ else:
110
+ logger.warning(f"Failed to send heartbeat: {response.text}")
111
+ except Exception as e:
112
+ logger.error(f"Error sending heartbeat: {str(e)}")
113
+
114
+ # Send heartbeat every 15 seconds
115
+ time.sleep(15)
116
+
117
+ @app.get("/status")
118
+ async def get_status():
119
+ """Get the status of the worker."""
120
+ return {
121
+ "status": "ok",
122
+ "model": model_name,
123
+ "supports_speech": support_s2s
124
+ }
125
+
126
+ @app.post("/generate_speech")
127
+ async def generate_speech(request_data: Dict):
128
+ """Generate speech response from a prompt."""
129
+ prompt = request_data.get("prompt")
130
+
131
+ if not prompt:
132
+ raise HTTPException(status_code=400, detail="Prompt is required")
133
+
134
+ try:
135
+ # This is a placeholder since we're not actually generating speech
136
+ # In a real implementation, it would process the prompt and return speech
137
+ logger.info(f"Received prompt: {prompt[:50]}...")
138
+
139
+ # Simulated response
140
+ response = {
141
+ "text": f"This is a response to: {prompt[:20]}...",
142
+ "speech_url": None, # In a real implementation, this would be the URL to the generated speech
143
+ "success": True
144
+ }
145
+
146
+ return response
147
+ except Exception as e:
148
+ logger.error(f"Error generating speech: {str(e)}")
149
+ logger.error(traceback.format_exc())
150
+ raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
151
+
152
+ @app.post("/generate_text")
153
+ async def generate_text(request_data: Dict):
154
+ """Generate text response from a prompt."""
155
+ prompt = request_data.get("prompt")
156
+
157
+ if not prompt:
158
+ raise HTTPException(status_code=400, detail="Prompt is required")
159
+
160
+ try:
161
+ # This is a placeholder since we're not actually generating text
162
+ # In a real implementation, it would process the prompt and return text
163
+ logger.info(f"Received prompt: {prompt[:50]}...")
164
+
165
+ # Simulated response
166
+ response = {
167
+ "text": f"This is a response to: {prompt[:20]}...",
168
+ "success": True
169
+ }
170
+
171
+ return response
172
+ except Exception as e:
173
+ logger.error(f"Error generating text: {str(e)}")
174
+ logger.error(traceback.format_exc())
175
+ raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
176
+
177
+ def main():
178
+ """Run the model worker."""
179
+ global controller_url, worker_url
180
+
181
+ parser = argparse.ArgumentParser(description="LLaMA-Omni model worker")
182
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server")
183
+ parser.add_argument("--port", type=int, default=40000, help="Port to bind the server")
184
+ parser.add_argument("--controller", type=str, required=True, help="Controller URL")
185
+ parser.add_argument("--worker", type=str, required=True, help="Worker URL")
186
+ parser.add_argument("--model-path", type=str, required=True, help="Path or name of the model to load")
187
+ parser.add_argument("--model-name", type=str, required=True, help="Name to register the model as")
188
+ parser.add_argument("--s2s", action="store_true", help="Enable speech-to-speech support")
189
+
190
+ args = parser.parse_args()
191
+
192
+ controller_url = args.controller
193
+ worker_url = args.worker
194
+
195
+ # Load the model
196
+ if not load_model(args.model_path, args.s2s):
197
+ logger.error("Failed to load model. Exiting.")
198
+ return
199
+
200
+ # Register with the controller
201
+ register_worker()
202
+
203
+ # Start heartbeat thread
204
+ heartbeat_thread = threading.Thread(target=heartbeat_sender, daemon=True)
205
+ heartbeat_thread.start()
206
+
207
+ # Start the server
208
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
209
+
210
+ if __name__ == "__main__":
211
+ main()
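
The worker's `/generate_speech` and `/generate_text` routes accept a JSON body with a `prompt` field. A minimal sketch of calling a running worker, assuming it listens on `http://localhost:40000` (the default port above):

```python
import requests

WORKER_URL = "http://localhost:40000"  # assumed default worker address

resp = requests.post(
    f"{WORKER_URL}/generate_speech",
    json={"prompt": "Hello, can you introduce yourself?"},
    timeout=60,
)
resp.raise_for_status()
data = resp.json()
# The placeholder worker returns an echoed text response and, for now, no speech_url.
print(data["text"], data.get("speech_url"))
```
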
predict.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+ import time
3
+ import subprocess
4
+ import whisper
5
+ from cog import BasePredictor, Input, Path
6
+ import torch
7
+ import tempfile
8
+
9
+ class Predictor(BasePredictor):
10
+ def setup(self):
11
+ """Load the model into memory to make inference faster"""
12
+ print("Loading models...")
13
+
14
+ # Load whisper for audio transcription
15
+ print("Loading Whisper model...")
16
+ self.whisper_model = whisper.load_model("large-v3", download_root="models/speech_encoder/")
17
+
18
+ # In a real implementation, this would load the LLaMA-Omni model
19
+ print("Note: In a real deployment, the LLaMA-Omni model would be loaded here")
20
+
21
+ # Start the controller
22
+ print("Starting controller...")
23
+ self.controller_process = subprocess.Popen([
24
+ "python", "-m", "omni_speech.serve.controller",
25
+ "--host", "0.0.0.0",
26
+ "--port", "10000"
27
+ ])
28
+ time.sleep(5) # Wait for controller to start
29
+
30
+ # Start model worker
31
+ print("Starting model worker...")
32
+ self.model_worker_process = subprocess.Popen([
33
+ "python", "-m", "omni_speech.serve.model_worker",
34
+ "--host", "0.0.0.0",
35
+ "--controller", "http://localhost:10000",
36
+ "--port", "40000",
37
+ "--worker", "http://localhost:40000",
38
+ "--model-path", "Llama-3.1-8B-Omni",
39
+ "--model-name", "Llama-3.1-8B-Omni",
40
+ "--s2s"
41
+ ])
42
+ time.sleep(10) # Wait for model worker to start
43
+
44
+ print("Setup complete")
45
+
46
+ def predict(
47
+ self,
48
+ audio: Path = Input(description="Audio file for speech input", default=None),
49
+ text: str = Input(description="Text input (used if no audio is provided)", default=None),
50
+ ) -> str:
51
+ """Run inference on the model"""
52
+ if audio is None and not text:
53
+ return "Error: Please provide either an audio file or text input."
54
+
55
+ if audio is not None:
56
+ # Process audio input
57
+ print(f"Transcribing audio from {audio}...")
58
+
59
+ # Transcribe audio using Whisper
60
+ result = self.whisper_model.transcribe(str(audio))
61
+ transcription = result["text"]
62
+
63
+ print(f"Transcription: {transcription}")
64
+
65
+ # In a real implementation, this would process the transcription through LLaMA-Omni
66
+ # For this placeholder, we'll just return the transcription with a simulated response
67
+ response = f"Transcription: {transcription}\n\nResponse: This is a simulated response to your audio. In a real deployment, this would be processed through the LLaMA-Omni model."
68
+
69
+ return response
70
+ else:
71
+ # Process text input
72
+ print(f"Processing text: {text}")
73
+
74
+ # In a real implementation, this would process the text through LLaMA-Omni
75
+ # For this placeholder, we'll just return the text with a simulated response
76
+ response = f"Input: {text}\n\nResponse: This is a simulated response to your text. In a real deployment, this would be processed through the LLaMA-Omni model."
77
+
78
+ return response
79
+
80
+ def __del__(self):
81
+ """Clean up processes on shutdown"""
82
+ if hasattr(self, 'controller_process'):
83
+ self.controller_process.terminate()
84
+
85
+ if hasattr(self, 'model_worker_process'):
86
+ self.model_worker_process.terminate()
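
Outside of Cog's HTTP serving, the predictor can also be exercised directly from Python. A minimal sketch, with the caveat that `setup()` loads Whisper large-v3 and spawns the controller and worker subprocesses, so it is only practical on a machine prepared for that:

```python
from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads Whisper and starts the controller/worker processes

# Text-only call; pass a Path to an audio file as `audio` to exercise transcription instead.
print(predictor.predict(audio=None, text="Hello from LLaMA-Omni"))
```
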
pyproject.toml ADDED
@@ -0,0 +1,30 @@
1
+ [build-system]
2
+ requires = ["setuptools>=42", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [tool.setuptools]
6
+ packages = ["omni_speech"]
7
+
8
+ [project]
9
+ name = "llama-omni"
10
+ version = "0.1.0"
11
+ description = "LLaMA-Omni: Seamless Speech Interaction with Large Language Models"
12
+ authors = [
13
+ {name = "Qingkai Fang", email = "[email protected]"},
14
+ ]
15
+ readme = "README.md"
16
+ requires-python = ">=3.10"
17
+ dependencies = [
18
+ "torch>=2.0.0",
19
+ "transformers>=4.34.0",
20
+ "accelerate>=0.21.0",
21
+ "gradio>=3.50.2",
22
+ "fastapi>=0.104.0",
23
+ "uvicorn>=0.23.2",
24
+ "pydantic>=2.3.0",
25
+ "openai-whisper>=0.0.1",
26
+ "numpy>=1.24.0",
27
+ "tqdm>=4.66.1",
28
+ "flash-attn>=2.3.0",
29
+ "fairseq>=0.12.2",
30
+ ]
requirements.txt CHANGED
@@ -1,15 +1,13 @@
1
  torch>=2.0.0
2
- torchaudio>=2.0.0
3
- transformers>=4.30.0
4
- tokenizers>=0.13.0
5
- gradio>=3.30.0
6
- huggingface-hub>=0.16.0
7
- safetensors>=0.3.1
 
8
  numpy>=1.24.0
9
- einops>=0.6.0
10
- diffusers>=0.18.0
11
- accelerate>=0.20.0
12
- soundfile>=0.12.1
13
- librosa>=0.10.0
14
- pydub
15
- ffmpeg-python
 
1
  torch>=2.0.0
2
+ transformers>=4.34.0
3
+ accelerate>=0.21.0
4
+ gradio>=3.50.2
5
+ fastapi>=0.104.0
6
+ uvicorn>=0.23.2
7
+ pydantic>=2.3.0
8
+ openai-whisper>=0.0.1
9
  numpy>=1.24.0
10
+ tqdm>=4.66.1
11
+ git+https://github.com/pytorch/fairseq.git
12
+ flash-attn>=2.3.0
13
+ requests>=2.31.0
 
 
 
run_without_downloads.sh DELETED
@@ -1,55 +0,0 @@
1
- #!/bin/bash
2
- # Script to run LLaMA-Omni2 without downloading models locally
3
- 
4
- # Set the NO_DOWNLOAD environment variable
5
- export NO_DOWNLOAD=1
6
- 
7
- # Check that the variable was set
8
- echo "Checking the NO_DOWNLOAD environment variable..."
9
- echo "NO_DOWNLOAD=$NO_DOWNLOAD"
10
- 
11
- # Enable verbose mode to verify the behavior
12
- export PYTHONVERBOSE=1
13
- export PYTHONPATH=$(pwd):$PYTHONPATH
14
- 
15
- # Create a temporary verification file
16
- python -c "
17
- import os
18
- with open('env_check.txt', 'w') as f:
19
-     f.write(f'NO_DOWNLOAD={os.environ.get(\"NO_DOWNLOAD\", \"not set\")}')
20
- "
21
- 
22
- # Show the contents of the verification file
23
- echo "Contents of the verification file:"
24
- cat env_check.txt
25
- 
26
- # Run the application
27
- echo "Running LLaMA-Omni2 in no-download mode (NO_DOWNLOAD=1)"
28
- echo "Models will be used directly from the Hugging Face Hub, without downloading them locally"
29
- echo "======================================================================"
30
- 
31
- # Decide which application to start
32
- if [ "$1" == "app" ] || [ "$1" == "" ]; then
33
-     echo "Starting app.py..."
34
-     # Check that the variable is visible to Python
35
-     python -c "import os; print('NO_DOWNLOAD environment variable:', os.environ.get('NO_DOWNLOAD', 'not set'))"
36
-     # Run with the environment variable set explicitly
37
-     NO_DOWNLOAD=1 python app.py
38
- elif [ "$1" == "launcher" ]; then
39
-     echo "Starting launcher..."
40
-     python -c "import os; print('NO_DOWNLOAD environment variable:', os.environ.get('NO_DOWNLOAD', 'not set'))"
41
-     # Use the command-line option
42
-     NO_DOWNLOAD=1 python launch_llama_omni2.py --no-model-download
43
- elif [ "$1" == "audio" ]; then
44
-     echo "Starting the audio interface..."
45
-     python -c "import os; print('NO_DOWNLOAD environment variable:', os.environ.get('NO_DOWNLOAD', 'not set'))"
46
-     NO_DOWNLOAD=1 python audio_interface.py
47
- else
48
-     echo "Usage: $0 [app|launcher|audio]"
49
-     echo "  app      - Starts app.py (default)"
50
-     echo "  launcher - Starts launch_llama_omni2.py"
51
-     echo "  audio    - Starts audio_interface.py"
52
- fi
53
- 
54
- # Remove the temporary verification file
55
- rm -f env_check.txt
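
The script's only contract with the applications it launches is the `NO_DOWNLOAD` variable. A minimal sketch of the pattern those applications are expected to follow (the branch contents and the Hub repo id are illustrative assumptions, not code from app.py):

```python
import os

# Read the flag exported by the shell script above.
no_download = os.environ.get("NO_DOWNLOAD", "0") == "1"

if no_download:
    # Stream weights straight from the Hugging Face Hub (hypothetical repo id).
    model_source = "ICTNLP/LLaMA-Omni2-0.5B"
else:
    # Use a locally downloaded copy under models/.
    model_source = "models/LLaMA-Omni2-0.5B"

print("Loading model from:", model_source)
```
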
tests/README.md DELETED
@@ -1,116 +0,0 @@
1
- # Testing LLaMA-Omni2-0.5B on Hugging Face
2
- 
3
- This directory contains a complete script for testing the LLaMA-Omni2-0.5B model deployed on Hugging Face.
4
- 
5
- ## Script Features
6
- 
7
- - Programmatic API testing (api mode)
8
- - Manual testing in the browser (manual mode)
9
- - Local audio transcription with Whisper
10
- - Sending text directly to the model
11
- - Saving the transcription and responses for reference
12
- 
13
- ## Prerequisites
14
- 
15
- Before running the test script, make sure the required dependencies are installed:
16
- 
17
- ```bash
18
- pip install requests gradio-client
19
- ```
20
- 
21
- For audio transcription (optional), you can install Whisper:
22
- 
23
- ```bash
24
- pip install openai-whisper
25
- ```
26
- 
27
- ## Usage
28
- 
29
- You can run the test script with the following command:
30
- 
31
- ```bash
32
- cd tests
33
- python test_llama_omni_api.py
34
- ```
35
- 
36
- By default, the script runs both modes (api and manual) and will:
37
- 1. Try to transcribe the test.mp3 file using Whisper (if available)
38
- 2. Fall back to a default test message if Whisper is unavailable or the file does not exist
39
- 3. Test the API programmatically and save the response
40
- 4. Save the input text to a file for easy copying
41
- 5. Open the LLaMA-Omni2-0.5B web interface on Hugging Face in your browser
42
- 6. Print instructions for manual testing
43
- 
44
- ### Command-line parameters
45
- 
46
- The script accepts the following command-line arguments:
47
- 
48
- `--api-url`: URL of the Gradio interface (default: https://marcosremar2-llama-omni.hf.space)
49
- `--audio-file`: Path to the audio file to transcribe locally (default: test.mp3)
50
- `--text`: Text to use directly (instead of transcribing audio)
51
- `--output-dir`: Directory where the transcription and responses are saved (default: ./output)
52
- `--mode`: Test mode: api (programmatic), manual (browser) or both (default: both)
53
- 
54
- ### Usage examples with custom parameters:
55
- 
56
- ```bash
57
- # Using direct text input
58
- python test_llama_omni_api.py --text "Hello, this is a test message for LLaMA-Omni2-0.5B."
59
- 
60
- # Using a custom audio file for transcription
61
- python test_llama_omni_api.py --audio-file /path/to/your/audio.mp3
62
- 
63
- # Testing only the API mode programmatically
64
- python test_llama_omni_api.py --mode api
65
- 
66
- # Only opening the web interface with custom text
67
- python test_llama_omni_api.py --mode manual --text "Manual test of LLaMA-Omni2-0.5B"
68
- ```
69
- 
70
- ## Test Modes
71
- 
72
- ### 1. API Mode (Programmatic)
73
- 
74
- Sends a request directly to the model API and saves the response to a file:
75
- 
76
- - Connects to the Gradio API with an increased timeout
77
- - Lists the available endpoints
78
- - Sends the text to the generation endpoint
79
- - Saves the received response to a file
80
- - Also queries basic model information
81
- 
82
- ### 2. Manual Mode (Web Interface)
83
- 
84
- Facilitates manual testing with the following workflow:
85
- 
86
- 1. **Text preparation**: The input text is saved to a file for easy copying
87
- 2. **Browser launch**: The script opens the web interface in your default browser
88
- 3. **Manual interaction**: You need to manually:
89
-    - Copy the text from the saved file
90
-    - Paste it into the "Input Text" field in the web interface
91
-    - Click the "Generate" button
92
-    - Wait for the response
93
-    - Copy and save the response for your records
94
- 
95
- ## Troubleshooting
96
- 
97
- If you run into problems:
98
- 
99
- 1. Check that the web interface URL is correct and the service is running
100
- 2. Make sure you have an internet connection
101
- 3. If you are using audio transcription, make sure Whisper is installed correctly
102
- 4. In API mode, check that the Gradio Space is active (Spaces sometimes "sleep" when idle)
103
- 
104
- ## Common Errors
105
- 
106
- ### Missing Dependencies
107
- 
108
- If you see errors about missing modules, install the required dependencies:
109
- 
110
- ```bash
111
- pip install requests gradio-client openai-whisper
112
- ```
113
- 
114
- ### Deploying to Hugging Face
115
- 
116
- This script only tests an LLaMA-Omni2-0.5B model that is already deployed on Hugging Face. To deploy the model on Hugging Face Spaces, you just need to push your code to the corresponding repository on Hugging Face.
tests/test.mp3 DELETED
Binary file (13.5 kB)
 
tests/test_llama_omni_api.py DELETED
@@ -1,223 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Complete test for LLaMA-Omni2-0.5B on Hugging Face
4
- This script can:
5
- 1. Transcribe audio locally and send it to the model
6
- 2. Send text directly to the model
7
- 3. Facilitate manual testing through the web interface
8
- 4. Test the API directly in a programmatic way
9
- """
10
- 
11
- import os
12
- import sys
13
- import time
14
- import argparse
15
- import requests
16
- import subprocess
17
- import webbrowser
18
- from pathlib import Path
19
- from gradio_client import Client
20
- 
21
- # Default settings
22
- DEFAULT_API_URL = "https://marcosremar2-llama-omni.hf.space"
23
- DEFAULT_OUTPUT_DIR = "./output"
24
- MODEL_NAME = "LLaMA-Omni2-0.5B"
25
- 
26
- def transcribe_audio_locally(audio_file_path):
27
-     """
28
-     Transcribe audio locally using Whisper if it is available
29
-     Otherwise, return a default message
30
-     """
31
-     try:
32
-         # Try to use the whisper CLI if available
33
-         result = subprocess.run(
34
-             ["whisper", audio_file_path, "--model", "tiny", "--output_format", "txt"],
35
-             capture_output=True,
36
-             text=True,
37
-             check=True
38
-         )
39
-         transcript_file = f"{os.path.splitext(audio_file_path)[0]}.txt"
40
-         if os.path.exists(transcript_file):
41
-             with open(transcript_file, "r") as f:
42
-                 transcript = f.read().strip()
43
-             print(f"Transcription: {transcript}")
44
-             return transcript
45
-     except (subprocess.CalledProcessError, FileNotFoundError) as e:
46
-         print(f"Whisper not available or error: {e}")
47
- 
48
-     # Default message
49
-     print("Using the default test message since whisper is not available")
50
-     return f"Hello, I am testing the {MODEL_NAME} model. Can you reply to me in Portuguese?"
51
- 
52
- def check_url_accessibility(url):
53
-     """Check whether the URL is reachable"""
54
-     try:
55
-         response = requests.get(url, timeout=10)
56
-         if response.status_code == 200:
57
-             return True
58
-         else:
59
-             print(f"URL returned status code {response.status_code}")
60
-             return False
61
-     except Exception as e:
62
-         print(f"Error accessing URL: {e}")
63
-         return False
64
- 
65
- def save_text_to_file(text, output_dir, filename="text.txt"):
66
-     """Save text to a file for easy copying"""
67
-     os.makedirs(output_dir, exist_ok=True)
68
-     filepath = os.path.join(output_dir, filename)
69
- 
70
-     with open(filepath, "w") as f:
71
-         f.write(text)
72
- 
73
-     print(f"Text saved to: {filepath}")
74
-     return filepath
75
- 
76
- def test_api_programmatically(api_url, text_input, output_dir=DEFAULT_OUTPUT_DIR):
77
-     """
78
-     Test the model API programmatically by sending a text prompt
79
-     and saving the response
80
-     """
81
-     output_path = os.path.join(output_dir, f"response_{int(time.time())}.txt")
82
-     os.makedirs(output_dir, exist_ok=True)
83
- 
84
-     print(f"Testing API at: {api_url}")
85
-     print(f"Input text: {text_input[:50]}..." if len(text_input) > 50 else f"Input text: {text_input}")
86
- 
87
-     try:
88
-         # Connect to the Gradio app with an increased timeout
89
-         client = Client(
90
-             api_url,
91
-             httpx_kwargs={"timeout": 300.0}  # 5 minute timeout
92
-         )
93
- 
94
-         print("Connected to the API successfully")
95
- 
96
-         # List the available endpoints
97
-         print("Available endpoints:")
98
-         client.view_api()
99
- 
100
-         # Send the prompt to the model
101
-         print("\nUsing the text generation endpoint (/lambda_1)...")
102
-         print(f"Sending prompt: '{text_input[:50]}...'")
103
-         job = client.submit(
104
-             text_input,
105
-             MODEL_NAME,
106
-             api_name="/lambda_1"
107
-         )
108
- 
109
-         print("Request sent, waiting for the response...")
110
-         result = job.result()
111
-         print(f"Response received (length: {len(str(result))} characters)")
112
- 
113
-         # Save the response to a file
114
-         with open(output_path, "w") as f:
115
-             f.write(str(result))
116
- 
117
-         print(f"Response saved to: {output_path}")
118
- 
119
-         # Try to fetch basic model information
120
-         try:
121
-             print("\nQuerying model information...")
122
-             model_info = client.submit(api_name="/lambda").result()
123
-             print(f"Model information: {model_info}")
124
-         except Exception as model_error:
125
-             print(f"Error fetching model information: {str(model_error)}")
126
- 
127
-         return True, result
128
- 
129
-     except Exception as e:
130
-         print(f"Error during API request: {str(e)}")
131
-         print("This can happen because the Space is sleeping and needs time to start up.")
132
-         print("Try accessing the Space directly first: " + api_url)
133
-         print(f"\nNote: This API is for the {MODEL_NAME} model and does not process audio directly.")
134
-         print("To work with audio, you would first need to transcribe it using Whisper,")
135
-         print("and then send the transcribed text to this API.")
136
-         return False, None
137
- 
138
- def test_manual_interface(api_url, text_input, output_dir=DEFAULT_OUTPUT_DIR):
139
-     """
140
-     Prepare a manual test of the model via the web interface:
141
-     1. Save the text to a file for easy copying
142
-     2. Open the web interface for manual testing
143
-     """
144
-     # Check whether the URL is reachable
145
-     print(f"Checking accessibility of {api_url}...")
146
-     if not check_url_accessibility(api_url):
147
-         print(f"Warning: {api_url} is not reachable. Manual testing may not be possible.")
148
- 
149
-     # Save the text to a file for easy copying
150
-     transcript_file = save_text_to_file(text_input, output_dir, "transcription.txt")
151
- 
152
-     # Instructions for manual testing
153
-     print("\n" + "=" * 50)
154
-     print(f"INSTRUCTIONS FOR MANUAL TESTING OF {MODEL_NAME}")
155
-     print("=" * 50)
156
-     print(f"1. The text was saved to: {transcript_file}")
157
-     print(f"2. Opening {api_url} in the browser...")
158
-     print("3. Copy the text from the saved file and paste it into the 'Input Text' field")
159
-     print("4. Click the 'Generate' button")
160
-     print("5. When you receive the response, copy and save it for your records")
161
-     print("=" * 50 + "\n")
162
- 
163
-     # Open the URL in the default browser
164
-     try:
165
-         webbrowser.open(api_url)
166
-         return True
167
-     except Exception as e:
168
-         print(f"Error opening the browser: {e}")
169
-         print(f"Please visit manually: {api_url}")
170
-         return False
171
- 
172
- def main():
173
-     parser = argparse.ArgumentParser(description=f"Test for {MODEL_NAME} on Hugging Face")
174
-     parser.add_argument("--api-url", type=str, default=DEFAULT_API_URL,
175
-                         help=f"URL of the Gradio interface (default: {DEFAULT_API_URL})")
176
-     parser.add_argument("--audio-file", type=str, default="test.mp3",
177
-                         help="Path to the audio file to transcribe locally (optional)")
178
-     parser.add_argument("--text", type=str, default=None,
179
-                         help="Text to use directly (instead of transcribing audio)")
180
-     parser.add_argument("--output-dir", type=str, default=DEFAULT_OUTPUT_DIR,
181
-                         help="Directory where the transcription and responses are saved")
182
-     parser.add_argument("--mode", type=str, choices=["api", "manual", "both"], default="both",
183
-                         help="Test mode: api (programmatic), manual (browser) or both")
184
-     args = parser.parse_args()
185
- 
186
-     # Convert relative paths to absolute paths
187
-     if args.audio_file and not os.path.isabs(args.audio_file):
188
-         if not os.path.exists(args.audio_file):
189
-             script_dir = os.path.dirname(os.path.abspath(__file__))
190
-             args.audio_file = os.path.join(script_dir, args.audio_file)
191
- 
192
-     if args.output_dir and not os.path.isabs(args.output_dir):
193
-         script_dir = os.path.dirname(os.path.abspath(__file__))
194
-         args.output_dir = os.path.join(script_dir, args.output_dir)
195
- 
196
-     # Get the input text from the transcription or from the --text parameter
197
-     input_text = args.text
198
-     if not input_text and args.audio_file:
199
-         if os.path.exists(args.audio_file):
200
-             input_text = transcribe_audio_locally(args.audio_file)
201
-         else:
202
-             print(f"Audio file not found: {args.audio_file}")
203
-             input_text = f"Hello, I am testing the {MODEL_NAME} model. Can you reply to me in Portuguese?"
204
-     if not input_text:
205
-         input_text = f"Hello, I am testing the {MODEL_NAME} model. Can you reply to me in Portuguese?"
206
- 
207
-     print(f"Input text: {input_text}")
208
- 
209
-     # Run the tests according to the selected mode
210
-     success = True
211
-     if args.mode in ["api", "both"]:
212
-         api_success, _ = test_api_programmatically(args.api_url, input_text, args.output_dir)
213
-         success = success and api_success
214
- 
215
-     if args.mode in ["manual", "both"]:
216
-         manual_success = test_manual_interface(args.api_url, input_text, args.output_dir)
217
-         success = success and manual_success
218
- 
219
-     # Exit with the appropriate status code
220
-     sys.exit(0 if success else 1)
221
- 
222
- if __name__ == "__main__":
223
-     main()