S2S inference

#6 opened by james-golden-arcadia

Thank you for open-sourcing this exciting model! It is impressive that the diverse set of training tasks allows the XL model to keep improving over the large model's performance. It's also great to see that this was trained with JAX on TPUs.

When running inference with the model, I had some issues reproducing the S2S completion examples in the README. For the sequence completion example, was teacher forcing used at inference? I also observed excellent performance with the [NLU] prefix for predicting masked sites.

Elnaggar Lab org

Hi @james-golden-arcadia,

Thanks for your kind words.

Could you please share the issues you are facing in reproducing the S2S completion example in the README?
Please share an example in Colab so we can easily debug the issue.

For the S2S example in the README, teacher forcing is not used at inference for the second half of the sequence; the model has to generate it by itself, amino acid by amino acid.
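To make the distinction concrete, here is a minimal sketch (not taken from the README; the example sequence and the half-way split are only illustrative) contrasting free-running generation, which is what the S2S completion uses, with a teacher-forced forward pass in which the true second half is supplied as labels and the model is only scored token by token:

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers.generation import GenerationConfig

ckpt = "ElnaggarLab/ankh3-large"
tokenizer = T5Tokenizer.from_pretrained(ckpt)
model = T5ForConditionalGeneration.from_pretrained(ckpt).eval()

sequence = "MKAYVLINSRGP"  # arbitrary example sequence
half = len(sequence) // 2
inputs = tokenizer("[S2S]" + sequence[:half], return_tensors="pt")

# Free-running inference (the README setting): the decoder conditions only on
# its own previous predictions, one amino acid at a time.
gen_config = GenerationConfig(max_length=half + 1, do_sample=False, num_beams=1)
generated = model.generate(inputs["input_ids"], generation_config=gen_config)

# Teacher forcing (NOT what the README completion does): the true second half is
# fed to the decoder as labels, so each step conditions on the correct history
# and only the per-token likelihood is evaluated.
labels = tokenizer(sequence[half:], return_tensors="pt")["input_ids"]
with torch.no_grad():
    loss = model(input_ids=inputs["input_ids"], labels=labels).loss

Teacher-forced scores tend to look much better because every step sees the correct prefix; the README completions come from the free-running path above.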

Thank you for sharing that our model achieved excellent performance in predicting masked amino acids using NLU in your tasks :)

james-golden-arcadia

Thank you for the quick response! I am working with transformers==4.53.1.

With the example from the HF README:

from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers.generation import GenerationConfig
import torch

sequence = "MDTAYPREDTRAPTPSKAGAHTALTLGAPHPPPRDHLIWSVFSTLYLNLCCLGFLALAYSIKARDQKVVGDLEAARRFGSKAKCYNILAAMWTLVPPLLLLGLVVTGALHLARLAKDSAAFFSTKFDDADYD"

ckpt = "ElnaggarLab/ankh3-large"
tokenizer = T5Tokenizer.from_pretrained(ckpt)
# To use the sequence to sequence task using the S2S prefix:
model = T5ForConditionalGeneration.from_pretrained(ckpt).eval()


half_length = int(len(sequence) * 0.5)
s2s_sequence = "[S2S]" + sequence[:half_length]
encoded_s2s_sequence = tokenizer(s2s_sequence, add_special_tokens=True, return_tensors="pt", is_split_into_words=False)
# + 1 to account for the start of sequence token.
gen_config = GenerationConfig(min_length=half_length + 1, max_length=half_length + 1, do_sample=False, num_beams=1)
generated_sequence = model.generate(encoded_s2s_sequence["input_ids"], generation_config=gen_config)
predicted_sequence = sequence[:half_length] + tokenizer.batch_decode(generated_sequence)[0]
# Compare the true second half with the generated completion.
print(sequence[half_length:], tokenizer.batch_decode(generated_sequence)[0].split('<pad>')[1])

The expected completion is:
'KVVGDLEAARRFGSKAKCYNILAAMWTLVPPLLLLGLVVTGALHLARLAKDSAAFFSTKFDDADYD'

but I get:
'GGRGFSAFYLRYFRAFATTLAVAVSITITFTFTVVVLPISALQRVIRLNGGFEFNNEDGLAGLGIF'

With ankh3-xl, the output is much closer (expected completion, then generated):

'KVVGDLEAARRFGSKAKCYNILAAMWTLVPPLLLLGLVVTGALHLARLAKDSAAFFSTKFDDADYD',
'RVAGDLEAARRFGSKAKCYNILATTWALVPPLLLLGLVVTGALHLSRLAKDSAAFFSTKLDDSDYD'

For the shorter example:

from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers.generation import GenerationConfig
import torch

sequence = "MKAYVLINSRGP"

ckpt = "ElnaggarLab/ankh3-large"
tokenizer = T5Tokenizer.from_pretrained(ckpt)
# To use the sequence to sequence task using the S2S prefix:
model = T5ForConditionalGeneration.from_pretrained(ckpt).eval()


half_length = int(len(sequence) * 0.5)
s2s_sequence = "[S2S]" + sequence[:half_length]
encoded_s2s_sequence = tokenizer(s2s_sequence, add_special_tokens=True, return_tensors="pt", is_split_into_words=False)
# + 1 to account for the start of sequence token.
gen_config = GenerationConfig(min_length=half_length + 1, max_length=half_length + 1, do_sample=False, num_beams=1)
generated_sequence = model.generate(encoded_s2s_sequence["input_ids"], generation_config=gen_config)
predicted_sequence = sequence[:half_length] + tokenizer.batch_decode(generated_sequence)[0]
# Compare the true second half with the generated completion.
print(sequence[half_length:], tokenizer.batch_decode(generated_sequence)[0].split('<pad>')[1])

The expected completion is:
'INSRGP'

I get:
'LALALA'

With ankh3-xl (expected completion, then generated):

'INSRGP',
'LLLLLL'

Elnaggar Lab org

It is expected that the XL model performs better than the large model, as you saw in your first test.

For the second test, it seems that either:

  1. The input to the encoder was too short to produce a good representation, so the decoder could not generate the correct result.
  2. The sequence is outside the distribution of the training data, which is why the model could not produce a sound output.

I would test the following for the above two points:

  1. Test the whole input except for the last amino acid and check the result, then repeat with n-2, n-3, etc. This gives the encoder a larger input chunk and lets you detect the input length below which the model can no longer provide what you are looking for (see the sketch after this list).
  2. Test different GenerationConfig parameters. This lets the model be more creative and generate protein sequences outside its training distribution.
    The list of parameters is here:
    https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig
    You can start testing with the following parameters:
    num_beams=4
    temperature=0.7
    penalty_alpha=10
    top_k=4
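A minimal sketch of both suggestions, assuming the same checkpoint and setup as in the snippets above (the complete() helper, prefix lengths, and decoding values are illustrative, not tuned):

from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers.generation import GenerationConfig

ckpt = "ElnaggarLab/ankh3-large"
tokenizer = T5Tokenizer.from_pretrained(ckpt)
model = T5ForConditionalGeneration.from_pretrained(ckpt).eval()

sequence = "MDTAYPREDTRAPTPSKAGAHTALTLGAPHPPPRDHLIWSVFSTLYLNLCCLGFLALAYSIKARDQKVVGDLEAARRFGSKAKCYNILAAMWTLVPPLLLLGLVVTGALHLARLAKDSAAFFSTKFDDADYD"

def complete(prefix_len, gen_config):
    # Generate the remaining amino acids from a [S2S]-prefixed partial sequence.
    inputs = tokenizer("[S2S]" + sequence[:prefix_len], return_tensors="pt")
    out = model.generate(inputs["input_ids"], generation_config=gen_config)
    return tokenizer.batch_decode(out, skip_special_tokens=True)[0]

# Point 1: hide only the last amino acid, then the last two, and so on, to find
# the input length at which the completion starts to degrade.
for hidden in range(1, 6):
    prefix_len = len(sequence) - hidden
    cfg = GenerationConfig(do_sample=False, num_beams=1,
                           min_new_tokens=hidden, max_new_tokens=hidden)
    print(hidden, sequence[prefix_len:], complete(prefix_len, cfg))

# Point 2: less greedy decoding settings. Note that temperature/top_k only take
# effect with do_sample=True, and penalty_alpha enables contrastive search
# together with top_k.
half = len(sequence) // 2
sampled = GenerationConfig(do_sample=True, temperature=0.7, top_k=4, max_new_tokens=half)
beam = GenerationConfig(num_beams=4, do_sample=False, max_new_tokens=half)
print(sequence[half:])
print(complete(half, sampled))
print(complete(half, beam))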
james-golden-arcadia changed discussion status to closed
