BERTopic
topic_model.transform(docs)[0][i] is sometimes different from topic_model.transform(docs[i])[0][0]
Hello
I read https://maartengr.github.io/BERTopic/api/bertopic.html#bertopic._bertopic.BERTopic.transform and understood from the documents parameter (described as "A single document or a list of documents to predict on") that I could submit a list of documents or a single document and still receive the same result when predicting with a fitted model.
I found that this is not true. Am I overlooking something?
Below is a minimal working example:
import numpy as np
import tqdm
from bertopic import BERTopic
from sklearn.datasets import fetch_20newsgroups

docs = fetch_20newsgroups(subset='all')['data'][:200]
topic_model = BERTopic().fit(docs)
topics, _ = topic_model.transform(docs)
topics = np.array(topics)

# calling the model with a single document several times
topics_single = []
for doc in tqdm.tqdm(docs):
    topic, _ = topic_model.transform([doc])
    topics_single.append(topic[0])
topics_single = np.array(topics_single)
mask_identical = topics_single == topics
percentage_equal = 100 * np.sum(mask_identical) / len(mask_identical)
print(f"{percentage_equal=}%")  # returns for example about 60%, but varying

# loop till finding a different entry
for i in range(len(docs)):
    print(
        i,
        topic_model.transform(docs[i])[0][0],
        topic_model.transform([docs[i]])[0][0],
        topics[i],
        topic_model.transform(docs)[0][i],
    )
    if topic_model.transform(docs[i])[0][0] != topic_model.transform(docs)[0][i]:
        print(f"Different outcome at iteration {i}")
        break
The repeated execution with the same documents seems fine:
topics2,_=topic_model.transform(docs)
percentage_equal_executed_with_multiple_docs = 100*np.sum(np.array(topics2)==topics)/len(topics)
print(f"{percentage_equal_executed_with_multiple_docs=}%") #this gives 100%
Thank you in advance!
PS: The Python version is 3.10.12. The list of installed packages is:
absl-py==1.0.0
accelerate==0.23.0
adagio==0.2.4
aiohttp==3.8.6
aiosignal==1.3.1
ansi2html==1.9.1
antlr4-python3-runtime==4.11.1
anyio==3.5.0
appdirs==1.4.4
arch==6.2.0
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
astor==0.8.1
asttokens==2.0.5
astunparse==1.6.3
async-timeout==4.0.3
attrs==22.1.0
audioread==3.0.1
azure-core==1.29.1
azure-cosmos==4.3.1
azure-storage-blob==12.19.0
azure-storage-file-datalake==12.14.0
backcall==0.2.0
bcrypt==3.2.0
beautifulsoup4==4.11.1
bertopic==0.16.0
black==22.6.0
bleach==4.1.0
blinker==1.4
blis==0.7.11
boto3==1.24.28
botocore==1.27.96
cachetools==5.3.2
catalogue==2.0.10
category-encoders==2.6.2
certifi==2022.12.7
cffi==1.15.1
chardet==4.0.0
charset-normalizer==2.0.4
click==8.1.7
cloudpathlib==0.16.0
cloudpickle==2.0.0
cmake==3.27.7
cmdstanpy==1.2.0
comm==0.1.2
confection==0.1.3
configparser==5.2.0
contourpy==1.0.5
cryptography==39.0.1
cycler==0.11.0
cymem==2.0.8
Cython==0.29.32
dacite==1.8.1
dash==2.14.2
dash-core-components==2.0.0
dash-html-components==2.0.0
dash-table==5.0.0
dask==2023.12.0
databricks-automl-runtime==0.2.20
databricks-cli==0.18.0
databricks-feature-engineering==0.1.2
databricks-feature-store==0.16.1
databricks-sdk==0.1.6
dataclasses-json==0.6.2
datasets==2.14.5
dbl-tempo==0.1.26
dbus-python==1.2.18
debugpy==1.6.7
decorator==5.1.1
deepspeed==0.11.1
defusedxml==0.7.1
dill==0.3.6
diskcache==5.6.3
distlib==0.3.7
distributed==2023.12.0
distro==1.7.0
distro-info==1.1+ubuntu0.1
docstring-to-markdown==0.11
dtw-python==1.3.0
einops==0.7.0
entrypoints==0.4
evaluate==0.4.1
executing==0.8.3
facets-overview==1.1.1
fastjsonschema==2.19.0
fasttext==0.9.2
filelock==3.9.0
filterpy==1.4.5
flash-attn==2.3.2
Flask==2.2.5
flatbuffers==23.5.26
fonttools==4.25.0
frozenlist==1.4.0
fs==2.4.16
fsspec==2023.6.0
fugue==0.8.7
fugue-sql-antlr==0.2.0
future==0.18.3
gast==0.4.0
gitdb==4.0.11
GitPython==3.1.27
gluonts==0.14.3
google-api-core==2.14.0
google-auth==2.21.0
google-auth-oauthlib==1.0.0
google-cloud-core==2.3.3
google-cloud-storage==2.11.0
google-crc32c==1.5.0
google-pasta==0.2.0
google-resumable-media==2.6.0
googleapis-common-protos==1.61.0
greenlet==2.0.1
grpcio==1.48.2
grpcio-status==1.48.1
gunicorn==20.1.0
gviz-api==1.10.0
h5py==3.7.0
hdbscan==0.8.33
hjson==3.1.0
hmmlearn==0.3.0
holidays==0.35
horovod==0.28.1
htmlmin==0.1.12
httplib2==0.20.2
huggingface-hub==0.16.4
idna==3.4
ImageHash==4.3.1
imbalanced-learn==0.11.0
importlib-metadata==7.0.0
importlib-resources==6.1.1
ipykernel==6.25.0
ipython==8.14.0
ipython-genutils==0.2.0
ipywidgets==7.7.2
isodate==0.6.1
itsdangerous==2.0.1
jedi==0.18.1
jeepney==0.7.1
Jinja2==3.1.2
jmespath==0.10.0
joblib==1.2.0
joblibspark==0.5.1
jsonpatch==1.33
jsonpointer==2.4
jsonschema==4.17.3
jupyter-client==7.3.4
jupyter-server==1.23.4
jupyter_core==5.2.0
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.0
keras==2.14.0
keras-self-attention==0.51.0
keyring==23.5.0
kiwisolver==1.4.4
kotsu==0.3.3
langchain==0.0.314
langcodes==3.3.0
langsmith==0.0.64
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lazy_loader==0.3
libclang==15.0.6.1
librosa==0.10.1
lightgbm==4.1.0
lit==17.0.5
llvmlite==0.39.1
locket==1.0.0
lxml==4.9.1
Mako==1.2.0
Markdown==3.4.1
MarkupSafe==2.1.1
marshmallow==3.20.1
matplotlib==3.7.0
matplotlib-inline==0.1.6
mccabe==0.7.0
mistune==0.8.4
ml-dtypes==0.2.0
mlflow-skinny==2.8.0
mne==1.6.0
more-itertools==8.10.0
mpmath==1.2.1
msgpack==1.0.7
multidict==6.0.4
multimethod==1.10
multiprocess==0.70.14
murmurhash==1.0.10
mypy-extensions==0.4.3
nbclassic==0.5.2
nbclient==0.5.13
nbconvert==6.5.4
nbformat==5.7.0
nest-asyncio==1.5.6
networkx==2.8.4
ninja==1.11.1.1
nltk==3.7
nodeenv==1.8.0
notebook==6.5.2
notebook_shim==0.2.2
numba==0.56.4
numpy==1.23.5
oauthlib==3.2.0
openai==0.28.1
opt-einsum==3.3.0
packaging==22.0
pandas==1.5.3
pandocfilters==1.5.0
paramiko==2.9.2
parso==0.8.3
partd==1.4.1
pathspec==0.10.3
pathy==0.10.3
patsy==0.5.3
petastorm==0.12.1
pexpect==4.8.0
phik==0.12.3
pickleshare==0.7.5
Pillow==9.4.0
platformdirs==2.5.2
plotly==5.9.0
pluggy==1.0.0
pmdarima==2.0.3
polars==0.19.19
pooch==1.8.0
preshed==3.0.9
prompt-toolkit==3.0.36
prophet==1.1.5
protobuf==4.24.0
psutil==5.9.0
psycopg2==2.9.3
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
pyaml==23.9.7
pyarrow==8.0.0
pyarrow-hotfix==0.5
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybind11==2.11.1
pycatch22==0.4.2
pycparser==2.21
pydantic==1.10.6
pyflakes==3.1.0
Pygments==2.11.2
PyGObject==3.42.1
PyJWT==2.3.0
pykalman-bardo==0.9.7
PyNaCl==1.5.0
pynndescent==0.5.11
pyod==1.1.2
pyodbc==4.0.32
pyparsing==3.0.9
pyright==1.1.294
pyrsistent==0.18.0
pytesseract==0.3.10
python-apt==2.4.0+ubuntu2
python-dateutil==2.8.2
python-editor==1.0.4
python-lsp-jsonrpc==1.1.1
python-lsp-server==1.8.0
pytoolconfig==1.2.5
pytz==2022.7
PyWavelets==1.4.1
PyYAML==6.0
pyzmq==23.2.0
qpd==0.4.4
regex==2022.7.9
requests==2.28.1
requests-oauthlib==1.3.1
responses==0.18.0
retrying==1.3.4
rope==1.7.0
rsa==4.9
s3transfer==0.6.2
safetensors==0.4.0
scikit-base==0.6.1
scikit-learn==1.1.1
scikit-optimize==0.9.0
scikit-posthocs==0.8.0
scipy==1.10.0
seaborn==0.12.2
seasonal==0.3.1
SecretStorage==3.3.1
Send2Trash==1.8.0
sentence-transformers==2.2.2
sentencepiece==0.1.99
shap==0.43.0
simplejson==3.17.6
six==1.16.0
skpro==2.1.1
sktime==0.24.1
slicer==0.0.7
smart-open==5.2.1
smmap==5.0.0
sniffio==1.2.0
sortedcontainers==2.4.0
soundfile==0.12.1
soupsieve==2.3.2.post1
soxr==0.3.7
spacy==3.7.1
spacy-legacy==3.0.12
spacy-loggers==1.0.5
spark-tensorflow-distributor==1.0.0
SQLAlchemy==1.4.39
sqlglot==20.2.0
sqlparse==0.4.2
srsly==2.4.8
ssh-import-id==5.11
stack-data==0.2.0
stanio==0.3.0
statsforecast==1.6.0
statsmodels==0.13.5
stumpy==1.12.0
sympy==1.11.1
tabulate==0.8.10
tangled-up-in-unicode==0.2.0
tbats==1.1.3
tblib==3.0.0
tenacity==8.1.0
tensorboard==2.14.0
tensorboard-data-server==0.7.2
tensorboard-plugin-profile==2.14.0
tensorflow==2.14.0
tensorflow-estimator==2.14.0
tensorflow-io-gcs-filesystem==0.34.0
termcolor==2.3.0
terminado==0.17.1
thinc==8.2.1
threadpoolctl==2.2.0
tiktoken==0.5.1
tinycss2==1.2.1
tokenize-rt==4.2.1
tokenizers==0.14.0
tomli==2.0.1
toolz==0.12.0
torch==2.0.1+cu118
torchvision==0.15.2+cu118
tornado==6.1
tqdm==4.64.1
traitlets==5.7.1
transformers==4.34.0
triad==0.9.3
triton==2.0.0
tsfresh==0.20.1
tslearn==0.5.3.2
typeguard==2.13.3
typer==0.9.0
typing-inspect==0.9.0
typing_extensions==4.4.0
ujson==5.4.0
umap-learn==0.5.5
unattended-upgrades==0.1
urllib3==1.26.14
virtualenv==20.16.7
visions==0.7.5
wadllib==1.3.6
wasabi==1.1.2
wcwidth==0.2.5
weasel==0.3.4
webencodings==0.5.1
websocket-client==0.58.0
Werkzeug==2.2.2
whatthepatch==1.0.2
widgetsnbextension==3.6.1
wordcloud==1.9.2
wrapt==1.14.1
xarray==2023.12.0
xgboost==1.7.6
xxhash==3.4.1
yapf==0.33.0
yarl==1.9.2
ydata-profiling==4.2.0
zict==3.0.0
zipp==3.11.0
With BERTopic it always depends on the underlying models that you choose. Some work differently than others so it is not possible to get the same behavior across all algorithms out there. As a result, it is important when creating your topic model to view BERTopic as something built from individual components.
Here, the inference that you refer to is generally a result of HDBSCAN, which does an approximation to assign documents to clusters using a different process than during training. Moreover, it does not do this in isolation. This means that if you add documents to the .transform step, HDBSCAN will use those to perform its assignment. If you only give a single document, the behavior will change.
To illustrate this further, if you use k-Means, this will not be the case since its inference process does not depend on all other documents.
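For anyone who wants to try the k-Means suggestion, here is a minimal sketch. It assumes the pattern from the BERTopic documentation of passing a scikit-learn clusterer through the hdbscan_model argument. The helper names build_kmeans_topic_model and batch_vs_single_agreement, and the value n_clusters=10, are made up for illustration and are not part of BERTopic. Note that the embedding and dimensionality-reduction steps still run before clustering, so this sketch does not by itself guarantee identical results.

```python
def build_kmeans_topic_model(n_clusters=10):
    """Build an unfitted BERTopic model that clusters with k-Means.

    Assumption: n_clusters=10 is an arbitrary illustrative value.
    Imports are local so the comparison helper below works without them.
    """
    from bertopic import BERTopic
    from sklearn.cluster import KMeans

    cluster_model = KMeans(n_clusters=n_clusters, random_state=42)
    # BERTopic accepts other clusterers through the `hdbscan_model` argument.
    return BERTopic(hdbscan_model=cluster_model)


def batch_vs_single_agreement(model, docs):
    """Fraction of docs that get the same topic in a batch and on their own.

    `model` is any fitted object whose transform(list_of_docs) returns
    (topics, probabilities), as BERTopic.transform does.
    """
    batch_topics, _ = model.transform(docs)
    single_topics = [model.transform([doc])[0][0] for doc in docs]
    matches = sum(int(b == s) for b, s in zip(batch_topics, single_topics))
    return matches / len(docs)


# Hypothetical usage (needs bertopic, scikit-learn and a list of strings `docs`):
#   topic_model = build_kmeans_topic_model().fit(docs)
#   print(batch_vs_single_agreement(topic_model, docs))
```

Running the same helper on the default model from the example at the top of the thread gives a like-for-like comparison between the two clustering choices.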
A small tip: I believe there are a number of issues, both open and closed, that discuss this, so I would advise searching through them for more details.
Thank you for your clarification! Is it worth adding a note to the documentation to prevent future questions?
Sure! An update to the docstrings of .transform would be appropriate, I believe. Although it will not prevent all future questions (since not everyone reads docstrings), it would be a good first step. If you want, a PR would be appreciated.
Hi, I'm not sure why this (a single document versus multiple documents yielding different predictions) is expected behavior. The method that topic_model.transform -> hdbscan_delegator employs under the hood in BERTopic is hdbscan's approximate_predict(), and in the hdbscan docs for it (https://hdbscan.readthedocs.io/en/latest/prediction_tutorial.html) I cannot see any description of such behavior.
It should freeze the whole condensed tree, so I cannot see why the presence of other documents affects the classification result of a particular document here. Could anyone please explain this to me? This matters because it means that the classification results of new points depend on the batch of documents that I construct at the transform phase, which makes the whole classification result unstable and unreliable.
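One way to narrow this down is to test each stage of the pipeline separately for batch dependence. The sketch below is only a diagnostic idea, not an explanation of the cause: the helper stage_is_batch_independent is a made-up name, and the commented usage assumes a fitted default model whose umap_model and hdbscan_model attributes hold the fitted UMAP and HDBSCAN objects.

```python
import math


def _rows_match(a, b, tol):
    """True if two 1-D numeric sequences are element-wise equal within tol."""
    return len(a) == len(b) and all(
        math.isclose(float(x), float(y), abs_tol=tol) for x, y in zip(a, b)
    )


def stage_is_batch_independent(stage, inputs, tol=1e-6):
    """Check whether stage(inputs)[i] equals stage(inputs[i:i+1])[0] for all i.

    `stage` maps a 2-D input (rows = documents) to a 2-D output with one row
    per document. `inputs` must support len() and slicing, for example a
    NumPy array or a list of lists.
    """
    batch_out = stage(inputs)
    for i in range(len(inputs)):
        single_out = stage(inputs[i:i + 1])[0]
        if not _rows_match(batch_out[i], single_out, tol):
            return False
    return True


# Hypothetical usage with a fitted default BERTopic model `topic_model` and
# precomputed document embeddings `emb` (a 2-D NumPy array):
#   import hdbscan
#   umap_ok = stage_is_batch_independent(topic_model.umap_model.transform, emb)
#   reduced = topic_model.umap_model.transform(emb)
#   predict = lambda X: hdbscan.approximate_predict(
#       topic_model.hdbscan_model, X)[0].reshape(-1, 1)
#   hdbscan_ok = stage_is_batch_independent(predict, reduced)
```

If one stage returns False while the other returns True, that points to where the batch dependence enters. The tolerance matters for the dimensionality-reduction stage, where small floating-point differences may or may not change the final cluster assignment.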