Лен: открытый подход Google к гибкости в машинном обучении

Думая об машинном обучении, первые фреймворки, которые приходят на ум, — это Tensorflow и PyTorch, которые в настоящее время являются современными фреймворками, если вы хотите работать с Deep Neural Networks. Технологии быстро меняются, и требуется большая гибкость, поэтому исследователи Google разрабатывают новую высокопроизводительную среду для сообщества открытого кода: Flax.

Основой для расчетов служит JAX вместо NumPy, который также является исследовательским проектом Google. Одним из самых больших преимуществ JAX является использование XLA, специального компилятора для линейной алгебры, который позволяет выполнять на GPU и TPU, а также,

Для тех, кто не знает, TPU (тензорный процессор) — это специальная микросхема, оптимизированная для машинного обучения. JAX переопределяет части NumPy для запуска ваших функций на GPU / TPU.

Лен фокусируется на ключевых моментах, таких как:

  • легко читать код
  • предпочитает дублированиевместо плохой абстракции или раздутых функций
  • полезные сообщения об ошибкахкажется, они узнали из сообщений об ошибках Tensorflow
  • легкая расширяемость базовых реализаций

Хватит похвал, теперь давайте начнем кодировать.

Поскольку пример MNIST становится скучным, я создам Классификацию изображений для семейства Симпсонов, к сожалению, Мэгги отсутствует в наборе данных :-(.

Образцы изображений набора данных

Сначала мы устанавливаем необходимые библиотеки и распаковываем наш набор данных. К сожалению, вам все еще понадобится Tensorflow на этом этапе, потому что Flax пропускает хороший конвейер ввода данных.

pip install -q --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ((0-9)*).((0-9)*),.*/cuda12/p"`/jaxlib-0.1.42-`python3 -V | sed -En "s/Python ((0-9)*).((0-9)*).*/cp12/p"`-none-linux_x86_64.whl jax
pip install -q git+https://github.com/google/flax.git@dev-setup
pip install tensorflow
pip install tensorflow_datasets
unzip simpsons_faces.zip

Теперь мы импортируем библиотеки. Вы видите, что у нас есть две «версии» numpy, обычная numpy lib и одна часть API, которую реализует JAX. Оператор печати печатает CPU, GPU или TPU в соответствии с доступным оборудованием.

from jax.lib import xla_bridge
import jax
import flax

import numpy as onp
import jax.numpy as jnp
import csv
import tensorflow as tf
import tensorflow_datasets as tfds

print(xla_bridge.get_backend().platform)

Для обучения и оценки нам сначала нужно создать два набора данных Tensorflow и преобразовать их в массивы numpy / jax, потому что FLAX не принимает типы данных TF. В настоящее время это немного глупо, потому что метод оценки не принимает партии.

Мне пришлось создать один большой пакет для шага eval и создать из него словарь функций TF, который теперь можно анализировать и который можно подавать на наш шаг eval после каждой эпохи.

def train():

  train_ds = create_dataset(tf.estimator.ModeKeys.TRAIN)
  test_ds = create_dataset(tf.estimator.ModeKeys.EVAL)
  
  test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)
  #test_ds is one giant batch
  test_ds = test_ds.batch(1000)
  #test ds is a feature dictonary!
  test_ds = tf.compat.v1.data.experimental.get_single_element(test_ds)
  test_ds = tfds.as_numpy(test_ds)
  test_ds = {'image': test_ds(0).astype(jnp.float32), 'label': test_ds(1).astype(jnp.int32)}

  _, initial_params = CNN.init_by_shape(jax.random.PRNGKey(0), (((1, 160, 120, 3), jnp.float32)))

  model = flax.nn.Model(CNN, initial_params)

  optimizer = flax.optim.Momentum(learning_rate=0.01, beta=0.9, weight_decay=0.0005).create(model)

  for epoch in range(50):
    for batch in tfds.as_numpy(train_ds):
      optimizer = train_step(optimizer, batch)

    metrics = eval(optimizer.target, test_ds)

    print('eval epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch+1,metrics('loss'), metrics('accuracy') * 100))

Модель

Класс CNN содержит нашу сверточную нейронную сеть. Когда вы знакомы с Tensorflow / Pytorch, вы видите, что это довольно просто. Каждый вызов нашего flax.nn.Conv определяет обучаемое ядро.

Я использовал MNIST-Example и расширил его несколькими дополнительными слоями. В итоге у нас есть плотный слой с четырьмя выходными нейронами, потому что у нас есть проблема четырех классов.

class CNN(flax.nn.Module):
  def apply(self, x):
    x = flax.nn.Conv(x, features=128, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=128, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=16, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape(0), -1))
    x = flax.nn.Dense(x, features=256)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=64)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=4)
    x = flax.nn.softmax(x)
    return x

В отличие от Tensorflow, функция активации вызывается явно, что позволяет очень легко тестировать новые и собственные письменные функции активации. FLAX основан на абстракции модуля, и инициация и вызов сети выполняются с помощью функции apply.

Метрики в FLAX

Конечно, мы хотим измерить, насколько хорошей становится наша сеть. Поэтому мы вычисляем наши показатели, такие как потери и точность. Наша точность затем вычисляется с помощью библиотеки JAX вместо NumPy, потому что мы можем использовать JAX на TPU / GPU.

def compute_metrics(logits, labels):
  loss = jnp.mean(cross_entropy_loss(logits, labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {'loss': loss, 'accuracy': accuracy}
Для измерения наших потерь мы используем перекрестную энтропийную потерю, в отличие от Tensorflow, которую вы рассчитываете сами, у нас пока нет возможности использовать готовые объекты потерь. Как вы можете видеть, мы используем

@jax.vmap

в качестве декоратора функции для нашей функции потерь. Это векторизует наш код для эффективной работы с пакетами.

@jax.vmap
def cross_entropy_loss(logits, label):
  return -jnp.log(logits(label))
Как работает

cross_entropy_loss

работай?

@jax.vmap

принимает как массивы, логиты и метки, так и выполняет

cross_entropy_loss

на каждой паре, что позволяет параллельный расчет партии. Формула кросс-энтропии для одного примера:

Наша основная правда y равна 0 или 1 для одного из четырех выходных нейронов, поэтому нам не нужна формула суммы в нашем коде, потому что мы просто вычисляем log (y_hat) правильной метки. Среднее значение в нашем расчете потерь используется, потому что у нас есть партии.

Обучение

На нашем этапе мы снова используем декоратор функций,

@jax.jit

, для ускорения нашей функции. Это работает очень похоже на Tensorflow. Пожалуйста, имейте в виду

batch(0)

наши данные изображения и

batch(1)

наш лейбл.

@jax.jit
def train_step(optimizer, batch):
  def loss_fn(model):
    logits = model(batch(0))
    loss = jnp.mean(cross_entropy_loss(
        logits, batch(1)))
    return loss
  grad = jax.grad(loss_fn)(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  return optimizer
Функция потерь loss_fn возвращает потери для нашей текущей модели,

optimizer.target

, и наш

jax.grad()

рассчитывает его градиент. После расчета мы применяем градиент как в Tensorflow.

Шаг оценки очень прост и минималистичен во льне. Обратите внимание, что полный набор данных оценки передается этой функции.

@jax.jit
def eval(model, eval_ds):
  logits = model(eval_ds('image'))
  return compute_metrics(logits, eval_ds('label'))

После 50 эпох у нас очень высокая точность. Конечно, мы можем продолжить настройку модели и оптимизировать гиперпараметр.

Для этого эксперимента я использовал Google Colab, поэтому, если вы хотите проверить его самостоятельно, создайте новую среду с GPU / TPU и импортируйте мой ноутбук из Github, Обратите внимание, что FLAX в настоящее время не работает под Windows.

Выводы

Важно отметить, что FLAX в настоящее время все еще в альфа и не является официальным продуктом Google.

Работа пока дает надежду на быстрый, легкий и легко настраиваемый каркас ML, Пока что полностью отсутствует конвейер ввода данных, поэтому Tensorflow все еще нужно использовать.

К сожалению, текущий набор оптимизаторов ограничен ADAM и SGD с Momentum. Мне особенно понравилось очень строгое направление использования этой платформы и высокая гибкость.

Мои следующие планы состоят в разработке некоторых функций активации, которые еще не доступны. Также было бы очень интересно сравнить скорость между Tensorflow, PyTorch и FLAX.



Источник: Лен: открытый подход Google к гибкости в машинном обучении


Похожие материалы по теме: Лен: открытый подход Google к гибкости в машинном обучении

Leave a comment