lhallee committed on
Commit
4464624
·
verified ·
1 Parent(s): 4b8accf

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +25 -20
README.md CHANGED
@@ -78,14 +78,17 @@ To generate a novel sequence of a specific length. DSM uses a progressive denois
78
  length = 100
79
  mask_token = tokenizer.mask_token
80
  # optionally, enforce starting with methionine
81
- input_template = tokenizer.encode('M' + ''.join([mask_token] * (length - 1)), add_special_tokens=True).to(device)
82
  output = model.mask_diffusion_generate(
83
- input_tokens=input_template,
84
- step_divisor=100, # lower is slower but better
85
- temperature=1.0, # sampling temperature
86
- remasking="random", # strategy for remasking tokens not kept
87
- preview=False #
88
- )
 
 
 
89
 
90
  generated_sequences = model.decode_output(output)
91
  print(f"Generated sequence: {generated_sequences[0]}")
@@ -101,15 +104,18 @@ To fill in masked regions of a template sequence:
101
  ```python
102
  # Mask Filling / Inpainting
103
  template_sequence = "MA<mask><mask><mask>KEG<mask><mask>STL"
104
- template_tokens = model.tokenizer.encode(template_sequence, add_special_tokens=True).to(device)
105
 
106
- filled_ids = model.mask_diffusion_generate(
107
- input_tokens=template_tokens,
108
- step_divisor=100, # lower is slower but better
109
- temperature=1.0, # sampling temperature
110
- remasking="random", # strategy for remasking tokens not kept
111
- preview=False
112
- )
 
 
 
113
 
114
  generated_sequences = model.decode_output(output)
115
  print(f"Generated sequence: {generated_sequences[0]}")
@@ -125,9 +131,8 @@ Generated sequence: MAVKFKEGGISTL
125
  # model_binder = DSM_ppi.from_pretrained("GleghornLab/DSM_650_ppi_lora").to(device).eval()
126
  # The lora version from the paper leads to unreliable outputs
127
  # Synthyra has generously trained a version through full fine tuning
128
- from models.modeling_dsm import DSM
129
 
130
- model_binder = DSM.from_pretrained("Synthyra/DSM_ppi_full").to(device).eval()
131
 
132
  # BBF-14
133
  target_seq = "MGTPLWALLGGPWRGTATYEDGTKVTLDYRYTRVSPDRLRADVTYTTPDGTTLEATVDLWKDANGVIRYHATYPDGTSADGTLTQLDADTLLATGTYDDGTKYTVTLTRVAPGSGWHHHHHH"
@@ -140,10 +145,10 @@ combined_input_str = target_seq + '<eos>' + interactor_template
140
 
141
  input_tokens = tokenizer.encode(combined_input_str, add_special_tokens=True, return_tensors='pt').to(device)
142
 
143
- output = model_binder.mask_diffusion_generate(
144
  tokenizer=tokenizer,
145
  input_tokens=input_tokens,
146
- step_divisor=10, # lower is slower but better
147
  temperature=1.0, # sampling temperature
148
  remasking="random", # strategy for remasking tokens not kept
149
  preview=False, # set this to True to watch the mask tokens get filled in real time
@@ -192,7 +197,7 @@ output = model.mask_diffusion_generate(
192
  seqa, seqb = model.decode_dual_input(output, seperator='<eos>')
193
  # Parse out the generated interactor part based on EOS tokens.
194
  # Example: generated_full_seq_str.split(model_binder.tokenizer.eos_token)[1]
195
- print(f"SeqA: {seqa[0][4:]}") # remove cls token
196
  print(f"SeqB: {seqb[0]}")
197
  ```
198
 
 
78
  length = 100
79
  mask_token = tokenizer.mask_token
80
  # optionally, enforce starting with methionine
81
+ input_tokens = tokenizer.encode('M' + ''.join([mask_token] * (length - 1)), add_special_tokens=True, return_tensors='pt').to(device)
82
  output = model.mask_diffusion_generate(
83
+ tokenizer=tokenizer,
84
+ input_tokens=input_tokens,
85
+ step_divisor=100, # lower is slower but better
86
+ temperature=1.0, # sampling temperature
87
+ remasking="random", # strategy for remasking tokens not kept
88
+ preview=False, # set this to True to watch the mask tokens get filled in real time
89
+ slow=False, # adds a small delay to the real time filling (because it is usually very fast and watching carefully is hard!)
90
+ return_trajectory=False # set this to True to return the trajectory of the generation (what you watch in the preview)
91
+ ) # Note: output will be a tuple if return_trajectory is True
92
 
93
  generated_sequences = model.decode_output(output)
94
  print(f"Generated sequence: {generated_sequences[0]}")
 
104
  ```python
105
  # Mask Filling / Inpainting
106
  template_sequence = "MA<mask><mask><mask>KEG<mask><mask>STL"
107
+ input_tokens = tokenizer.encode(template_sequence, add_special_tokens=True, return_tensors='pt').to(device)
108
 
109
+ output = model.mask_diffusion_generate(
110
+ tokenizer=tokenizer,
111
+ input_tokens=input_tokens,
112
+ step_divisor=100, # lower is slower but better
113
+ temperature=1.0, # sampling temperature
114
+ remasking="random", # strategy for remasking tokens not kept
115
+ preview=False, # set this to True to watch the mask tokens get filled in real time
116
+ slow=False, # adds a small delay to the real time filling (because it is usually very fast and watching carefully is hard!)
117
+ return_trajectory=False # set this to True to return the trajectory of the generation (what you watch in the preview)
118
+ ) # Note: output will be a tuple if return_trajectory is True
119
 
120
  generated_sequences = model.decode_output(output)
121
  print(f"Generated sequence: {generated_sequences[0]}")
 
131
  # model_binder = DSM_ppi.from_pretrained("GleghornLab/DSM_650_ppi_lora").to(device).eval()
132
  # The lora version from the paper leads to unreliable outputs
133
  # Synthyra has generously trained a version through full fine tuning
 
134
 
135
+ model = DSM.from_pretrained("Synthyra/DSM_ppi_full").to(device).eval()
136
 
137
  # BBF-14
138
  target_seq = "MGTPLWALLGGPWRGTATYEDGTKVTLDYRYTRVSPDRLRADVTYTTPDGTTLEATVDLWKDANGVIRYHATYPDGTSADGTLTQLDADTLLATGTYDDGTKYTVTLTRVAPGSGWHHHHHH"
 
145
 
146
  input_tokens = tokenizer.encode(combined_input_str, add_special_tokens=True, return_tensors='pt').to(device)
147
 
148
+ output = model.mask_diffusion_generate(
149
  tokenizer=tokenizer,
150
  input_tokens=input_tokens,
151
+ step_divisor=100, # lower is slower but better
152
  temperature=1.0, # sampling temperature
153
  remasking="random", # strategy for remasking tokens not kept
154
  preview=False, # set this to True to watch the mask tokens get filled in real time
 
197
  seqa, seqb = model.decode_dual_input(output, seperator='<eos>')
198
  # Parse out the generated interactor part based on EOS tokens.
199
  # Example: generated_full_seq_str.split(model_binder.tokenizer.eos_token)[1]
200
+ print(f"SeqA: {seqa[0][5:]}") # remove cls token
201
  print(f"SeqB: {seqb[0]}")
202
  ```
203