
Compatibility with Transformers.jl

Open ceferisbarov opened this issue 1 year ago • 9 comments

Transformers.jl models require NamedTuple input, while ExplainableAI.jl analyzers require a subtype of AbstractArray. We could solve this by modifying XAIBase.jl and ExplainableAI.jl to support the Transformers.jl interface. I can start working on a PR if the maintainers are interested.
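For context, the mismatch looks roughly like this. A minimal sketch; the NamedTuple field names are illustrative, not the exact ones Transformers.jl produces:

```julia
# Illustrative only: a Transformers.jl text encoder returns a NamedTuple
# batch (field names here are made up), while analyzers dispatch on arrays.
input = (token = rand(Float32, 768, 5), attention_mask = trues(5))

input isa NamedTuple     # true
input isa AbstractArray  # false, so a method restricted to
                         # analyze(::AbstractArray, ...) cannot accept it
```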

ceferisbarov avatar Jul 15 '24 22:07 ceferisbarov

@adrhill this is related to https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl/pull/413 and perhaps a good first step towards integrating our systems a bit more 😄

pat-alt avatar Jul 16 '24 05:07 pat-alt

Sorry for the late answer @ceferisbarov, @pat-alt, I caught a bad case of COVID and spent last week recovering from it! We should absolutely make this package compatible with Transformers.jl.

ExplainableAI.jl analyzers require a subtype of AbstractArray.

This constraint is not intentional. I dug into it, and it comes from overly strict type annotations in the XAIBase interface. I've opened a more specific issue at https://github.com/Julia-XAI/XAIBase.jl/issues/18.

I'll leave this issue open to track compatibility of ExplainableAI.jl with Transformers.jl. Do you have a specific use case you expected to work that you could share?

adrhill avatar Jul 22 '24 12:07 adrhill

I hope you are doing better now!

Here is an example:

using Transformers
using Transformers.TextEncoders
using Transformers.HuggingFace
using ExplainableAI

classifier = hgf"gtfintechlab/FOMC-RoBERTa:ForSequenceClassification"
encoder = hgf"gtfintechlab/FOMC-RoBERTa"[1]
analyzer = IntegratedGradients(classifier)

input = encode(encoder, "Hello, world!")
expl = analyze(input, analyzer)

The input variable is a NamedTuple. We can either

  • modify the analyze function and analyzers to accept this format or
  • create dedicated types that accept a model and a tokenizer and handle the process themselves. SequenceClassificationExplainer from transformers_interpret is a good example of this interface:
from transformers_interpret import SequenceClassificationExplainer

cls_explainer = SequenceClassificationExplainer(model, tokenizer)
word_attributions = cls_explainer("I love you, I like you")
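In Julia, the second option could look something like the following sketch. Everything here is hypothetical: the struct name just mirrors the Python API above, and the model, encoder, and analysis step are passed in as plain callables rather than concrete Transformers.jl or ExplainableAI.jl types.

```julia
# Hypothetical wrapper type mirroring transformers_interpret's
# SequenceClassificationExplainer: it bundles a model, a tokenizer/encoder,
# and an analysis function behind a single callable interface.
struct SequenceClassificationExplainer{M,E,A}
    model::M     # e.g. a Transformers.jl classifier
    encoder::E   # e.g. a text encoder / tokenizer
    analyzer::A  # e.g. a function applying an XAI method to (model, input)
end

function (ex::SequenceClassificationExplainer)(text::AbstractString)
    input = ex.encoder(text)             # tokenize the raw string
    return ex.analyzer(ex.model, input)  # run the explanation method
end
```

With dummy callables this already composes: SequenceClassificationExplainer(identity, length, (m, x) -> m(x)) applied to "hello" returns 5.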

ceferisbarov avatar Jul 22 '24 21:07 ceferisbarov

Sorry to hear @adrhill, hope you've recovered by now

pat-alt avatar Jul 24 '24 07:07 pat-alt

Thanks, things are getting better!

I'm addressing this issue by updating the ecosystem interface in https://github.com/Julia-XAI/XAIBase.jl/pull/20. Since this will already be a breaking change, is there anything else you'd like to see changed @ceferisbarov?

adrhill avatar Jul 26 '24 11:07 adrhill

That was quick, thanks! I don't have anything else to add.

I can use the new version and give feedback if I face any issues. Please, let me know if I can help in any other way.

ceferisbarov avatar Jul 27 '24 00:07 ceferisbarov

I just merged PR #166, which includes the changes from https://github.com/Julia-XAI/XAIBase.jl/pull/20. Could you try out whether things now work for you on the main branch?

adrhill avatar Jul 27 '24 20:07 adrhill

Sorry, I am having laptop issues, so I won't be able to try it this week.

To be clear, I am supposed to create a new analyzer, since the existing ones do not support Transformer models, right?

ceferisbarov avatar Jul 29 '24 02:07 ceferisbarov

Hi @ceferisbarov, I hope #176 clears up how to use the package. Existing analyzers should support anything that takes an input and is differentiable.

adrhill avatar Oct 11 '24 14:10 adrhill

Trying this code still gives an error:

let
	classifier = hgf"bert-base-uncased:ForSequenceClassification"
	encoder = hgf"bert-base-uncased"[1]
	analyzer = IntegratedGradients(classifier)
	input = encode(encoder, "Hello, world!")
	expl = analyze(input, analyzer)
end

VarLad avatar Feb 21 '25 03:02 VarLad

This should work in theory:

using ExplainableAI
using TextHeatmaps
using Transformers
using Transformers.TextEncoders
using Transformers.HuggingFace

classifier = hgf"bert-base-uncased:ForSequenceClassification"
encoder = hgf"bert-base-uncased"[1]
input = encode(encoder, "Hello, world!")

analyzer = Gradient(x -> classifier(x).logit)
expl = analyze(input, analyzer)

However, it looks like classifier(x).logit doesn't support gradient computations:

julia> isnothing(expl.val)
true
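One way to narrow this down is to call the AD backend directly, bypassing the analyzer. The sketch below assumes Zygote as the backend; the commented-out line reuses the classifier and input from the snippet above and is not runnable on its own.

```julia
using Zygote

# Zygote can differentiate with respect to NamedTuple inputs in general:
nt = (token = [1.0, 2.0, 3.0],)
g = Zygote.gradient(x -> sum(abs2, x.token), nt)
g[1].token  # equals 2 .* nt.token

# So if calling the backend directly on the model also yields `nothing`
# (or throws), the limitation is in the Transformers.jl forward pass,
# not in the XAI interface:
# Zygote.gradient(x -> sum(classifier(x).logit), input)
```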

adrhill avatar Feb 26 '25 13:02 adrhill

If there is a way to use AD with Transformers.jl, we could of course add a package extension to XAIBase to make all of this more ergonomic.
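For reference, such an extension could follow the standard weak-dependency layout introduced in Julia 1.9. All names below are hypothetical and the UUID is a placeholder:

```julia
# ext/XAIBaseTransformersExt.jl -- hypothetical extension module, loaded
# automatically once Transformers.jl is in the environment, given entries
# like these in XAIBase's Project.toml:
#
#   [weakdeps]
#   Transformers = "<Transformers.jl's UUID>"
#
#   [extensions]
#   XAIBaseTransformersExt = "Transformers"
module XAIBaseTransformersExt

using XAIBase
using Transformers

# e.g. add methods here that accept Transformers.jl's encoded
# NamedTuple batches and unwrap them for the existing analyzers

end
```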

adrhill avatar Feb 26 '25 13:02 adrhill