coresoftware

CNN photon classifier

Open Shuonli opened this issue 1 year ago • 2 comments

Types of changes

  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work for users)
  • [ ] Requiring change in macros repository (Please provide links to the macros pull request in the last section)
  • [x] I am a member of GitHub organization of sPHENIX Collaboration, EIC, or ECCE (contact Chris Pinkenburg to join)

What kind of change does this PR introduce? (Bug fix, feature, ...)

I trained a toy CNN to separate EMCal photon clusters from non-photon showers.

Training sample

The training sample is derived from a Pythia 1M photon-jet sample and a 10M Detroit jet sample. Cluster truth matching is done with the g4eval module. For clusters with ET > 7 GeV, I took the 5x5 block of tower_calib energies around the maximum-energy tower. For both the photon and non-photon samples, I require the primary particle matched to the cluster to contribute >99% of the energy in the 5x5 tower block. For the truth photon sample, photons that converted to e+ e- are identified from the shower history and excluded from the training sample. For the non-photon sample, when the primary particle is a pi0 or eta, I require the decay photons to have delta R < 0.05 (otherwise only one photon could fall in the 5x5 tower block and still be labeled as a non-photon shower) and a di-photon asymmetry < 0.8.
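The cuts described above can be written down as a small labeling helper. This is only a sketch of the selection as stated in the text: the function and argument names are hypothetical, and the asymmetry is assumed to follow the usual convention |E1 - E2| / (E1 + E2), which the PR does not spell out.

```python
import math

ET_MIN = 7.0                 # GeV, cluster ET threshold from the text
MIN_ENERGY_FRACTION = 0.99   # matched primary must give >99% of the 5x5 energy
MAX_DELTA_R = 0.05           # decay photons must be close together
MAX_ASYMMETRY = 0.8          # di-photon energy asymmetry cut


def delta_r(eta1, phi1, eta2, phi2):
    """Angular distance with the phi difference wrapped into [-pi, pi]."""
    dphi = math.remainder(phi1 - phi2, 2.0 * math.pi)
    return math.hypot(eta1 - eta2, dphi)


def asymmetry(e1, e2):
    """Assumed definition: |E1 - E2| / (E1 + E2)."""
    return abs(e1 - e2) / (e1 + e2)


def passes_common_cuts(cluster_et, primary_fraction):
    """Cuts applied to both the photon and non-photon samples."""
    return cluster_et > ET_MIN and primary_fraction > MIN_ENERGY_FRACTION


def keep_photon(cluster_et, primary_fraction, converted):
    """Truth photon: drop photons that converted to e+ e-."""
    return passes_common_cuts(cluster_et, primary_fraction) and not converted


def keep_meson_decay(cluster_et, primary_fraction, photon1, photon2):
    """pi0/eta -> gamma gamma; each photon is an (energy, eta, phi) tuple."""
    e1, eta1, phi1 = photon1
    e2, eta2, phi2 = photon2
    return (passes_common_cuts(cluster_et, primary_fraction)
            and delta_r(eta1, phi1, eta2, phi2) < MAX_DELTA_R
            and asymmetry(e1, e2) < MAX_ASYMMETRY)
```

For example, a pi0 whose photons carry 6 GeV and 4 GeV and land 0.02 apart in both eta and phi passes the cuts, while a 9.5 GeV / 0.5 GeV split fails the asymmetry requirement.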

Model

I implemented a simple CNN in TensorFlow for the classification task. The model summary is as follows:

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d_93 (Conv2D)              │ (None, 5, 5, 32)       │           320 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_92 (MaxPooling2D) │ (None, 2, 2, 32)       │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_94 (Conv2D)              │ (None, 2, 2, 64)       │        18,496 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_93 (MaxPooling2D) │ (None, 1, 1, 64)       │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_96 (Dropout)            │ (None, 1, 1, 64)       │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten_46 (Flatten)            │ (None, 64)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_105 (Dense)               │ (None, 16)             │         1,040 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_106 (Dense)               │ (None, 1)              │            17 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
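The parameter counts in the summary constrain part of the architecture. Assuming a single-channel 5x5 input and 3x3 convolution kernels (both inferred from the counts, not stated in the PR), a quick arithmetic check reproduces every non-zero row:

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    """Kernel weights plus one bias per filter."""
    return kernel_h * kernel_w * in_channels * filters + filters


def dense_params(in_features, units):
    """Weights plus one bias per unit."""
    return in_features * units + units


# Assumed: 3x3 kernels and a 1-channel input (tower energy only).
counts = {
    "conv2d_93": conv2d_params(3, 3, 1, 32),    # table: 320
    "conv2d_94": conv2d_params(3, 3, 32, 64),   # table: 18,496
    "dense_105": dense_params(64, 16),          # table: 1,040
    "dense_106": dense_params(16, 1),           # table: 17
}
total = sum(counts.values())
print(counts, total)
```

The spatial size being unchanged across each convolution (5x5 stays 5x5, 2x2 stays 2x2) is also consistent with `padding='same'`. The activations and dropout rate cannot be recovered from the summary.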

The optimizer and loss are the usual choices for a binary CNN classifier:

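import tensorflow as tf

# Exponentially decaying learning rate; with staircase=True the rate drops
# by a factor of 0.96 once every 100000 optimizer steps.
initial_learning_rate = 0.0001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=100000,
    decay_rate=0.96,
    staircase=True)

# `model` is the CNN summarized above. Binary cross-entropy matches the
# single-output photon / non-photon classifier.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
              loss='binary_crossentropy',
              metrics=['accuracy'])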
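For reference, the staircase schedule configured above evaluates to `initial_learning_rate * decay_rate ** floor(step / decay_steps)`. The pure-Python helper below (a hypothetical name, not part of the PR) reproduces that formula so the rate at a given optimizer step can be checked without TensorFlow:

```python
def staircase_lr(step, initial_learning_rate=1e-4,
                 decay_steps=100_000, decay_rate=0.96):
    """Learning rate at a given optimizer step for staircase exponential decay."""
    return initial_learning_rate * decay_rate ** (step // decay_steps)


for step in (0, 99_999, 100_000, 250_000):
    print(step, staircase_lr(step))
```

With these settings the rate stays at 1e-4 for the first 100,000 optimizer steps, so the schedule only matters if training runs longer than that.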

For differentiating photons from non-photons, this CNN model seems to do somewhat better than the cluster prob alone. The gain is clearest at a high probability threshold, where the model gives better non-photon rejection, which leads to a better signal-to-background ratio for isolated-photon studies.
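One way to quantify a comparison like this at a fixed threshold is to compute the signal efficiency and background rejection for a cut on the score. The helper below is a generic sketch with a hypothetical name; the toy scores are purely illustrative and are not results from this PR.

```python
def efficiency_and_rejection(scores, labels, threshold):
    """Return (signal efficiency, background rejection) for score > threshold.

    labels: 1 for a truth photon, 0 for a non-photon shower.
    """
    n_sig = sum(1 for y in labels if y == 1)
    n_bkg = len(labels) - n_sig
    sig_pass = sum(1 for s, y in zip(scores, labels) if y == 1 and s > threshold)
    bkg_pass = sum(1 for s, y in zip(scores, labels) if y == 0 and s > threshold)
    return sig_pass / n_sig, 1.0 - bkg_pass / n_bkg


# Illustrative scores only, not output of the trained model.
toy_scores = [0.95, 0.90, 0.60, 0.85, 0.40, 0.10]
toy_labels = [1, 1, 1, 0, 0, 0]
eff, rej = efficiency_and_rejection(toy_scores, toy_labels, threshold=0.8)
print(eff, rej)
```

Running the same function on the cluster prob and on the CNN output for the same clusters would give a like-for-like comparison at each threshold.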

Implementation in Fun4All

The model is exported in ONNX format and applied through our ONNX wrapper. I added a new function signature so the wrapper works with models that require a higher-dimensional input.

The RawClusterCNNClassifier module takes the cluster container and the tower container, extracts the 5x5 tower energies, feeds them to the trained model, and then modifies the cluster prob based on the model output. The method is very flexible: if you have your own trained model that you want to use instead of my crappy CNN model, point the ONNX session at your model file, adjust the model dimensions accordingly, and it should just work.
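The 5x5 extraction step can be illustrated in a few lines of Python. This is a sketch under stated assumptions, not the module's C++ code: the 96 x 256 (eta x phi) tower grid, the wrap-around in phi, and the zero-fill beyond the eta edges are assumptions, and `window_5x5` is a hypothetical name.

```python
N_ETA, N_PHI = 96, 256  # assumed EMCal tower grid; adjust to the real geometry


def window_5x5(tower_energy, ieta_max, iphi_max, half=2):
    """Build a 5x5 energy block centred on the leading tower.

    tower_energy: dict mapping (ieta, iphi) -> calibrated tower energy.
    phi indices wrap around; bins beyond the eta edges are zero-filled
    (both choices are assumptions of this sketch).
    """
    window = []
    for deta in range(-half, half + 1):
        ieta = ieta_max + deta
        inside = 0 <= ieta < N_ETA
        row = []
        for dphi in range(-half, half + 1):
            iphi = (iphi_max + dphi) % N_PHI
            row.append(tower_energy.get((ieta, iphi), 0.0) if inside else 0.0)
        window.append(row)
    return window


# Leading tower at the eta edge and at the phi seam.
towers = {(0, 0): 8.0, (0, 255): 1.0, (1, 1): 0.5}
block = window_5x5(towers, ieta_max=0, iphi_max=0)
```

Assuming a single input channel, reshaping such a block to (1, 5, 5, 1) would match the input expected by the model summarized earlier.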

TODOs (if applicable)

In-place modification of the cluster prob is a bad idea (but that's what I'm doing now). In the future I will either add a new field to the rawcluster object (the complication is that changing it in v1 might screw up the DST readback, so I would need to make a new class) or make a new cluster node with a different name on the node tree.

I still need a more systematic study to test this method, more parameter optimization, and better feature design before finally putting the model in the CDB.

But the framework is here and it works, so feel free to check out this branch and play with it if you are interested.

Links to other PRs in macros and calibration repositories (if applicable)

Shuonli, Aug 17 '24 01:08

Build & test report

Report for commit 2d03d601f6f50f0a74c8a1a4aa1cc08e7dc72f2f: Jenkins on fire


Automatically generated by sPHENIX Jenkins continuous integration

sphenix-jenkins-ci[bot], Aug 17 '24 06:08

Build & test report

Report for commit 1b859ac4bd082ee9d95b06b803ee748f237216f6: Jenkins passed



sphenix-jenkins-ci[bot], Aug 17 '24 06:08

Build & test report

Report for commit 7d6ce239b8e74fe1db335cab39ed2db5d24bab92: Jenkins passed



sphenix-jenkins-ci[bot], Aug 18 '24 19:08