import streamlit as st
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
)
import torch
from PIL import Image
import io
import numpy as np
from kokoro import KPipeline
import soundfile as sf
import re


# Cache model loading to optimize resource use
@st.cache_resource
def load_caption_model():
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    return processor, model


@st.cache_resource
def load_story_model():
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return pipeline("text-generation", model=model, tokenizer=tokenizer)


@st.cache_resource
def load_audio_pipeline():
    # Kokoro TTS pipeline; lang_code='a' selects American English voices
    return KPipeline(lang_code='a')


# Truncate a story to the last complete sentence within 100 words
def truncate_to_100_words(story):
    sentences = re.split(r'(?<=[.!?]) +', story)
    word_count = 0
    truncated_story = []
    for sentence in sentences:
        sentence_words = sentence.split()
        if word_count + len(sentence_words) > 100:
            break
        truncated_story.append(sentence)
        word_count += len(sentence_words)
    return ' '.join(truncated_story)


# Generate a caption from an image
def generate_caption(image_bytes, processor, caption_model):
    try:
        image = Image.open(io.BytesIO(image_bytes))
        inputs = processor(images=image, return_tensors="pt")
        outputs = caption_model.generate(**inputs)
        caption = processor.decode(outputs[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        st.error(f"Error generating caption: {e}")
        return None


# Generate a story from a caption
def generate_story(caption, story_generator):
    try:
        tokenizer = story_generator.tokenizer
        prompt = [
            {
                "role": "user",
                "content": (
                    f"Based on this text: '{caption}', write a short story "
                    "of no more than 100 words for young children."
                ),
            }
        ]
        # add_generation_prompt=True appends the assistant turn marker so the
        # model writes a reply instead of continuing the user message
        input_text = tokenizer.apply_chat_template(
            prompt, tokenize=False, add_generation_prompt=True
        )
        story_output = story_generator(
            input_text,
            max_new_tokens=120,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            return_full_text=False,
        )
        story = story_output[0]["generated_text"]
        story = truncate_to_100_words(story)
        return story
    except Exception as e:
        st.error(f"Error generating story: {e}")
        return None


# Generate audio from a story
def generate_audio(story, audio_pipeline):
    try:
        audio_generator = audio_pipeline(story, voice='af_heart', speed=1)
        audio_segments = [audio for gs, ps, audio in audio_generator]
        if not audio_segments:
            return None
        concatenated_audio = np.concatenate(audio_segments)
        audio_buffer = io.BytesIO()
        # Kokoro produces 24 kHz audio
        sf.write(audio_buffer, concatenated_audio, 24000, format='WAV')
        audio_buffer.seek(0)
        return audio_buffer
    except Exception as e:
        st.error(f"Error generating audio: {e}")
        return None


# Main function to encapsulate Streamlit logic
def main():
    # Load models
    processor, caption_model = load_caption_model()
    story_generator = load_story_model()
    audio_pipeline = load_audio_pipeline()

    # Streamlit UI
    st.title("A Cute Story Teller for Children \n🧸💐🦄🩰👼🐇🍄🍨🍩🧚‍♀️")
    st.write("✨Upload an image to generate a short children’s story with audio.✨")

    uploaded_file = st.file_uploader("Choose an image📷...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        image_bytes = uploaded_file.read()
        st.image(image_bytes, caption="Uploaded Image🐾", use_container_width=True)

        with st.spinner("Generating caption✍..."):
            caption = generate_caption(image_bytes, processor, caption_model)

        if caption:
            st.write("**Generated Caption🐈:**")
            st.write(caption)

            with st.spinner("Generating story📚..."):
                story = generate_story(caption, story_generator)

            if story:
                st.write("**Generated Story💐:**")
                st.write(story)

                with st.spinner("Generating audio💽..."):
                    audio_buffer = generate_audio(story, audio_pipeline)

                if audio_buffer:
                    st.audio(audio_buffer, format="audio/wav")
                    st.download_button(
                        label="Download Story Audio🎀",
                        data=audio_buffer,
                        file_name="story_audio.wav",
                        mime="audio/wav",
                    )


if __name__ == "__main__":
    main()
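
# Usage sketch: assuming this script is saved as app.py (the filename is an
# assumption, not given above), the app can be launched locally with:
#
#   streamlit run app.py
#
# Required packages, inferred from the imports above: streamlit, transformers,
# torch, pillow, numpy, soundfile, and kokoro.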