# OGAI-24B-Q6_K-GGUF / handler.py
import os
import json
import torch
from typing import Dict, List, Any, Optional, Union
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
# Configuration from environment variables, with defaults
MODEL_ID = os.environ.get("MODEL_ID", "GainEnergy/OGAI-24B")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 2048))
TEMPERATURE = float(os.environ.get("TEMPERATURE", 0.7))
TOP_P = float(os.environ.get("TOP_P", 0.95))
TOP_K = int(os.environ.get("TOP_K", 40))
REPETITION_PENALTY = float(os.environ.get("REPETITION_PENALTY", 1.1))
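# Example override (illustrative, using the variable names above): set
#   export MAX_NEW_TOKENS=1024 TEMPERATURE=0.2
# in the environment before the endpoint starts to shorten and tighten generation.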
# Load model and tokenizer
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto" if DEVICE == "cuda" else None,
        torch_dtype=TORCH_DTYPE,
        trust_remote_code=True,
        use_cache=True
    )
    model.eval()
    return model, tokenizer
# Initialize
print(f"Loading model {MODEL_ID}...")
model, tokenizer = load_model()
print(f"Model loaded on {DEVICE} with dtype {TORCH_DTYPE}")
# Default system prompt for the model
DEFAULT_SYSTEM_PROMPT = """You are OGAI, an expert assistant in oil and gas engineering.
You provide **technically accurate, structured, and detailed** responses to inquiries related to **drilling, reservoir engineering, completions, production optimization, and oilfield calculations**. Your goal is to offer step-by-step explanations, precise calculations, and practical industry insights.
### **Guidelines for Responses:**
- **Use Markdown formatting** for better readability.
- **Explain formulas step-by-step**, defining each variable.
- **Ensure numerical consistency** in calculations.
- **Use real-world examples** where applicable.
- **Provide unit conversions** if relevant.
### **Example Format:**
#### **Q: How do you calculate bottomhole pressure?**
Bottomhole pressure (BHP) can be determined using the hydrostatic pressure equation:
\[ BHP = P_s + (0.052 \cdot \rho \cdot h) \]
Where:
- \( BHP \) = Bottomhole Pressure (psi)
- \( P_s \) = Surface Pressure (psi)
- \( \rho \) = Mud Density (lb/gal)
- \( h \) = True Vertical Depth, TVD (ft)
- \( 0.052 \) = Field-unit conversion constant (psi/ft per lb/gal), which absorbs gravitational acceleration
**Example Calculation:**
If:
- \( P_s = 500 \) psi
- \( \rho = 9.5 \) lb/gal
- \( h = 10,000 \) ft
Convert density:
\[ \rho' = 0.052 \times \rho = 0.052 \times 9.5 = 0.494 \text{ psi/ft} \]
Calculate BHP:
\[ BHP = 500 + (0.494 \times 10,000) = 5,440 \text{ psi} \]
Thus, **BHP is approximately 5,440 psi.**
Ensure all responses maintain technical precision, and clarify assumptions if necessary."""
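# Illustrative sanity check (not part of the handler logic): mirrors the worked
# BHP example in DEFAULT_SYSTEM_PROMPT using the 0.052 psi/ft-per-lb/gal field
# constant. The helper name below is introduced here for illustration only and
# is not referenced anywhere else.
def _bhp_example_psi(surface_pressure_psi: float, mud_density_ppg: float, tvd_ft: float) -> float:
    """Hydrostatic bottomhole pressure in field units: P_s + 0.052 * rho * TVD."""
    return surface_pressure_psi + 0.052 * mud_density_ppg * tvd_ft

# Matches the example in the prompt: 500 + 0.052 * 9.5 * 10,000 = 5,440 psi
assert abs(_bhp_example_psi(500, 9.5, 10_000) - 5_440) < 1e-6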
def format_prompt(messages: List[Dict[str, str]], system_prompt: Optional[str] = None) -> str:
    """Format the conversation messages into a prompt the model can understand."""
    system = system_prompt or DEFAULT_SYSTEM_PROMPT
    formatted_prompt = f"<|system|>\n{system}\n<|user|>\n"

    # Process all messages
    for i, message in enumerate(messages):
        role = message["role"]
        content = message["content"]

        if i == 0 and role == "user":
            # First user turn: the <|user|> tag was already opened above
            formatted_prompt += f"{content}\n<|assistant|>\n"
        else:
            if role == "user":
                formatted_prompt += f"<|user|>\n{content}\n<|assistant|>\n"
            elif role == "assistant":
                formatted_prompt += f"{content}\n"

    return formatted_prompt
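# Illustrative layout of the string produced above for a single user turn
# (the real output contains the full DEFAULT_SYSTEM_PROMPT in place of <system text>):
#
#   <|system|>
#   <system text>
#   <|user|>
#   How do you calculate bottomhole pressure?
#   <|assistant|>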
def generate(
    prompt: str,
    max_new_tokens: int = MAX_NEW_TOKENS,
    temperature: float = TEMPERATURE,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    repetition_penalty: float = REPETITION_PENALTY,
    stream: bool = True
):
    """Generate text from the model, optionally streaming decoded chunks."""
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    generation_config = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "do_sample": temperature > 0,
        "pad_token_id": tokenizer.eos_token_id
    }

    if stream:
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # Pass the attention mask explicitly since pad_token_id is set to eos_token_id
        generation_kwargs = dict(
            inputs=input_ids,
            attention_mask=attention_mask,
            streamer=streamer,
            **generation_config
        )

        # Start generation in a separate thread so the streamer can be consumed here
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Return an iterator over the decoded text chunks
        def iterator():
            for text in streamer:
                yield text

        return iterator()
    else:
        # Non-streaming mode: generate the full completion, then strip the prompt tokens
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        return tokenizer.decode(output[0][len(input_ids[0]):], skip_special_tokens=True)
def handler(event: Dict[str, Any], context: Any) -> Union[Dict[str, Any], Any]:
    """API handler function."""
    try:
        # Parse request
        body = json.loads(event.get("body", "{}"))

        # Get parameters
        messages = body.get("messages", [{"role": "user", "content": "Hello, how can you help me with oil and gas engineering?"}])
        system_prompt = body.get("system_prompt", DEFAULT_SYSTEM_PROMPT)
        params = body.get("parameters", {})
        max_new_tokens = params.get("max_new_tokens", MAX_NEW_TOKENS)
        temperature = params.get("temperature", TEMPERATURE)
        top_p = params.get("top_p", TOP_P)
        top_k = params.get("top_k", TOP_K)
        repetition_penalty = params.get("repetition_penalty", REPETITION_PENALTY)
        stream = params.get("stream", False)

        # Format prompt
        prompt = format_prompt(messages, system_prompt)

        # Check for streaming
        if stream:
            # Return a streaming response
            def generate_stream():
                for chunk in generate(
                    prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    stream=True
                ):
                    yield json.dumps({"generated_text": chunk}) + "\n"
            return generate_stream()
        else:
            # Generate text
            generated_text = generate(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                stream=False
            )
            # Return response
            return {
                "statusCode": 200,
                "headers": {
                    "Content-Type": "application/json"
                },
                "body": json.dumps({"generated_text": generated_text})
            }
    except Exception as e:
        # Handle errors
        print(f"Error: {str(e)}")
        return {
            "statusCode": 500,
            "headers": {
                "Content-Type": "application/json"
            },
            "body": json.dumps({"error": str(e)})
        }
# For local testing
if __name__ == "__main__":
    # Simple test
    test_event = {
        "body": json.dumps({
            "messages": [
                {"role": "user", "content": "Explain the principles of reservoir simulation in oil and gas engineering."}
            ]
        })
    }
    response = handler(test_event, None)
    if isinstance(response, dict):
        print(json.loads(response["body"])["generated_text"])
    else:
        # Stream response
        for chunk in response:
            print(json.loads(chunk)["generated_text"], end="")
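
    # Illustrative streaming test (assumes the caller can consume a generator
    # response); same payload shape as above, with "stream": True in parameters.
    print()  # separate the two test outputs
    stream_event = {
        "body": json.dumps({
            "messages": [
                {"role": "user", "content": "List the main inputs to a reservoir simulation model."}
            ],
            "parameters": {"stream": True}
        })
    }
    for chunk in handler(stream_event, None):
        print(json.loads(chunk)["generated_text"], end="")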