torch_ecg

Are pre-trained weights available for ECG_SEQ_LAB_NET?

Open ben120-web opened this issue 11 months ago • 12 comments

I am looking to test this deep learning peak detector (ECG_SEQ_LAB_NET) against my own. To do so without retraining, are the weights for this network available?

Thanks

ben120-web • Jan 17 '25 14:01

Essentially, I am looking to establish the feasibility of the following:

  1. Use the pre-trained ECG_SEQ_LAB_NET network to predict peak locations and compare them against my own.
  2. Retrain the network on both the data used in this challenge and my own data to see if performance improves.

ben120-web • Jan 17 '25 15:01

Yes, there are a few:

  • https://drive.google.com/file/d/18Jta73DjqXVarEYjN_CWeYqM8rH7I3An/view?usp=sharing
  • https://drive.google.com/file/d/1ueTZ2pVPp6mgLhaNt9Z6AOJqqM5WwoSY/view?usp=sharing
  • https://drive.google.com/file/d/1RV7SIUDGuhlhTxmCW7ORuMnPemxLu1rf/view?usp=sharing

You can load them with ECG_SEQ_LAB_NET.from_remote, or download them to local storage and use ECG_SEQ_LAB_NET.from_checkpoint.

wenh06 • Jan 17 '25 16:01

This is great. Thanks very much. Excellent repo 👍

ben120-web • Jan 17 '25 16:01

Thank you. If you have any further questions or suggestions, please let me know.

wenh06 • Jan 17 '25 17:01

One other thing: I am trying to load the trained models by calling the from_remote method, but it throws an error for me.

I am using:


```python
import torch
from torch_ecg.models import ECG_SEQ_LAB_NET

remote_url = "https://drive.google.com/uc?id=18Jta73DjqXVarEYjN_CWeYqM8rH7I3An"  # Direct-download URL of the Google Drive file
model_dir = r"C:\models"  # Directory to store the downloaded model

# Load the model from remote
try:
    model, _ = ECG_SEQ_LAB_NET.from_remote(
        url=remote_url,
        model_dir=model_dir,
        filename="best_model.pth",
        device=torch.device("cpu"),  # Using CPU currently; CUDA is disabled
    )

    model.eval()  # Set the model to evaluation mode
    print("Model loaded successfully and ready for R-peak detection.")

except Exception as e:
    print(f"An error occurred while loading the model: {e}")
```

I get the following error:

[WinError 32] The process cannot access the file because it is being used by another process: 'C:\models\tmpn8qdrek_'

Have you seen this before?

I should note that it does look like the model is loaded right before this error is hit.

[Screenshot]

ben120-web • Jan 17 '25 19:01

I haven't seen this type of error before. I asked GitHub Copilot, and it answered:

The "[WinError 32] The process cannot access the file because it is being used by another process" error typically occurs when a file is held open by another process and therefore cannot be opened or modified. To handle this error, you can use a retry mechanism with a delay to wait for the file to become available.
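The retry-with-delay idea can be sketched in plain Python. This is a minimal illustration, not code from torch_ecg: `retry_on_permission_error` is a hypothetical helper, and it assumes the failing step raises `PermissionError`, which is how Python on Windows surfaces WinError 32.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def retry_on_permission_error(
    func: Callable[[], T], retries: int = 5, delay: float = 0.5
) -> T:
    """Call func(); on PermissionError (e.g. WinError 32, the file is still
    held by another process), sleep `delay` seconds and try again.
    The last PermissionError is re-raised after `retries` failed attempts."""
    if retries < 1:
        raise ValueError("retries must be at least 1")
    for attempt in range(1, retries + 1):
        try:
            return func()
        except PermissionError:
            if attempt == retries:
                raise  # give up: propagate the last error
            time.sleep(delay)  # give the other process time to release the file
    raise AssertionError("unreachable")  # the loop always returns or raises
```

For example, `retry_on_permission_error(lambda: os.replace(tmp_path, final_path))` would retry moving a downloaded file into place, where `tmp_path` and `final_path` are placeholder names.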

I typically use Ubuntu, where it worked fine:

[Screenshot]

Now that you've downloaded the file successfully (with a temporary filename), you can try to load it with the from_checkpoint method.

I will later add a delay after the download.

wenh06 • Jan 18 '25 08:01

Thanks @wenh06, I appreciate the support. I have used from_checkpoint, which seems to have loaded the model into Python correctly.

My next issue is obtaining the R-peak locations from the model. I am currently testing on a 5000-sample ECG signal with the following preprocessing:

```python
import torch

# Preprocess the ECG signal (ecg_signal: 1-D array of 5000 samples at 500 Hz)
input_tensor = torch.tensor(ecg_signal, dtype=torch.float32).unsqueeze(0)  # (1, 5000)
input_tensor = input_tensor.unsqueeze(0)  # (1, 1, 5000)

# Call the model to obtain outputs
with torch.no_grad():
    output = model(input_tensor)
```

To obtain the R peaks, I am running the following:

```python
import torch


def post_process_output(output):
    """
    Post-process the model output to extract R-peak indices.
    output: torch.Tensor containing model outputs.
    Returns a list of detected R-peak indices.
    """
    probabilities = torch.sigmoid(output)  # Convert logits to probabilities
    qrs_probs = probabilities.squeeze().cpu().numpy()  # Convert to numpy array

    threshold = 0.5  # Threshold for detecting peaks
    r_peak_candidates = (qrs_probs > threshold).nonzero()[0]

    refractory_period = 200  # Minimum distance between peaks (in samples)
    r_peaks = []
    last_peak = -refractory_period

    for idx in r_peak_candidates:
        if idx - last_peak > refractory_period:
            r_peaks.append(idx)
            last_peak = idx

    return r_peaks
```

The r_peaks being returned don't look correct, so I assume something is wrong in my setup. Do you know what is happening?

Thanks

ben120-web • Jan 19 '25 14:01

The model was trained with ECGs with a sampling frequency of 500 Hz. Did you resample your input signal to 500 Hz?
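If a recording were not already at 500 Hz, it would need resampling before inference. The hypothetical `resample_linear` helper below is a minimal sketch using plain linear interpolation; it is an assumption for illustration, not the resampling torch_ecg performs. Because it applies no anti-aliasing filter, a polyphase resampler such as `scipy.signal.resample_poly` is the better choice when downsampling.

```python
from typing import List, Sequence


def resample_linear(signal: Sequence[float], fs_in: float, fs_out: float) -> List[float]:
    """Resample a 1-D signal from fs_in Hz to fs_out Hz by linear interpolation.
    No anti-aliasing filter is applied, so this is only a rough sketch."""
    n_in = len(signal)
    if n_in == 0:
        return []
    duration = n_in / fs_in  # signal length in seconds
    n_out = int(round(duration * fs_out))
    out: List[float] = []
    for k in range(n_out):
        t = k * fs_in / fs_out  # output sample k, expressed in input-sample units
        i = int(t)
        if i >= n_in - 1:
            out.append(float(signal[-1]))  # clamp past the last input sample
        else:
            frac = t - i
            out.append((1 - frac) * signal[i] + frac * signal[i + 1])
    return out
```

For instance, a 10 s recording sampled at 250 Hz (2500 samples) comes out as 5000 samples, the length used in this thread.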

The second value returned by the from_checkpoint method is the model's training config, which tells the user the preprocessing steps (e.g. resampling and normalization; this model uses no normalization) needed to convert a raw signal into a torch.Tensor that is ready to feed to the forward method.

Currently, the inference method of every model in torch_ecg.models raises NotImplementedError("implement a task specific inference method").

You've reminded me that we could add some sort of Inferencer that is constructed from a training config and directly produces the desired output (for instance, the indices of R peaks) from a raw signal.

wenh06 • Jan 19 '25 15:01

Yeah, the input signal is at 500 Hz 🤔

Do the signals need to be segmented before being passed into the model? I.e., do I need to split the 10-second signal (5000 samples / 500 Hz) into 2-second (1000-sample) segments?

ben120-web • Jan 19 '25 17:01

I guess the problem is that the model's output mask is downsampled (by a ratio of 8 or 16?). You can check the shape of the output tensor. An example of an appropriate inference function, and in particular its post-processing function, can be found here.
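Comparing the input length with the mask length makes the ratio explicit. A minimal sketch with hypothetical helper names that infers the ratio from the two lengths and maps a mask index back to the center of the input-sample block it covers; it assumes each mask bin corresponds to one consecutive, non-overlapping block of input samples, which is worth confirming against the model before relying on it.

```python
def infer_reduction(input_len: int, output_len: int) -> int:
    """Downsampling ratio between the input signal and the model's output mask."""
    if output_len <= 0 or input_len % output_len != 0:
        raise ValueError(
            f"cannot infer an integer ratio from lengths {input_len} and {output_len}"
        )
    return input_len // output_len


def mask_index_to_sample(idx: int, reduction: int) -> int:
    """Map a mask index to the center of the block of input samples it covers,
    assuming bin `idx` covers samples [idx * reduction, (idx + 1) * reduction)."""
    return idx * reduction + reduction // 2
```

With tensors, the lengths would come from the shapes, e.g. `infer_reduction(input_tensor.shape[-1], output.shape[1])` if the output is laid out as (batch, sequence, classes); printing `output.shape` first shows which axis holds the sequence length.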

wenh06 • Jan 20 '25 02:01

The signal need not be split into multiple segments.

wenh06 • Jan 20 '25 02:01

> I guess the problem is that the output mask of the model is a downsampled version (the ratio is 8 or 16?)

You are right, it is downsampled by a factor of 8 (5000 samples to 625). By upsampling the original predictions by a factor of 8 and lowering the minimum distance between peaks from 200 to 50 samples, I get the following detected peaks:

[Screenshot]
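The upsample-then-pick approach can be sketched in plain Python. These helpers are hypothetical and use one common choice, taking the center of each contiguous above-threshold run as the R peak, rather than the first sample over the threshold as `post_process_output` does; the factor of 8, the 0.5 threshold, and the 50-sample minimum distance are the values from this thread.

```python
from typing import List, Optional, Sequence, Tuple


def upsample_repeat(probs: Sequence[float], factor: int) -> List[float]:
    """Nearest-neighbor upsampling: repeat each mask value `factor` times."""
    return [p for p in probs for _ in range(factor)]


def runs_above(probs: Sequence[float], threshold: float) -> List[Tuple[int, int]]:
    """(start, end) pairs, end exclusive, of runs where probs > threshold."""
    runs: List[Tuple[int, int]] = []
    start: Optional[int] = None
    for i, p in enumerate(probs):
        if p > threshold and start is None:
            start = i  # a run begins
        elif p <= threshold and start is not None:
            runs.append((start, i))  # the run ended at the previous index
            start = None
    if start is not None:
        runs.append((start, len(probs)))  # run reaches the end of the signal
    return runs


def detect_r_peaks(
    probs: Sequence[float], factor: int = 8, threshold: float = 0.5, min_distance: int = 50
) -> List[int]:
    """Upsample the mask, take the center of each above-threshold run as an
    R-peak candidate, and drop candidates closer than `min_distance` samples
    to the previously kept peak."""
    peaks: List[int] = []
    for start, end in runs_above(upsample_repeat(probs, factor), threshold):
        center = (start + end - 1) // 2
        if not peaks or center - peaks[-1] >= min_distance:
            peaks.append(center)
    return peaks
```

Here `min_distance` is counted in input samples after upsampling (50 samples is 100 ms at 500 Hz). Applied directly to the 625-step mask, the original 200-sample refractory period instead spans 200 × 8 = 1600 input samples, i.e. 3.2 s at 500 Hz, which would skip over several consecutive beats at typical heart rates.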

ben120-web • Jan 20 '25 11:01