import os
import random
import shutil
import datetime

import torch
import torchvision.transforms.functional as f
from PIL import Image

from typing import Optional, Tuple
from threading import Lock

from langchain import ConversationChain
import gradio as gr

from chat_anything.tts_talker.tts_edge import TTSTalker
from chat_anything.sad_talker.sad_talker import SadTalker
from chat_anything.chatbot.chat import load_chain
from chat_anything.chatbot.select import model_selection_chain


# Width (in pixels) of the talking-head widget in the UI.
TALKING_HEAD_WIDTH = "350"

# Directory of stock reference faces, used when no reference image is given.
FACE_DIR = "resources/images/faces"

sadtalker_checkpoint_path = "MODELS/SadTalker"
config_path = "chat_anything/sad_talker/config"

|
class ChatWrapper: |
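    """Serializes chat requests and renders talking-head replies via SadTalker."""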
|
    def __init__(self):
        # One lock per wrapper: chat requests and SadTalker rendering are
        # serialized so concurrent gradio callbacks cannot interleave.
        self.lock = Lock()
        self.sad_talker = SadTalker(
            sadtalker_checkpoint_path, config_path, lazy_load=True)
|
    def __call__(
        self,
        api_key: str,
        inp: str,
        history: Optional[Tuple[str, str]],
        chain: Optional[ConversationChain],
        speak_text: bool,
        talking_head: bool,
        uid: str,
        talker: Optional[TTSTalker],
        fullbody: str,
    ):
        """Execute the chat functionality."""
        self.lock.acquire()
        history = history or []
        html_video, temp_file, html_audio, temp_aud_file = None, None, None, None
        try:
            if chain is None:
                history.append((inp, "Please register with your API key first!"))
            else:
                print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
                print("inp: " + inp)
                print("speak_text: ", speak_text)
                print("talking_head: ", talking_head)

                output = chain.predict(input=inp).strip()
                output = output.replace("\n", "\n\n")

                text_to_display = output
                history.append((inp, text_to_display))

                if speak_text:
                    if talking_head:
                        html_video, temp_file = self.do_html_video_speak(
                            talker, output, fullbody, uid)
                    else:
                        html_audio, temp_aud_file = self.do_html_audio_speak(
                            talker, output, uid)
                elif talking_head:
                    # Not speaking: keep showing the last rendered video.
                    temp_file = os.path.join('tmp', uid, 'videos')
                    html_video = create_html_video(
                        temp_file, TALKING_HEAD_WIDTH)
        finally:
            # Release on every path so a missing chain or a failed predict()
            # cannot leave the lock held forever.
            self.lock.release()
        return history, history, html_video, temp_file, html_audio, temp_aud_file, ""
|
    def do_html_audio_speak(self, talker, words_to_speak, uid):
        """Synthesize speech for `words_to_speak` and build an autoplay <audio> tag."""
        audio_path = os.path.join('tmp', uid, 'audios')
        print('uid:', uid, ":", words_to_speak)
        audio_file_path = talker.test(text=words_to_speak, audio_path=audio_path)
        html_audio = '<pre>no audio</pre>'
        try:
            temp_aud_file = gr.File(audio_file_path)
            temp_aud_file_url = "/file=" + temp_aud_file.value['name']
            html_audio = f'<audio autoplay><source src="{temp_aud_file_url}" type="audio/mp3"></audio>'
        except IOError as error:
            print(error)
            return None, None
        return html_audio, audio_file_path
|
    def do_html_video_speak(self, talker, words_to_speak, fullbody, uid):
        """Speak the text, then animate the generated face into a video with SadTalker."""
        # Full-body portraits need SadTalker's 'full' preprocessing; otherwise
        # crop to the face region.
        preprocess = 'full' if fullbody else 'crop'
        video_path = os.path.join('tmp', uid, 'videos')
        os.makedirs(video_path, exist_ok=True)
        video_file_path = os.path.join(video_path, 'tempfile.mp4')
        _, audio_path = self.do_html_audio_speak(talker, words_to_speak, uid)
        face_file_path = os.path.join('tmp', uid, 'images', 'test.jpg')

        video = self.sad_talker.test(face_file_path, audio_path, preprocess, uid=uid)
        shutil.move(video, video_file_path)

        html_video = create_html_video(video_file_path, TALKING_HEAD_WIDTH)
        return html_video, video_file_path
|
    def generate_init_face_video(self, class_concept="clock", llm=None, uid=None, fullbody=None, ref_image=None, seed=None):
        """Bootstrap a persona for `class_concept`.

        Loads the conversation chain, selects a generation model and an
        edge-TTS voice, synthesizes a concept face image, and renders a
        greeting video. Returns (chain, memory, greeting video HTML, talker).
        """
        print('generate concept of', class_concept)
        print("=================================================")
        print('fullbody:', fullbody)
        print('uid:', uid)
        print("==================================================")
        chain, memory, personality_text = load_chain(llm, class_concept)
        model_conf, selected_model = model_selection_chain(
            llm, class_concept, conf_file='resources/models.yaml')
        # The same selection chain picks a matching edge-TTS voice from the
        # voices config, keyed on the generated personality text.
        voice_conf, selected_voice = model_selection_chain(
            llm, personality_text, conf_file='resources/voices_edge.yaml')
        talker = TTSTalker(selected_voice=selected_voice,
                           gender=voice_conf['gender'],
                           language=voice_conf['language'])

        # Words prepended to the concept on each retry to nudge the face
        # generator toward an easier-to-detect face.
        augment_word_list = ["Female ", "female ", "beautiful ", "small ", "cute "]
        first_sentence = "Hello, how are you doing today?"

        if ref_image is None:
            face_files = os.listdir(FACE_DIR)
            face_img_path = os.path.join(FACE_DIR, random.choice(face_files))
            ref_image = Image.open(face_img_path)

        print('loading face generating model')
        anything_facemaker = load_face_generator(
            model_dir=model_conf['model_dir'],
            lora_path=model_conf['lora_path'],
            prompt_template=model_conf['prompt_template'],
            negative_prompt=model_conf['negative_prompt'],
        )

        # Start strong when the reference already contains a face and back off
        # on each retry; do the opposite when no face was detected.
        has_face = anything_facemaker.has_face(ref_image)
        init_strength = 1.0 if has_face else 0.85
        strength_retry_step = -0.04 if has_face else 0.04
        retry_cnt = 0
        htm_video = None
        while retry_cnt < 8:
            try:
                generate_face_image(
                    anything_facemaker,
                    class_concept,
                    ref_image,
                    uid=uid,
                    # Equals init_strength on the first attempt.
                    strength=init_strength + retry_cnt * strength_retry_step,
                    # Tighten landmark control on the final attempt.
                    controlnet_conditioning_scale=0.5 if retry_cnt == 7 else 0.3,
                    seed=seed,
                )
                self.do_html_video_speak(talker, first_sentence, fullbody, uid=uid)
                video_file_path = os.path.join('tmp', uid, 'videos', 'tempfile.mp4')
                htm_video = create_html_video(
                    video_file_path, TALKING_HEAD_WIDTH)
                break
            except Exception as e:
                retry_cnt += 1
                class_concept = random.choice(augment_word_list) + class_concept
                print(e)

        return chain, memory, htm_video, talker
|
def update_talking_head(self, widget, uid, state): |
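        """Gradio callback: enable the talking head and refresh its video HTML."""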
        if widget:
            state = widget
            temp_file = os.path.join('tmp', uid, 'videos')
            video_html_talking_head = create_html_video(
                temp_file, TALKING_HEAD_WIDTH)
            return state, video_html_talking_head
        else:
            return None, "<pre></pre>"

|
def reset_memory(history, memory): |
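    """Clear the LangChain memory and reset the visible chat history."""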
|
    memory.clear()
    history = []
    return history, history, memory

|
def create_html_video(file_name, width): |
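    """Return the talking-head video for display.

    The video widget in this app consumes the raw file path, so the path is
    returned unchanged; `width` is kept for interface compatibility but is
    currently unused.
    """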
|
    return file_name

|
def create_html_audio(file_name): |
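    """Wrap an existing audio file in an HTML <audio> tag, or return '' if missing."""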
|
    if os.path.exists(file_name):
        tmp_audio_file = gr.File(file_name, visible=False)
        tmp_aud_file_url = "/file=" + tmp_audio_file.value['name']
        html_audio = f'<audio><source src="{tmp_aud_file_url}" type="audio/mp3"></audio>'
    else:
        html_audio = ''
    return html_audio

|
def update_foo(widget, state): |
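    """Generic gradio callback: copy a widget's value into session state."""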
|
    if widget:
        state = widget
    return state

|
def update_use_embeddings(widget, state): |
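    """Gradio callback mirroring `update_foo` for the use-embeddings toggle."""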
|
    if widget:
        state = widget
    return state

|
def load_face_generator(model_dir, lora_path, prompt_template, negative_prompt): |
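    """Construct the LongPromptControlGenerator used to synthesize concept faces."""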
|
    # Local import: defers loading the heavy diffusers stack until a face
    # generator is actually requested.
    from chat_anything.face_generator.long_prompt_control_generator import LongPromptControlGenerator

    model_zoo = "MODELS"
    face_control_dir = os.path.join(
        model_zoo, "Face-Landmark-ControlNet", "models_for_diffusers")
    face_detect_path = os.path.join(
        model_zoo, "SadTalker", "shape_predictor_68_face_landmarks.dat")

    anything_facemaker = LongPromptControlGenerator(
        model_dir=model_dir,
        lora_path=lora_path,
        prompt_template=prompt_template,
        negative_prompt=negative_prompt,
        face_control_dir=face_control_dir,
        face_detect_path=face_detect_path,
    )
    # Load with the diffusers safety checker disabled.
    anything_facemaker.load_model(safety_checker=None)
    return anything_facemaker
|
|
def generate_face_image(
    anything_facemaker,
    class_concept,
    face_img_pil,
    uid=None,
    controlnet_conditioning_scale=1.0,
    strength=0.95,
    seed=42,
):
    """Generate a concept face image and save it to tmp/<uid>/images/test.jpg."""
    # Normalize the reference image to a 512x512 RGB crop.
    face_img_pil = f.center_crop(
        f.resize(face_img_pil, 512), 512).convert('RGB')
    prompt = anything_facemaker.prompt_template.format(class_concept)

    print('USING SEED:', seed)
    generator = torch.Generator(device=anything_facemaker.face_control_pipe.device)
    if seed is None:
        # Callers may pass seed=None; fall back to a nondeterministic seed
        # instead of crashing in manual_seed().
        generator.seed()
    else:
        generator.manual_seed(seed)

    if strength is None:
        # Generate from the landmark condition alone (no inversion).
        pure_control = anything_facemaker.face_control_generate(
            prompt=prompt, face_img_pil=face_img_pil, do_inversion=False,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=generator)
        init_face_pil = pure_control
    else:
        # Inversion-based generation: start from the reference image and
        # denoise with the given strength.
        control_inversion = anything_facemaker.face_control_generate(
            prompt=prompt, face_img_pil=face_img_pil, do_inversion=True,
            strength=strength,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=generator)
        init_face_pil = control_inversion
    print('succeeded generating face image')

    face_path = os.path.join('tmp', uid, 'images')
    os.makedirs(face_path, exist_ok=True)
    face_file_path = os.path.join(face_path, 'test.jpg')
    init_face_pil.save(face_file_path)
    return init_face_pil
|
|
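# Example wiring (a minimal sketch, not executed by the app). The `llm` object
# and `uid` session id are placeholders supplied elsewhere by the gradio UI,
# and the MODELS/ checkpoints plus resources/ YAML files are assumed present:
#
#     chat = ChatWrapper()
#     chain, memory, greeting_html, talker = chat.generate_init_face_video(
#         class_concept="clock", llm=llm, uid=uid, fullbody=None, seed=42)
#     history, _, html_video, *rest = chat(
#         api_key, "Hello!", [], chain, speak_text=True, talking_head=True,
#         uid=uid, talker=talker, fullbody=None)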