Article Figure 9
Hi,
I came across your work and saw the t-SNE visualization of semantic embeddings in Figure 9.
Is there a way I can reproduce it? Thank you!
Thank you for your interest in our work. You can simply add output_hidden_states=True to the generate call to output the embeddings:
result = model.generate(
    input_ids,
    images=images,
    do_sample=temperature > 0,
    temperature=temperature,
    max_new_tokens=max_new_tokens,
    streamer=streamer,
    use_cache=True,
    top_p=top_p,
    stopping_criteria=[stopping_criteria],
    output_attentions=True,
    output_hidden_states=True,  # return the per-layer hidden states (the embeddings)
    output_scores=True,
    return_dict_in_generate=True  # return an output object so result.hidden_states is available
)
Then you can use the following code to extract the embedding of the last token in the last layer.
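# Inside query() (the helper wrapping the generate call above), return the hidden states:
hidden_states = result.hidden_states

# For each query: step 0 (the prompt) -> last layer -> last token -> first batch item
embedding = []  # collects one vector per query
hidden_states = query(text_prompt=query_text, image_prompt=query_image)
embedding.append(hidden_states[0][-1][:, -1, :].cpu().numpy()[0])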
This gives a semantic representation of one query, which you append to the list of embeddings.
After collecting several lists of embeddings, you can use t-SNE to produce figures like this one.
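The saving step itself is not shown in the thread. As a minimal sketch, assuming each list holds 1-D NumPy vectors, NumPy can write one embedding per row, comma-separated and with no header, which is the layout the plotting script reads back with `pd.read_csv(..., header=None)`. `save_embeddings` is a hypothetical helper name, and the temporary directory stands in for the `emb/` folder.

```python
import os
import tempfile

import numpy as np


def save_embeddings(embedding, path):
    """Write one embedding per row, comma-separated, with no header row."""
    arr = np.vstack(embedding)  # shape: (num_queries, hidden_size)
    np.savetxt(path, arr, delimiter=",")
    return arr.shape


# Fake 16-dimensional vectors stand in for real hidden states
rng = np.random.default_rng(0)
embedding = [rng.standard_normal(16) for _ in range(150)]

# Hypothetical output location; the thread's files live under emb/
out_dir = tempfile.mkdtemp()
path = os.path.join(out_dir, "harmful_150_embedding.csv")
shape = save_embeddings(embedding, path)

# Read it back (NumPy here; the plotting script uses pandas with header=None)
loaded = np.loadtxt(path, delimiter=",")
print(shape, loaded.shape)
```

This prints `(150, 16) (150, 16)`. `np.savetxt` uses the `%.18e` format by default, so the float64 values survive the round trip.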
In our case, we saved these embeddings to CSV files and drew the figures from them. The code is as follows:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Load the six embedding sets (150 queries each, one embedding per row, no header)
df = pd.read_csv("emb/harmful_150_embedding.csv", header=None)
harmful_150_embedding = df.to_numpy()
df = pd.read_csv("emb/harmless_150_embedding.csv", header=None)
harmless_150_embedding = df.to_numpy()
df = pd.read_csv("emb/figstep_harmful_150_embedding.csv", header=None)
figstep_harmful_150_embedding = df.to_numpy()
df = pd.read_csv("emb/figstep_harmless_150_embedding.csv", header=None)
figstep_harmless_150_embedding = df.to_numpy()
df = pd.read_csv("emb/harmless_mode2_150_embedding.csv", header=None)
harmless_mode2_embedding = df.to_numpy()
df = pd.read_csv("emb/harmful_mode2_150_embedding.csv", header=None)
harmful_mode2_embedding = df.to_numpy()

# Stack into one array; this order must match the color/marker/label lists below
combined_data = np.vstack((harmful_150_embedding, harmless_150_embedding,
                           figstep_harmful_150_embedding, figstep_harmless_150_embedding,
                           harmful_mode2_embedding, harmless_mode2_embedding))

fig, ax = plt.subplots(figsize=(8, 5), dpi=300)
color = ["#39c5bb", "tomato", "#EE82EE", "orange", "blue", "skyblue"]
marker = ["o", "s", "*", "x", "+", "^"]
# Raw strings keep the LaTeX backslashes from being read as escape sequences
label = [r"prohibited $\mathbb{Q}^{va}$", r"benign $\mathbb{Q}^{va}$",
         r"prohibited $\mathsf{FigStep}$", r"benign $\mathsf{FigStep}$",
         r"prohibited $\mathbb{Q}_2'$", r"benign $\mathbb{Q}_2'$"]

# Project all embeddings to 2-D with t-SNE
vis = TSNE(n_components=2).fit_transform(combined_data)

# Set i occupies rows 150*i to 150*(i+1) of vis; draw the sets in this order
for i in [1, 0, 5, 4, 3, 2]:
    ax.scatter(vis[150*i:150*(i+1), 0], vis[150*i:150*(i+1), 1],
               c=color[i], marker=marker[i], label=label[i])

plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05)
ax.legend(loc="upper left", fontsize=14)
ax.axis("off")
plt.savefig("emb-llava-1.png", transparent=True)
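The plotting loop assumes every CSV holds exactly 150 rows. If your sets differ in size, one option (our own variation, not the code used for the figure) is to compute the row offsets from the array lengths. The sketch below shows only that bookkeeping; `split_by_sizes` is a hypothetical helper, and the toy arrays and the column slice stand in for the real embeddings and the t-SNE output.

```python
import numpy as np


def split_by_sizes(vis, sizes):
    """Split stacked rows back into one block per embedding set."""
    offsets = np.cumsum([0] + list(sizes))
    if offsets[-1] != len(vis):
        raise ValueError("set sizes do not add up to the number of rows")
    return [vis[offsets[k]:offsets[k + 1]] for k in range(len(sizes))]


# Toy sets of different sizes (constant values make the split easy to check)
sets = [np.zeros((150, 8)), np.ones((120, 8)), np.full((90, 8), 2.0)]
combined = np.vstack(sets)
sizes = [len(s) for s in sets]

# Stand-in for TSNE(n_components=2).fit_transform(combined): keep two columns
vis = combined[:, :2]
blocks = split_by_sizes(vis, sizes)
print([b.shape for b in blocks])
```

This prints `[(150, 2), (120, 2), (90, 2)]`. In the plotting loop, `vis[150*i:150*(i+1), 0]` would then become `blocks[i][:, 0]`, with `sizes` taken from the six loaded arrays in the same order as `combined_data`.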
I hope this helps.
Hi again, thanks a lot for the answer; this seems to work for me.