Repository: skada

Cannot deep copy a DomainAwareNet after fitting

Status: Open · opened by YanisLalou on Mar 05, 2024 · 2 comments

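After a DomainAwareNet has been fitted, calling `copy.deepcopy` on it fails. Deep-copying the same estimator before fitting works. The error raised is: `AttributeError: Can't pickle local object '_get_intermediate_layers.<locals>.hook'`. The message points to a function named `hook` that is defined locally inside `_get_intermediate_layers`. Python's pickle machinery cannot serialize locally defined (nested) functions. This suggests that such a hook is attached to the model during fitting.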

Deep-copying a fitted estimator must work, because the `CircularValidation` scorer requires it.

To reproduce:

import torch
import numpy as np
from skada.deep.base import DomainAwareCriterion, DomainAwareModule, DomainAwareNet, DomainBalancedDataLoader, BaseDALoss
from skada.metrics import (
    ImportanceWeightedScorer,
    PredictionEntropyScorer,
    SoftNeighborhoodDensity,
    DeepEmbeddedValidation,
    CircularValidation,
)
from skada.deep.modules import ToyModule2D
from skada import make_da_pipeline
from sklearn.model_selection import ShuffleSplit, cross_validate
from skada.datasets import make_shifted_datasets
from copy import deepcopy


class TestLoss(BaseDALoss):
    """Placeholder domain-adaptation loss for exercising the deep API.

    The loss carries no state of its own. ``forward`` ignores whatever it is
    given and always reports a loss of ``0``, so it adds nothing to the
    training objective.
    """

    def __init__(self):
        # Nothing to configure here; all initialization is done by the base class.
        super().__init__()

    def forward(self, *_inputs):
        """Return the domain-adaptation loss, which is always ``0``."""
        zero_loss = 0
        return zero_loss


# Small toy problem: 20 source and 20 target samples with concept drift.
da_dataset = make_shifted_datasets(
    n_samples_source=20,
    n_samples_target=20,
    shift="concept_drift",
    noise=0.1,
    random_state=42,
    return_dataset=True,
)

# Pack domain "s" as the source and "t" as the target for training;
# the test split uses only the target domain.
X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"])
X_test, y_test, sample_domain_test = da_dataset.pack_test(as_targets=["t"])

# Convert the feature matrices to float32
# (presumably to match the torch module's parameter dtype).
X = X.astype(np.float32)
X_test = X_test.astype(np.float32)

module = ToyModule2D()

# Wrap the base module and build the combined criterion (task loss plus the
# no-op TestLoss) before handing both to the estimator.
da_module = DomainAwareModule(module, "dropout")
da_criterion = DomainAwareCriterion(torch.nn.CrossEntropyLoss(), TestLoss())

estimator = DomainAwareNet(
    da_module,
    iterator_train=DomainBalancedDataLoader,
    criterion=da_criterion,
    batch_size=10,
    max_epochs=2,
    train_split=None,
)

# Before fitting, deep-copying the estimator succeeds.
estimator_copy = deepcopy(estimator)

estimator.fit(X, y, sample_domain=sample_domain)

# After fitting, the same call raises:
#   AttributeError: Can't pickle local object
#   '_get_intermediate_layers.<locals>.hook'
estimator_copy_after_fit = deepcopy(estimator)

— YanisLalou, Mar 05, 2024, 15:03