Unconditional Sequence Generation does not work

#2
by DaniDubi - opened

Hi GleghornLab and Synthyra,
@lhallee

Thanks again for releasing this and your other protein language models.

I'm following your example for "Unconditional Sequence Generation".
If I run it as is, I get this error:
TypeError: GenerateMixin.mask_diffusion_generate() missing 1 required positional argument: 'tokenizer'

If I add the tokenizer object (tokenizer=tokenizer), I get this error:

    148 self.special_token_ids = self._get_special_token_ids(extra_tokens)
    149 self.special_token_ids = torch.tensor(self.special_token_ids, device=device).flatten()
--> 151 num_mask_tokens = (input_tokens == mask_token_id).sum().item()
    152 steps = max(1, num_mask_tokens // step_divisor)
    154 trajectory = []
AttributeError: 'bool' object has no attribute 'sum'
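A note on the second error: `(input_tokens == mask_token_id)` evaluated to a single Python `bool` instead of an elementwise mask. That happens when `input_tokens` is not a tensor. One plausible but unconfirmed cause is passing a plain list of ids, which is what `tokenizer.encode(...)` returns without `return_tensors='pt'`. A `torch` tensor compared to an int gives a boolean tensor that does have `.sum()`. The sketch below reproduces the failure with plain Python and needs no `torch`. The token ids and `MASK_ID` value are hypothetical, chosen only for illustration.

```python
# Hypothetical ids for illustration; real ids come from the model's tokenizer.
MASK_ID = 32
input_ids = [0, 20, 32, 32, 32, 2]  # a plain Python list, not a tensor

# Comparing a whole list to an int yields ONE bool, not a per-token mask.
result = input_ids == MASK_ID
print(type(result), result)  # <class 'bool'> False

try:
    result.sum()  # same call pattern as the traceback
except AttributeError as err:
    print(err)  # 'bool' object has no attribute 'sum'


def count_mask_tokens(tokens, mask_token_id):
    """Count mask tokens in a flat list of ids.

    With a torch tensor, the equivalent is
    (tokens == mask_token_id).sum().item().
    """
    return sum(1 for tok in tokens if tok == mask_token_id)


print(count_mask_tokens(input_ids, MASK_ID))  # 3
```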
Gleghorn Lab org

Hey @DaniDubi ,

The README was outdated and has now been updated. It is running fine on my end. If you still get the error, could you paste a complete step-by-step account of how you got there?
Best,
Logan

The code I ran to confirm:

import torch
from models.modeling_dsm import DSM # Or DSM_ppi for binder generation

# Load a pre-trained model
model_name_or_path = "GleghornLab/DSM_650" # Replace with your model of choice
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DSM.from_pretrained(model_name_or_path).to(device).eval()
tokenizer = model.tokenizer

### Unconditional generation
length = 100
mask_token = tokenizer.mask_token
# optionally, enforce starting with methionine
input_tokens = tokenizer.encode('M' + ''.join([mask_token] * (length - 1)), add_special_tokens=True, return_tensors='pt').to(device)
output = model.mask_diffusion_generate(
    tokenizer=tokenizer,
    input_tokens=input_tokens,
    step_divisor=100,          # lower is slower but better
    temperature=1.0,           # sampling temperature
    remasking="random",        # strategy for remasking tokens not kept
    preview=False,             # set this to True to watch the mask tokens get filled in real time
    slow=False,                # adds a small delay to the real-time filling (because it is usually very fast and watching carefully is hard!)
    return_trajectory=False    # set this to True to return the trajectory of the generation (what you watch in the preview)
) # Note: output will be a tuple if return_trajectory is True

generated_sequences = model.decode_output(output)
print(f"Generated sequence: {generated_sequences[0]}")

The output:

Generated sequence: MKRIDLLFTGFVDQRPHNEEVILVAYGITLGAPASERTGFTRDLQGDLIDERARGGEFRFDMIAKDDFAPAGFTCHGAVHVLRRFIFLGAPDPIYVNMSL
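For intuition about `step_divisor`, the traceback earlier in this thread shows the step count computed as `steps = max(1, num_mask_tokens // step_divisor)`. The sketch below assumes the updated code still uses that formula and that each mask token string encodes to exactly one token. Under those assumptions, the example above (`length = 100`, a fixed leading `M`, so 99 masks) with `step_divisor=100` fills every mask in a single step. Smaller divisors give more, finer-grained steps. `diffusion_steps` is a hypothetical helper, not part of the DSM API.

```python
def diffusion_steps(num_mask_tokens: int, step_divisor: int) -> int:
    # Mirrors the formula in the traceback: floor division, with at least one step.
    return max(1, num_mask_tokens // step_divisor)


length = 100
num_masks = length - 1  # 'M' is fixed; the remaining positions start as masks

for divisor in (100, 10, 1):
    print(divisor, diffusion_steps(num_masks, divisor))
# 100 1
# 10 9
# 1 99
```

This matches the comment "lower is slower but better": fewer tokens are committed per step, at the cost of more forward passes.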
DaniDubi changed discussion status to closed
