# Author: JoannaKOKO
# Commit: 758ffbc "Update app.py" (verified)
import streamlit as st
from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image
import io
import numpy as np
from kokoro import KPipeline
import soundfile as sf
import re
# Cache model loading to optimize resource use
@st.cache_resource
def load_caption_model():
    """Load the BLIP captioning processor and model once per server process.

    Returns:
        tuple: ``(processor, model)`` loaded from the
        "Salesforce/blip-image-captioning-large" checkpoint.
    """
    checkpoint = "Salesforce/blip-image-captioning-large"
    return (
        BlipProcessor.from_pretrained(checkpoint),
        BlipForConditionalGeneration.from_pretrained(checkpoint),
    )
@st.cache_resource
def load_story_model():
    """Build and cache a text-generation pipeline around Qwen2.5-1.5B-Instruct.

    Returns:
        A transformers ``pipeline("text-generation", ...)`` whose
        ``.tokenizer`` is the Qwen tokenizer (used for its chat template).
    """
    checkpoint = "Qwen/Qwen2.5-1.5B-Instruct"
    story_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    story_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return pipeline(task="text-generation", model=story_model, tokenizer=story_tokenizer)
@st.cache_resource
def load_audio_pipeline():
    """Create and cache the Kokoro text-to-speech pipeline.

    ``lang_code='a'`` is forwarded to ``KPipeline``. It presumably selects
    (American) English; confirm against the kokoro documentation.
    """
    tts_pipeline = KPipeline(lang_code='a')
    return tts_pipeline
# Function to truncate story to last complete sentence within 100 words
def truncate_to_100_words(story, max_words=100):
    """Trim *story* to at most *max_words* words, ending on a complete sentence.

    Sentences are split after ``.``, ``!`` or ``?``. The punctuation may be
    followed by a closing quote, and the break may be any whitespace,
    including newlines. Whole sentences are kept while the running word count
    stays within *max_words*.

    If more than one sentence was kept and the last one lacks terminal
    punctuation, that trailing fragment is dropped. This happens, for example,
    when the generator stopped at its token limit mid-sentence. If even the
    first sentence exceeds *max_words*, the text is hard-cut to its first
    *max_words* words rather than returning an empty story.

    Args:
        story: Generated story text.
        max_words: Word budget (defaults to 100, the original fixed limit).

    Returns:
        The trimmed story, or ``""`` for empty/whitespace-only input.
    """
    text = story.strip()
    if not text:
        return ""

    # Two fixed-width lookbehinds: `... end. Next` and `... "Go!" Next`.
    sentences = re.split(r'(?:(?<=[.!?])|(?<=[.!?]["\'”’]))\s+', text)
    kept = []
    word_count = 0
    for sentence in sentences:
        n_words = len(sentence.split())
        if word_count + n_words > max_words:
            break
        kept.append(sentence)
        word_count += n_words

    if not kept:
        # A single over-long first sentence: better a hard cut than nothing.
        return ' '.join(text.split()[:max_words])

    # Drop an unfinished trailing sentence, but never the only content.
    if len(kept) > 1 and not re.search(r'[.!?]["\'”’)]*$', kept[-1]):
        kept.pop()
    return ' '.join(kept)
# Function to generate a caption from an image
def generate_caption(image_bytes, processor, caption_model):
    """Describe an image in one sentence using the BLIP captioning model.

    Args:
        image_bytes: Raw bytes of the uploaded image file.
        processor: The BLIP processor from ``load_caption_model``.
        caption_model: The BLIP conditional-generation model.

    Returns:
        The decoded caption string, or ``None`` if anything fails. In that
        case the error is shown in the app via ``st.error``.
    """
    try:
        pil_image = Image.open(io.BytesIO(image_bytes))
        model_inputs = processor(images=pil_image, return_tensors="pt")
        token_ids = caption_model.generate(**model_inputs)
        # Only one image is captioned, so decode the first generated sequence.
        return processor.decode(token_ids[0], skip_special_tokens=True)
    except Exception as err:
        st.error(f"Error generating caption: {err}")
        return None
# Function to generate a story from a caption
def generate_story(caption, story_generator):
    """Write a short children's story inspired by an image caption.

    Args:
        caption: Text describing the image (from ``generate_caption``).
        story_generator: A transformers "text-generation" pipeline whose
            ``tokenizer`` provides a chat template.

    Returns:
        The generated story, trimmed by ``truncate_to_100_words``, or ``None``
        on failure. In that case the error is shown via ``st.error``.
    """
    try:
        tokenizer = story_generator.tokenizer
        prompt = [{"role": "user",
                   "content": f"Based on this text:'{caption}', elaborate a short story with no more than 100 words for young children."}]
        # add_generation_prompt=True appends the assistant-turn header. Without
        # it the prompt ends after the user turn, and the model may produce the
        # header itself, leaking text like "assistant" into the story.
        input_text = tokenizer.apply_chat_template(
            prompt, tokenize=False, add_generation_prompt=True
        )
        story_output = story_generator(
            input_text,
            # Roughly 100 English words need more than 120 tokens. Leave
            # headroom and let truncate_to_100_words enforce the word limit.
            max_new_tokens=160,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            return_full_text=False,
        )
        story = story_output[0]["generated_text"].strip()
        return truncate_to_100_words(story)
    except Exception as e:
        st.error(f"Error generating story: {e}")
        return None
# Function to generate audio from a story
def generate_audio(story, audio_pipeline, voice='af_heart', speed=1, sample_rate=24000):
    """Synthesize *story* to speech and return it as an in-memory WAV file.

    Args:
        story: Text to speak.
        audio_pipeline: Text-to-speech callable (the Kokoro ``KPipeline`` from
            ``load_audio_pipeline``). It is called as
            ``audio_pipeline(story, voice=..., speed=...)`` and iterated for
            ``(graphemes, phonemes, audio)`` triples.
        voice: Voice name forwarded to the pipeline.
        speed: Speed factor forwarded to the pipeline.
        sample_rate: Rate (Hz) written into the WAV header. It must match what
            the pipeline produces (24000 in the original code).

    Returns:
        An ``io.BytesIO`` rewound to position 0 holding WAV data, or ``None``
        if no audio was produced or an error occurred. Errors are shown via
        ``st.error``.
    """
    try:
        segments = []
        for _graphemes, _phonemes, audio in audio_pipeline(story, voice=voice, speed=speed):
            if audio is None:
                # Nothing synthesized for this chunk; np.concatenate can't take None.
                continue
            # Segments may be torch tensors. Convert explicitly, because numpy's
            # implicit conversion fails for tensors that live on an accelerator.
            if isinstance(audio, torch.Tensor):
                audio = audio.detach().cpu().numpy()
            segments.append(np.asarray(audio))
        if not segments:
            return None
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, np.concatenate(segments), sample_rate, format='WAV')
        audio_buffer.seek(0)
        return audio_buffer
    except Exception as e:
        st.error(f"Error generating audio: {e}")
        return None
# Main function to encapsulate Streamlit logic
def main():
    """Render the Streamlit UI and run the image -> caption -> story -> audio flow.

    Streamlit re-executes the script on every widget interaction, including a
    click on the download button. Generated results are therefore kept in
    ``st.session_state``, keyed by a hash of the uploaded image bytes.
    Otherwise every rerun would re-sample a different story (sampling is
    enabled) and re-synthesize the audio.
    """
    import hashlib  # only needed here, to key cached results by image content

    # Render the header first so the initial (uncached) model load isn't a blank page.
    st.title("A Cute Story Teller for Children \n🧸💐🦄🩰👼🐇🍄🍨🍩🧚‍♀️")
    st.write("✨Upload an image to generate a short children’s story with audio.✨")

    # Load models (cached across reruns by st.cache_resource)
    processor, caption_model = load_caption_model()
    story_generator = load_story_model()
    audio_pipeline = load_audio_pipeline()

    uploaded_file = st.file_uploader("Choose an image📷...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # getvalue() returns the full upload regardless of the current read position.
    image_bytes = uploaded_file.getvalue()
    st.image(image_bytes, caption="Uploaded Image🐾", use_container_width=True)

    image_key = hashlib.sha256(image_bytes).hexdigest()
    results = st.session_state.get("story_results")
    if not results or results.get("image_key") != image_key:
        # New image: discard results generated for a previous upload.
        results = {"image_key": image_key, "caption": None, "story": None, "audio": None}
        st.session_state["story_results"] = results

    if not results["caption"]:
        with st.spinner("Generating caption✍..."):
            results["caption"] = generate_caption(image_bytes, processor, caption_model)
    caption = results["caption"]
    if not caption:
        return
    st.write("**Generated Caption🐈:**")
    st.write(caption)

    if not results["story"]:
        with st.spinner("Generating story📚..."):
            results["story"] = generate_story(caption, story_generator)
    story = results["story"]
    if not story:
        return
    st.write("**Generated Story💐:**")
    st.write(story)

    if results["audio"] is None:
        with st.spinner("Generating audio💽..."):
            audio_buffer = generate_audio(story, audio_pipeline)
        # Store plain bytes so both widgets receive the full payload on every rerun.
        results["audio"] = audio_buffer.getvalue() if audio_buffer is not None else None
    audio_bytes = results["audio"]
    if audio_bytes:
        st.audio(audio_bytes, format="audio/wav")
        st.download_button(
            label="Download Story Audio🎀",
            data=audio_bytes,
            file_name="story_audio.wav",
            mime="audio/wav",
        )
# Script entry point: run the app only when executed as the main module.
if __name__ == "__main__":
    main()