skada
Deep copy a DomainAwareNet after fitting
After fitting a `DomainAwareNet`, it's not possible to deep copy it.
Strangely, deep copying works before fitting.
Error raised:
AttributeError: Can't pickle local object '_get_intermediate_layers.<locals>.hook'
Deep copying a fitted estimator must work in order to use the `CircularValidation`
scorer.
To reproduce:
import torch
import numpy as np
from skada.deep.base import DomainAwareCriterion, DomainAwareModule, DomainAwareNet, DomainBalancedDataLoader, BaseDALoss
from skada.metrics import (
ImportanceWeightedScorer,
PredictionEntropyScorer,
SoftNeighborhoodDensity,
DeepEmbeddedValidation,
CircularValidation,
)
from skada.deep.modules import ToyModule2D
from skada import make_da_pipeline
from sklearn.model_selection import ShuffleSplit, cross_validate
from skada.datasets import make_shifted_datasets
from copy import deepcopy
class TestLoss(BaseDALoss):
    """Placeholder domain-adaptation loss for exercising the deep API.

    Every call to ``forward`` ignores its arguments and yields ``0``, so this
    loss adds nothing to the training objective.
    """

    def __init__(self):
        # Nothing to configure; defer entirely to the base class.
        super().__init__()

    def forward(self, *args):
        """Return a constant domain-adaptation loss of ``0``.

        All positional arguments are accepted and ignored.
        """
        return 0
# --- Data: small two-domain toy problem (20 source / 20 target samples) ---
da_dataset = make_shifted_datasets(
    n_samples_source=20,
    n_samples_target=20,
    shift="concept_drift",
    noise=0.1,
    random_state=42,
    return_dataset=True,
)

# Training data mixes source "s" and target "t"; test data is target-only.
X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"])
X_test, y_test, sample_domain_test = da_dataset.pack_test(as_targets=["t"])

# Cast features to float32 (presumably to match the torch module's
# parameter dtype -- confirm against ToyModule2D).
X = X.astype(np.float32)
X_test = X_test.astype(np.float32)

# --- Model: domain-aware net wrapping a toy 2D module ---
module = ToyModule2D()
estimator = DomainAwareNet(
    # "dropout" presumably names the layer whose output is exposed as
    # intermediate features -- verify against DomainAwareModule.
    DomainAwareModule(module, "dropout"),
    iterator_train=DomainBalancedDataLoader,
    criterion=DomainAwareCriterion(torch.nn.CrossEntropyLoss(), TestLoss()),
    batch_size=10,
    max_epochs=2,
    train_split=None,
)

# --- Reproduction ---
# Deep copying the unfitted estimator succeeds.
estimator_copy = deepcopy(estimator)

estimator.fit(X, y, sample_domain=sample_domain)

# Deep copying the fitted estimator fails with:
# AttributeError: Can't pickle local object '_get_intermediate_layers.<locals>.hook'
estimator_copy_after_fit = deepcopy(estimator)