Spaces:
Running
Running
import torch | |
from torch.nn import Linear | |
from torch_geometric.nn import HGTConv, MLP | |
import pandas as pd | |
import yaml | |
import os | |
from datasets import load_dataset | |
import gdown | |
import copy | |
import json | |
import gzip | |
class ProtHGT(torch.nn.Module):
    """Heterogeneous Graph Transformer that scores Protein -> target-node pairs.

    Each node type gets its own linear projection to ``hidden_channels``.
    The projected features are refined by ``num_layers`` HGTConv layers.
    A pair is scored by an MLP applied to the concatenated embeddings of its
    two endpoints.
    """

    def __init__(self, data, hidden_channels, num_heads, num_layers, mlp_hidden_layers, mlp_dropout):
        super().__init__()
        # Input width per node type is read from the feature matrix itself.
        self.lin_dict = torch.nn.ModuleDict({
            ntype: Linear(data[ntype].x.size(1), hidden_channels)
            for ntype in data.node_types
        })
        graph_meta = data.metadata()
        self.convs = torch.nn.ModuleList([
            HGTConv(hidden_channels, hidden_channels, graph_meta, num_heads, group='sum')
            for _ in range(num_layers)
        ])
        self.mlp = MLP(mlp_hidden_layers, dropout=mlp_dropout, norm=None)

    def generate_embeddings(self, x_dict, edge_index_dict):
        """Return per-node-type embeddings after projection (+ReLU) and all HGT layers."""
        hidden = {}
        for ntype, feats in x_dict.items():
            hidden[ntype] = self.lin_dict[ntype](feats).relu_()
        for layer in self.convs:
            hidden = layer(hidden, edge_index_dict)
        return hidden

    def forward(self, x_dict, edge_index_dict, tr_edge_label_index, target_type, test=False):
        """Score the (Protein, ``target_type``) pairs listed in ``tr_edge_label_index``.

        Returns a tuple:
        - a flat tensor of raw MLP outputs, one per pair;
        - the embedding dict those scores were computed from.

        ``test`` is unused and accepted only for call-site compatibility.
        """
        emb = self.generate_embeddings(x_dict, edge_index_dict)
        src, dst = tr_edge_label_index
        pair_features = torch.cat((emb["Protein"][src], emb[target_type][dst]), dim=-1)
        scores = self.mlp(pair_features).view(-1)
        return scores, emb
def _load_data(heterodata, protein_ids, go_category):
    """Attach candidate Protein <-> GO-term edges for the requested proteins.

    Every protein in ``protein_ids`` is connected to every node of type
    ``go_category`` so that all pairs can be scored. Edges are laid out
    protein-major: first protein with term indices ``0..n_terms-1``, then the
    second protein, and so on. The reverse edge type receives the same pairs
    flipped.

    Args:
        heterodata: graph to modify in place (``generate_prediction_df`` passes
            a deep copy).
        protein_ids: IDs looked up in ``heterodata['Protein']['id_mapping']``.
        go_category: GO node type, e.g. ``'GO_term_F'``.

    Returns:
        The same ``heterodata`` object with both edge types overwritten.

    Raises:
        KeyError: if any protein ID is absent from the protein ID mapping.
    """
    protein_mapping = heterodata['Protein']['id_mapping']
    missing = [pid for pid in protein_ids if pid not in protein_mapping]
    if missing:
        raise KeyError(f"Protein ID(s) not found in the knowledge graph: {missing}")

    protein_indices = torch.tensor(
        [protein_mapping[pid] for pid in protein_ids], dtype=torch.long
    )
    n_terms = len(heterodata[go_category]['id_mapping'])

    # Vectorised form of "for protein: for term: edge". Unlike building from a
    # Python list, this is always a (2, N) long tensor, even when there are no
    # proteins or no terms.
    rows = protein_indices.repeat_interleave(n_terms)
    cols = torch.arange(n_terms, dtype=torch.long).repeat(len(protein_indices))

    heterodata[('Protein', 'protein_function', go_category)].edge_index = torch.stack([rows, cols])
    heterodata[(go_category, 'rev_protein_function', 'Protein')].edge_index = torch.stack([cols, rows])
    return heterodata
def get_available_proteins(name_file='data/name_info.json.gz'):
    """Return the protein IDs (keys of the ``'Protein'`` section) in a gzipped name-info JSON file."""
    with gzip.open(name_file, mode='rt', encoding='utf-8') as handle:
        protein_names = json.load(handle)['Protein']
    return [protein_id for protein_id in protein_names]
def _generate_predictions(heterodata, model, target_type):
    """Return sigmoid probabilities for every Protein -> ``target_type`` edge.

    Uses CUDA when available and CPU otherwise. ``model`` is moved to that
    device and switched to eval mode. The graph returned by
    ``heterodata.to(device)`` is what gets scored. Inference runs without
    gradient tracking, and the result is returned as a CPU tensor.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    model.eval()
    graph = heterodata.to(device)
    edge_key = ('Protein', 'protein_function', target_type)
    with torch.no_grad():
        logits, _ = model(
            graph.x_dict,
            graph.edge_index_dict,
            graph.edge_index_dict[edge_key],
            target_type,
        )
        probabilities = logits.sigmoid()
    return probabilities.cpu()
def _create_prediction_df(predictions, heterodata, protein_ids, go_category,
                          name_file='data/name_info.json.gz'):
    """Turn a flat vector of edge probabilities into a tidy prediction table.

    ``predictions`` must follow the edge layout built by ``_load_data``:
    protein-major, one score per GO term of ``go_category``. Within each
    protein, scores are ordered by the term's node index.

    Args:
        predictions: 1-D tensor of length ``len(protein_ids) * n_go_terms``.
        heterodata: graph providing ``heterodata[go_category]['id_mapping']``
            (GO ID -> node index).
        protein_ids: protein IDs, in the order used to build the edges.
        go_category: ``'GO_term_F'``, ``'GO_term_P'`` or ``'GO_term_C'``.
        name_file: gzipped JSON with ``'Protein'`` and ``'GO_term'`` ID -> name maps.

    Returns:
        DataFrame with columns UniProt_ID, Protein, GO_ID, GO_term,
        GO_category and Probability, one row per (protein, GO term).

    Raises:
        KeyError: if ``go_category`` is not one of the three categories.
        ValueError: if the number of predictions is not proteins x terms.
    """
    go_category_dict = {
        'GO_term_F': 'Molecular Function',
        'GO_term_P': 'Biological Process',
        'GO_term_C': 'Cellular Component'
    }
    category_name = go_category_dict[go_category]

    with gzip.open(name_file, 'rt', encoding='utf-8') as file:
        name_info = json.load(file)
    protein_names = name_info['Protein']
    go_names = name_info['GO_term']

    protein_ids = list(protein_ids)
    term_mapping = heterodata[go_category]['id_mapping']
    # _load_data emits scores in node-index order (0..n-1), so label them by
    # mapped index rather than by dict insertion order. The two orders only
    # differ if the mapping was not inserted in index order.
    go_terms = sorted(term_mapping, key=lambda term: term_mapping[term])
    go_term_names = [go_names.get(term_id, term_id) for term_id in go_terms]

    n_go_terms = len(go_terms)
    n_proteins = len(protein_ids)
    n_rows = n_proteins * n_go_terms

    probabilities = predictions.tolist()
    if len(probabilities) != n_rows:
        raise ValueError(
            f"Expected {n_rows} predictions ({n_proteins} proteins x {n_go_terms} "
            f"GO terms), got {len(probabilities)}."
        )

    prediction_df = pd.DataFrame({
        'UniProt_ID': [pid for pid in protein_ids for _ in range(n_go_terms)],
        'Protein': [protein_names.get(pid, pid) for pid in protein_ids for _ in range(n_go_terms)],
        'GO_ID': go_terms * n_proteins,
        'GO_term': go_term_names * n_proteins,
        'GO_category': [category_name] * n_rows,
        'Probability': probabilities,
    })
    return prediction_df
# Google Drive file ID and local cache path of the serialized ProtHGT knowledge graph.
_KG_FILE_ID = "18u1o2sm8YjMo9joFw4Ilwvg0-rUU0PXK"
_KG_PATH = "data/prothgt-kg.pt"


def _as_list(value):
    """Return ``[value]`` for a single string, otherwise ``list(value)``."""
    return [value] if isinstance(value, str) else list(value)


def _ensure_kg_file(file_id, output):
    """Download ``output`` from Google Drive ``file_id`` unless it already exists."""
    if os.path.exists(output):
        print(f"File already exists at {output}")
        return
    url = f"https://drive.google.com/uc?id={file_id}"
    try:
        # gdown writes directly to ``output``; make sure its directory exists.
        os.makedirs(os.path.dirname(output) or '.', exist_ok=True)
        print(f"Downloading file from {url}...")
        gdown.download(url, output, quiet=False)
    except Exception as e:
        print(f"Error downloading file: {e}")
        raise
    # NOTE(review): gdown may signal some failures by returning None instead
    # of raising, so confirm the file actually exists before reporting success.
    if not os.path.exists(output):
        raise RuntimeError(f"Download from {url} did not create {output}")
    print(f"File downloaded to {output}")


def generate_prediction_df(protein_ids, model_paths, model_config_paths, go_category):
    """Predict GO-term probabilities with one ProtHGT model per GO category.

    ``go_category``, ``model_paths`` and ``model_config_paths`` are parallel
    sequences: the i-th category is scored with the i-th config/weights pair.
    Each may also be a single string for a one-category run. ``protein_ids``
    may be a single ID.

    Returns:
        ``(heterodata, final_df)``:
        - ``heterodata`` is the knowledge graph as loaded from disk. Candidate
          edges are only added to per-category deep copies.
        - ``final_df`` stacks the prediction tables of all categories.

    Raises:
        ValueError: if no GO category is given, or if the three parallel
            sequences differ in length.
    """
    protein_ids = _as_list(protein_ids)
    go_categories = _as_list(go_category)
    model_paths = _as_list(model_paths)
    model_config_paths = _as_list(model_config_paths)

    if not go_categories:
        raise ValueError("At least one GO category is required.")
    # zip() would silently drop unmatched entries; fail loudly instead.
    if not len(go_categories) == len(model_paths) == len(model_config_paths):
        raise ValueError(
            "go_category, model_paths and model_config_paths must have the same length "
            f"(got {len(go_categories)}, {len(model_paths)}, {len(model_config_paths)})."
        )

    print('Loading data...')
    _ensure_kg_file(_KG_FILE_ID, _KG_PATH)
    # The file holds a pickled graph object rather than a plain tensor state
    # dict. torch >= 2.6 defaults to weights_only=True and would refuse it.
    # NOTE: weights_only=False unpickles arbitrary objects; only load this
    # file from the trusted source above.
    heterodata = torch.load(_KG_PATH, weights_only=False)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    all_predictions = []
    for go_cat, model_config_path, model_path in zip(go_categories, model_config_paths, model_paths):
        print(f'Generating predictions for {go_cat}...')
        # _load_data mutates the graph it is given, so work on a copy.
        processed_data = _load_data(copy.deepcopy(heterodata), protein_ids, go_cat)

        with open(model_config_path, 'r') as file:
            model_config = yaml.safe_load(file)

        # hidden_channels[0] is the HGT width; hidden_channels[1] is the MLP layer spec.
        model = ProtHGT(
            processed_data,
            hidden_channels=model_config['hidden_channels'][0],
            num_heads=model_config['num_heads'],
            num_layers=model_config['num_layers'],
            mlp_hidden_layers=model_config['hidden_channels'][1],
            mlp_dropout=model_config['mlp_dropout']
        )
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f'Loaded model weights from {model_path}')

        predictions = _generate_predictions(processed_data, model, go_cat)
        prediction_df = _create_prediction_df(predictions, processed_data, protein_ids, go_cat)
        all_predictions.append(prediction_df)

        # Release per-category objects before the next model is built.
        del processed_data
        del model
        del predictions
        torch.cuda.empty_cache()

    final_df = pd.concat(all_predictions, ignore_index=True)
    del all_predictions
    torch.cuda.empty_cache()
    return heterodata, final_df