""" Healthcare Standards RAFT (RAG + Fine-Tuned LoRA) implementation. This is a simple wrapper class to make the model easy to use. """ import os from pathlib import Path import torch from transformers import AutoModelForCausalLM, AutoTokenizer from llama_index.core import VectorStoreIndex, load_index_from_storage from llama_index.embeddings.huggingface import HuggingFaceEmbedding from peft import PeftModel, PeftConfig class HealthcareStandardsRAFT: """ Healthcare Standards RAFT system that combines RAG and LoRA fine-tuning. """ def __init__(self, model_path=None, device="cuda" if torch.cuda.is_available() else "cpu"): """ Initialize the Healthcare Standards RAFT system. Args: model_path: Path to model directory or Hugging Face repo name device: Device to use for inference (cuda/cpu) """ # Handle local fallback if no path is provided if model_path is None: model_path = "./healthcare-standards-raft" self.model_path = model_path self.adapter_dir = os.path.join(self.model_path, "model") self.device = device # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-4-mini-instruct") # Load base model and apply LoRA adapter self._load_model() # Load vector index for RAG self._load_vector_index() def _load_model(self): """Load base model and apply LoRA weights using PEFT .""" print("Loading base Phi-4-mini model...") self.model = AutoModelForCausalLM.from_pretrained( "microsoft/phi-4-mini-instruct", torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, device_map="auto" if self.device == "cuda" else None ) adapter_dir = os.path.join(self.model_path, "model") print(f"Applying LoRA adapter from: {adapter_dir}") self.model = PeftModel.from_pretrained(self.model,adapter_dir) def _load_vector_index(self): """Load the vector index for RAG.""" try: # Determine the index path if os.path.isdir(self.model_path): index_path = os.path.join(self.model_path, "vector_index") else: index_path = "healthcare-standards-raft/vector_index" # Local cached path # Create embedding model embed_model = HuggingFaceEmbedding( model_name="sentence-transformers/all-MiniLM-L6-v2" ) # Load storage context from llama_index.core import StorageContext if os.path.exists(index_path): storage_context = StorageContext.from_defaults( persist_dir=index_path ) # Load index from llama_index.core import load_index_from_storage self.index = load_index_from_storage( storage_context, embed_model=embed_model ) else: print(f"Warning: Vector index not found at {index_path}") self.index = None except Exception as e: print(f"Error loading vector index: {e}") self.index = None def query(self, question, max_tokens=512, temperature=0.7): """ Query the Healthcare Standards RAFT system. 
        Args:
            question: The healthcare standards question to answer.
            max_tokens: Maximum number of new tokens in the response.
            temperature: Sampling temperature for generation.

        Returns:
            str: The model's response.
        """
        # Retrieve relevant context if the index is available
        context = ""
        if self.index is not None:
            # Get retriever
            retriever = self.index.as_retriever(similarity_top_k=3)

            # Retrieve relevant documents
            nodes = retriever.retrieve(question)

            # Extract text from nodes
            if nodes:
                context_texts = []
                for node in nodes:
                    if hasattr(node, 'node') and hasattr(node.node, 'get_content'):
                        context_texts.append(node.node.get_content())
                    elif hasattr(node, 'get_content'):
                        context_texts.append(node.get_content())
                    elif hasattr(node, 'text'):
                        context_texts.append(node.text)
                    else:
                        context_texts.append(str(node))

                # Join context texts
                context = "\n\n".join(context_texts)

        # Create prompt
        if context:
            prompt = f"""You are a Healthcare Standards Expert. Use the following context to answer the question.

Context:
{context}

Question: {question}

Answer:"""
        else:
            prompt = f"""You are a Healthcare Standards Expert. Answer the following question based on your knowledge.

Question: {question}

Answer:"""

        # Generate response
        # Ensure pad token is set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Tokenize (the tokenizer already returns a matching attention mask)
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)

        # Move to device
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.pad_token_id,  # good practice
            )

        # Decode only the newly generated tokens
        response = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        return response

    def get_retrieved_contexts(self, question):
        """
        Get the contexts retrieved for a specific question.

        Args:
            question (str): The question to retrieve contexts for.

        Returns:
            list: List of retrieved context strings.
        """
        try:
            if hasattr(self, 'index') and self.index is not None:
                # Use the retriever to get relevant documents
                retriever = self.index.as_retriever(similarity_top_k=3)
                nodes = retriever.retrieve(question)

                # Extract text from the retrieved nodes
                contexts = []
                for node in nodes:
                    if hasattr(node, 'node') and hasattr(node.node, 'get_content'):
                        contexts.append(node.node.get_content())
                    elif hasattr(node, 'get_content'):
                        contexts.append(node.get_content())
                    elif hasattr(node, 'text'):
                        contexts.append(node.text)
                    else:
                        contexts.append(str(node))
                return contexts
            else:
                return ["Vector index not available"]
        except Exception as e:
            print(f"Error retrieving contexts: {e}")
            return [f"Error retrieving contexts: {e}"]
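

if __name__ == "__main__":
    # Minimal usage sketch. The model directory and the question below are
    # assumptions for illustration only: point model_path at wherever the
    # fine-tuned artifacts (LoRA adapter in "model/" and the persisted
    # "vector_index/") actually live, and substitute your own question.
    raft = HealthcareStandardsRAFT(model_path="./healthcare-standards-raft")

    question = "What is the purpose of HL7 FHIR?"  # example question

    # Inspect what the retriever pulls in before generating an answer
    for i, ctx in enumerate(raft.get_retrieved_contexts(question), start=1):
        print(f"--- Context {i} ---\n{ctx[:200]}\n")

    # Generate an answer grounded in the retrieved context with the LoRA-tuned model
    print(raft.query(question, max_tokens=256, temperature=0.2))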