Commit cc0c580 · Parent: ad298ab
Commit message: ddssdsds
Build status: error

Files changed:
- Dockerfile +0 -83
- debug_llama_omni2.py +0 -168
- extract_llama_omni2_scripts.py +0 -215
- run_controller_directly.py +0 -192
- run_gradio_directly.py +0 -191
- run_model_worker_directly.py +0 -208
- test_llama_omni_api.py +0 -84
- tests/README.md +46 -18
- tests/test_llama_omni_api.py +114 -46
Dockerfile
DELETED
@@ -1,83 +0,0 @@

```dockerfile
# Use an official Python runtime as a parent image
FROM python:3.10-slim

# Set the working directory in the container
WORKDIR /code

# Set environment variables for pip
ENV PIP_NO_CACHE_DIR=off \
    PIP_DISABLE_PIP_VERSION_CHECK=on

# Install system dependencies (git for cloning, build-essential for compiling C/C++ extensions)
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    build-essential \
    curl \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Copy all files from your Hugging Face Space repo
COPY . /code/

# Clone LLaMA-Omni2 and install it (WITHOUT editable flag)
RUN git clone https://github.com/ICTNLP/LLaMA-Omni2.git /tmp/LLaMA-Omni2 \
    && cd /tmp/LLaMA-Omni2 \
    && pip install . \
    && echo "--- PIP LIST AFTER LLaMA-Omni2 INSTALL ---" \
    && pip list | grep -i llama \
    && echo "--- PYTHON SYS.PATH AFTER LLaMA-Omni2 INSTALL ---" \
    && python -c "import sys; print(sys.path)" \
    && echo "--- TRYING TO IMPORT LLaMA-Omni2 ---" \
    && python -c "import llama_omni2; print(f'LLaMA-Omni2 imported successfully from {llama_omni2.__file__}')" \
    && echo "--- CHECKING WHERE LLAMA_OMNI2 IS INSTALLED ---" \
    && pip show llama-omni2 \
    && echo "--- DIAGNOSTICS END ---"

# Copy the LLaMA-Omni2 source code to /code as well for direct file access
RUN cp -r /tmp/LLaMA-Omni2/llama_omni2 /code/ \
    && echo "--- COPIED LLAMA_OMNI2 SOURCE TO /code ---" \
    && ls -la /code/llama_omni2 \
    && echo "--- CHECKING SERVE SCRIPTS ---" \
    && ls -la /code/llama_omni2/serve || echo "serve directory not found!"

# Make sure PYTHONPATH includes both /code and site-packages
ENV PYTHONPATH "${PYTHONPATH}:/code"

# Install any other explicit dependencies from requirements.txt
RUN pip install -r requirements.txt

# Make debug and extraction scripts executable
RUN chmod +x /code/debug_llama_omni2.py \
    && chmod +x /code/extract_llama_omni2_scripts.py

# Create startup script with enhanced diagnostics and fallbacks
RUN echo '#!/bin/bash\n\
echo "--- CONTAINER STARTING ---"\n\
echo "PYTHONPATH: $PYTHONPATH"\n\
echo "Python sys.path:"\n\
python -c "import sys; print(sys.path)"\n\
\n\
echo "Running diagnostic script..."\n\
python /code/debug_llama_omni2.py\n\
\n\
# Check if llama_omni2 module is importable\n\
if ! python -c "import llama_omni2" > /dev/null 2>&1; then\n\
    echo "WARNING: llama_omni2 module cannot be imported. Extracting scripts as fallback..."\n\
    python /code/extract_llama_omni2_scripts.py\n\
\n\
    # Add the extracted directory to PYTHONPATH\n\
    if [ -d "/code/llama_omni2_extracted" ]; then\n\
        export PYTHONPATH="$PYTHONPATH:/code/llama_omni2_extracted"\n\
        echo "Added /code/llama_omni2_extracted to PYTHONPATH: $PYTHONPATH"\n\
    fi\n\
fi\n\
\n\
echo "Starting LLaMA-Omni2 application..."\n\
python app.py\n' > /code/startup.sh \
    && chmod +x /code/startup.sh

# Expose the port Gradio will run on
EXPOSE 7860

# Command to run the application
CMD ["/code/startup.sh"]
```
debug_llama_omni2.py
DELETED
@@ -1,168 +0,0 @@

```python
#!/usr/bin/env python3
"""
LLaMA-Omni2 Debug Script
------------------------
This script helps diagnose issues with LLaMA-Omni2 installation and imports.
It checks:
1. Python environment
2. Module locations
3. Import capabilities
4. Script locations
"""

import os
import sys
import importlib
import subprocess

def print_section(title):
    """Print a section header for better readability"""
    print("\n" + "=" * 50)
    print(f" {title} ".center(50, "="))
    print("=" * 50)

def find_module_in_paths(module_name, paths=None):
    """Find all occurrences of a module in the specified paths"""
    if paths is None:
        paths = sys.path

    found_locations = []
    for path in paths:
        potential_path = os.path.join(path, module_name)
        if os.path.exists(potential_path):
            found_locations.append(potential_path)

    return found_locations

def find_scripts(script_name, search_dirs=None):
    """Find scripts by name in the specified directories"""
    if search_dirs is None:
        search_dirs = [
            '/code',
            '/tmp/LLaMA-Omni2',
            '/usr/local/lib/python3.10/site-packages',
            '/home/user'
        ]

    found_scripts = []

    for search_dir in search_dirs:
        if not os.path.exists(search_dir):
            continue

        for root, dirs, files in os.walk(search_dir):
            # Skip .git and other large dirs
            dirs[:] = [d for d in dirs if d not in ('.git', 'node_modules')]

            if script_name in files:
                found_scripts.append(os.path.join(root, script_name))

    return found_scripts

def check_pip_installed():
    """Check if llama_omni2 is properly installed via pip"""
    try:
        result = subprocess.run(['pip', 'show', 'llama-omni2'],
                                capture_output=True, text=True)
        if result.returncode == 0:
            print("LLaMA-Omni2 is installed via pip:")
            print(result.stdout)
        else:
            print("LLaMA-Omni2 is NOT installed via pip")
    except Exception as e:
        print(f"Error checking pip installation: {e}")

def main():
    # 1. Environment Information
    print_section("ENVIRONMENT INFORMATION")
    print(f"Python Executable: {sys.executable}")
    print(f"Python Version: {sys.version}")
    print(f"Working Directory: {os.getcwd()}")

    # 2. PYTHONPATH
    print_section("PYTHONPATH")
    pythonpath = os.environ.get('PYTHONPATH', 'Not set')
    print(f"PYTHONPATH Environment Variable: {pythonpath}")

    # 3. sys.path
    print_section("sys.path")
    for i, path in enumerate(sys.path):
        print(f"{i}: {path}")

    # 4. Check if llama_omni2 is pip-installed
    print_section("PIP INSTALLATION")
    check_pip_installed()

    # 5. Find llama_omni2 in sys.path
    print_section("LLAMA_OMNI2 MODULE LOCATIONS")
    found_locations = find_module_in_paths('llama_omni2')
    if found_locations:
        print("Found llama_omni2 module in the following locations:")
        for loc in found_locations:
            print(f" - {loc}")
    else:
        print("Could not find llama_omni2 module in sys.path!")

    # 6. Try to import llama_omni2
    print_section("IMPORT TEST")
    try:
        import llama_omni2
        print(f"Successfully imported llama_omni2 from: {llama_omni2.__file__}")

        # Check if key modules exist
        modules_to_check = [
            'llama_omni2.serve.controller',
            'llama_omni2.serve.model_worker',
            'llama_omni2.serve.gradio_web_server'
        ]

        for module in modules_to_check:
            try:
                importlib.import_module(module)
                print(f"✅ Successfully imported {module}")
            except ImportError as e:
                print(f"❌ Failed to import {module}: {e}")

    except ImportError as e:
        print(f"Failed to import llama_omni2: {e}")

    # 7. Find core script files
    print_section("SCRIPT LOCATIONS")
    scripts_to_find = ['controller.py', 'model_worker.py', 'gradio_web_server.py']

    for script in scripts_to_find:
        found_scripts = find_scripts(script)
        if found_scripts:
            print(f"Found {script} at:")
            for path in found_scripts:
                print(f" - {path}")
        else:
            print(f"Could not find {script}")

    # 8. Test running the scripts directly
    print_section("DIRECT SCRIPT EXECUTION TEST")

    for script in scripts_to_find:
        found_scripts = find_scripts(script)
        if found_scripts:
            script_path = found_scripts[0]
            print(f"Testing if {script_path} can be executed...")
            try:
                # Just import the script module directly to see if it loads
                result = subprocess.run(
                    [sys.executable, '-c', f"import importlib.util; spec = importlib.util.spec_from_file_location('test', '{script_path}'); module = importlib.util.module_from_spec(spec); spec.loader.exec_module(module); print('Successfully loaded {script}')"],
                    capture_output=True, text=True, timeout=5
                )
                if result.returncode == 0:
                    print(f"✅ Script can be imported: {script_path}")
                    print(result.stdout)
                else:
                    print(f"❌ Script import failed: {script_path}")
                    print(result.stderr)
            except subprocess.TimeoutExpired:
                print(f"⚠️ Script import timed out: {script_path}")
            except Exception as e:
                print(f"❌ Error testing script: {e}")

if __name__ == "__main__":
    main()
```
extract_llama_omni2_scripts.py
DELETED
@@ -1,215 +0,0 @@

````python
#!/usr/bin/env python3
"""
Extract LLaMA-Omni2 Serve Scripts
---------------------------------
This script downloads and extracts just the necessary serve scripts from
the LLaMA-Omni2 GitHub repository to enable running the server components
without a full package installation.
"""

import os
import sys
import subprocess
import shutil
import tempfile
import importlib.util

def print_section(title):
    """Print a section header"""
    print("\n" + "=" * 60)
    print(f" {title} ".center(60, "="))
    print("=" * 60)

def clone_repo(target_dir):
    """Clone the LLaMA-Omni2 repository to a temporary location"""
    print_section("CLONING REPOSITORY")
    print(f"Cloning LLaMA-Omni2 repository to {target_dir}...")

    try:
        subprocess.run(
            ["git", "clone", "https://github.com/ICTNLP/LLaMA-Omni2.git", target_dir],
            check=True
        )
        print(f"Successfully cloned repository to {target_dir}")
        return True
    except subprocess.CalledProcessError as e:
        print(f"Failed to clone repository: {e}")
        return False

def extract_serve_scripts(repo_dir, output_dir):
    """Extract the serve scripts and dependencies to the output directory"""
    print_section("EXTRACTING SERVE SCRIPTS")

    # Ensure output directories exist
    serve_dir = os.path.join(output_dir, "llama_omni2", "serve")
    os.makedirs(serve_dir, exist_ok=True)

    # Copy serve scripts
    source_serve_dir = os.path.join(repo_dir, "llama_omni2", "serve")
    if not os.path.exists(source_serve_dir):
        print(f"Error: Source serve directory not found at {source_serve_dir}")
        return False

    print(f"Copying serve scripts from {source_serve_dir} to {serve_dir}")

    # Copy all files from serve directory
    for filename in os.listdir(source_serve_dir):
        source_file = os.path.join(source_serve_dir, filename)
        if os.path.isfile(source_file):
            shutil.copy2(source_file, serve_dir)
            print(f"Copied {filename}")

    # Copy __init__.py files to make the modules importable
    init_files = [
        os.path.join(output_dir, "llama_omni2", "__init__.py"),
        os.path.join(serve_dir, "__init__.py")
    ]

    for init_file in init_files:
        if not os.path.exists(init_file):
            with open(init_file, 'w') as f:
                f.write("# Auto-generated __init__.py file\n")
            print(f"Created {init_file}")

    # Also copy key dependencies from the llama_omni2 module
    modules_to_copy = [
        "model",
        "common"
    ]

    for module in modules_to_copy:
        source_module_dir = os.path.join(repo_dir, "llama_omni2", module)
        if os.path.exists(source_module_dir):
            target_module_dir = os.path.join(output_dir, "llama_omni2", module)
            print(f"Copying {module} module to {target_module_dir}")
            shutil.copytree(source_module_dir, target_module_dir, dirs_exist_ok=True)

            # Add __init__.py file if it doesn't exist
            init_file = os.path.join(target_module_dir, "__init__.py")
            if not os.path.exists(init_file):
                with open(init_file, 'w') as f:
                    f.write("# Auto-generated __init__.py file\n")

    print("Extraction completed successfully")
    return True

def test_scripts(scripts_dir):
    """Test if the extracted scripts can be imported"""
    print_section("TESTING EXTRACTED SCRIPTS")

    # Make sure the scripts directory is in the Python path
    parent_dir = os.path.dirname(scripts_dir)
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)

    # Try to import each script
    script_paths = [
        os.path.join(scripts_dir, "llama_omni2", "serve", "controller.py"),
        os.path.join(scripts_dir, "llama_omni2", "serve", "model_worker.py"),
        os.path.join(scripts_dir, "llama_omni2", "serve", "gradio_web_server.py")
    ]

    for script_path in script_paths:
        if not os.path.exists(script_path):
            print(f"❌ Script not found: {script_path}")
            continue

        try:
            script_name = os.path.basename(script_path).replace(".py", "")
            spec = importlib.util.spec_from_file_location(script_name, script_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            print(f"✅ Successfully imported {script_path}")
        except Exception as e:
            print(f"❌ Failed to import {script_path}: {e}")

def create_usage_instructions(output_dir, scripts_dir):
    """Create usage instructions for the extracted scripts"""
    print_section("CREATING USAGE INSTRUCTIONS")

    instruction_file = os.path.join(output_dir, "README.md")

    with open(instruction_file, 'w') as f:
        f.write("""# LLaMA-Omni2 Extracted Serve Scripts

This directory contains the extracted serve scripts from LLaMA-Omni2 to run without a full package installation.

## Usage

### 1. Make sure Python can find these modules:

```bash
export PYTHONPATH=$PYTHONPATH:/path/to/this/directory
```

### 2. Run the controller:

```bash
python -m llama_omni2.serve.controller --host 0.0.0.0 --port 10000
```

### 3. Run the model worker:

```bash
python -m llama_omni2.serve.model_worker \\
    --host 0.0.0.0 \\
    --controller http://localhost:10000 \\
    --port 40000 \\
    --worker http://localhost:40000 \\
    --model-path /path/to/model \\
    --model-name MODEL_NAME
```

### 4. Run the Gradio web server:

```bash
python -m llama_omni2.serve.gradio_web_server \\
    --host 0.0.0.0 \\
    --port 7860 \\
    --controller-url http://localhost:10000 \\
    --model-list-mode reload \\
    --vocoder-dir /path/to/vocoder
```

Alternatively, you can run these scripts directly:

```bash
python llama_omni2/serve/controller.py --host 0.0.0.0 --port 10000
```
""")

    print(f"Created usage instructions at {instruction_file}")

def main():
    """Main function to extract and test LLaMA-Omni2 scripts"""
    output_dir = "/home/user/app/llama_omni2_extracted"

    print(f"This script will extract LLaMA-Omni2 serve scripts to {output_dir}")

    # Create temporary directory for cloning
    with tempfile.TemporaryDirectory() as temp_dir:
        # Clone repository
        if not clone_repo(temp_dir):
            print("Failed to clone repository. Exiting.")
            return 1

        # Extract serve scripts
        if not extract_serve_scripts(temp_dir, output_dir):
            print("Failed to extract serve scripts. Exiting.")
            return 1

        # Create usage instructions
        create_usage_instructions(output_dir, output_dir)

        # Test scripts
        test_scripts(output_dir)

    print_section("EXTRACTION COMPLETED")
    print(f"LLaMA-Omni2 serve scripts have been extracted to {output_dir}")
    print(f"Add this directory to PYTHONPATH: export PYTHONPATH=$PYTHONPATH:{output_dir}")
    print("See README.md for usage instructions")

    return 0

if __name__ == "__main__":
    sys.exit(main())
````
run_controller_directly.py
DELETED
@@ -1,192 +0,0 @@

```python
#!/usr/bin/env python3
"""
A standalone implementation of the LLaMA-Omni2 controller
that doesn't rely on any LLaMA-Omni2 imports.
"""

import argparse
import asyncio
import dataclasses
import json
import logging
import time
from typing import Dict, List, Optional, Set, Tuple, Union

import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn

# Define constants
CONTROLLER_HEART_BEAT_EXPIRATION = 120
MODEL_WORKER_API_TIMEOUT = 100

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define data models using dataclasses instead of pydantic
@dataclasses.dataclass
class ModelInfo:
    id: str
    name: str
    worker_names: List[str]
    time: float = dataclasses.field(default_factory=time.time)

@dataclasses.dataclass
class WorkerInfo:
    worker_name: str
    model_names: List[str]
    check_heart_beat: bool
    last_heart_beat: float = dataclasses.field(default_factory=time.time)

# Global state
worker_info: Dict[str, WorkerInfo] = {}
model_info: Dict[str, ModelInfo] = {}
worker_addr: Dict[str, str] = {}

# FastAPI app
app = fastapi.FastAPI()

@app.post("/register_worker")
async def register_worker(request: Request):
    data = await request.json()
    worker_name = data.get("worker_name")
    worker_url = data.get("worker_url")
    model_names = data.get("model_names", [])
    check_heart_beat = data.get("check_heart_beat", True)

    logger.info(f"Registering worker {worker_name} at {worker_url}")

    worker_info[worker_name] = WorkerInfo(
        worker_name=worker_name,
        model_names=model_names,
        check_heart_beat=check_heart_beat,
        last_heart_beat=time.time()
    )
    worker_addr[worker_name] = worker_url

    # Register models
    for model_name in model_names:
        if model_name in model_info:
            model_info[model_name].worker_names.append(worker_name)
        else:
            model_id = f"model-{len(model_info)}"
            model_info[model_name] = ModelInfo(
                id=model_id,
                name=model_name,
                worker_names=[worker_name]
            )

    return {"result": "success"}

@app.post("/unregister_worker")
async def unregister_worker(request: Request):
    data = await request.json()
    worker_name = data.get("worker_name")

    logger.info(f"Unregistering worker {worker_name}")

    if worker_name in worker_info:
        for model_name in worker_info[worker_name].model_names:
            if model_name in model_info:
                if worker_name in model_info[model_name].worker_names:
                    model_info[model_name].worker_names.remove(worker_name)
                if len(model_info[model_name].worker_names) == 0:
                    del model_info[model_name]

        del worker_info[worker_name]

    if worker_name in worker_addr:
        del worker_addr[worker_name]

    return {"result": "success"}

@app.post("/heart_beat")
async def heart_beat(request: Request):
    data = await request.json()
    worker_name = data.get("worker_name")

    if worker_name not in worker_info or worker_name not in worker_addr:
        return {"result": "failure", "error": f"Worker {worker_name} not found"}

    worker_info[worker_name].last_heart_beat = time.time()

    return {"result": "success"}

@app.get("/list_models")
async def list_models():
    models = []
    for name, info in model_info.items():
        models.append({
            "id": info.id,
            "name": name
        })
    return {"models": models}

@app.get("/get_worker_address")
async def get_worker_address(model_name: str):
    if model_name not in model_info or not model_info[model_name].worker_names:
        return JSONResponse(
            {"error": f"No available workers for model {model_name}"},
            status_code=400
        )

    # Simple round-robin selection among available workers
    workers = model_info[model_name].worker_names
    selected_worker = workers[int(time.time()) % len(workers)]

    return {"address": worker_addr.get(selected_worker)}

@app.get("/worker_status")
async def worker_status():
    return {"worker_info": [
        {
            "name": name,
            "address": worker_addr.get(name),
            "models": info.model_names,
            "last_heart_beat": info.last_heart_beat,
            "status": "alive" if not info.check_heart_beat or
                      (time.time() - info.last_heart_beat) < CONTROLLER_HEART_BEAT_EXPIRATION
                      else "dead"
        }
        for name, info in worker_info.items()
    ]}

@app.get("/status")
async def status():
    return {
        "model_info": [
            {
                "name": name,
                "id": info.id,
                "workers": info.worker_names
            }
            for name, info in model_info.items()
        ],
        "worker_info": [
            {
                "name": name,
                "address": worker_addr.get(name),
                "models": info.model_names,
                "last_heart_beat": info.last_heart_beat,
                "status": "alive" if not info.check_heart_beat or
                          (time.time() - info.last_heart_beat) < CONTROLLER_HEART_BEAT_EXPIRATION
                          else "dead"
            }
            for name, info in worker_info.items()
        ]
    }

# Run the server
def main():
    parser = argparse.ArgumentParser(description="Controller for LLaMA-Omni2")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=10000)
    args = parser.parse_args()

    logger.info(f"Starting controller server at http://{args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)

if __name__ == "__main__":
    main()
```
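
Once the controller is running, its registration and lookup endpoints can be exercised with plain curl; a minimal sketch in which the worker name, worker URL, and model name are illustrative:

```bash
# Register a worker with the controller
curl -X POST http://localhost:10000/register_worker \
  -H "Content-Type: application/json" \
  -d '{"worker_name": "worker-demo", "worker_url": "http://localhost:40000", "model_names": ["LLaMA-Omni2-0.5B"], "check_heart_beat": true}'

# List the models the controller currently knows about
curl http://localhost:10000/list_models

# Resolve a worker address for a given model
curl "http://localhost:10000/get_worker_address?model_name=LLaMA-Omni2-0.5B"
```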
run_gradio_directly.py
DELETED
@@ -1,191 +0,0 @@

```python
#!/usr/bin/env python3
"""
A minimal Gradio web interface for LLaMA-Omni2 that doesn't rely on
importing from the LLaMA-Omni2 package.
"""

import argparse
import asyncio
import json
import logging
import os
import time
from typing import Dict, List, Optional

import aiohttp
import gradio as gr

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class LLaMA_Omni2_UI:
    def __init__(
        self,
        controller_url: str,
        vocoder_dir: str
    ):
        self.controller_url = controller_url
        self.vocoder_dir = vocoder_dir
        self.model_list = []
        self.model_names = []

        # Verify vocoder directory exists
        if not os.path.exists(vocoder_dir):
            logger.warning(f"Vocoder directory not found at {vocoder_dir}")
            logger.warning("Voice synthesis will not be available")
        else:
            logger.info(f"Using vocoder at {vocoder_dir}")

    async def fetch_model_list(self):
        """Fetch the list of models from the controller"""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.controller_url}/list_models",
                    timeout=30
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        self.model_list = data.get("models", [])
                        self.model_names = [model.get("name") for model in self.model_list]
                        return self.model_names
                    else:
                        logger.error(f"Failed to fetch model list: {await response.text()}")
                        return []
        except Exception as e:
            logger.error(f"Error fetching model list: {e}")
            return []

    async def get_worker_address(self, model_name: str):
        """Get the address of a worker serving the specified model"""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    f"{self.controller_url}/get_worker_address?model_name={model_name}",
                    timeout=30
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("address")
                    else:
                        logger.error(f"Failed to get worker address: {await response.text()}")
                        return None
        except Exception as e:
            logger.error(f"Error getting worker address: {e}")
            return None

    async def generate_text(self, prompt: str, model_name: str):
        """Generate text using the specified model"""
        worker_addr = await self.get_worker_address(model_name)
        if not worker_addr:
            return f"Error: No worker available for model {model_name}"

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{worker_addr}/generate",
                    json={"prompt": prompt},
                    timeout=120
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("response", "No response received from model")
                    else:
                        error_text = await response.text()
                        logger.error(f"Failed to generate text: {error_text}")
                        return f"Error: {error_text}"
        except Exception as e:
            logger.error(f"Error generating text: {e}")
            return f"Error: {str(e)}"

    def build_demo(self):
        """Build the Gradio interface"""
        with gr.Blocks(title="LLaMA-Omni2 Web UI") as demo:
            gr.Markdown("# LLaMA-Omni2 Web UI")

            with gr.Row():
                with gr.Column(scale=1):
                    model_dropdown = gr.Dropdown(
                        choices=self.model_names or ["No models available"],
                        label="Model",
                        value=self.model_names[0] if self.model_names else None
                    )

                    refresh_button = gr.Button("Refresh Models")

            with gr.Row():
                with gr.Column(scale=3):
                    text_input = gr.Textbox(
                        lines=5,
                        placeholder="Enter text here...",
                        label="Input Text"
                    )

            with gr.Row():
                with gr.Column(scale=1):
                    submit_button = gr.Button("Generate", variant="primary")
                    clear_button = gr.Button("Clear")

            with gr.Row():
                with gr.Column(scale=3):
                    text_output = gr.Textbox(
                        lines=10,
                        label="Generated Text",
                        interactive=False
                    )

            async def refresh_models():
                model_names = await self.fetch_model_list()
                return gr.update(choices=model_names or ["No models available"])

            async def generate(text, model):
                if not text.strip():
                    return "Please enter some text"
                if not model or model == "No models available":
                    return "Please select a model"

                return await self.generate_text(text, model)

            def clear():
                return "", ""

            # Gradio runs async callbacks natively, so the coroutines are
            # passed directly instead of being wrapped in asyncio tasks
            refresh_button.click(fn=refresh_models, outputs=[model_dropdown])
            submit_button.click(fn=generate,
                                inputs=[text_input, model_dropdown],
                                outputs=[text_output])
            clear_button.click(fn=clear, outputs=[text_input, text_output])

        return demo

def main():
    parser = argparse.ArgumentParser(description="Gradio web server for LLaMA-Omni2")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--vocoder-dir", type=str, required=True)
    parser.add_argument("--share", action="store_true", help="Create a public link")
    args = parser.parse_args()

    logger.info(f"Using controller at {args.controller_url}")

    # Create the UI
    ui = LLaMA_Omni2_UI(
        controller_url=args.controller_url,
        vocoder_dir=args.vocoder_dir
    )

    # Start by fetching the model list
    asyncio.run(ui.fetch_model_list())

    # Build and launch the demo
    demo = ui.build_demo()
    demo.queue()
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )

if __name__ == "__main__":
    main()
```
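
Launching the interface follows directly from the argument parser above; a sketch in which the vocoder path is a placeholder (`--vocoder-dir` is the only required flag):

```bash
python run_gradio_directly.py \
  --host 0.0.0.0 \
  --port 7860 \
  --controller-url http://localhost:10000 \
  --vocoder-dir /path/to/vocoder
```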
run_model_worker_directly.py
DELETED
@@ -1,208 +0,0 @@

```python
#!/usr/bin/env python3
"""
A simplified implementation of the LLaMA-Omni2 model worker
that doesn't rely on deep LLaMA-Omni2 imports.
"""

import argparse
import asyncio
import json
import logging
import os
import re
import threading
import time
import uuid
from typing import Dict, List, Optional, Tuple

import aiohttp
import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import gradio as gr
import uvicorn
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define constants
WORKER_HEART_BEAT_INTERVAL = 30
CONTROLLER_HEART_BEAT_EXPIRATION = 120

class ModelWorker:
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_name: str,
        device: str = "cpu",
        limit_worker_concurrency: int = 5,
    ):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        self.model_path = model_path
        self.model_name = model_name
        self.device = device
        self.limit_worker_concurrency = limit_worker_concurrency

        # Track current requests
        self.lock = asyncio.Lock()
        self.messages = {}
        self.sem = asyncio.Semaphore(limit_worker_concurrency)

        # Placeholders - the real implementation would load the model
        logger.info(f"Loading model from {model_path}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            self.model = None  # In a real implementation, we would load the model here
            logger.info("Model initialization successful (tokenizer only, no model)")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            logger.info("Using dummy model instead")
            self.tokenizer = None
            self.model = None

        logger.info(f"Model loaded successfully ({model_name})")

    async def generate_response(self, request_data):
        """Generate a response (simulated)"""
        prompt = request_data.get("prompt", "")
        response = f"This is a simulated response for prompt: {prompt[:30]}..."
        return response

    async def register_to_controller(self):
        """Register this worker with the controller"""
        controller_addr = self.controller_addr
        worker_addr = self.worker_addr
        worker_id = self.worker_id
        model_name = self.model_name

        data = {
            "worker_name": worker_id,
            "worker_url": worker_addr,
            "model_names": [model_name],
            "check_heart_beat": True,
        }

        logger.info(f"Register to controller at {controller_addr}")

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{controller_addr}/register_worker",
                json=data,
                timeout=30,
            ) as response:
                if response.status != 200:
                    logger.error(f"Failed to register to controller: {await response.text()}")
                    return False
                else:
                    logger.info("Registered to controller successfully")
                    return True

    async def send_heart_beat(self):
        """Send a heartbeat to the controller periodically"""
        controller_addr = self.controller_addr
        worker_id = self.worker_id

        data = {
            "worker_name": worker_id,
        }

        async with aiohttp.ClientSession() as session:
            while True:
                try:
                    async with session.post(
                        f"{controller_addr}/heart_beat",
                        json=data,
                        timeout=30,
                    ) as response:
                        if response.status != 200:
                            logger.error(f"Failed to send heart beat: {await response.text()}")
                except Exception as e:
                    logger.error(f"Error sending heart beat: {e}")

                await asyncio.sleep(WORKER_HEART_BEAT_INTERVAL)

# FastAPI app
app = fastapi.FastAPI()

@app.post("/generate")
async def generate(request: Request):
    """Generate text based on the prompt"""
    global model_worker

    if not model_worker:
        return JSONResponse(
            {"error": "Model worker not initialized"},
            status_code=500,
        )

    data = await request.json()
    response = await model_worker.generate_response(data)

    return {"response": response}

@app.get("/status")
async def status():
    """Get the status of the worker"""
    global model_worker

    if not model_worker:
        return {"status": "offline"}

    return {
        "status": "online",
        "model_name": model_worker.model_name,
        "worker_id": model_worker.worker_id,
    }

# Global model worker instance
model_worker = None

def start_background_tasks(app):
    """Start background tasks when the app starts"""
    asyncio.create_task(model_worker.register_to_controller())
    asyncio.create_task(model_worker.send_heart_beat())

# Run the server
def main():
    global model_worker

    parser = argparse.ArgumentParser(description="Model worker for LLaMA-Omni2")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=40000)
    parser.add_argument("--controller", type=str, default="http://localhost:10000")
    parser.add_argument("--worker", type=str, default="http://localhost:40000")
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--limit-worker-concurrency", type=int, default=5)
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    logger.info(f"Initializing model worker with model {args.model_name}")

    # Initialize the model worker
    worker_id = f"worker-{str(uuid.uuid4())[:8]}"
    model_worker = ModelWorker(
        controller_addr=args.controller,
        worker_addr=args.worker,
        worker_id=worker_id,
        model_path=args.model_path,
        model_name=args.model_name,
        device=args.device,
        limit_worker_concurrency=args.limit_worker_concurrency,
    )

    # Start the FastAPI app with background tasks
    app.add_event_handler("startup", lambda: start_background_tasks(app))

    logger.info(f"Starting model worker server at http://{args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)

if __name__ == "__main__":
    main()
```
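
A sketch of starting the worker and hitting its stubbed `/generate` endpoint; the model path and name are placeholders, and the reply is the simulated one produced by `generate_response` above:

```bash
# Start the worker (model path and name are placeholders)
python run_model_worker_directly.py \
  --controller http://localhost:10000 \
  --worker http://localhost:40000 \
  --port 40000 \
  --model-path /path/to/model \
  --model-name LLaMA-Omni2-0.5B

# Query it directly once it is up
curl -X POST http://localhost:40000/generate \
  -H "Content-Type: application/json" \
  -d '{"prompt": "Hello"}'
```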
test_llama_omni_api.py
DELETED
@@ -1,84 +0,0 @@

```python
#!/usr/bin/env python3
"""
Test script for LLaMA-Omni API on Hugging Face Spaces.
This script sends a text message to the LLaMA-Omni2-0.5B API and saves the response.
"""

import os
import time
from pathlib import Path
from gradio_client import Client

# API endpoint
API_URL = "https://marcosremar2-llama-omni.hf.space"  # Gradio Space URL

# Input and output paths
INPUT_AUDIO_PATH = "/Users/marcos/Documents/projects/test/whisper-realtime/llama-omni/llama-omni/test.mp3"
OUTPUT_DIR = "./output"
OUTPUT_TEXT_PATH = os.path.join(OUTPUT_DIR, f"response_{int(time.time())}.txt")

def main():
    """Main function to test the LLaMA-Omni API"""
    # Ensure output directory exists
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"Audio file path: {INPUT_AUDIO_PATH}")
    print(f"API URL: {API_URL}")

    try:
        # Connect to the Gradio app with increased timeout
        client = Client(
            API_URL,
            httpx_kwargs={"timeout": 300.0}  # Increase timeout to 5 minutes
        )

        print("Connected to API successfully")

        # Inspect the API endpoints
        print("Available API endpoints:")
        client.view_api()

        # Since this is a text-based model (LLaMA-Omni2), we'll send a text prompt
        # The audio file can't be directly processed by this API
        print("\nUsing the text generation endpoint (/lambda_1)...")

        # Create a text prompt describing the audio
        prompt = """This is a test of the LLaMA-Omni2-0.5B API.
Please respond with a sample of what you can do as an AI assistant."""

        # Submit the text to the API
        print(f"Sending text prompt: '{prompt[:50]}...'")
        job = client.submit(
            prompt,
            "LLaMA-Omni2-0.5B",  # Updated model name
            api_name="/lambda_1"
        )

        print("Job submitted, waiting for response...")
        result = job.result()
        print(f"Response received (length: {len(str(result))} characters)")

        # Save the text result
        with open(OUTPUT_TEXT_PATH, "w") as f:
            f.write(str(result))

        print(f"Text response saved to: {OUTPUT_TEXT_PATH}")

        # Also try the model info endpoint
        try:
            print("\nQuerying model information...")
            model_info = client.submit(api_name="/lambda").result()
            print(f"Model info: {model_info}")
        except Exception as model_error:
            print(f"Error getting model info: {str(model_error)}")

    except Exception as e:
        print(f"Error during API request: {str(e)}")
        print("This could be because the Space is currently sleeping and needs time to wake up.")
        print("Try accessing the Space directly in a browser first: " + API_URL)
        print("\nNote: This API is for the LLaMA-Omni2-0.5B model and does not directly process audio files.")
        print("To work with audio, you would need to first transcribe the audio using a service like Whisper,")
        print("then send the transcribed text to this API.")

if __name__ == "__main__":
    main()
```
tests/README.md
CHANGED
````diff
@@ -1,13 +1,21 @@
 # Testing LLaMA-Omni2-0.5B on Hugging Face
 
-This directory contains a script
+This directory contains a complete script for testing the LLaMA-Omni2-0.5B model deployed on Hugging Face.
+
+## Script Features
+
+- Programmatic API testing (api mode)
+- Manual test interface in the browser (manual mode)
+- Local audio transcription with Whisper
+- Sending text directly to the model
+- Saving the transcription and the responses for reference
 
 ## Prerequisites
 
 Before running the test script, make sure you have installed the required dependencies:
 
 ```bash
-pip install requests
+pip install requests gradio-client
 ```
 
 For audio transcription (optional), you can install Whisper:
@@ -25,23 +33,25 @@ cd tests
 python test_llama_omni_api.py
 ```
 
-By default, the script will:
+By default, the script runs both modes (api and manual) and will:
 1. Try to transcribe the test.mp3 file using Whisper (if available)
-2. If Whisper is not available, it will use a default test message
-3.
-4.
-5.
+2. If Whisper is not available or the file does not exist, use a default test message
+3. Test the API programmatically and save the response
+4. Save the input text to a file for easy copying
+5. Open the LLaMA-Omni2-0.5B web interface on Hugging Face in your browser
+6. Provide instructions for manual testing
 
 ### Command-line parameters
 
 The script accepts the following command-line arguments:
 
 - `--api-url`: Gradio interface URL (default: https://marcosremar2-llama-omni.hf.space)
-- `--audio-file`: Path to the audio file to be transcribed locally (default:
+- `--audio-file`: Path to the audio file to be transcribed locally (default: test.mp3)
 - `--text`: Text to use directly (instead of transcribing audio)
-- `--output-dir`: Directory to save the transcription (default:
+- `--output-dir`: Directory to save the transcription and responses (default: ./output)
+- `--mode`: Test mode: api (programmatic), manual (browser), or both (default: both)
 
-Usage examples with custom parameters:
+### Usage examples with custom parameters:
 
 ```bash
 # Using direct text input
@@ -49,16 +59,33 @@ python test_llama_omni_api.py --text "Hello, this is a test message for
 
 # Using a custom audio file for transcription
 python test_llama_omni_api.py --audio-file /path/to/your/audio.mp3
+
+# Testing only the API mode programmatically
+python test_llama_omni_api.py --mode api
+
+# Only opening the web interface with a custom text
+python test_llama_omni_api.py --mode manual --text "Manual test of LLaMA-Omni2-0.5B"
 ```
 
-##
+## Test Modes
+
+### 1. API Mode (Programmatic)
+
+Sends a request directly to the model's API and saves the response to a file:
 
+- Connects to the Gradio API with an increased timeout
+- Lists the available endpoints
+- Sends the text to the generation endpoint
+- Saves the received response to a file
+- Also queries basic model information
+
+### 2. Manual Mode (Web Interface)
+
+Facilitates manual testing with the following workflow:
+
+1. **Text preparation**: the input text is saved to a file for easy copying
+2. **Browser launch**: the script opens the web interface in your default browser
+3. **Manual interaction**: you must manually:
 - Copy the text from the saved file
 - Paste it into the "Input Text" field in the web interface
 - Click the "Generate" button
@@ -72,15 +99,16 @@ If you run into any problems:
 1. Check that the web interface URL is correct and the service is running
 2. Make sure you have an internet connection
 3. If you are using audio transcription, make sure Whisper is installed correctly
+4. In API mode, check that the Gradio Space is active (Spaces sometimes "sleep" when idle)
 
 ## Common Errors
 
 ### Missing Dependencies
 
-If you see errors
+If you see errors related to missing modules, install the required dependencies:
 
 ```bash
-pip install openai-whisper
+pip install requests gradio-client openai-whisper
 ```
 
 ### Deploy on Hugging Face
````
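
The API mode the README describes reduces to a `gradio_client` call like the following minimal sketch; the `/lambda_1` endpoint name and `httpx_kwargs` timeout come from the deleted `test_llama_omni_api.py` above and may differ on the live Space:

```bash
python - <<'PY'
from gradio_client import Client

# Connect to the Space with a generous timeout (sleeping Spaces are slow to wake)
client = Client("https://marcosremar2-llama-omni.hf.space",
                httpx_kwargs={"timeout": 300.0})

# Submit a prompt to the generation endpoint and wait for the result
job = client.submit("Hello, testing LLaMA-Omni2-0.5B.",
                    "LLaMA-Omni2-0.5B",
                    api_name="/lambda_1")
print(job.result())
PY
```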
tests/test_llama_omni_api.py
CHANGED
```diff
@@ -1,16 +1,27 @@
 #!/usr/bin/env python3
 """
-Test
-This script
+Complete test for LLaMA-Omni2-0.5B on Hugging Face
+This script can:
+1. Transcribe audio locally and send it to the model
+2. Send text directly to the model
+3. Facilitate manual testing through the web interface
+4. Test the API directly in programmatic mode
 """
 
 import os
 import sys
+import time
 import argparse
 import requests
 import subprocess
 import webbrowser
 from pathlib import Path
+from gradio_client import Client
+
+# Default settings
+DEFAULT_API_URL = "https://marcosremar2-llama-omni.hf.space"
+DEFAULT_OUTPUT_DIR = "./output"
+MODEL_NAME = "LLaMA-Omni2-0.5B"
 
 def transcribe_audio_locally(audio_file_path):
     """
@@ -36,7 +47,7 @@ def transcribe_audio_locally(audio_file_path):
 
     # Default message
     print("Using the default test message, since whisper is not available")
-    return "Hello, I am testing the model
+    return f"Hello, I am testing the {MODEL_NAME} model. Can you reply in Portuguese?"
 
 def check_url_accessibility(url):
     """Check whether the URL is reachable"""
@@ -51,58 +62,98 @@ def check_url_accessibility(url):
     print(f"Error accessing URL: {e}")
     return False
 
-def
-    """Saves
+def save_text_to_file(text, output_dir, filename="text.txt"):
+    """Save text to a file for easy copying"""
     os.makedirs(output_dir, exist_ok=True)
     filepath = os.path.join(output_dir, filename)
 
     with open(filepath, "w") as f:
         f.write(text)
 
-    print(f"
+    print(f"Text saved to: {filepath}")
     return filepath
 
-def
+def test_api_programmatically(api_url, text_input, output_dir=DEFAULT_OUTPUT_DIR):
     """
-
-
-    2. Saves the text to a file for easy copying
-    3. Opens the web interface for manual testing
-
-    Args:
-        api_url: Gradio interface URL
-        audio_file_path: Path to the audio file
-        text_input: Text to use directly (instead of transcribing audio)
-        output_dir: Directory to save the transcription
-
-    Returns:
-        bool: True if the preparation succeeded, False otherwise
+    Test the model API programmatically by sending a text
+    and saving the response
     """
-
     os.makedirs(output_dir, exist_ok=True)
 
     # Check whether the URL is reachable
     print(f"Checking accessibility of {api_url}...")
     if not check_url_accessibility(api_url):
         print(f"Warning: {api_url} is not accessible. Manual testing may not be possible.")
 
-    # Get the input text from the transcription or from the parameter
-    input_text = text_input
-    if not input_text and audio_file_path:
-        input_text = transcribe_audio_locally(audio_file_path)
-    if not input_text:
-        input_text = "Hello, I am testing the LLaMA-Omni2-0.5B model. Can you reply in Portuguese?"
-
-    print(f"Text to use: {input_text}")
-
     # Save the text to a file for easy copying
-    transcript_file =
 
     # Instructions for manual testing
     print("\n" + "=" * 50)
-    print("INSTRUCTIONS FOR MANUAL TESTING")
     print("=" * 50)
-    print(f"1.
     print(f"2. Opening {api_url} in the browser...")
     print("3. Copy the text from the saved file and paste it into the 'Input Text' field")
    print("4. Click the 'Generate' button")
@@ -119,15 +170,17 @@ def test_llama_omni_manual(api_url, audio_file_path=None, text_input=None, outpu
     return False
 
 def main():
-    parser = argparse.ArgumentParser(description="Test for
-    parser.add_argument("--api-url", type=str, default=
-                        help="Gradio interface URL (default:
-    parser.add_argument("--audio-file", type=str, default="
                         help="Path to the audio file to be transcribed locally (optional)")
     parser.add_argument("--text", type=str, default=None,
                         help="Text to use directly (instead of transcribing audio)")
-    parser.add_argument("--output-dir", type=str, default=
-                        help="Directory to save the transcription")
     args = parser.parse_args()
 
     # Convert relative paths to absolute
@@ -140,13 +193,28 @@ def main():
     script_dir = os.path.dirname(os.path.abspath(__file__))
     args.output_dir = os.path.join(script_dir, args.output_dir)
 
-    #
-
-
-
-
-
-
 
     # Exit with the appropriate status code
     sys.exit(0 if success else 1)
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
"""
|
81 |
+
output_path = os.path.join(output_dir, f"response_{int(time.time())}.txt")
|
82 |
os.makedirs(output_dir, exist_ok=True)
|
83 |
|
84 |
+
print(f"Testando API em: {api_url}")
|
85 |
+
print(f"Texto de entrada: {text_input[:50]}..." if len(text_input) > 50 else f"Texto de entrada: {text_input}")
|
86 |
+
|
87 |
+
try:
|
88 |
+
# Conecta ao app Gradio com timeout aumentado
|
89 |
+
client = Client(
|
90 |
+
api_url,
|
91 |
+
httpx_kwargs={"timeout": 300.0} # 5 minutos de timeout
|
92 |
+
)
|
93 |
+
|
94 |
+
print("Conectado à API com sucesso")
|
95 |
+
|
96 |
+
# Lista os endpoints disponíveis
|
97 |
+
print("Endpoints disponíveis:")
|
98 |
+
client.view_api()
|
99 |
+
|
100 |
+
# Envia o prompt para o modelo
|
101 |
+
print(f"\nUsando endpoint de geração de texto (/lambda_1)...")
|
102 |
+
print(f"Enviando prompt: '{text_input[:50]}...'")
|
103 |
+
job = client.submit(
|
104 |
+
text_input,
|
105 |
+
MODEL_NAME,
|
106 |
+
api_name="/lambda_1"
|
107 |
+
)
|
108 |
+
|
109 |
+
print("Requisição enviada, aguardando resposta...")
|
110 |
+
result = job.result()
|
111 |
+
print(f"Resposta recebida (tamanho: {len(str(result))} caracteres)")
|
112 |
+
|
113 |
+
# Salva a resposta em arquivo
|
114 |
+
with open(output_path, "w") as f:
|
115 |
+
f.write(str(result))
|
116 |
+
|
117 |
+
print(f"Resposta salva em: {output_path}")
|
118 |
+
|
119 |
+
# Tenta obter informações do modelo
|
120 |
+
try:
|
121 |
+
print("\nConsultando informações do modelo...")
|
122 |
+
model_info = client.submit(api_name="/lambda").result()
|
123 |
+
print(f"Informações do modelo: {model_info}")
|
124 |
+
except Exception as model_error:
|
125 |
+
print(f"Erro ao obter informações do modelo: {str(model_error)}")
|
126 |
+
|
127 |
+
return True, result
|
128 |
+
|
129 |
+
except Exception as e:
|
130 |
+
print(f"Erro durante requisição à API: {str(e)}")
|
131 |
+
print("Isso pode ocorrer porque o Space está dormindo e precisa de tempo para iniciar.")
|
132 |
+
print("Tente acessar o Space diretamente primeiro: " + api_url)
|
133 |
+
print(f"\nNota: Esta API é para o modelo {MODEL_NAME} e não processa áudio diretamente.")
|
134 |
+
print("Para trabalhar com áudio, você precisaria primeiro transcrever o áudio usando Whisper,")
|
135 |
+
print("e então enviar o texto transcrito para esta API.")
|
136 |
+
return False, None
|
137 |
+
|
138 |
+
def test_manual_interface(api_url, text_input, output_dir=DEFAULT_OUTPUT_DIR):
|
139 |
+
"""
|
140 |
+
Prepara o teste manual do modelo via interface web:
|
141 |
+
1. Salva o texto em arquivo para fácil cópia
|
142 |
+
2. Abre a interface web para teste manual
|
143 |
+
"""
|
144 |
# Verifica se a URL é acessível
|
145 |
print(f"Verificando acessibilidade de {api_url}...")
|
146 |
if not check_url_accessibility(api_url):
|
147 |
print(f"Aviso: {api_url} não está acessível. Teste manual pode não ser possível.")
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
# Salva o texto em arquivo para fácil cópia
|
150 |
+
transcript_file = save_text_to_file(text_input, output_dir, "transcription.txt")
|
151 |
|
152 |
# Instruções para teste manual
|
153 |
print("\n" + "=" * 50)
|
154 |
+
print(f"INSTRUÇÕES PARA TESTE MANUAL DO {MODEL_NAME}")
|
155 |
print("=" * 50)
|
156 |
+
print(f"1. O texto foi salvo em: {transcript_file}")
|
157 |
print(f"2. Abrindo {api_url} no navegador...")
|
158 |
print("3. Copie o texto do arquivo salvo e cole no campo 'Input Text'")
|
159 |
print("4. Clique no botão 'Generate'")
|
|
|
170 |
return False
|
171 |
|
172 |
def main():
|
173 |
+
parser = argparse.ArgumentParser(description=f"Teste para {MODEL_NAME} no Hugging Face")
|
174 |
+
parser.add_argument("--api-url", type=str, default=DEFAULT_API_URL,
|
175 |
+
help=f"URL da interface Gradio (padrão: {DEFAULT_API_URL})")
|
176 |
+
parser.add_argument("--audio-file", type=str, default="test.mp3",
|
177 |
help="Caminho para o arquivo de áudio a ser transcrito localmente (opcional)")
|
178 |
parser.add_argument("--text", type=str, default=None,
|
179 |
help="Texto para usar diretamente (em vez de transcrever áudio)")
|
180 |
+
parser.add_argument("--output-dir", type=str, default=DEFAULT_OUTPUT_DIR,
|
181 |
+
help="Diretório para salvar a transcrição e respostas")
|
182 |
+
parser.add_argument("--mode", type=str, choices=["api", "manual", "both"], default="both",
|
183 |
+
help="Modo de teste: api (programático), manual (navegador) ou both (ambos)")
|
184 |
args = parser.parse_args()
|
185 |
|
186 |
# Converte caminhos relativos para absolutos
|
|
|
193 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
194 |
args.output_dir = os.path.join(script_dir, args.output_dir)
|
195 |
|
196 |
+
# Obtém texto de entrada da transcrição ou do parâmetro
|
197 |
+
input_text = args.text
|
198 |
+
if not input_text and args.audio_file:
|
199 |
+
if os.path.exists(args.audio_file):
|
200 |
+
input_text = transcribe_audio_locally(args.audio_file)
|
201 |
+
else:
|
202 |
+
print(f"Arquivo de áudio não encontrado: {args.audio_file}")
|
203 |
+
input_text = f"Olá, estou testando o modelo {MODEL_NAME}. Você pode me responder em português?"
|
204 |
+
if not input_text:
|
205 |
+
input_text = f"Olá, estou testando o modelo {MODEL_NAME}. Você pode me responder em português?"
|
206 |
+
|
207 |
+
print(f"Texto de entrada: {input_text}")
|
208 |
+
|
209 |
+
# Executa os testes conforme o modo selecionado
|
210 |
+
success = True
|
211 |
+
if args.mode in ["api", "both"]:
|
212 |
+
api_success, _ = test_api_programmatically(args.api_url, input_text, args.output_dir)
|
213 |
+
success = success and api_success
|
214 |
+
|
215 |
+
if args.mode in ["manual", "both"]:
|
216 |
+
manual_success = test_manual_interface(args.api_url, input_text, args.output_dir)
|
217 |
+
success = success and manual_success
|
218 |
|
219 |
# Sai com código apropriado
|
220 |
sys.exit(0 if success else 1)
|