import streamlit as st
from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image
import io
import numpy as np
from kokoro import KPipeline
import soundfile as sf
import re
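# Streamlit app: upload an image, caption it with BLIP
# (Salesforce/blip-image-captioning-large), expand the caption into a short
# children's story with Qwen2.5-1.5B-Instruct, and read the story aloud with
# Kokoro text-to-speech. Launch with `streamlit run <this_file>.py`
# (substitute the actual filename).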

# Cached loaders: Streamlit re-runs the whole script on every interaction, so
# @st.cache_resource keeps the heavy models in memory across reruns.
@st.cache_resource
def load_caption_model():
    # BLIP image-captioning model and its processor.
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    return processor, model


@st.cache_resource
def load_story_model():
    # Small instruction-tuned LLM wrapped in a text-generation pipeline.
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return pipeline("text-generation", model=model, tokenizer=tokenizer)
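# Note: load_story_model() above loads the model on CPU by default. On a machine with
# a GPU you could pass e.g. torch_dtype=torch.float16 and device_map="auto" to
# from_pretrained (this assumes the `accelerate` package is installed); it is left out
# here to keep the defaults simple.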

@st.cache_resource
def load_audio_pipeline():
    # Kokoro text-to-speech pipeline.
    return KPipeline(lang_code='a')
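# Assumptions about the kokoro package (worth checking against the installed version):
# lang_code='a' selects American English, 'af_heart' (used in generate_audio below) is
# one of its American English voices, and the pipeline produces audio at 24 kHz, which
# is why sf.write() uses a 24000 Hz sample rate.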

def truncate_to_100_words(story):
    # Keep whole sentences until adding the next one would exceed 100 words.
    sentences = re.split(r'(?<=[.!?]) +', story)
    word_count = 0
    truncated_story = []
    for sentence in sentences:
        sentence_words = sentence.split()
        if word_count + len(sentence_words) > 100:
            break
        truncated_story.append(sentence)
        word_count += len(sentence_words)
    return ' '.join(truncated_story)
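# Example: given sentences of 40, 40 and 40 words, only the first two (80 words) are
# kept, because adding the third would push the total past 100. Edge case: if the very
# first sentence alone exceeds 100 words, an empty string is returned.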

def generate_caption(image_bytes, processor, caption_model):
    # Describe the uploaded image with BLIP and return the caption text.
    try:
        image = Image.open(io.BytesIO(image_bytes))
        inputs = processor(images=image, return_tensors="pt")
        outputs = caption_model.generate(**inputs)
        caption = processor.decode(outputs[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        st.error(f"Error generating caption: {e}")
        return None


def generate_story(caption, story_generator):
    # Turn the caption into a short children's story (at most 100 words).
    try:
        tokenizer = story_generator.tokenizer
        prompt = [{"role": "user",
                   "content": f"Based on this text: '{caption}', write a short story of no more than 100 words for young children."}]
        # add_generation_prompt=True appends the assistant header so the model
        # generates a reply instead of continuing the user turn.
        input_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)

        story_output = story_generator(
            input_text,
            max_new_tokens=120,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            return_full_text=False
        )
        story = story_output[0]["generated_text"]
        # The model may overshoot the word budget, so trim to whole sentences.
        story = truncate_to_100_words(story)
        return story
    except Exception as e:
        st.error(f"Error generating story: {e}")
        return None
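# With add_generation_prompt=True, apply_chat_template ends the prompt with an
# assistant header (Qwen2.5 uses a ChatML-style template with <|im_start|>/<|im_end|>
# markers), so the pipeline's output is the assistant's reply rather than a
# continuation of the user message.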

def generate_audio(story, audio_pipeline):
    # Synthesize the story with Kokoro and return an in-memory WAV buffer.
    try:
        audio_generator = audio_pipeline(story, voice='af_heart', speed=1)
        audio_segments = []
        # The pipeline yields the text in chunks: (graphemes, phonemes, audio array).
        for _graphemes, _phonemes, audio in audio_generator:
            audio_segments.append(audio)
        if not audio_segments:
            return None
        concatenated_audio = np.concatenate(audio_segments)
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, concatenated_audio, 24000, format='WAV')
        audio_buffer.seek(0)  # rewind so Streamlit can read the buffer from the start
        return audio_buffer
    except Exception as e:
        st.error(f"Error generating audio: {e}")
        return None
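# generate_audio() returns an in-memory WAV (a BytesIO rewound to position 0), which
# main() feeds to both st.audio and st.download_button. If the downloaded file ever
# comes out empty because one widget consumes the buffer, passing
# audio_buffer.getvalue() as the download data is a simple workaround.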

def main():
    # Load the cached models/pipelines (instantiated once per session).
    processor, caption_model = load_caption_model()
    story_generator = load_story_model()
    audio_pipeline = load_audio_pipeline()

    st.title("A Cute Story Teller for Children \n🧸💐🦄🩰👼🐇🍄🍨🍩🧚♀️")
    st.write("✨Upload an image to generate a short children’s story with audio.✨")

    uploaded_file = st.file_uploader("Choose an image📷...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        image_bytes = uploaded_file.read()
        st.image(image_bytes, caption="Uploaded Image🐾", use_container_width=True)

        # Caption -> story -> audio; each later step runs only if the previous one succeeded.
        with st.spinner("Generating caption✍..."):
            caption = generate_caption(image_bytes, processor, caption_model)
            if caption:
                st.write("**Generated Caption🐈:**")
                st.write(caption)

                with st.spinner("Generating story📚..."):
                    story = generate_story(caption, story_generator)
                    if story:
                        st.write("**Generated Story💐:**")
                        st.write(story)

                        with st.spinner("Generating audio💽..."):
                            audio_buffer = generate_audio(story, audio_pipeline)
                            if audio_buffer:
                                st.audio(audio_buffer, format="audio/wav")
                                st.download_button(
                                    label="Download Story Audio🎀",
                                    data=audio_buffer,
                                    file_name="story_audio.wav",
                                    mime="audio/wav"
                                )


if __name__ == "__main__":
    main()