import json
import os
from threading import Thread
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Runtime configuration. The model ID and generation defaults are read from environment
# variables; generation parameters can also be overridden per request (see handler()).
MODEL_ID = os.environ.get("MODEL_ID", "GainEnergy/OGAI-24B")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 2048))
TEMPERATURE = float(os.environ.get("TEMPERATURE", 0.7))
TOP_P = float(os.environ.get("TOP_P", 0.95))
TOP_K = int(os.environ.get("TOP_K", 40))
REPETITION_PENALTY = float(os.environ.get("REPETITION_PENALTY", 1.1))


def load_model():
    """Load the tokenizer and model, placing weights automatically when a GPU is available."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto" if DEVICE == "cuda" else None,
        torch_dtype=TORCH_DTYPE,
        trust_remote_code=True,
        use_cache=True,
    )
    model.eval()
    return model, tokenizer


print(f"Loading model {MODEL_ID}...")
model, tokenizer = load_model()
print(f"Model loaded on {DEVICE} with dtype {TORCH_DTYPE}")


DEFAULT_SYSTEM_PROMPT = r"""You are OGAI, an expert assistant in oil and gas engineering.

You provide **technically accurate, structured, and detailed** responses to inquiries related to **drilling, reservoir engineering, completions, production optimization, and oilfield calculations**. Your goal is to offer step-by-step explanations, precise calculations, and practical industry insights.

### **Guidelines for Responses:**
- **Use Markdown formatting** for better readability.
- **Explain formulas step-by-step**, defining each variable.
- **Ensure numerical consistency** in calculations.
- **Use real-world examples** where applicable.
- **Provide unit conversions** if relevant.

### **Example Format:**
#### **Q: How do you calculate bottomhole pressure?**

Bottomhole pressure (BHP) is the surface pressure plus the hydrostatic pressure of the mud column. In field units:

\[ BHP = P_s + (0.052 \cdot \rho \cdot h) \]

Where:
- \( BHP \) = Bottomhole Pressure (psi)
- \( P_s \) = Surface Pressure (psi)
- \( \rho \) = Mud Density (lb/gal)
- \( 0.052 \) = Conversion factor (psi per ft per lb/gal)
- \( h \) = True Vertical Depth (ft)

**Example Calculation:**
If:
- \( P_s = 500 \) psi
- \( \rho = 9.5 \) lb/gal
- \( h = 10,000 \) ft

Convert density to a pressure gradient:
\[ \rho' = 0.052 \times \rho = 0.052 \times 9.5 = 0.494 \text{ psi/ft} \]

Calculate BHP:
\[ BHP = 500 + (0.494 \times 10,000) = 5,440 \text{ psi} \]

Thus, **BHP is approximately 5,440 psi.**

Ensure all responses maintain technical precision, and clarify assumptions if necessary."""


def format_prompt(messages: List[Dict[str, str]], system_prompt: Optional[str] = None) -> str:
    """Format the conversation messages into a prompt the model can understand."""
    system = system_prompt or DEFAULT_SYSTEM_PROMPT
    formatted_prompt = f"<|system|>\n{system}\n<|user|>\n"

    for i, message in enumerate(messages):
        role = message["role"]
        content = message["content"]

        if i == 0 and role == "user":
            formatted_prompt += f"{content}\n<|assistant|>\n"
        else:
            if role == "user":
                formatted_prompt += f"<|user|>\n{content}\n<|assistant|>\n"
            elif role == "assistant":
                formatted_prompt += f"{content}\n"

    return formatted_prompt
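

# For reference, format_prompt lays out a two-turn conversation (illustrative content)
# as follows:
#
#   <|system|>
#   {system prompt}
#   <|user|>
#   How do I estimate equivalent circulating density?
#   <|assistant|>
#   {first model answer}
#   <|user|>
#   And at 12,000 ft true vertical depth?
#   <|assistant|>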


def generate(
    prompt: str,
    max_new_tokens: int = MAX_NEW_TOKENS,
    temperature: float = TEMPERATURE,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    repetition_penalty: float = REPETITION_PENALTY,
    stream: bool = True,
):
    """Generate text from the model, either streamed chunk by chunk or as a single string."""
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

    generation_config = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "do_sample": temperature > 0,
        "pad_token_id": tokenizer.eos_token_id,
    }

    if stream:
        # Run generation in a background thread and yield decoded text as it is produced.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            streamer=streamer,
            **generation_config,
        )

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        def iterator():
            for text in streamer:
                yield text

        return iterator()
    else:
        output = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **generation_config,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        return tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
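

# Typical direct (non-streaming) usage; the question text is illustrative:
#   prompt = format_prompt([{"role": "user", "content": "What does the Darcy equation describe?"}])
#   answer = generate(prompt, max_new_tokens=512, stream=False)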


def handler(event: Dict[str, Any], context: Any) -> Union[Dict[str, Any], Any]:
    """API handler function."""
    try:
        # The request body may arrive as a JSON string or already parsed as a dict.
        body = event.get("body", "{}")
        if isinstance(body, str):
            body = json.loads(body)

        messages = body.get("messages", [{"role": "user", "content": "Hello, how can you help me with oil and gas engineering?"}])
        system_prompt = body.get("system_prompt", DEFAULT_SYSTEM_PROMPT)
        params = body.get("parameters", {})

        max_new_tokens = params.get("max_new_tokens", MAX_NEW_TOKENS)
        temperature = params.get("temperature", TEMPERATURE)
        top_p = params.get("top_p", TOP_P)
        top_k = params.get("top_k", TOP_K)
        repetition_penalty = params.get("repetition_penalty", REPETITION_PENALTY)
        stream = params.get("stream", False)

        prompt = format_prompt(messages, system_prompt)

        if stream:
            def generate_stream():
                for chunk in generate(
                    prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    stream=True,
                ):
                    yield json.dumps({"generated_text": chunk}) + "\n"

            return generate_stream()
        else:
            generated_text = generate(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                stream=False,
            )

            return {
                "statusCode": 200,
                "headers": {
                    "Content-Type": "application/json"
                },
                "body": json.dumps({"generated_text": generated_text})
            }
    except Exception as e:
        print(f"Error: {str(e)}")
        return {
            "statusCode": 500,
            "headers": {
                "Content-Type": "application/json"
            },
            "body": json.dumps({"error": str(e)})
        }
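

# The handler expects `event["body"]` to be a JSON document shaped like the example
# below; every field is optional and falls back to the defaults above (values shown
# are illustrative):
#   {
#     "messages": [{"role": "user", "content": "..."}],
#     "system_prompt": "...",
#     "parameters": {"max_new_tokens": 512, "temperature": 0.2, "stream": false}
#   }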


if __name__ == "__main__":
    # Local smoke test of the non-streaming path.
    test_event = {
        "body": json.dumps({
            "messages": [
                {"role": "user", "content": "Explain the principles of reservoir simulation in oil and gas engineering."}
            ]
        })
    }

    response = handler(test_event, None)
    if isinstance(response, dict):
        print(json.loads(response["body"])["generated_text"])
    else:
        # Streaming responses come back as a generator of NDJSON chunks.
        for chunk in response:
            print(json.loads(chunk)["generated_text"], end="")
|
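    # Hypothetical follow-up exercising the streaming path; the question and parameter
    # values below are illustrative, not part of the original test.
    stream_event = {
        "body": json.dumps({
            "messages": [
                {"role": "user", "content": "List the main inputs to a black-oil reservoir simulation model."}
            ],
            "parameters": {"stream": True, "max_new_tokens": 256}
        })
    }
    for chunk in handler(stream_event, None):
        print(json.loads(chunk)["generated_text"], end="")
    print()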