PyTorch-SRGAN-tutorial icon indicating copy to clipboard operation
PyTorch-SRGAN-tutorial copied to clipboard

A tutorial to super resolution and SRGAN in PyTorch

A PyTorch SRGAN tutorial (Russian)

Этот репозиторий является руководством по обучению и использованию модели SRGAN (Super Resolution GAN). Предполагается, что у читателя имеются базовые знания о нейронных сетях!

Содержание

Туториал состоит из 4 основных частей:

О задаче SISR и SRGAN

SISR (Single Image Super Resolution) - задача увеличения разрешения одного изображения. Из входного изображения с низким разрешением (Low Resolution, LR) необходимо реконструировать изображение с высоким разрешением (Super Resolution, SR), которое будет максимально похоже на изначальное фото (High Resolution, HR).

Одни из самых известных методов увеличения разрешения изображения - это билинейная и бикубическая интерполяции. Эти методы работают быстро, но выдают результат с чрезмерно гладкой текстурой. Более мощные подходы основаны на использовании нейросетей, которых обучают генерировать изображения высокого качества из изображений с низким качеством.

Для многих подходов, использующих нейронные сети, применяются попиксельные (pixel-wise) функции потерь (например, MSE или MAE). Минимизация такой ошибки приводит к тому, что модель находит средние вероятностные решения для каждого пикселя, поэтому они получаются слишком гладкими и, следовательно, имеют низкое качество восприятия, как показано на рисунке ниже.
Различные методы SISR

SRGAN

Оригинальная статья

SRGAN (Super Resolution GAN) - подход к решению задачи SISR, основанный на GAN-ах (генеративно-состязательных сетях).

Генеративно-состязательные сети - это алгоритм, построенный на комбинации из двух нейронных сетей, одна из которых (генератор) генерирует образцы, а другая (дискриминатор) старается отличить правильные («подлинные») образцы от неправильных.

Успех SRGAN в задаче SISR обусловлен двумя важными особенностями:

  • Генеративно-состязательные сети позволяют создавать более реалистичные изображения, чем нейросети, основанные на оптимизации MSE между пикселями. Модели, ориентированные на оптимизацию MSE по пикселям, "усредняли" текстуры, что делало их чрезмерно гладкими. Использование GAN-ов сдвигает реконструированное фото в сторону множества естественных изображений, позволяя получить более реалистичные решения.
    Alt Распределение возможных фотореалистичных изображений
  • Второй важной особенностью стало использование perceptual loss, которая основывается на евклидовых расстояниях (MSE, MAE), вычисленных в пространстве признаков глубокой сверточной нейронной сети (например, предварительно обученной VGG). Такая функция ошибки будет более инвариантна к изменениям пикселей на изображении, чем попиксельные MSE или MAE.

Архитектура сети

Основу генеративной сети составляют B residual блоков с одинаковой структурой. В каждом блоке находятся два свёрточных слоя с ядрами 3x3 и 64 каналами, за которыми расположены batch-norm слои. В качестве функции активации используется PReLU (Parametric Rectified Linear Unit). Входное изображение увеличивается попиксельно с помощью двух свёрточных слоев с операцией PixelShuffle и функцией активации PReLU.

В качестве дискриминатора используется типичная для классификатора изображений архитектура с fully-connected слоями в конце.

Архитектуры генератора и дискриминатора в картинке: SRGAN structure

Подготовка датасета

Обучение нейронной сети будет производится на парах изображений LR-HR. Однако не обязательно заранее подготавливать LR и HR пары изображений вместе, достаточно подготовить только High Resolution фотографии. Low Resolution изображения мы сможем получить из HR изображений, уменьшив их билинейной/бикубической интерполяцией прямо во время обучения.

Этап подготовки данных описан и реализован в ноутбуке 1_create_dataset.ipynb

Датасеты, которые я рекомендую использовать:

  1. DIV2K - https://data.vision.ee.ethz.ch/cvl/DIV2K/
  2. Flickr2k - https://drive.google.com/drive/folders/1AAI2a2BmafbeVExLH-l0aZgvPgJCk5Xm
  3. FFHQ - https://github.com/NVlabs/ffhq-dataset

Обучение

Обучение модели разделяют на два этапа: сначала на обычной MSE обучают генератор и получают сеть, которую называют SRResNet. Затем обучают SRGAN, проинициализованный весами SRResNet. Такое разделение необходимо, чтобы генератор не выдавал шум на начальных стадиях обучения и дискриминатор не начинал сразу выигрывать. Это позволяет избежать попадания в нежелательный локальный минимум при обучении SRGAN.

Код для обучения SRResNet вы можете найти в ноутбуке: 2_train_srresnet.ipynb

Обучение SRGAN реализовано в ноутбуке: 3_train_srgan.ipynb

Запуск модели

Сравнение методов апскейла

Теперь давайте протестируем модель. Веса SRGAN и SRResNet, обученные на датасетах DIV2k и Flickr2k, можно найти здесь: weights/pretrained/

Соответствующий ноутбук: 4_evaluate_model.ipynb

К сожалению, такая модель не будет хорошо работать на многих изображениях из-за различных шумов. Например, она будет давать плохой результат на фотографиях с расширениями jpg/jpeg. Формат jpeg использует алгоритмы сжатия изображений, поэтому, чтобы модель хорошо работала и на них, нужно ее обучить убирать артефакты сжатия. Есть и другие шумы, которые нужно удалять в процессе увеличения качества изображения, но пожалуй более подробно расскажу об этом в следующем гайде)

Если нужна более совершенная модель, которая будет убирать артефакты и повышать качество изображений, то советую заглянуть в этот репозиторий (Real-ESRGAN).