Spaces: Running on Zero
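The listing below is the app.py for this Space: it loads the ibm-granite/granite-4.0-h-1b model with transformers and serves it through a Gradio ChatInterface. On a ZeroGPU Space, GPU-backed work has to happen inside a function decorated with @spaces.GPU; the chat_with_model function below carries that decorator so a GPU is attached only while a response is being generated.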
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import spaces
import os

# Model configuration
MODEL_PATH = "ibm-granite/granite-4.0-h-1b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Global variables to store model and tokenizer
tokenizer = None
model = None
def load_model():
    """Load the model and tokenizer"""
    global tokenizer, model
    if tokenizer is None or model is None:
        print("Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map=DEVICE)
        model.eval()
        print("Model loaded successfully!")
# Use GPU for inference: on a ZeroGPU Space, @spaces.GPU attaches a GPU only
# while this function is executing
@spaces.GPU
def chat_with_model(message, history):
| """ | |
| Chat function that processes user input and generates responses | |
| Args: | |
| message (str): Current user message | |
| history (list): Previous conversation history | |
| Returns: | |
| str: Model response | |
| """ | |
    try:
        # Load model if not already loaded
        load_model()

        # Prepare chat format
        messages = []

        # Add system message for better performance
        messages.append({
            "role": "system",
            "content": "You are a helpful AI assistant. Provide clear, accurate, and helpful responses."
        })

        # Add conversation history
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})

        # Add current message
        messages.append({"role": "user", "content": message})

        # Apply chat template
        chat = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize input
        input_tokens = tokenizer(chat, return_tensors="pt").to(DEVICE)
        # Generate response
        with torch.no_grad():
            output = model.generate(
                **input_tokens,
                max_new_tokens=200,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        # Decode response (special tokens are kept so the role markers can be located)
        full_response = tokenizer.batch_decode(output)[0]

        # Extract only the assistant's response.
        # Use rfind so the newly generated assistant turn is taken rather than an
        # earlier assistant message replayed from the conversation history.
        assistant_start = full_response.rfind('<|start_of_role|>assistant<|end_of_role|>')
        if assistant_start != -1:
            assistant_start += len('<|start_of_role|>assistant<|end_of_role|>')
            assistant_response = full_response[assistant_start:].strip()
        else:
            # Fallback if the Granite role markers are not present
            response_start = full_response.rfind('<|assistant|>')
            if response_start != -1:
                response_start += len('<|assistant|>')
                assistant_response = full_response[response_start:].strip()
            else:
                assistant_response = full_response.strip()
        # Clean up the response - remove end markers
        assistant_response = assistant_response.replace('<|endoftext|>', '').replace('<|end_of_text|>', '').strip()

        return assistant_response

    except Exception as e:
        print(f"Error generating response: {e}")
        return f"I apologize, but I encountered an error: {str(e)}. Please try again."


def clear_chat():
    """Clear the chat history"""
    return []
# Create the Gradio chat interface
def create_chat_app():
    with gr.Blocks(title="IBM Granite Chat", css="""
        .header {
            text-align: center;
            padding: 10px;
            background: linear-gradient(90deg, #0066cc, #004499);
            color: white;
            margin-bottom: 20px;
            border-radius: 10px;
        }
        .header a {
            color: #ffffff;
            text-decoration: none;
            font-weight: bold;
        }
        .header a:hover {
            text-decoration: underline;
        }
    """) as demo:
        # Header with attribution
        gr.HTML("""
        <div class="header">
            <h1>IBM Granite 4.0 Chat</h1>
            <p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
        </div>
        """)
        # Chat interface
        chatbot = gr.ChatInterface(
            fn=chat_with_model,
            title="Chat with IBM Granite 4.0",
            description="Chat with the IBM Granite 4.0 1B parameter language model. Ask questions, get help, or have a conversation!",
            examples=[
                "What is machine learning?",
                "Explain quantum computing in simple terms",
                "How can I improve my programming skills?",
                "What are the latest developments in AI?",
                "Tell me about IBM Research"
            ],
        )
        # Additional info
        with gr.Accordion("Model Information", open=False):
            gr.Markdown(f"""
            ## Model Details
            - **Model**: {MODEL_PATH}
            - **Parameters**: 1B
            - **Device**: {DEVICE.upper()}
            - **Max Tokens**: 200 per response
            - **Temperature**: 0.7 (for balanced creativity and accuracy)

            ## Tips
            - Ask specific questions for better results
            - The model works best with clear, concise prompts
            - Try asking follow-up questions to dive deeper into topics
            - The model can help with programming, explanations, and general knowledge
            """)

    return demo
if __name__ == "__main__":
    # Create and launch the app
    app = create_chat_app()

    # Launch configuration
    app.launch()
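
For a quick check without opening the UI, the chat function can also be called directly. The snippet below is a minimal sketch, assuming the listing above is saved as app.py (the Spaces convention) and that torch, transformers, gradio, and spaces are installed; outside a ZeroGPU Space the @spaces.GPU decorator is expected to be a no-op, so the call simply runs on whatever DEVICE resolves to.

# local_test.py - illustrative smoke test for the chat function (hypothetical file name)
from app import chat_with_model

# An empty history means this is the first turn of the conversation
reply = chat_with_model("What is machine learning?", [])
print(reply)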