Deep-Learning-for-Causal-Inference

Large dataset

Open macsermkiat opened this issue 2 years ago • 2 comments

Hi, this is a great tutorial! Thank you for sharing.

I have a question about implementing Dragonnet with a large dataset (in my case 200k subjects). Calculating the loss requires constructing a large matrix (200k x 200k) in float32, which cannot fit into memory. Do you have any suggestions?

Thanks

macsermkiat avatar Oct 26 '22 14:10 macsermkiat

Glad you enjoyed it. :) I haven't looked at these models particularly recently, so could you clarify for me which part is the bottleneck?

kochbj avatar Oct 26 '22 21:10 kochbj

There are many places that involve constructing large tensors. For example, calculating the squared Euclidean distance between PhiControl and PhiTreatment in Tutorial 3:

class AIPW_Metrics:
    def pdist2sq(x, y):
        # Pairwise squared Euclidean distances via the expansion
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>
        x2 = tf.reduce_sum(x ** 2, axis=-1, keepdims=True)
        y2 = tf.reduce_sum(y ** 2, axis=-1, keepdims=True)
        dist = x2 + tf.transpose(y2, (1, 0)) - 2. * x @ tf.transpose(y, (1, 0))
        return dist

Since the Phi layer has 200 nodes, x has shape (Nc, 200) and y has shape (Nt, 200), where Nc is the number of control subjects and Nt is the number of treated subjects.

Then dist has shape (Nc, Nt), roughly (100000, 100000) in my dataset. For float32, that is 4 × 10^5 × 10^5 bytes ≈ 40 GB, which does not fit into GPU memory. So I have to move to the CPU, which also affects other functions.

Also, computing the full pairwise distance matrix takes a long time due to its quadratic complexity.
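(One workaround, not from the tutorial: compute the distances block by block so the full (Nc, Nt) matrix is never materialized. Below is a hypothetical NumPy sketch mirroring the pdist2sq expansion; `chunked_nearest` and its `chunk` parameter are my own names, and the same loop could be written with `tf.reduce_sum`/`tf.argmin` on the GPU.)

```python
import numpy as np

def chunked_nearest(x, y, chunk=1024):
    """For each row of x, return the index of the nearest row of y
    under squared Euclidean distance, holding only a (chunk, Nt)
    block of the distance matrix in memory at a time."""
    y2 = np.sum(y ** 2, axis=-1)                          # (Nt,)
    nearest = np.empty(len(x), dtype=np.int64)
    for start in range(0, len(x), chunk):
        xb = x[start:start + chunk]                       # (chunk, d)
        x2 = np.sum(xb ** 2, axis=-1, keepdims=True)      # (chunk, 1)
        # Same expansion as pdist2sq, but only for this block of rows
        dist = x2 + y2 - 2.0 * xb @ y.T                   # (chunk, Nt)
        nearest[start:start + chunk] = np.argmin(dist, axis=1)
    return nearest
```

With chunk=1024 the peak extra memory is about 1024 × Nt × 4 bytes (≈0.4 GB for Nt = 100k) instead of 40 GB, at the cost of a Python-level loop.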

macsermkiat avatar Oct 27 '22 11:10 macsermkiat

Hmm, those calculations are only used to find nearest neighbors in representation space for validation. A couple of quick solutions:

  1. Skip them during training, since they are only used for validation. You could compute them periodically on frozen models, perhaps in parallel on the CPU?
  2. Do the training on the CPU. It is slower, but unless you are trying to do this in a production setting (which I strongly advise against, since these algorithms are still pretty immature), it might be okay. You don't need thousands of training iterations, so unless you are running a lot of simulations it could be reasonable?
  3. Try a different validation technique altogether like Alaa's influence functions or credence?

Hope this helps a bit! -B

https://proceedings.mlr.press/v162/parikh22a.html https://www.vanderschaar-lab.com/papers/Validating_CI_models_via_IF_manuscript.pdf

kochbj avatar Oct 29 '22 20:10 kochbj

Thank you. This helps a lot!

macsermkiat avatar Oct 30 '22 14:10 macsermkiat