import streamlit as st
from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline, AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image
import io
import numpy as np
from kokoro import KPipeline
import soundfile as sf
import re

# Cache model loading to optimize resource use
@st.cache_resource
def load_caption_model():
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    return processor, model
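
# Cache the story-generation LLM, wrapped in a text-generation pipeline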
@st.cache_resource
def load_story_model():
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
return pipeline("text-generation", model=model, tokenizer=tokenizer)
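
# Cache the Kokoro text-to-speech pipeline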
@st.cache_resource
def load_audio_pipeline():
    return KPipeline(lang_code='a')  # 'a' selects American English

# Function to truncate story to last complete sentence within 100 words
def truncate_to_100_words(story):
    sentences = re.split(r'(?<=[.!?]) +', story)
    word_count = 0
    truncated_story = []
    for sentence in sentences:
        sentence_words = sentence.split()
        if word_count + len(sentence_words) > 100:
            break
        truncated_story.append(sentence)
        word_count += len(sentence_words)
    return ' '.join(truncated_story)

# Function to generate a caption from an image
def generate_caption(image_bytes, processor, caption_model):
    try:
        image = Image.open(io.BytesIO(image_bytes))
        inputs = processor(images=image, return_tensors="pt")
        outputs = caption_model.generate(**inputs)
        caption = processor.decode(outputs[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        st.error(f"Error generating caption: {e}")
        return None

# Function to generate a story from a caption
def generate_story(caption, story_generator):
    try:
        tokenizer = story_generator.tokenizer
        prompt = [{"role": "user",
                   "content": f"Based on this text: '{caption}', write a short story for young children in no more than 100 words."}]
        # add_generation_prompt=True appends the assistant-turn marker so the
        # model writes a reply instead of continuing the user message
        input_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
        story_output = story_generator(
            input_text,
            max_new_tokens=120,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            return_full_text=False
        )
        story = story_output[0]["generated_text"]
        story = truncate_to_100_words(story)
        return story
    except Exception as e:
        st.error(f"Error generating story: {e}")
        return None

# Function to generate audio from a story
def generate_audio(story, audio_pipeline):
    try:
        # The pipeline yields (graphemes, phonemes, audio) tuples, one per chunk
        audio_generator = audio_pipeline(story, voice='af_heart', speed=1)
        audio_segments = [audio for _, _, audio in audio_generator]
        if not audio_segments:
            return None
        concatenated_audio = np.concatenate(audio_segments)
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, concatenated_audio, 24000, format='WAV')  # Kokoro outputs 24 kHz audio
        audio_buffer.seek(0)
        return audio_buffer
    except Exception as e:
        st.error(f"Error generating audio: {e}")
        return None

# Main function to encapsulate Streamlit logic
def main():
    # Load models (cached across reruns by st.cache_resource)
    processor, caption_model = load_caption_model()
    story_generator = load_story_model()
    audio_pipeline = load_audio_pipeline()

    # Streamlit UI
    st.title("A Cute Story Teller for Children \n🧸💐🦄🩰👼🐇🍄🍨🍩🧚‍♀️")
    st.write("✨Upload an image to generate a short children’s story with audio.✨")
    uploaded_file = st.file_uploader("Choose an image📷...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        image_bytes = uploaded_file.read()
        st.image(image_bytes, caption="Uploaded Image🐾", use_container_width=True)

        # Pipeline: image -> caption -> story -> audio
        with st.spinner("Generating caption✍..."):
            caption = generate_caption(image_bytes, processor, caption_model)
        if caption:
            st.write("**Generated Caption🐈:**")
            st.write(caption)

            with st.spinner("Generating story📚..."):
                story = generate_story(caption, story_generator)
            if story:
                st.write("**Generated Story💐:**")
                st.write(story)

                with st.spinner("Generating audio💽..."):
                    audio_buffer = generate_audio(story, audio_pipeline)
                if audio_buffer:
                    st.audio(audio_buffer, format="audio/wav")
                    st.download_button(
                        label="Download Story Audio🎀",
                        data=audio_buffer,
                        file_name="story_audio.wav",
                        mime="audio/wav"
                    )

if __name__ == "__main__":
    main()
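
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py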