|
import os |
|
import sys |
|
import uvicorn |
|
from fastapi import FastAPI, Query, HTTPException |
|
from fastapi.responses import HTMLResponse |
|
from starlette.middleware.cors import CORSMiddleware |
|
from datasets import load_dataset, list_datasets, concatenate_datasets
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from loguru import logger
import concurrent.futures
import psutil
import asyncio
import torch
from tenacity import retry, stop_after_attempt, wait_fixed
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
from dotenv import load_dotenv

load_dotenv()

HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if not HUGGINGFACE_TOKEN:
    logger.error("Hugging Face token not found. Please set the HUGGINGFACE_TOKEN environment variable.")
    sys.exit(1)

datasets_dict = {}
example_usage_list = []

CACHE_DIR = os.path.expanduser("~/.cache/huggingface")
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_TOKEN"] = HUGGINGFACE_TOKEN

pipeline_instance = None
initialization_complete = False
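

# Load the base GPT-2 model and tokenizer once and expose a shared text-generation pipeline.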
def initialize_model():
    global pipeline_instance, initialization_complete
    try:
        logger.info("Initializing the GPT-2 model and tokenizer.")
        base_model_repo = "gpt2"
        model = AutoModelForCausalLM.from_pretrained(
            base_model_repo,
            cache_dir=CACHE_DIR,
            ignore_mismatched_sizes=True
        )
        tokenizer = AutoTokenizer.from_pretrained(base_model_repo, cache_dir=CACHE_DIR)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        pipeline_instance = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1
        )
        logger.info("GPT-2 model and tokenizer initialized successfully.")
        initialization_complete = True
    except Exception as e:
        # loguru does not understand the stdlib `exc_info` argument; logger.exception records the traceback.
        logger.exception(f"Error initializing model and tokenizer: {e}")
        sys.exit(1)
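

# Download a single dataset into the shared cache, retrying up to three times with a 5-second wait.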
@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
def download_dataset(dataset_name):
    try:
        logger.info(f"Starting download for dataset: {dataset_name}")
        datasets_dict[dataset_name] = load_dataset(dataset_name, cache_dir=CACHE_DIR)
        create_example_usage(dataset_name)
    except Exception as e:
        logger.exception(f"Error loading dataset {dataset_name}: {e}")
        raise
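

# Push the in-memory model and tokenizer to the Hugging Face Hub, creating the target repo if it is missing.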
def upload_model_to_hub():
    try:
        api = HfApi()
        model_repo = "Yhhxhfh/Hhggg"
        try:
            api.repo_info(repo_id=model_repo)
            logger.info(f"Model repository {model_repo} already exists.")
        except RepositoryNotFoundError:
            api.create_repo(repo_id=model_repo, private=False, token=HUGGINGFACE_TOKEN)
            logger.info(f"Created model repository {model_repo}.")
        logger.info(f"Pushing the model and tokenizer to {model_repo}.")
        pipeline_instance.model.push_to_hub(model_repo, token=HUGGINGFACE_TOKEN)
        pipeline_instance.tokenizer.push_to_hub(model_repo, token=HUGGINGFACE_TOKEN)
        logger.info(f"Successfully pushed the model and tokenizer to {model_repo}.")
    except Exception as e:
        logger.exception(f"Error uploading model to Hugging Face Hub: {e}")
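

# Record a few sample generations for each downloaded dataset in example_usage_list.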
def create_example_usage(dataset_name):
    try:
        logger.info(f"Creating example usage for dataset {dataset_name}")
        example_prompts = [
            "Once upon a time,",
            "In a world where AI rules,",
            "The future of technology is",
            "Explain the concept of",
            "Describe a scenario where"
        ]
        examples = []
        for prompt in example_prompts:
            generated_text = pipeline_instance(prompt, max_length=50, num_return_sequences=1)[0]['generated_text']
            examples.append({"prompt": prompt, "response": generated_text})
        example_usage_list.append({"dataset_name": dataset_name, "examples": examples})
        logger.info(f"Example usage created for dataset {dataset_name}")
    except Exception as e:
        logger.exception(f"Error creating example usage for dataset {dataset_name}: {e}")
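

# Merge every downloaded dataset into a single 'unified' entry in datasets_dict.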
def unify_datasets():
    try:
        logger.info("Starting to unify datasets")
        unified_dataset = None
        for dataset in datasets_dict.values():
            if unified_dataset is None:
                unified_dataset = dataset
            else:
                # Dataset objects have no .concatenate() method; use datasets.concatenate_datasets,
                # which assumes the pieces being merged share a compatible schema.
                unified_dataset = concatenate_datasets([unified_dataset, dataset])
        datasets_dict['unified'] = unified_dataset
        logger.info("Datasets successfully unified.")
    except Exception as e:
        logger.exception(f"Error unifying datasets: {e}")
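

# Size the download thread pool from physical CPUs, available memory, and GPUs, capped between 1 and 10 workers.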
cpu_count = psutil.cpu_count(logical=False) or 1
memory_available_mb = psutil.virtual_memory().available / (1024 * 1024)
memory_per_download_mb = 100
memory_available = int(memory_available_mb / memory_per_download_mb)
gpu_count = torch.cuda.device_count()
max_concurrent_downloads = min(cpu_count, memory_available, gpu_count * 2 if gpu_count else cpu_count)
max_concurrent_downloads = max(1, max_concurrent_downloads)
max_concurrent_downloads = min(10, max_concurrent_downloads)

logger.info(f"Using up to {max_concurrent_downloads} concurrent workers for downloading datasets.")

executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_downloads)
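

# Fan dataset downloads out across the thread pool, then unify the results and upload the model.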
async def download_and_process_datasets():
    dataset_names = list_datasets()
    logger.info(f"Found {len(dataset_names)} datasets to download.")
    loop = asyncio.get_event_loop()
    tasks = []
    for dataset_name in dataset_names:
        task = loop.run_in_executor(executor, download_dataset, dataset_name)
        tasks.append(task)
    await asyncio.gather(*tasks)
    unify_datasets()
    upload_model_to_hub()


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)

message_history = []
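

# Run the heavy initialization in the background so the server can start serving requests immediately.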
@app.on_event("startup")
async def startup_event():
    logger.info("Application startup initiated.")
    loop = asyncio.get_event_loop()
    asyncio.create_task(run_initialization(loop))
    logger.info("Startup tasks initiated.")


async def run_initialization(loop):
    try:
        await loop.run_in_executor(None, initialize_model)
        await download_and_process_datasets()
        logger.info("All startup tasks completed successfully.")
    except Exception as e:
        logger.exception(f"Error during startup tasks: {e}")
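

# Serve the single-page chat interface.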
@app.get('/')
async def index():
    html_code = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>GPT-2 Chatbot</title>
        <!-- Bootstrap CSS for a professional interface -->
        <link href="https://cdn.jsdelivr.net/npm/bootstrap/dist/css/bootstrap.min.css" rel="stylesheet">
        <style>
            body {
                background-color: #f8f9fa;
                font-family: Arial, sans-serif;
            }
            .container {
                max-width: 800px;
                margin-top: 50px;
            }
            .chat-container {
                background-color: #ffffff;
                border-radius: 10px;
                box-shadow: 0 0 15px rgba(0,0,0,0.2);
                padding: 20px;
                display: flex;
                flex-direction: column;
                height: 600px;
            }
            .chat-box {
                flex: 1;
                overflow-y: auto;
                margin-bottom: 15px;
            }
            .chat-input {
                width: 100%;
                padding: 10px;
                border: 1px solid #ced4da;
                border-radius: 5px;
                font-size: 16px;
            }
            .chat-input:focus {
                outline: none;
                border-color: #80bdff;
                box-shadow: 0 0 5px rgba(0,123,255,0.5);
            }
            .user-message {
                text-align: right;
                margin-bottom: 10px;
            }
            .user-message .message {
                display: inline-block;
                background-color: #007bff;
                color: #fff;
                padding: 10px 15px;
                border-radius: 15px;
                max-width: 70%;
            }
            .bot-message {
                text-align: left;
                margin-bottom: 10px;
            }
            .bot-message .message {
                display: inline-block;
                background-color: #6c757d;
                color: #fff;
                padding: 10px 15px;
                border-radius: 15px;
                max-width: 70%;
            }
            .toggle-history {
                text-align: center;
                cursor: pointer;
                color: #007bff;
                margin-top: 10px;
            }
            .history-container {
                display: none;
                background-color: #ffffff;
                border-radius: 10px;
                box-shadow: 0 0 15px rgba(0,0,0,0.2);
                padding: 20px;
                margin-top: 20px;
                max-height: 300px;
                overflow-y: auto;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <h1 class="text-center mb-4">GPT-2 Chatbot</h1>
            <div class="chat-container">
                <div class="chat-box" id="chat-box">
                </div>
                <input type="text" class="chat-input" id="user-input" placeholder="Type your message..." onkeypress="handleKeyPress(event)">
                <button class="btn btn-primary mt-3" onclick="sendMessage()">Send</button>
                <div class="toggle-history mt-3" onclick="toggleHistory()">Toggle History</div>
                <div class="history-container" id="history-container">
                    <h3>Chat History</h3>
                    <div id="history-content"></div>
                </div>
            </div>
        </div>

        <!-- Bootstrap JS (optional) -->
        <script src="https://cdn.jsdelivr.net/npm/bootstrap/dist/js/bootstrap.bundle.min.js"></script>
        <script>
            function toggleHistory() {
                const historyContainer = document.getElementById('history-container');
                // The container is hidden with CSS `display: none`, so toggling Bootstrap's
                // d-none class never reveals it; toggle the inline display style instead.
                historyContainer.style.display = historyContainer.style.display === 'block' ? 'none' : 'block';
            }

            function saveMessage(sender, message) {
                const historyContent = document.getElementById('history-content');
                const messageElement = document.createElement('div');
                messageElement.className = sender === 'user' ? 'user-message' : 'bot-message';
                messageElement.innerHTML = `<div class="message">${message}</div>`;
                historyContent.appendChild(messageElement);
            }

            function appendMessage(sender, message) {
                const chatBox = document.getElementById('chat-box');
                const messageElement = document.createElement('div');
                messageElement.className = sender === 'user' ? 'user-message' : 'bot-message';
                messageElement.innerHTML = `<div class="message">${message}</div>`;
                chatBox.appendChild(messageElement);
                chatBox.scrollTop = chatBox.scrollHeight;
            }

            function handleKeyPress(event) {
                if (event.key === 'Enter') {
                    event.preventDefault();
                    sendMessage();
                }
            }

            function sendMessage() {
                const userInput = document.getElementById('user-input');
                const userMessage = userInput.value.trim();
                if (userMessage === '') return;

                appendMessage('user', userMessage);
                saveMessage('user', userMessage);
                userInput.value = '';

                fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`)
                    .then(response => {
                        if (response.status === 503) {
                            return response.json().then(data => { throw new Error(data.detail); });
                        }
                        return response.json();
                    })
                    .then(data => {
                        const botMessages = data.result;
                        botMessages.forEach(message => {
                            appendMessage('bot', message);
                            saveMessage('bot', message);
                        });
                    })
                    .catch(error => {
                        console.error('Error:', error);
                        appendMessage('bot', "Sorry, I'm not available right now. Please try again later.");
                        saveMessage('bot', "Sorry, I'm not available right now. Please try again later.");
                    });
            }

            function retryLastMessage() {
                const lastUserMessage = document.querySelector('.user-message:last-of-type .message');
                if (lastUserMessage) {
                    const userInput = document.getElementById('user-input');
                    userInput.value = lastUserMessage.innerText;
                    sendMessage();
                }
            }
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_code, status_code=200)
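

# Text-generation endpoint backing the chat box; returns 503 until the model finishes initializing.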
@app.get('/autocomplete')
async def autocomplete(q: str = Query(..., title='query')):
    global message_history, pipeline_instance, initialization_complete
    message_history.append(('user', q))

    if not initialization_complete:
        logger.warning("Model is not initialized yet.")
        raise HTTPException(status_code=503, detail="Model is not initialized yet. Please try again later.")

    try:
        response = pipeline_instance(q, max_length=50, num_return_sequences=1)[0]['generated_text']
        logger.debug(f"Autocomplete succeeded, q: {q}, res: {response}")
        return {"result": [response]}
    except Exception as e:
        logger.exception(f"Ignored error in autocomplete: {e}")
        return {"result": ["Sorry, I encountered an error processing your request."]}


if __name__ == '__main__':
    port = 7860
    uvicorn.run(app=app, host='0.0.0.0', port=port)