"""XAI for Transformers Intent Classifier App.""" from collections import Counter from itertools import count from operator import itemgetter from re import DOTALL, sub import streamlit as st from plotly.express import bar from transformers import (AutoModelForSequenceClassification, AutoTokenizer, pipeline) from transformers_interpret import SequenceClassificationExplainer hide_streamlit_style = """ """ hide_plotly_bar = {"displayModeBar": False} st.markdown(hide_streamlit_style, unsafe_allow_html=True) repo_id = "remzicam/privacy_intent" task = "text-classification" title = "XAI for Intent Classification and Model Interpretation" st.markdown( f"

{title}

", unsafe_allow_html=True ) @st.cache(allow_output_mutation=True, suppress_st_warning=True) def load_models(): """ It loads the model and tokenizer from the HuggingFace model hub, and then creates a pipeline object that can be used to make predictions. Also, it creates model interpretation object. Returns: the privacy_intent_pipe and cls_explainer. """ model = AutoModelForSequenceClassification.from_pretrained( repo_id, low_cpu_mem_usage=True ) tokenizer = AutoTokenizer.from_pretrained(repo_id) privacy_intent_pipe = pipeline( task, model=model, tokenizer=tokenizer, return_all_scores=True ) cls_explainer = SequenceClassificationExplainer(model, tokenizer) return privacy_intent_pipe, cls_explainer privacy_intent_pipe, cls_explainer = load_models() def label_probs_figure_creater(input_text:str): """ It takes in a string, runs it through the pipeline, and returns a figure and the label with the highest probability Args: input_text (str): The text you want to analyze Returns: A tuple of a figure and a string. """ outputs = privacy_intent_pipe(input_text)[0] sorted_outputs = sorted(outputs, key=lambda k: k["score"]) prediction_label = sorted_outputs[-1]["label"] fig = bar( sorted_outputs, x="score", y="label", color="score", color_continuous_scale="rainbow", width=600, height=400, ) fig.update_layout( title="Model Prediction Probabilities for Each Label", xaxis_title="", yaxis_title="", xaxis=dict( # attribures for x axis showline=True, showgrid=True, linecolor="black", tickfont=dict(family="Calibri"), ), yaxis=dict( # attribures for y axis showline=True, showgrid=True, linecolor="black", tickfont=dict( family="Times New Roman", ), ), plot_bgcolor="white", title_x=0.5, ) return fig, prediction_label def xai_attributions_html(input_text: str): """ 1. The function takes in a string of text as input. 2. It then uses the explainer to generate attributions for each word in the input text. 3. It then uses the explainer to generate an HTML visualization of the attributions. 4. It then cleans up the HTML visualization by removing some unnecessary HTML tags. 5. It then returns the attributions and the HTML visualization Args: input_text (str): The text you want to explain. Returns: the word attributions and the html. """ word_attributions = cls_explainer(input_text) #remove special tokens word_attributions = word_attributions[1:-1] # remove strings shorter than 1 chrachter word_attributions = [i for i in word_attributions if len(i[0]) > 1] html = cls_explainer.visualize().data html = html.replace("#s", "") html = html.replace("#/s", "") html = sub("", "", html, 4, DOTALL) html = sub("", "", html, 4, DOTALL) return word_attributions, html+"
" def explanation_intro(prediction_label: str): """ generates model explanaiton html markdown from prediction label of the model. Args: prediction_label (str): The label that the model predicted. Returns: A string """ return f"""
The model predicted the given sentence as '{prediction_label}'. The figure below shows the contribution of each token to this decision. Green tokens indicate a positive contribution, while red tokens indicate a negative contribution. The bolder the color, the greater the value.

""" def explanation_viz(prediction_label: str, word_attributions): """ It takes in a prediction label and a list of word attributions, and returns a markdown string that contains the word that had the highest attribution and the prediction label Args: prediction_label (str): The label that the model predicted. word_attributions: a list of tuples of the form (word, attribution score) Returns: A string """ top_attention_word = max(word_attributions, key=itemgetter(1))[0] return f"""The token **_'{top_attention_word}'_** is the biggest driver for the decision of the model as **'{prediction_label}'**""" def word_attributions_dict_creater(word_attributions): """ It takes a list of tuples, reverses it, splits it into two lists, colors the scores, numerates duplicated strings, and returns a dictionary Args: word_attributions: This is the output of the model explainer. Returns: A dictionary with the keys "word", "score", and "colors". """ word_attributions.reverse() words, scores = zip(*word_attributions) # colorize positive and negative scores colors = ["red" if x < 0 else "lightgreen" for x in scores] # darker tone for max score max_index = scores.index(max(scores)) colors[max_index] = "darkgreen" # numerate duplicated strings c = Counter(words) iters = {k: count(1) for k, v in c.items() if v > 1} words_ = [x + "_" + str(next(iters[x])) if x in iters else x for x in words] # plotly accepts dictionaries return { "word": words_, "score": scores, "colors": colors, } def attention_score_figure_creater(word_attributions_dict): """ It takes a dictionary of words and their attention scores and returns a bar graph of the words and their attention scores with specified colors. Args: word_attributions_dict: a dictionary with keys "word", "score", and "colors" Returns: A figure object """ fig = bar(word_attributions_dict, x="score", y="word", width=400, height=500) fig.update_traces(marker_color=word_attributions_dict["colors"]) fig.update_layout( title="Word-Attention Score", xaxis_title="", yaxis_title="", xaxis=dict( # attribures for x axis showline=True, showgrid=True, linecolor="black", tickfont=dict(family="Calibri"), ), yaxis=dict( # attribures for y axis showline=True, showgrid=True, linecolor="black", tickfont=dict( family="Times New Roman", ), ), plot_bgcolor="white", title_x=0.5, ) return fig form = st.form(key="intent-form") input_text = form.text_area( label="Text", value="At any time during your use of the Services, you may decide to share some information or content publicly or privately.", ) submit = form.form_submit_button("Submit") if submit: label_probs_figure, prediction_label = label_probs_figure_creater(input_text) st.plotly_chart(label_probs_figure, config=hide_plotly_bar) explanation_general = explanation_intro(prediction_label) st.markdown(explanation_general, unsafe_allow_html=True) with st.spinner(): word_attributions, html = xai_attributions_html(input_text) st.markdown(html, unsafe_allow_html=True) explanation_specific = explanation_viz(prediction_label, word_attributions) st.info(explanation_specific) word_attributions_dict = word_attributions_dict_creater(word_attributions) attention_score_figure = attention_score_figure_creater(word_attributions_dict) st.plotly_chart(attention_score_figure, config=hide_plotly_bar)