import json

import gradio as gr
import soundfile as sf
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
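
# Runtime dependencies: torch, transformers (a version recent enough to load
# Phi-4-multimodal), gradio, soundfile and pillow.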

model_path = "microsoft/Phi-4-multimodal-instruct"

# Load the processor and model once at startup. trust_remote_code is required
# because the checkpoint ships its own modeling and processing code.
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
    _attn_implementation="eager",
)

generation_config = GenerationConfig.from_pretrained(model_path)
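
# The prompt strings below follow the Phi-4-multimodal chat format: <|system|>,
# <|user|> and <|assistant|> mark roles, <|end|> closes a turn, and the
# <|image_N|> / <|audio_N|> placeholders are matched by the processor to the
# images/audios lists in the order they are passed.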


def process_task(mode, system_msg, user_msg, image_multi, audio, vs_images, vs_audio):
    """
    Build the prompt for the selected task mode and generate a response with the
    processor and model.
    """
    if mode == "Text Chat":
        prompt = f"<|system|>{system_msg}<|end|><|user|>{user_msg}<|end|><|assistant|>"
        inputs = processor(text=prompt, return_tensors='pt').to(model.device)

    elif mode == "Tool-enabled Function Calling":
        # Declare the available tool; its JSON schema is wrapped in
        # <|tool|> ... <|/tool|> tags inside the system turn.
        tools = [{
            "name": "get_weather_updates",
            "description": "Fetches weather updates for a given city using the RapidAPI Weather API.",
            "parameters": {
                "city": {
                    "description": "The name of the city for which to retrieve weather information.",
                    "type": "str",
                    "default": "London"
                }
            }
        }]
        tools_json = json.dumps(tools, ensure_ascii=False)
        prompt = f"<|system|>{system_msg}<|tool|>{tools_json}<|/tool|><|end|><|user|>{user_msg}<|end|><|assistant|>"
        inputs = processor(text=prompt, return_tensors='pt').to(model.device)

    elif mode == "Vision-Language":
        if image_multi is not None and len(image_multi) > 0:
            num = len(image_multi)
            image_tags = ''.join([f"<|image_{i+1}|>" for i in range(num)])
            prompt = f"<|user|>{image_tags}{user_msg}<|end|><|assistant|>"
            images = []
            for file in image_multi:
                images.append(Image.open(file))
            inputs = processor(text=prompt, images=images, return_tensors='pt').to(model.device)
        else:
            return "No image provided."

    elif mode == "Speech-Language":
        if audio is None:
            return "No audio provided."
        prompt = f"<|user|><|audio_1|>{user_msg}<|end|><|assistant|>"
        # gr.Audio may return either a (sample_rate, ndarray) tuple or a file path.
        if isinstance(audio, tuple):
            sample_rate, audio_data = audio
        else:
            audio_data, sample_rate = sf.read(audio)
        inputs = processor(text=prompt, audios=[(audio_data, sample_rate)], return_tensors='pt').to(model.device)

    elif mode == "Vision-Speech":
        prompt = "<|user|>"
        images = []
        if vs_images is not None and len(vs_images) > 0:
            num = len(vs_images)
            image_tags = ''.join([f"<|image_{i+1}|>" for i in range(num)])
            prompt += image_tags
            for file in vs_images:
                images.append(Image.open(file))
        if vs_audio is None:
            return "No audio provided for vision-speech."
        prompt += "<|audio_1|><|end|><|assistant|>"
        # Accept either a (sample_rate, ndarray) tuple or a file path, as above.
        if isinstance(vs_audio, tuple):
            samplerate, audio_data = vs_audio
        else:
            audio_data, samplerate = sf.read(vs_audio)
        inputs = processor(text=prompt, images=images, audios=[(audio_data, samplerate)], return_tensors='pt').to(model.device)

    else:
        return "Invalid mode."

    generate_ids = model.generate(
        **inputs,
        max_new_tokens=1000,
        generation_config=generation_config,
    )
    # Slice off the prompt tokens so only the newly generated continuation is decoded.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return response
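
# A quick sanity check without launching the UI (text-only mode, so all media
# arguments are None); run from a Python shell after the model has loaded:
#
#   print(process_task("Text Chat", "You are a helpful assistant.",
#                      "hi who are you?", None, None, None, None))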


def update_visibility(mode):
    # One gr.update per input component, in the same order as the outputs list
    # wired to mode_radio.change:
    # (system_text, user_text, image_upload_multi, audio_upload, vs_image_upload, vs_audio_upload)
    if mode == "Text Chat":
        return (gr.update(visible=True),
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False))
    elif mode == "Tool-enabled Function Calling":
        return (gr.update(visible=True),
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False))
    elif mode == "Vision-Language":
        return (gr.update(visible=False),
                gr.update(visible=True),
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False))
    elif mode == "Speech-Language":
        return (gr.update(visible=False),
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False))
    elif mode == "Vision-Speech":
        return (gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(visible=True),
                gr.update(visible=True))
    else:
        # Fallback: leave all six components unchanged.
        return (gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
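

# Gradio UI: one set of inputs per modality; visibility is toggled by the task-mode radio.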
with gr.Blocks() as demo:
    gr.Markdown("## Multi-Modal Prompt Builder & Model Inference")

    mode_radio = gr.Radio(
        choices=["Text Chat", "Tool-enabled Function Calling", "Vision-Language",
                 "Speech-Language", "Vision-Speech"],
        label="Select Task Mode",
        value="Text Chat"
    )

    # Inputs whose visibility depends on the selected mode
    system_text = gr.Textbox(label="System Message", value="You are a helpful assistant.", visible=True)
    user_text = gr.Textbox(label="User Message", visible=True)

    image_upload_multi = gr.File(label="Upload Image(s) (Multiple)", file_count="multiple", visible=False)
    audio_upload = gr.Audio(label="Upload Audio (wav, mp3, flac)", visible=False)

    vs_image_upload = gr.File(label="Upload Image(s) for Vision-Speech", file_count="multiple", visible=False)
    vs_audio_upload = gr.Audio(label="Upload Audio for Vision-Speech", visible=False)

    submit_btn = gr.Button("Submit")
    output_text = gr.Textbox(label="Result", lines=6)

    examples = gr.Examples(
        examples=[
            ["Text Chat", "hi who are you?"],
            ["Vision-Language", "Describe the image in detail."],
            ["Speech-Language", "Transcribe the audio to text."],
            ["Vision-Speech", ""]
        ],
        inputs=[mode_radio, user_text],
        label="Examples"
    )

    # Show or hide inputs when the task mode changes
    mode_radio.change(fn=update_visibility,
                      inputs=mode_radio,
                      outputs=[system_text, user_text, image_upload_multi, audio_upload, vs_image_upload, vs_audio_upload])

    # Run inference on submit
    submit_btn.click(
        fn=process_task,
        inputs=[mode_radio, system_text, user_text, image_upload_multi, audio_upload, vs_image_upload, vs_audio_upload],
        outputs=output_text
    )

demo.launch()
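
# To reach the demo from outside the local machine (e.g. in a hosted notebook),
# Gradio's share tunnel can be enabled instead: demo.launch(share=True).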