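"""PyQt5 front end for the MultiModalTransformer model.

One tab per capability: chat, speech recognition, image captioning, music
generation, and anomaly detection. Long-running model calls are pushed onto
a background WorkerThread so the interface stays responsive.

Runtime dependencies (inferred from the imports below): PyQt5, tensorflow,
numpy, sounddevice, soundfile, librosa, Pillow, and the local
multimodal_transformer module.
"""
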
import sys
import numpy as np
import tensorflow as tf
from PyQt5.QtWidgets import (QApplication, QWidget, QVBoxLayout, QHBoxLayout,
                             QTextEdit, QPushButton, QLineEdit, QLabel,
                             QFileDialog, QTabWidget, QProgressBar)
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtGui import QPixmap
import sounddevice as sd
import soundfile as sf
import librosa
from PIL import Image
from multimodal_transformer import MultiModalTransformer, HParams
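
# The companion multimodal_transformer module is assumed to expose:
#   model.pipe(inputs, task)                   -- task-dispatched inference
#   model.conversation(text)                   -- chat response generation
#   model.safe_word_format(text)               -- safe-word reply, or a falsy value
#   model.fine_tune_personality(trait, value)  -- personality adjustment
# These signatures are not verified here; they simply mirror how the GUI
# calls the model below.
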

class WorkerThread(QThread):
    """Runs a blocking callable off the GUI thread and reports its result.

    The result is delivered on resultReady; a custom signal name avoids
    shadowing QThread's built-in parameterless finished signal.
    """
    resultReady = pyqtSignal(object)

    def __init__(self, func, *args, **kwargs):
        super().__init__()
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def run(self):
        result = self.func(*self.args, **self.kwargs)
        self.resultReady.emit(result)
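

# NOTE: Python-side QThread objects are garbage collected like any other
# object, so each handler below keeps its active WorkerThread on self.worker;
# otherwise a running thread could be collected before resultReady fires.
# Starting a new task while another is running replaces that reference.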
class EnhancedChatGUI(QWidget):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.worker = None  # reference to the active background thread
        self.initUI()

    def initUI(self):
        self.setWindowTitle('MultiModal Transformer Interface')
        self.setGeometry(100, 100, 800, 600)
        layout = QVBoxLayout()

        # One tab per model capability
        self.tabs = QTabWidget()
        self.tabs.addTab(self.createChatTab(), "Chat")
        self.tabs.addTab(self.createSpeechTab(), "Speech Recognition")
        self.tabs.addTab(self.createImageTab(), "Image Captioning")
        self.tabs.addTab(self.createMusicTab(), "Music Generation")
        self.tabs.addTab(self.createAnomalyTab(), "Anomaly Detection")
        layout.addWidget(self.tabs)
        self.setLayout(layout)

    def createChatTab(self):
        widget = QWidget()
        layout = QVBoxLayout()

        self.chatDisplay = QTextEdit()
        self.chatDisplay.setReadOnly(True)
        layout.addWidget(self.chatDisplay)

        inputLayout = QHBoxLayout()
        self.inputField = QLineEdit()
        self.inputField.returnPressed.connect(self.sendMessage)
        inputLayout.addWidget(self.inputField)
        sendButton = QPushButton('Send')
        sendButton.clicked.connect(self.sendMessage)
        inputLayout.addWidget(sendButton)
        layout.addLayout(inputLayout)

        # Controls for adjusting a single personality trait by name and value
        traitLayout = QHBoxLayout()
        self.traitLabel = QLabel('Adjust trait:')
        self.traitInput = QLineEdit()
        self.traitValue = QLineEdit()
        self.traitButton = QPushButton('Update')
        self.traitButton.clicked.connect(self.updateTrait)
        traitLayout.addWidget(self.traitLabel)
        traitLayout.addWidget(self.traitInput)
        traitLayout.addWidget(self.traitValue)
        traitLayout.addWidget(self.traitButton)
        layout.addLayout(traitLayout)

        widget.setLayout(layout)
        return widget

    def createSpeechTab(self):
        widget = QWidget()
        layout = QVBoxLayout()
        self.recordButton = QPushButton('Record Audio (5 seconds)')
        self.recordButton.clicked.connect(self.recordAudio)
        layout.addWidget(self.recordButton)
        self.speechOutput = QTextEdit()
        self.speechOutput.setReadOnly(True)
        layout.addWidget(self.speechOutput)
        widget.setLayout(layout)
        return widget

    def createImageTab(self):
        widget = QWidget()
        layout = QVBoxLayout()
        self.imageButton = QPushButton('Select Image')
        self.imageButton.clicked.connect(self.selectImage)
        layout.addWidget(self.imageButton)
        self.imageLabel = QLabel()
        layout.addWidget(self.imageLabel)
        self.captionOutput = QTextEdit()
        self.captionOutput.setReadOnly(True)
        layout.addWidget(self.captionOutput)
        widget.setLayout(layout)
        return widget

    def createMusicTab(self):
        widget = QWidget()
        layout = QVBoxLayout()
        self.generateMusicButton = QPushButton('Generate Music')
        self.generateMusicButton.clicked.connect(self.generateMusic)
        layout.addWidget(self.generateMusicButton)
        self.musicOutput = QTextEdit()
        self.musicOutput.setReadOnly(True)
        layout.addWidget(self.musicOutput)
        widget.setLayout(layout)
        return widget

    def createAnomalyTab(self):
        widget = QWidget()
        layout = QVBoxLayout()
        self.anomalyButton = QPushButton('Detect Anomalies')
        self.anomalyButton.clicked.connect(self.detectAnomalies)
        layout.addWidget(self.anomalyButton)
        self.anomalyOutput = QTextEdit()
        self.anomalyOutput.setReadOnly(True)
        layout.addWidget(self.anomalyOutput)
        widget.setLayout(layout)
        return widget

    def sendMessage(self):
        userInput = self.inputField.text().strip()
        if not userInput:
            return
        self.inputField.clear()
        self.displayMessage("User: " + userInput)
        # A safe-word match short-circuits normal conversation handling
        safeWordResponse = self.model.safe_word_format(userInput)
        if safeWordResponse:
            self.displayMessage("AI: " + safeWordResponse)
            return
        response = self.model.conversation(userInput)
        self.displayMessage("AI: " + response)

    def displayMessage(self, message):
        self.chatDisplay.append(message)

    def updateTrait(self):
        trait = self.traitInput.text()
        try:
            # float() sits inside the try so bad numeric input is reported too
            value = float(self.traitValue.text())
            self.model.fine_tune_personality(trait, value)
            self.displayMessage(f"System: Updated {trait} to {value}")
        except ValueError as e:
            self.displayMessage(f"System Error: {str(e)}")

    def recordAudio(self):
        duration = 5  # seconds
        fs = 16000    # sample rate (Hz)
        # Note: sd.wait() blocks the GUI thread for the whole recording
        recording = sd.rec(int(duration * fs), samplerate=fs, channels=1)
        sd.wait()
        sf.write('temp_recording.wav', recording, fs)
        self.processSpeech('temp_recording.wav')

    def processSpeech(self, file_path):
        audio, _ = librosa.load(file_path, sr=16000)
        audio_tensor = tf.convert_to_tensor(audio, dtype=tf.float32)
        audio_tensor = tf.expand_dims(audio_tensor, axis=0)  # add batch dim
        self.worker = WorkerThread(self.model.pipe, audio_tensor, 'speech_recognition')
        self.worker.resultReady.connect(self.onSpeechRecognitionFinished)
        self.worker.start()

    def onSpeechRecognitionFinished(self, result):
        self.speechOutput.setText(f"Recognized Speech: {result}")

    def selectImage(self):
        file_path, _ = QFileDialog.getOpenFileName(self, "Select Image", "", "Image Files (*.png *.jpg *.bmp)")
        if file_path:
            pixmap = QPixmap(file_path)
            self.imageLabel.setPixmap(pixmap.scaled(300, 300, Qt.KeepAspectRatio))
            self.processImage(file_path)

    def processImage(self, file_path):
        # Force 3 channels so RGBA/grayscale files also yield (224, 224, 3)
        image = Image.open(file_path).convert('RGB')
        image = image.resize((224, 224))
        image_array = np.array(image) / 255.0
        image_tensor = tf.convert_to_tensor(image_array, dtype=tf.float32)
        image_tensor = tf.expand_dims(image_tensor, axis=0)
        self.worker = WorkerThread(self.model.pipe, [image_tensor, tf.zeros((1, 1), dtype=tf.int32)], 'image_captioning')
        self.worker.resultReady.connect(self.onImageCaptioningFinished)
        self.worker.start()

    def onImageCaptioningFinished(self, result):
        self.captionOutput.setText(f"Generated Caption: {result}")

    def generateMusic(self):
        # Random pitch/duration/velocity tokens as a placeholder seed; a real
        # musical prompt would give more meaningful output
        pitch = tf.random.uniform((1, 100), maxval=128, dtype=tf.int32)
        duration = tf.random.uniform((1, 100), maxval=32, dtype=tf.int32)
        velocity = tf.random.uniform((1, 100), maxval=128, dtype=tf.int32)
        self.worker = WorkerThread(self.model.pipe, [pitch, duration, velocity], 'music_generation')
        self.worker.resultReady.connect(self.onMusicGenerationFinished)
        self.worker.start()

    def onMusicGenerationFinished(self, result):
        self.musicOutput.setText(f"Generated Music: {result}")

    def detectAnomalies(self):
        # Random embeddings stand in for real input features
        anomaly_input = tf.random.normal((1, 100, 768))
        self.worker = WorkerThread(self.model.pipe, anomaly_input, 'anomaly_detection')
        self.worker.resultReady.connect(self.onAnomalyDetectionFinished)
        self.worker.start()

    def onAnomalyDetectionFinished(self, result):
        reconstructed, anomalies = result
        self.anomalyOutput.setText(f"Detected Anomalies: {anomalies}")


def main():
    # Roughly GPT-2-small-sized hyperparameters
    hparams = HParams(
        n_vocab=50000,
        n_ctx=1024,
        n_embd=768,
        n_head=12,
        n_layer=12
    )
    # Toy knowledge base; real entries would carry meaningful embeddings
    knowledge_base = [
        {'text': 'Example knowledge 1', 'vector': np.random.rand(768)},
        {'text': 'Example knowledge 2', 'vector': np.random.rand(768)},
    ]
    model = MultiModalTransformer(hparams, knowledge_base)

    app = QApplication(sys.argv)
    gui = EnhancedChatGUI(model)
    gui.show()
    sys.exit(app.exec_())


if __name__ == '__main__':
    main()