ghostai1 committed on
Commit 8579576 · verified · 1 Parent(s): b36dea8

Update app.py

Files changed (1)
  1. app.py +125 -158
app.py CHANGED
@@ -1,168 +1,135 @@
  import gradio as gr
- from gliner import GLiNER
- from vllm import LLM, SamplingParams
- from sentence_transformers import SentenceTransformer
- import faiss
- import numpy as np
- import json
  import torch
- import requests
- import threading
- from queue import Queue
  import logging
- import pynvml

- # Configure logging
- logging.basicConfig(level=logging.DEBUG)
  logger = logging.getLogger(__name__)

- # Initialize NVML for GPU debugging
  try:
-     pynvml.nvmlInit()
-     device_count = pynvml.nvmlDeviceGetCount()
-     logger.info(f"NVML Initialized. GPU Count: {device_count}")
-     for i in range(device_count):
-         handle = pynvml.nvmlDeviceGetHandleByIndex(i)
-         name = pynvml.nvmlDeviceGetName(handle)
-         logger.info(f"GPU {i}: {name}")
- except pynvml.NVMLError as e:
-     logger.error(f"NVML Initialization Failed: {str(e)}")
-     raise RuntimeError("Cannot initialize NVML. Check NVIDIA drivers.")
-
- # Verify CUDA
- if not torch.cuda.is_available():
-     logger.error("CUDA not available")
-     raise RuntimeError("No GPU detected. Ensure H200 GPU is available.")
- logger.info(f"CUDA Version: {torch.version.cuda}")
- logger.info(f"GPU Detected: {torch.cuda.get_device_name(0)}")
- logger.info(f"Device Count: {torch.cuda.device_count()}")
-
- # Load legal corpus
- with open("legal_corpus.json", "r", encoding="utf-8") as f:
-     corpus = json.load(f)
- documents = [item["text"] for item in corpus]
-
- # Initialize sentence transformer (GPU)
- embedder = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
- embeddings = embedder.encode(documents, convert_to_numpy=True)
-
- # Initialize FAISS-GPU
- dimension = embeddings.shape[1]
- index = faiss.IndexFlatL2(dimension)
- index.add(embeddings)
-
- # Initialize GLiNER (GPU)
- gliner_model = GLiNER.from_pretrained("NAMAA-Space/gliner_arabic-v2.1", load_tokenizer=True)
- gliner_model = gliner_model.cuda()
-
- # Initialize LLM (default to Qwen2-7B-Instruct-AWQ)
- use_qwq_32b = False # Set to True if H200 detection is fixed
- model_name = "Qwen/Qwen2-7B-Instruct-AWQ" if not use_qwq_32b else "Qwen/QwQ-32B"
- try:
-     llm = LLM(
-         model=model_name,
-         quantization="awq",
-         max_model_len=4096,
-         gpu_memory_utilization=0.9,
-         device="cuda"
-     )
-     logger.info(f"Loaded LLM: {model_name}")
- except Exception as e:
-     logger.error(f"Failed to initialize LLM: {str(e)}")
-     raise
-
- sampling_params = SamplingParams(temperature=0.7, max_tokens=512)
-
- def fetch_external_legal_data(query, queue):
-     """Fetch external legal data via HTTP request (mock API)."""
-     try:
-         response = requests.get(
-             "https://api.example.com/legal",
-             params={"query": query},
-             timeout=5
          )
-         response.raise_for_status()
-         queue.put(response.json().get("text", "No external data found"))
-     except requests.RequestException:
-         queue.put("Failed to fetch external data")
-
- def run_ner(text, entity_types, queue):
-     """Run NER with gliner_arabic-v2.1."""
-     if not text or not entity_types:
-         queue.put([])
-         return
-     entity_list = [e.strip() for e in entity_types.split(",")]
-     entities = gliner_model.predict_entities(text, entity_list, threshold=0.5)
-     queue.put([{"text": e["text"], "label": e["label"], "score": round(e["score"], 2)} for e in entities])
-
- def retrieve_documents(query, k=2):
-     """Retrieve top-k documents using FAISS-GPU."""
-     query_embedding = embedder.encode([query], convert_to_numpy=True)
-     distances, indices = index.search(query_embedding, k)
-     return [documents[idx] for idx in indices[0]]
-
- def generate_legal_insight(text, entities, retrieved_docs, external_data):
-     """Generate insight with LLM using RAG."""
-     entity_str = ", ".join([f"{e['text']} ({e['label']})" for e in entities])
-     context = "\n".join(retrieved_docs) + "\nExternal Data: " + external_data
-     prompt = f"""You are a legal assistant for Arabic law. Using the following context, extracted entities, and external data, provide a concise legal insight.
-
- Context:
- {context}
-
- Entities:
- {entity_str}
-
- Input Text:
- {text}
-
- Insight:"""
-     outputs = llm.generate([prompt], sampling_params)
-     return outputs[0].outputs[0].text
-
- def main_interface(text, entity_types):
-     """Main Gradio interface with threading."""
-     ner_queue = Queue()
-     external_queue = Queue()
-
-     ner_thread = threading.Thread(target=run_ner, args=(text, entity_types, ner_queue))
-     external_thread = threading.Thread(target=fetch_external_legal_data, args=(text, external_queue))
-
-     ner_thread.start()
-     external_thread.start()
-
-     ner_thread.join()
-     external_thread.join()
-
-     ner_result = ner_queue.get()
-     external_data = external_queue.get()
-
-     retrieved_docs = retrieve_documents(text)
-
-     insight = generate_legal_insight(text, ner_result, retrieved_docs, external_data)
-
-     return ner_result, retrieved_docs, external_data, insight
-
- # Gradio interface
- with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
-     gr.Markdown("# Arabic Legal Demo: NER & RAG with GLiNER and LLM")
-     with gr.Row():
-         text_input = gr.Textbox(label="Arabic Legal Text", lines=5, placeholder="Enter Arabic legal text...")
-         entity_types = gr.Textbox(
-             label="Entity Types (comma-separated)",
-             value="person,law,organization",
-             placeholder="e.g., person,law,organization"
          )
-     submit_btn = gr.Button("Analyze")
-     ner_output = gr.JSON(label="Extracted Entities")
-     docs_output = gr.Textbox(label="Retrieved Legal Context")
-     external_output = gr.Textbox(label="External Legal Data")
-     insight_output = gr.Textbox(label="Legal Insight")
-
-     submit_btn.click(
-         fn=main_interface,
-         inputs=[text_input, entity_types],
-         outputs=[ner_output, docs_output, external_output, insight_output]
-     )

- demo.launch()
  import gradio as gr
  import torch
  import logging
+ from transformers import AutoTokenizer, AutoModel
+ from diffusers import DiffusionPipeline
+ import soundfile as sf
+ import numpy as np

+ # Set up logging to debug startup issues
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  logger = logging.getLogger(__name__)

  try:
+     # Load text tokenizer and embedding model (umt5-base)
+     def load_text_processor():
+         logger.info("Loading text processor (umt5-base)...")
+         tokenizer = AutoTokenizer.from_pretrained("./umt5-base")
+         text_model = AutoModel.from_pretrained(
+             "./umt5-base",
+             use_safetensors=True,
+             torch_dtype=torch.float16,
+             device_map="auto"
          )
+         logger.info("Text processor loaded successfully.")
+         return tokenizer, text_model
+
+     # Load the transformer backbone (phantomstep_transformer)
+     def load_transformer():
+         logger.info("Loading transformer (phantomstep_transformer)...")
+         transformer = DiffusionPipeline.from_pretrained(
+             "./phantomstep_transformer",
+             use_safetensors=True,
+             torch_dtype=torch.float16,
+             device_map="auto"
+         )
+         logger.info("Transformer loaded successfully.")
+         return transformer
+
+     # Load the DCAE for audio encoding/decoding (phantomstep_dcae)
+     def load_dcae():
+         logger.info("Loading DCAE (phantomstep_dcae)...")
+         dcae = DiffusionPipeline.from_pretrained(
+             "./phantomstep_dcae",
+             use_safetensors=True,
+             torch_dtype=torch.float16,
+             device_map="auto"
+         )
+         logger.info("DCAE loaded successfully.")
+         return dcae
+
+     # Load the vocoder for audio synthesis (phantomstep_vocoder)
+     def load_vocoder():
+         logger.info("Loading vocoder (phantomstep_vocoder)...")
+         vocoder = DiffusionPipeline.from_pretrained(
+             "./phantomstep_vocoder",
+             use_safetensors=True,
+             torch_dtype=torch.float16,
+             device_map="auto"
          )
+         logger.info("Vocoder loaded successfully.")
+         return vocoder
+
+     # Generate music from a text prompt
+     def generate_music(prompt, duration=20, seed=42):
+         logger.info(f"Generating music with prompt: {prompt}, duration: {duration}, seed: {seed}")
+         torch.manual_seed(seed)
+
+         # Load all components
+         tokenizer, text_model = load_text_processor()
+         transformer = load_transformer()
+         dcae = load_dcae()
+         vocoder = load_vocoder()
+
+         # Step 1: Process text prompt to embeddings
+         logger.info("Processing text prompt to embeddings...")
+         inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+         inputs = {k: v.to(text_model.device) for k, v in inputs.items()}
+         with torch.no_grad():
+             embeddings = text_model(**inputs).last_hidden_state.mean(dim=1)
+
+         # Step 2: Pass embeddings through transformer
+         logger.info("Generating with transformer...")
+         transformer_output = transformer(
+             embeddings,
+             num_inference_steps=50,
+             audio_length_in_s=duration
+         ).audios[0]
+
+         # Step 3: Decode audio features with DCAE
+         logger.info("Decoding with DCAE...")
+         dcae_output = dcae(
+             transformer_output,
+             num_inference_steps=50,
+             audio_length_in_s=duration
+         ).audios[0]
+
+         # Step 4: Synthesize final audio with vocoder
+         logger.info("Synthesizing with vocoder...")
+         audio = vocoder(
+             dcae_output,
+             num_inference_steps=50,
+             audio_length_in_s=duration
+         ).audios[0]
+
+         # Save audio to a file
+         output_path = "output.wav"
+         sf.write(output_path, audio, 22050) # 22kHz sample rate
+         logger.info("Music generation complete.")
+         return output_path
+
+     # Gradio interface
+     logger.info("Setting up Gradio interface...")
+     with gr.Blocks(title="PhantomStep: Text-to-Music Generation 🎵") as demo:
+         gr.Markdown("# PhantomStep by GhostAI 🚀")
+         gr.Markdown("Enter a text prompt to generate music! 🎶")
+
+         prompt_input = gr.Textbox(label="Text Prompt", placeholder="A jazzy piano melody with a fast tempo")
+         duration_input = gr.Slider(label="Duration (seconds)", minimum=10, maximum=60, value=20, step=1)
+         seed_input = gr.Number(label="Random Seed", value=42, precision=0)
+         generate_button = gr.Button("Generate Music")
+
+         audio_output = gr.Audio(label="Generated Music")
+
+         generate_button.click(
+             fn=generate_music,
+             inputs=[prompt_input, duration_input, seed_input],
+             outputs=audio_output
+         )
+
+     logger.info("Launching Gradio app...")
+     demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)

+ except Exception as e:
+     logger.error(f"Failed to start the application: {str(e)}")
+     raise
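
A minimal way to exercise the new click handler once this version of app.py is running locally is a small gradio_client call. This is a sketch, not part of the commit: it assumes the app is reachable on port 7860 (as configured above) and that Gradio exposes the endpoint under its default API name, which varies by Gradio version; check the running app's "Use via API" panel to confirm.

# Hypothetical smoke test; not part of this commit.
# Requires `pip install gradio_client` and a running instance of app.py on localhost:7860.
# The api_name below is an assumption; confirm it in the app's "Use via API" panel.
from gradio_client import Client

client = Client("http://localhost:7860")
wav_path = client.predict(
    "A jazzy piano melody with a fast tempo",  # prompt_input
    20,                                        # duration_input (seconds)
    42,                                        # seed_input
    api_name="/generate_music",
)
print("Generated audio saved at:", wav_path)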