Modify get_enc_embs to allow for custom batch sizes and lazy evaluation

Open · vrodriguezf opened this issue 2 years ago · 0 comments

Here is a working version of the code:

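#|export
import numpy as np
from fastcore.basics import nested_attr
from fastai.torch_core import to_concat
# Assumed to be available from elsewhere in the notebook: get_acts_and_grads
# (tsai) and AE_EMBS_MODULE_NAME (deepvats' mapping from model type to the
# name of the layer used for embeddings).
def get_enc_embs(X, enc_learn, module=None, cpu=False, average_seq_dim=True,
                to_numpy=True, bs=None):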
    """
        (From deepvats https://github.com/vrodriguezf/deepvats/blob/master/nbs/encoder.ipynb)
        Get the embeddings of X from an encoder, passed in `enc_learn` as a fastai
        learner. By default, the embeddings are obtained from the last layer
        before the model head, although any layer can be passed to `module`.
        Input
        - `cpu`: Whether to run model inference on the CPU or the GPU (GPU recommended)
        - `average_seq_dim`: Whether to average the embeddings over the sequence dimension
        - `to_numpy`: Whether to return the result as a numpy array (if False, returns a tensor)
        - `bs`: Batch size to use for inference (if None, uses the batch size of
            `enc_learn.dls`, or 64 if that is not positive)
    """
    if cpu:
        enc_learn.dls.cpu()
        enc_learn.cpu()
    else:
        enc_learn.dls.cuda()
        enc_learn.cuda()
    
    aux_dl = enc_learn.dls.valid.new_dl(X=X)
    # Set batch size for aux_dl
    if bs is not None:
        aux_dl.bs = bs
    elif enc_learn.dls.bs>0:
        aux_dl.bs = enc_learn.dls.bs
    else:
        aux_dl.bs = 64
    
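    # Layer whose activations are used as embeddings: the per-model default
    # from AE_EMBS_MODULE_NAME, unless `module` is passed explicitly
    module = nested_attr(enc_learn.model,
                         AE_EMBS_MODULE_NAME[type(enc_learn.model)]) \
                if module is None else module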
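    def embs_generator():
        # Yield the activations of `module` for one batch at a time
        for b in aux_dl:
            emb = get_acts_and_grads(model=enc_learn.model,
                                     modules=module,
                                     x=b[0], cpu=cpu)[0]
            if to_numpy:
                emb = emb.numpy() if cpu else emb.cpu().numpy()
            yield emb

    # list() consumes the generator, so every batch is computed and kept in
    # memory before concatenation (the result is not evaluated lazily)
    embs = list(embs_generator())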
    embs = np.concatenate(embs, axis=0) if to_numpy else to_concat(embs)
    # For 3-D activations, axis 2 is treated as the sequence dimension and averaged out
    if embs.ndim == 3 and average_seq_dim: embs = embs.mean(axis=2)
    return embs
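The title asks for lazy evaluation, but `embs = list(embs_generator())` runs every batch before returning. A minimal sketch of one way to expose the generator instead, written in plain NumPy so it runs without fastai or tsai. Everything here is hypothetical and not part of deepvats: `encode` stands in for the hooked encoder forward pass (`get_acts_and_grads(...)[0]`), `default_bs` stands in for `enc_learn.dls.bs`, and `resolve_bs`, `iter_enc_embs`, `get_enc_embs_lazy` and the `lazy` flag are illustrative names.

```python
import numpy as np
from typing import Callable, Iterator, Optional


def resolve_bs(bs: Optional[int], default_bs: int, fallback: int = 64) -> int:
    """Batch-size fallback mirroring get_enc_embs: explicit `bs`,
    else the learner's batch size if positive, else `fallback`."""
    if bs is not None:
        return bs
    if default_bs > 0:
        return default_bs
    return fallback


def iter_enc_embs(X: np.ndarray,
                  encode: Callable[[np.ndarray], np.ndarray],
                  bs: int,
                  average_seq_dim: bool = True) -> Iterator[np.ndarray]:
    """Yield embeddings one batch at a time; nothing is encoded until iterated."""
    for start in range(0, len(X), bs):
        # `encode` is a hypothetical stand-in for the encoder + activation hook
        emb = encode(X[start:start + bs])
        # The mean over axis 2 is taken per sample, so reducing each batch
        # here gives the same result as reducing after concatenation.
        if emb.ndim == 3 and average_seq_dim:
            emb = emb.mean(axis=2)
        yield emb


def get_enc_embs_lazy(X, encode, bs=None, default_bs=0,
                      average_seq_dim=True, lazy=True):
    """Return a generator of batches if `lazy`, else one concatenated array."""
    gen = iter_enc_embs(X, encode, resolve_bs(bs, default_bs), average_seq_dim)
    return gen if lazy else np.concatenate(list(gen), axis=0)


# Usage with a hypothetical encoder that keeps the (batch, features, seq_len) shape
X = np.random.default_rng(0).random((10, 3, 8))
encode = lambda xb: xb * 2.0

for chunk in get_enc_embs_lazy(X, encode, bs=4):
    print(chunk.shape)  # (4, 3), then (4, 3), then (2, 3)

embs = get_enc_embs_lazy(X, encode, bs=4, lazy=False)
print(embs.shape)  # (10, 3)
```

Averaging inside the generator means a lazy consumer only ever holds `(batch, features)` arrays rather than full sequences, and because the reduction is per sample, the eager path still returns the same values as averaging after concatenation.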

vrodriguezf · Jun 14 '23, 17:06