ARM-Net
A ready-to-use framework of the state-of-the-art models for structured (tabular) data learning with PyTorch. Applications include recommendation, CTR prediction, healthcare analytics, anomaly detectio...
ARM-Net: Adaptive Relation Modeling Network for Structured Data
This repository contains our PyTorch implementation of ARM-Net: Adaptive Relation Modeling Network for Structured Data. We also provide the implementation of relevant baseline models for structured (tabular) data learning.

ARM-Net for Large Real-world Datasets
Benchmark Datasets
- Frappe - App Recommendation
- MovieLens - Movie Recommendation
- Avazu - Click-Through Rate Prediction
- Criteo - Display Advertising Challenge
- Diabetes130 - Diabetes Readmission Prediction
Summary of Results
- Main results are summarized below.
- ARM-Net achieves the overall best performance.
- More results and technical details can be found in the paper.
- Note that these results are reported with a fixed embedding size of 10 for a fair comparison; higher AUC can be obtained by increasing the embedding size.
E.g., with an embedding size of 100, ARM-Net (single head, not ensembled with a DNN) obtains 0.9817 AUC on Frappe with only 10 exponential neurons, using the following command:
CUDA_VISIBLE_DEVICES=0 python train.py --model armnet_1h --nemb 100 --h 10 --alpha 1.7 --lr 0.001 --exp_name frappe_armnet_1h_nemb_100
The AUC and Model Size of this ARM-Net with different embedding sizes are listed below.
Embedding Size | 10 | 20 | 30 | 40 | 50 | 60 | 70 | 80 | 90 | 100 | 110 | 120 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
AUC | 0.9777 | 0.9779 | 0.9801 | 0.9803 | 0.9798 | 0.9807 | 0.9808 | 0.9810 | 0.9810 | 0.9817 | 0.9811 | 0.9805 |
Model Size | 177K | 262K | 348K | 434K | 520K | 606K | 692K | 779K | 866K | 953K | 1.04M | 1.13M |
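The table above varies only the embedding size. Assuming the other hyperparameters stay as in the Frappe command above (the README does not say whether they were retuned for each size), a minimal sketch for generating the sweep commands could look like this. `build_sweep_commands` is a hypothetical helper, not part of the repository; it only builds strings and does not launch training.

```python
import shlex


def build_sweep_commands(emb_sizes, gpu=0):
    """Return one shell command per embedding size.

    Hypothetical helper: the flags are copied from the single Frappe run in
    this README, and only --nemb and --exp_name change across the sweep.
    """
    commands = []
    for nemb in emb_sizes:
        args = [
            "python", "train.py",
            "--model", "armnet_1h",
            "--nemb", str(nemb),
            "--h", "10",          # number of exponential neurons
            "--alpha", "1.7",
            "--lr", "0.001",
            "--exp_name", f"frappe_armnet_1h_nemb_{nemb}",
        ]
        # shlex.join quotes arguments only when needed (Python 3.8+).
        commands.append(f"CUDA_VISIBLE_DEVICES={gpu} " + shlex.join(args))
    return commands


if __name__ == "__main__":
    # Embedding sizes from the table: 10, 20, ..., 120.
    for cmd in build_sweep_commands(range(10, 130, 10)):
        print(cmd)
```

For an embedding size of 100 the generated command is identical to the one shown above; the commands can then be run one at a time or distributed across GPUs by changing `gpu`.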

Baseline Models
Model | Code | Reference |
---|---|---|
Logistic Regression | LR, lr.py | - |
Factorization Machine | FM, fm.py | [ICDM-10] FM |
Attentional Factorization Machine | AFM, afm.py | [IJCAI-17] AFM |
Higher-Order Factorization Machines | HOFM, hofm.py | [NeurIPS-16] HOFM |
Deep Neural Network | DNN, dnn.py | - |
Graph Convolutional Networks | GCN, gcn.py | [ICLR-17] GCN |
Graph Attention Networks | GAT, gat.py | [ICLR-18] GAT |
Wide&Deep | Wide&Deep, wd.py | [RecSys-16] Wide&Deep |
Product Neural Network | IPNN/KPNN, pnn.py | [ICDM-16] PNN |
Neural Factorization Machine | NFM, nfm.py | [SIGIR-17] NFM |
DeepFM | DeepFM, dfm.py | [IJCAI-17] DeepFM |
Deep & Cross Network | DCN/DCN+, dcn.py | [KDD-17] DCN |
Gated Linear Unit | SA_GLU, sa_glu.py | [ICML-17] GLU |
xDeepFM | CIN/xDeepFM, xdfm.py | [KDD-18] xDeepFM |
Context-Aware Self-Attention Network | GC_ARM, gc_arm.py | [AAAI-19] GC-ARM |
AFN | AFN/AFN+, afn.py | [AAAI-20] AFN |
ARM-Net | ARM-Net/ARM-Net+, armnet.py | [SIGMOD-21] ARM-Net |
ARM-Net-1h (one-head, recommended) | ARM-Net/ARM-Net+, armnet_1h.py | [SIGMOD-21] ARM-Net-1h |
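Several of these baselines build on the factorization machine's second-order term, where each pair of features interacts through the dot product of their latent vectors. The sketch below shows that term in its textbook form, assuming plain Python lists for a single input `x` and a latent matrix `v` (one row per feature). It is not taken from `fm.py`, which may be organized differently; it only checks that the naive pairwise sum equals the well-known linear-time reformulation.

```python
from itertools import combinations


def fm_pairwise_naive(x, v):
    """Sum over i < j of <v_i, v_j> * x_i * x_j, computed directly (O(k * n^2))."""
    total = 0.0
    for i, j in combinations(range(len(x)), 2):
        dot = sum(a * b for a, b in zip(v[i], v[j]))
        total += dot * x[i] * x[j]
    return total


def fm_pairwise_fast(x, v):
    """Same quantity in O(k * n) via
    0.5 * sum_f [(sum_i v_if x_i)^2 - sum_i (v_if x_i)^2]."""
    k = len(v[0])
    total = 0.0
    for f in range(k):
        s = sum(v[i][f] * x[i] for i in range(len(x)))
        s_sq = sum((v[i][f] * x[i]) ** 2 for i in range(len(x)))
        total += s * s - s_sq
    return 0.5 * total


if __name__ == "__main__":
    x = [1.0, 0.0, 2.0]                        # toy feature values
    v = [[0.1, 0.2], [0.3, -0.1], [-0.2, 0.4]]  # toy 2-d latent vectors
    print(fm_pairwise_naive(x, v), fm_pairwise_fast(x, v))  # both ≈ 0.12
```

The linear-time form is what makes FM practical on sparse, high-dimensional inputs such as Avazu or Criteo.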
ARM-Net for Small to Medium Tabular Datasets (121 UCI datasets)
We also implement and evaluate prior art, recent models, and our ARM-Net on 121 UCI datasets. These are real-world classification tasks, many of them multi-class, whose features are all converted into numerical features following common practice. Models and utilities for evaluating models on the 121 UCI datasets are included in this branch.
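The exact conversion used for the UCI datasets is not described here. As an illustration only, the sketch below applies one common choice: columns that do not parse as numbers are treated as categorical and mapped to integer codes, and every column is then z-score normalized. `to_numeric` is a hypothetical helper, and the encoding choices are assumptions, not the repository's preprocessing. It handles features only; class labels would be encoded separately.

```python
from statistics import mean, pstdev


def to_numeric(rows):
    """Convert mixed-type rows (lists of str/float) to standardized floats.

    Assumptions (not taken from the repository): a column is categorical if
    any value fails to parse as a float; categories get integer codes in
    sorted order; every column is then z-score normalized.
    """
    n_cols = len(rows[0])
    columns = [[row[c] for row in rows] for c in range(n_cols)]
    numeric_cols = []
    for col in columns:
        try:
            values = [float(v) for v in col]
        except ValueError:
            # Categorical column: map each distinct category to an integer code.
            codes = {cat: idx for idx, cat in enumerate(sorted(set(map(str, col))))}
            values = [float(codes[str(v)]) for v in col]
        mu, sigma = mean(values), pstdev(values)
        # Guard against constant columns to avoid division by zero.
        numeric_cols.append([(v - mu) / sigma if sigma > 0 else 0.0 for v in values])
    # Transpose back to row-major order.
    return [list(r) for r in zip(*numeric_cols)]


if __name__ == "__main__":
    print(to_numeric([["red", 1.0], ["blue", 3.0], ["red", 5.0]]))
```

Integer codes keep the feature count unchanged; one-hot encoding is an equally common alternative when categories have no natural order.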
Summary of UCI Results
- Main results are summarized below.
- ARM-Net achieves overall best performance.
- More results and technical details can be found here.
Model | Rank(Best_Cnt) | abalone | acute-inflammation | acute-nephritis | adult | annealing | arrhythmia | audiology-std | balance-scale | balloons | bank | blood | breast-cancer | breast-cancer-wisc | breast-cancer-wisc-diag | breast-cancer-wisc-prog | breast-tissue | car | cardiotocography-10clases | cardiotocography-3clases | chess-krvk | chess-krvkp | congressional-voting | conn-bench-sonar-mines-rocks | conn-bench-vowel-deterding | connect-4 | contrac | credit-approval | cylinder-bands | dermatology | echocardiogram | ecoli | energy-y1 | energy-y2 | fertility | flags | glass |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
n_samples | - | 4177 | 120 | 120 | 48842 | 898 | 452 | 196 | 625 | 16 | 4521 | 748 | 286 | 699 | 569 | 198 | 106 | 1728 | 2126 | 2126 | 28056 | 3196 | 435 | 208 | 990 | 67557 | 1473 | 690 | 512 | 366 | 131 | 336 | 768 | 768 | 100 | 194 | 214 |
n_features | - | 9 | 7 | 7 | 15 | 32 | 263 | 60 | 5 | 5 | 17 | 5 | 10 | 10 | 31 | 34 | 10 | 7 | 22 | 22 | 7 | 37 | 17 | 61 | 12 | 43 | 10 | 16 | 36 | 35 | 11 | 8 | 9 | 9 | 10 | 29 | 10 |
LR | 6th (0/36) | 0.6293/0.0080 | 0.9833/0.0211 | 0.9533/0.0552 | 0.8423/0.0008 | 0.1280/0.0172 | 0.5442/0.0184 | 0.7040/0.0480 | 0.8718/0.0310 | 0.7250/0.0935 | 0.8904/0.0023 | 0.7610/0.0043 | 0.6923/0.0171 | 0.9490/0.0090 | 0.9641/0.0103 | 0.6626/0.0656 | 0.5283/0.1371 | 0.8032/0.0052 | 0.7595/0.0118 | 0.8798/0.0120 | 0.2743/0.0009 | 0.9438/0.0035 | 0.5705/0.0328 | 0.7385/0.0186 | 0.7121/0.0088 | 0.7547/0.0004 | 0.4829/0.0383 | 0.8557/0.0119 | 0.6305/0.0647 | 0.9399/0.0313 | 0.7600/0.0605 | 0.7988/0.0510 | 0.8391/0.0123 | 0.8448/0.0297 | 0.5800/0.1066 | 0.4206/0.0365 | 0.5290/0.0281 |
FM | 5th (3/36) | 0.6329/0.0067 | 0.9767/0.0389 | 0.8700/0.0945 | 0.8443/0.0005 | 0.1960/0.1493 | 0.5283/0.0211 | 0.4880/0.0588 | 0.9224/0.0087 | 0.5750/0.1275 | 0.8882/0.0028 | 0.7647/0.0000 | 0.6909/0.0604 | 0.9599/0.0048 | 0.9697/0.0048 | 0.6626/0.0849 | 0.5094/0.0818 | 0.8882/0.0097 | 0.7616/0.0161 | 0.8903/0.0172 | 0.3127/0.0035 | 0.9796/0.0038 | 0.5705/0.0306 | 0.9502/0.0087 | 0.9502/0.0087 | 0.8264/0.0005 | 0.4524/0.0140 | 0.8638/0.0093 | 0.7016/0.0250 | 0.9202/0.0350 | 0.7846/0.0600 | 0.7595/0.0680 | 0.8823/0.0086 | 0.8604/0.0283 | 0.7720/0.0688 | 0.3423/0.0200 | 0.5907/0.0361 |
DNN | 4th (6/36) | 0.6560/0.0051 | 0.9900/0.0200 | 0.9500/0.0316 | 0.8519/0.0015 | 0.4420/0.2346 | 0.6442/0.0114 | 0.6880/0.0466 | 0.8987/0.0048 | 0.5500/0.2318 | 0.8900/0.0035 | 0.7583/0.0050 | 0.7147/0.0082 | 0.9633/0.0033 | 0.9648/0.0107 | 0.7091/0.0475 | 0.5849/0.0396 | 0.9442/0.0034 | 0.7797/0.0121 | 0.9178/0.0031 | 0.6842/0.0147 | 0.9775/0.0032 | 0.5834/0.0147 | 0.7481/0.0377 | 0.9745/0.0063 | 0.8501/0.0023 | 0.5084/0.0158 | 0.8417/0.0187 | 0.7359/0.0386 | 0.9639/0.0101 | 0.7846/0.0337 | 0.8524/0.0166 | 0.8688/0.0107 | 0.8865/0.0094 | 0.8320/0.0722 | 0.4969/0.0272 | 0.5850/0.0316 |
SNN | 3rd (6/36) | 0.6457/0.0043 | 0.9567/0.0389 | 0.9000/0.0548 | 0.8489/0.0009 | 0.2280/0.2671 | 0.5841/0.0410 | 0.7200/0.0253 | 0.9058/0.0240 | 0.7250/0.1225 | 0.8885/0.0019 | 0.8885/0.0019 | 0.7105/0.0105 | 0.9656/0.0041 | 0.9690/0.0112 | 0.6727/0.0903 | 0.6000/0.0690 | 0.9632/0.0066 | 0.8008/0.0125 | 0.9029/0.0086 | 0.6796/0.0141 | 0.9726/0.0061 | 0.5779/0.0209 | 0.7135/0.0300 | 0.9693/0.0100 | 0.8491/0.0013 | 0.5106/0.0098 | 0.8719/0.0121 | 0.7000/0.0163 | 0.9388/0.0269 | 0.7877/0.0439 | 0.8179/0.035 | 0.8714/0.0142 | 0.8854/0.0154 | 0.7600/0.1180 | 0.4804/0.0231 | 0.5738/0.0602 |
Perceiver-IO | 2nd (6/36) | 0.6381/0.0143 | 1.0000/0.0000 | 0.9367/0.0531 | 0.8521/0.0011 | 0.7600/0.0000 | 0.5602/0.0053 | 0.0080/0.0160 | 0.8821/0.0166 | 0.7750/0.0500 | 0.8850/0.0000 | 0.7620/0.0000 | 0.7063/0.0088 | 0.9352/0.0313 | 0.9556/0.0142 | 0.7596/0.0118 | 0.3208/0.0597 | 0.9326/0.0120 | 0.5325/0.0861 | 0.7817/0.0035 | 0.6834/0.0151 | 0.8106/0.0895 | 0.6129/0.0000 | 0.5635/0.0817 | 0.6732/0.0521 | 0.7538/0.0000 | 0.4457/0.0122 | 0.7745/0.1075 | 0.6133/0.0078 | 0.4295/0.0754 | 0.7662/0.0834 | 0.6440/0.0239 | 0.8417/0.0295 | 0.8807/0.0325 | 0.8560/0.0480 | 0.3010/0.0247 | 0.4093/0.0415 |
ARM-Net | 1st (15/36) | 0.6603/0.0034 | 0.9767/0.0389 | 0.9600/0.0800 | 0.8562/0.0011 | 0.1500/0.1131 | 0.6487/0.0214 | 0.5520/0.0299 | 0.9135/0.0070 | 0.7500/0.0791 | 0.8922/0.0012 | 0.8922/0.0012 | 0.7203/0.0193 | 0.9530/0.0118 | 0.9521/0.0186 | 0.6828/0.0485 | 0.5170/0.0638 | 0.9463/0.0086 | 0.7868/0.0054 | 0.9146/0.0051 | 0.6982/0.0109 | 0.9826/0.0040 | 0.5760/0.0193 | 0.7712/0.0335 | 0.9675/0.0115 | 0.8672/0.0028 | 0.5228/0.0119 | 0.8620/0.0187 | 0.7133/0.0305 | 0.9497/0.0181 | 0.8338/0.0406 | 0.8214/0.0279 | 0.8844/0.0048 | 0.8750/0.0304 | 0.8240/0.0528 | 0.4330/0.0526 | 0.6150/0.0232 |
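The `Best_Cnt` part of the rank column counts how many of the 36 datasets a model wins. A minimal sketch of that count follows, assuming each cell is `mean/std` and that a tie awards a point to every tied model; the README states neither, and it does not say how the `Rank` itself is derived, so that is not reproduced here. `best_counts` is a hypothetical helper.

```python
def best_counts(results):
    """Count, per model, the datasets on which it has the highest mean score.

    `results` maps dataset -> {model: "mean/std"}, following the cell format
    of the table above. Assumption: ties give a point to every tied model.
    """
    counts = {}
    for scores in results.values():
        # Keep only the mean, i.e. the part before the slash.
        means = {model: float(cell.split("/")[0]) for model, cell in scores.items()}
        for model in means:
            counts.setdefault(model, 0)
        best = max(means.values())
        for model, m in means.items():
            if m == best:
                counts[model] += 1
    return counts


if __name__ == "__main__":
    # Three datasets and three models copied from the table (a subset only).
    subset = {
        "abalone": {"LR": "0.6293/0.0080", "DNN": "0.6560/0.0051", "ARM-Net": "0.6603/0.0034"},
        "acute-inflammation": {"LR": "0.9833/0.0211", "DNN": "0.9900/0.0200", "ARM-Net": "0.9767/0.0389"},
        "adult": {"LR": "0.8423/0.0008", "DNN": "0.8519/0.0015", "ARM-Net": "0.8562/0.0011"},
    }
    print(best_counts(subset))
```

Because the example uses only three models and three datasets, its counts differ from the table's (for instance, Perceiver-IO is excluded although it scores highest on acute-inflammation).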
ARM-Net for Log-Based Anomaly Detection
ARM-Net can also be readily adapted to support log-based anomaly detection, which aims to discover abnormal system behaviors (a binary classification task) by analyzing the log sequences that a system routinely generates at runtime.

Each log is a message in an unstructured format (raw text), which can be parsed into a structured format with a number of key information fields, e.g., date, PID, level, and event ID. Models and utilities for end-to-end log-based anomaly detection can be found in this branch.
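To illustrate the parsing step, the sketch below turns one raw log line into structured fields with a regular expression and assigns an event ID by matching the message against a small template table. The log layout (loosely modeled on HDFS-style logs), the pattern `LOG_PATTERN`, the `TEMPLATES` table, and `parse_line` are all hypothetical; the branch's actual parser and templates may differ, and real pipelines usually mine templates automatically.

```python
import re
from dataclasses import dataclass

# Hypothetical layout: "<date> <time> <pid> <level> <component>: <message>".
LOG_PATTERN = re.compile(
    r"(?P<date>\d{6}) (?P<time>\d{6}) (?P<pid>\d+) (?P<level>[A-Z]+) "
    r"(?P<component>[\w.$]+): (?P<message>.*)"
)
# Toy template table mapping message patterns to event IDs (assumption).
TEMPLATES = {
    r"Receiving block \S+ src: \S+ dest: \S+": "E1",
    r"PacketResponder \d+ for block \S+ terminating": "E2",
}


@dataclass
class LogRecord:
    date: str
    time: str
    pid: int
    level: str
    component: str
    event_id: str


def parse_line(line):
    """Parse one raw log line into a LogRecord, or return None if it does not fit."""
    m = LOG_PATTERN.match(line)
    if m is None:
        return None
    event_id = "UNKNOWN"
    for template, eid in TEMPLATES.items():
        if re.fullmatch(template, m["message"]):
            event_id = eid
            break
    return LogRecord(m["date"], m["time"], int(m["pid"]), m["level"],
                     m["component"], event_id)


if __name__ == "__main__":
    line = ("081109 203615 148 INFO dfs.DataNode$PacketResponder: "
            "PacketResponder 1 for block blk_38865049064139660 terminating")
    print(parse_line(line))
```

Each parsed record is a row of categorical fields, which is the structured form a tabular model such as ARM-Net consumes.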
Citation
If you use our code in your research, please cite:
S. Cai, K. Zheng, G. Chen, H.V. Jagadish, B.C. Ooi, M. Zhang. ARM-Net: Adaptive Relation Modeling Network for Structured Data. ACM International Conference on Management of Data (SIGMOD), 2021
Contact
To ask questions or report issues, you can drop us an email.