David Prihoda committed
Commit 5500725 · 1 Parent(s): fb77734
Files changed (4)
  1. Dockerfile +0 -2
  2. requirements.txt +2 -2
  3. src/example.py +0 -34
  4. src/streamlit_app.py +156 -67
Dockerfile CHANGED
@@ -14,8 +14,6 @@ COPY src/ ./src/
 
 RUN pip3 install -r requirements.txt
 
-ENV HF_HUB_CACHE="/tmp/huggingface"
-
 EXPOSE 8501
 
 HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 altair
 pandas
 streamlit
-transformers
-torch
+logomaker
+promb
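
The new requirements swap the transformers/torch stack for two lighter packages. A quick sanity check that both resolve and expose the names the app imports (names taken from src/streamlit_app.py below; this snippet is illustrative, not part of the commit):

    # Both names below are imported verbatim by src/streamlit_app.py
    import logomaker
    from promb import init_db, print_nearest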
src/example.py DELETED
@@ -1,34 +0,0 @@
-import torch
-import pandas as pd
-from transformers import RobertaForMaskedLM, RobertaTokenizer
-
-tokenizer = RobertaTokenizer.from_pretrained('prihodad/biophi-sapiens1-tokenizer')
-vh_model = RobertaForMaskedLM.from_pretrained('prihodad/biophi-sapiens1-vh')
-vl_model = RobertaForMaskedLM.from_pretrained('prihodad/biophi-sapiens1-vl')
-
-def sapiens_predict(tokenizer, model, seq, probs=True):
-    encoded_input = tokenizer(seq, return_tensors='pt')
-    with torch.no_grad():
-        logits = model(**encoded_input).logits[0][1:-1].cpu()
-
-    index_to_token = {idx: token for token, idx in tokenizer.get_vocab().items()}
-    return pd.DataFrame(
-        logits.numpy() if not probs else torch.softmax(logits, dim=-1).numpy(),
-        columns=[index_to_token[i] for i in range(logits.shape[1])]
-    )[list('ACDEFGHIKLMNPQRSTVWY')]
-
-seq = "QVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYDDHYCLDYWGQGTTLTVSS"
-probs = sapiens_predict(
-    tokenizer,
-    vh_model,
-    seq,
-)
-
-top = probs.idxmax(axis=1)
-
-print(" Input:", seq)
-print("       ", ''.join('|' if aa == bb else ' ' for aa, bb in zip(seq, top)))
-print("Output:", ''.join(top))
-
-print("Probs:")
-print(probs.head())
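
The deleted quick-start depended on transformers and torch. A hypothetical minimal equivalent using promb (a sketch assuming the API exercised by src/streamlit_app.py below; not part of this commit):

    # Sketch: humanness quick-start with promb, mirroring calls from the app
    from promb import init_db

    seq = "SPLQKASDSLINIAIKMLRNGINPELAKKLWDIAYKISMSHIDPSSFYEALKELKKLIEEQEEELIEA"

    db = init_db("human-reference", 9)    # database name and 9mer length match the app's defaults
    peptides = db.chop_seq_peptides(seq)  # overlapping 9mer peptides

    # Fraction of peptides found in the human database
    print("{:.1%} human".format(db.compute_peptide_content(seq)))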
src/streamlit_app.py CHANGED
@@ -2,97 +2,186 @@ import altair as alt
 import numpy as np
 import pandas as pd
 import streamlit as st
-from transformers import RobertaForMaskedLM, RobertaTokenizer
-import torch
-import torch.nn.functional as F
+from promb import init_db, print_nearest
+from io import StringIO
+import logomaker
 
 st.set_page_config(layout="wide")
 
-@st.cache_resource
-def get_tokenizer():
-    return RobertaTokenizer.from_pretrained('prihodad/biophi-sapiens1-tokenizer')
-
-
-@st.cache_resource
-def get_vh_model():
-    return RobertaForMaskedLM.from_pretrained('prihodad/biophi-sapiens1-vh')
-
-
-@st.cache_resource
-def get_vl_model():
-    return RobertaForMaskedLM.from_pretrained('prihodad/biophi-sapiens1-vl')
-
-
-def sapiens_predict(tokenizer, model, seq, probs=True):
-    encoded_input = tokenizer(seq, return_tensors='pt')
-    with torch.no_grad():
-        logits = model(**encoded_input).logits[0][1:-1].cpu()
-
-    index_to_token = {idx: token for token, idx in tokenizer.get_vocab().items()}
-    df = pd.DataFrame(
-        logits.numpy() if not probs else torch.softmax(logits, dim=-1).numpy(),
-        index=range(1, len(seq) + 1),
-        columns=[index_to_token[i] for i in range(logits.shape[1])]
-    )[list('ACDEFGHIKLMNPQRSTVWY')]
-    df.index.name = "position"
-    return df
-
-
-def display_prediction(probs, seq):
-    top = probs.idxmax(axis=1)
-
-    out = f" Input: {seq}" + "\n"
-    out += "        " + "".join(' ' if aa == bb else '#' for aa, bb in zip(seq, top)) + "\n"
-    out += "Output: " + "".join(top)
-
-    st.code(out, language="text")
-
-    probs_long = probs.reset_index().melt(id_vars="position", var_name="aa", value_name="probability")
-
-    heatmap = alt.Chart(probs_long).mark_rect().encode(
-        x=alt.X("position:O", title="Sequence Position"),
-        y=alt.Y("aa:N", title="Amino Acid", sort=None),
-        color=alt.Color("probability:Q", scale=alt.Scale(scheme="viridis")),
-        tooltip=["position", "aa", "probability"]
-    ).properties(
-        height=600,
-        title="Amino Acid Probabilities"
-    )
-    st.altair_chart(heatmap, use_container_width=True)
-
-
-st.write('# Sapiens human antibody language model')
-
-vh = st.text_area("Heavy chain", value="QVQLQQSGAELARPGASVKMSCKASGYTFTRYTMHWVKQRPGQGLEWIGYINPSRGYTNYNQKFKDKATLTTDKSSSTAYMQLSSLTSEDSAVYYCARYYDDHYCLDYWGQGTTLTVSS")
-
-vl = st.text_area("Light chain")
-
-if len(vh) > 180 or len(vl) > 180:
-    st.error("Please enter only the variable region - got sequence of length > 180")
-    st.stop()
-
-with st.spinner("Predicting..."):
-    tokenizer = get_tokenizer()
-    vh_model = get_vh_model()
-    vl_model = get_vl_model()
-
-    st.write("### VH")
-    if vh:
-        vh_probs = sapiens_predict(tokenizer, vh_model, vh)
-        display_prediction(vh_probs, vh)
-    else:
-        st.write("No heavy chain provided")
-
-    st.write("### VL")
-    if vl:
-        vl_probs = sapiens_predict(tokenizer, vl_model, vl)
-        display_prediction(vl_probs, vl)
-    else:
-        st.write("No light chain provided")
-
-st.write("## Try it out yourself")
-
-st.write("Install dependencies `pip install transformers torch` and run:")
-
-with open("src/example.py") as f:
-    st.code(f.read(), language="python")
+@st.cache_resource(show_spinner=False)
+def init_db_cached(db_name, *args, **kwargs):
+    return init_db(db_name, *args, **kwargs)
+
+
+st.write('# promb - protein humanness evaluation')
+
+st.code('''pip install promb
+* █
+▄▄▄▄ ▄▄▄ ▄█▄ ▄▄▄▄ █▄▄▄
+█ █ █ █▓███ █ █ █ █ █
+█▄▄▄▀ █ ▀███▀ █ █ █▄▄▄▀
+
+▀ protein mutation burden
+''')
+
+if st.button("Load example"):
+    st.session_state["seq"] = "SPLQKASDSLINIAIKMLRNGINPELAKKLWDIAYKISMSHIDPSSFYEALKELKKLIEEQEEELIEA"
+
+with st.form(border=False, key="input"):
+    seq = st.text_area(
+        "Amino acid sequence",
+        key="seq"
+    )
+    seq = "".join(seq.split())
+
+    left, mid, right = st.columns(3)
+    with left:
+        database_name = st.selectbox("Database", options=["human-reference", "human-swissprot", "human-oas"])
+    with mid:
+        peptide_length = st.number_input("Peptide length", value=9, min_value=2, max_value=20, disabled=database_name == "human-oas")
+    with right:
+        num_nearest = st.number_input("Nearest peptides", value=1, min_value=1, max_value=5, help="Number of nearest human peptides used for visualization and to compute PSSM and suggest humanizing mutations")
+
+    st.form_submit_button("Run", type="primary")
+
+if not seq:
+    st.stop()
+
+with st.spinner(f"Finding nearest {peptide_length}mer peptides in {database_name}..."):
+    db = init_db_cached(database_name, peptide_length if database_name != "human-oas" else None)
+    peptides = db.chop_seq_peptides(seq)
+    nearest = db.find_nearest_peptides(peptides, n=num_nearest)
+
+st.write("## Result")
+
+st.metric("Human Peptide Content", "{:.1%} human".format(db.compute_peptide_content(seq)))
+num_mutations = 0
+for peptide, hits in zip(peptides, nearest):
+    num_mutations += sum(aa != bb for aa, bb in zip(peptide, hits[0]))
+
+st.metric("Mutation Burden", "{:.1f} mutations per {}mer".format(num_mutations / len(peptides), peptide_length))
+
+likelihood = pd.DataFrame({
+    "likelihood": db.compute_positional_likelihood(seq, nearest_peptides=nearest),
+    "metric": "likelihood",
+    "aa": list(seq),
+    "position": range(1, len(seq)+1)
+})
+
+wrap = 50
+st.write("### Positional likelihood")
+st.write(f"Fraction of nearest overlapping {peptide_length}mers that contain the input amino acid at that position. Positions with values close to 0 can be considered non-human, and values close to 1 human.")
+for start in range(0, len(seq), wrap):
+    chunk = likelihood.iloc[start:start+wrap]
+    heatmap = alt.Chart(chunk).mark_rect().encode(
+        x=alt.X("position:O", title="Sequence Position"),
+        y=alt.Y("metric:N", title="Metric"),
+        color=alt.Color("likelihood:Q", scale=alt.Scale(scheme="reds", reverse=True, domain=(0, 1))),
+        tooltip=["position", "likelihood", "aa"]
+    )
+    text = alt.Chart(chunk).mark_text(baseline="middle", fontSize=12).encode(
+        x=alt.X("position:O", title="Sequence Position"),
+        y=alt.Y("metric:N", title="Metric"),
+        text=alt.Text("aa:N"),
+        tooltip=["position", "likelihood", "aa"]
+    )
+    chart = (heatmap + text).properties(
+        width=250 + (15 * len(chunk)),
+        height=180,
+        title=f"Positions {start+1}-{start+wrap}"
+    )
+    st.altair_chart(chart, use_container_width=False)
+
+st.write("### Position-specific scoring matrix")
+st.write("A PSSM (PWM) computed by counting occurrences of amino acids in nearest overlapping human peptides at each position.")
+pssm = db.compute_pssm(seq, nearest_peptides=nearest)
+freqs_long = pssm.reset_index().melt(id_vars="position", var_name="aa", value_name="count")
+heatmap = alt.Chart(freqs_long).mark_rect().encode(
+    x=alt.X("position:O", title="Sequence Position"),
+    y=alt.Y("aa:N", title="Amino Acid", sort=None),
+    color=alt.Color("count:Q", scale=alt.Scale(scheme="viridis")),
+    tooltip=["position", "aa", "count"]
+).properties(
+    height=600,
+    title="Amino Acid Frequencies"
+)
+st.altair_chart(heatmap, use_container_width=True)
+
+st.write("#### Sequence logo")
+st.write("PSSM computed from the nearest human peptides, visualized using the logomaker library")
+
+logo = logomaker.Logo(pssm, figsize=(min(50, 2 + 0.2 * len(seq)), 2))
+st.pyplot(logo.fig)
+
+st.write("#### Suggested mutations")
+st.write("PSSM computed without counting the amino acid found in the input sequence at each position")
+
+pssm_mutations = db.compute_pssm(seq, nearest_peptides=nearest, ignore_wildtype=True)
+logo = logomaker.Logo(pssm_mutations, figsize=(min(50, 2 + 0.2 * len(seq)), 2))
+st.pyplot(logo.fig)
+
+st.write("### Nearest human peptides")
+
+stream = StringIO()
+print_nearest(peptides, nearest, file=stream)
+with st.container(height=400):
+    st.code(stream.getvalue())
+
+st.write("## Humanization (naive approach)")
+
+st.write("Generate **slightly** humanized variants by applying 1-3 mutations based on nearest overlapping peptides")
+
+if st.button("Generate humanized mutants", type="primary"):
+
+    st.write("### Point mutant candidates")
+
+    with st.spinner("Generating point mutants..."):
+        for candidate in db.suggest_point_mutant_candidates(seq, nearest_peptides=nearest)[:5]:
+            mutations = " ".join(f"{aa}{pos}{bb}" for pos, (aa, bb) in enumerate(zip(seq, candidate), start=1) if aa != bb)
+            pmb = db.compute_pmb(candidate)
+            st.code(f">{mutations} PMB={pmb:.2f}\n{candidate}")
+
+    st.write("### Double mutant candidates")
+
+    with st.spinner("Generating double mutants..."):
+        for candidate in db.suggest_double_mutant_candidates(seq, nearest_peptides=nearest)[:5]:
+            mutations = " ".join(f"{aa}{pos}{bb}" for pos, (aa, bb) in enumerate(zip(seq, candidate), start=1) if aa != bb)
+            pmb = db.compute_pmb(candidate)
+            st.code(f">{mutations} PMB={pmb:.2f}\n{candidate}")
+
+    st.write("### Triple mutant candidates")
+
+    with st.spinner("Generating triple mutants..."):
+        for candidate in db.suggest_triple_mutant_candidates(seq, nearest_peptides=nearest)[:5]:
+            mutations = " ".join(f"{aa}{pos}{bb}" for pos, (aa, bb) in enumerate(zip(seq, candidate), start=1) if aa != bb)
+            pmb = db.compute_pmb(candidate)
+            st.code(f">{mutations} PMB={pmb:.2f}\n{candidate}")
+
+st.divider()
+
+st.write("## Run locally")
+
+st.write("Install and run `promb` locally:")
+
+st.code("""
+# Install promb
+pip install promb
+# See CLI commands
+promb --help
+""", language="text")
+
+st.write("More instructions in the GitHub repo: https://github.com/MSDLLCpapers/promb")
+
+st.write("You can also clone this space as a git repository and run it locally:")
+
+st.code("""
+# Clone the Hugging Face Spaces repository
+git clone https://huggingface.co/spaces/prihodad/promb-humanness
+# Open the directory
+cd promb-humanness
+# Install dependencies (you should do this in a separate conda/venv environment)
+pip install -r requirements.txt
+# Run the Streamlit app
+streamlit run src/streamlit_app.py
+""", language="text")