
Implementing a Complete Vision Transformer with JAX


2023-03-04 09:17

Source: DeepHub IMBA
This article is about 3,200 words; suggested reading time: 10+ minutes.

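This article shows how to implement a Vision Transformer (ViT) with JAX/Flax and how to train it with JAX/Flax.


Vision Transformer

When implementing the Vision Transformer, start by keeping the model overview figure from the ViT paper in mind.

The ViT procedure described in the paper is as follows.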

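• Extract patches from the input image and flatten each one into a vector.
• Project them to the dimension processed by the Transformer Encoder.
• Prepend a learnable embedding (the [class] token) and add a position embedding.
• Encode the sequence with the Transformer Encoder.
• Use the output at the [class] token as the input to an MLP for classification.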
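As a quick sanity check on the steps above, the length of the sequence fed to the encoder follows from simple arithmetic. The helper below, vit_sequence_shapes (a hypothetical name, not part of the article's code), computes it for square images split into non-overlapping square patches, using the 32×32 CIFAR-10 images and the patch_size=16, hidden_dim=192 settings that initialize_model uses later in this article.

```python
def vit_sequence_shapes(image_size: int, patch_size: int, channels: int, hidden_dim: int) -> dict:
    """Shape bookkeeping for the ViT input pipeline (non-overlapping square patches)."""
    assert image_size % patch_size == 0, "image size must be divisible by patch size"
    grid = image_size // patch_size                 # patches per side
    num_patches = grid * grid                       # step 1: number of patches
    patch_dim = patch_size * patch_size * channels  # length of one flattened patch
    seq_len = num_patches + 1                       # step 3: +1 for the [class] token
    return {
        "num_patches": num_patches,
        "patch_dim": patch_dim,
        "encoder_input": (seq_len, hidden_dim),     # steps 2-4: tokens x hidden_dim
    }


shapes = vit_sequence_shapes(image_size=32, patch_size=16, channels=3, hidden_dim=192)
print(shapes)  # {'num_patches': 4, 'patch_dim': 768, 'encoder_input': (5, 192)}
```

With patch_size=16, a 32×32 image yields only 4 patches, so each image becomes a sequence of 5 tokens; for 224×224 images with 16×16 patches, as in the ViT paper, the same arithmetic gives 196 patches.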


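Implementation details

Below, we build each module with JAX/Flax.

1. Image to flattened image patches

The code below extracts patches from the input image. This is done with a convolution whose kernel size is patch_size × patch_size and whose stride is also patch_size in each direction, so the patches do not overlap. Because the convolution has embed_dim output channels, it also linearly projects each patch to an embed_dim-dimensional vector.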

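```python
import math
from typing import Optional

import jax
import jax.numpy as jnp
import flax.linen as nn


class Patches(nn.Module):
    patch_size: int
    embed_dim: int

    def setup(self):
        # Non-overlapping patches: kernel size == stride == patch_size
        self.conv = nn.Conv(
            features=self.embed_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding='VALID',
        )

    def __call__(self, images):
        # (B, H, W, C) -> (B, H/P, W/P, embed_dim)
        patches = self.conv(images)
        b, h, w, c = patches.shape
        # Flatten the patch grid into a sequence: (B, num_patches, embed_dim)
        patches = jnp.reshape(patches, (b, h * w, c))
        return patches
```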
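Because the kernel size equals the stride, each output position of this convolution sees exactly one patch, so the convolution is equivalent to flattening each patch and multiplying it by a single weight matrix. The sketch below checks this numerically in plain JAX (no Flax, bias omitted). The helper names extract_patches and conv_patch_embed are illustrative, not from the article; the kernel uses the (height, width, in_features, features) layout that Flax's nn.Conv also uses.

```python
import jax
import jax.numpy as jnp


def extract_patches(images, patch_size):
    """(B, H, W, C) -> (B, num_patches, P*P*C), patches in row-major grid order."""
    b, h, w, c = images.shape
    p = patch_size
    x = images.reshape(b, h // p, p, w // p, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, H/P, W/P, P, P, C)
    return x.reshape(b, (h // p) * (w // p), p * p * c)


def conv_patch_embed(images, kernel, patch_size):
    """Strided 'VALID' convolution with an HWIO kernel, flattened to (B, num_patches, D)."""
    out = jax.lax.conv_general_dilated(
        images, kernel,
        window_strides=(patch_size, patch_size),
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    b, gh, gw, d = out.shape
    return out.reshape(b, gh * gw, d)


key_img, key_w = jax.random.split(jax.random.PRNGKey(0))
images = jax.random.normal(key_img, (2, 32, 32, 3))
kernel = jax.random.normal(key_w, (16, 16, 3, 192))  # (P, P, C, embed_dim)

via_conv = conv_patch_embed(images, kernel, patch_size=16)
# The same weights viewed as a (P*P*C, embed_dim) matrix applied to flattened patches
via_matmul = extract_patches(images, 16) @ kernel.reshape(16 * 16 * 3, 192)

print(via_conv.shape)  # (2, 4, 192)
print(jnp.allclose(via_conv, via_matmul, rtol=1e-4, atol=1e-3))
```

This is why a single strided convolution covers both step 1 (patch extraction) and the linear projection of each flattened patch.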

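2 and 3. Linear projection of the flattened patches / adding the [CLS] token / position embedding

The Transformer Encoder uses the same size, hidden_dim, in all of its layers. The patch vectors created above are projected to hidden_dim-dimensional vectors. As in BERT, a [CLS] token is prepended to the sequence, and a learnable position embedding is added to retain positional information.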

```python
class PatchEncoder(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        assert x.ndim == 3
        n, seq_len, _ = x.shape
        # Hidden dim
        x = nn.Dense(self.hidden_dim)(x)
        # Add cls token
        cls = self.param('cls_token', nn.initializers.zeros, (1, 1, self.hidden_dim))
        cls = jnp.tile(cls, (n, 1, 1))
        x = jnp.concatenate([cls, x], axis=1)
        # Add position embedding
        pos_embed = self.param(
            'position_embedding',
            nn.initializers.normal(stddev=0.02),  # From BERT
            (1, seq_len + 1, self.hidden_dim),
        )
        return x + pos_embed
```
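Stripped of Flax parameter handling, the [CLS] and position-embedding step is just a concatenation and a broadcast add. The plain-JAX sketch below (add_cls_and_position is an illustrative name) uses jnp.broadcast_to where the article uses jnp.tile; both copy the single learnable [CLS] vector across the batch. The shapes assume 4 projected patches of width 192.

```python
import jax
import jax.numpy as jnp


def add_cls_and_position(x, cls_token, pos_embed):
    """x: (B, N, D); cls_token: (1, 1, D); pos_embed: (1, N + 1, D) -> (B, N + 1, D)."""
    b, _, d = x.shape
    # One learnable [CLS] vector, shared by every image in the batch
    cls = jnp.broadcast_to(cls_token, (b, 1, d))
    x = jnp.concatenate([cls, x], axis=1)
    # The position embedding broadcasts over the batch axis
    return x + pos_embed


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
tokens = jax.random.normal(k1, (8, 4, 192))             # 4 projected patches per image
cls_token = jnp.zeros((1, 1, 192))                      # zero init, as in PatchEncoder
pos_embed = 0.02 * jax.random.normal(k2, (1, 5, 192))   # normal(stddev=0.02) init

out = add_cls_and_position(tokens, cls_token, pos_embed)
print(out.shape)  # (8, 5, 192)
```

Since the [CLS] token starts at zero, position 0 of every sequence initially holds just the first position embedding; training then makes both learnable vectors useful.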

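4. Transformer encoder

As shown in the ViT overview figure, the encoder consists of alternating multi-head self-attention (MSA) and MLP blocks. LayerNorm (LN) is applied before each MSA and MLP block, and a residual connection is added after each block.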
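In equation form (Eqs. 2 and 3 of the ViT paper), encoder layer l maps its input sequence z_{l-1} to z_l as follows:

```latex
\begin{aligned}
z'_{\ell} &= \mathrm{MSA}\big(\mathrm{LN}(z_{\ell-1})\big) + z_{\ell-1}, && \ell = 1, \dots, L \\
z_{\ell}  &= \mathrm{MLP}\big(\mathrm{LN}(z'_{\ell})\big) + z'_{\ell},   && \ell = 1, \dots, L
\end{aligned}
```

Here L is num_layers. In TransformerEncoder, inputs plays the role of z_{l-1}, x the role of z'_l, and x + y the role of z_l.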

```python
class TransformerEncoder(nn.Module):
    embed_dim: int
    hidden_dim: int
    n_heads: int
    drop_p: float
    mlp_dim: int

    def setup(self):
        self.mha = MultiHeadSelfAttention(self.hidden_dim, self.n_heads, self.drop_p)
        self.mlp = MLP(self.mlp_dim, self.drop_p)
        # Separate LayerNorms (separate parameters) before the MSA and MLP blocks
        self.layer_norm1 = nn.LayerNorm(epsilon=1e-6)
        self.layer_norm2 = nn.LayerNorm(epsilon=1e-6)

    def __call__(self, inputs, train=True):
        # Attention block: LN -> MSA -> residual
        x = self.layer_norm1(inputs)
        x = self.mha(x, train)
        x = inputs + x
        # MLP block: LN -> MLP -> residual
        y = self.layer_norm2(x)
        y = self.mlp(y, train)
        return x + y
```

The MLP is a two-layer network with a GELU activation. As in the ViT paper, Dropout is applied after each Dense layer.

```python
class MLP(nn.Module):
    mlp_dim: int
    drop_p: float
    out_dim: Optional[int] = None

    @nn.compact
    def __call__(self, inputs, train=True):
        # Output width defaults to the input width
        actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
        x = nn.Dense(features=self.mlp_dim)(inputs)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
        x = nn.Dense(features=actual_out_dim)(x)
        x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
        return x
```
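To show what the train flag does inside this block, here is a framework-free sketch of the same Dense, GELU, Dropout, Dense, Dropout path. It uses inverted dropout (surviving units are scaled by 1 / (1 - rate)), which is also what Flax's nn.Dropout does; the parameter-dictionary layout and the function names dropout and mlp_block are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def dropout(x, rate, rng, deterministic):
    """Inverted dropout: zero each unit with prob `rate`, scale survivors by 1 / (1 - rate)."""
    if deterministic or rate == 0.0:
        return x
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(rng, p=keep_prob, shape=x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)


def mlp_block(params, x, rate, rng, train=True):
    """Two Dense layers with GELU in between and dropout after each Dense."""
    rng1, rng2 = jax.random.split(rng)
    h = x @ params["w1"] + params["b1"]
    h = jax.nn.gelu(h)
    h = dropout(h, rate, rng1, deterministic=not train)
    out = h @ params["w2"] + params["b2"]
    return dropout(out, rate, rng2, deterministic=not train)


hidden_dim, mlp_dim = 192, 768
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "w1": 0.02 * jax.random.normal(k1, (hidden_dim, mlp_dim)), "b1": jnp.zeros(mlp_dim),
    "w2": 0.02 * jax.random.normal(k2, (mlp_dim, hidden_dim)), "b2": jnp.zeros(hidden_dim),
}
x = jax.random.normal(k3, (8, 5, hidden_dim))

y_train = mlp_block(params, x, rate=0.1, rng=k4, train=True)   # dropout active
y_eval = mlp_block(params, x, rate=0.1, rng=k4, train=False)   # dropout disabled
print(y_train.shape, y_eval.shape)  # (8, 5, 192) (8, 5, 192)
```

Scaling at training time keeps the expected activation unchanged, so evaluation needs no rescaling; that is why passing deterministic=not train is all the article's MLP has to do.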

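Multi-head self-attention (MSA)

q, k and v are reshaped to [B, N, T, D] (batch size, number of heads, sequence length, head dimension). After the attention weights and the attention output are computed for each head, as in single-head attention, the result is reshaped back to the original dimensions [B, T, C = N*D].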

```python
class MultiHeadSelfAttention(nn.Module):
    hidden_dim: int
    n_heads: int
    drop_p: float

    def setup(self):
        self.q_net = nn.Dense(self.hidden_dim)
        self.k_net = nn.Dense(self.hidden_dim)
        self.v_net = nn.Dense(self.hidden_dim)

        self.proj_net = nn.Dense(self.hidden_dim)

        self.att_drop = nn.Dropout(self.drop_p)
        self.proj_drop = nn.Dropout(self.drop_p)

    def __call__(self, x, train=True):
        B, T, C = x.shape  # batch_size, seq_length, hidden_dim
        N, D = self.n_heads, C // self.n_heads  # num_heads, head_dim
        q = self.q_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)  # (B, N, T, D)
        k = self.k_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)
        v = self.v_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)

        # weights (B, N, T, T)
        weights = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(D)
        normalized_weights = nn.softmax(weights, axis=-1)

        # attention (B, N, T, D)
        attention = jnp.matmul(normalized_weights, v)
        attention = self.att_drop(attention, deterministic=not train)

        # gather heads: (B, N, T, D) -> (B, T, N*D)
        attention = attention.transpose(0, 2, 1, 3).reshape(B, T, N * D)

        # project
        out = self.proj_drop(self.proj_net(attention), deterministic=not train)

        return out
```
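The reshape and transpose in MultiHeadSelfAttention are pure bookkeeping: computing all heads at once in the [B, N, T, D] layout gives the same result as slicing the hidden dimension into N chunks of width D and running single-head attention on each chunk. The sketch below checks this in plain JAX, without dropout and with random matrices standing in for q_net, k_net and v_net (multi_head and per_head_loop are illustrative names).

```python
import math

import jax
import jax.numpy as jnp


def multi_head(q, k, v, n_heads):
    """q, k, v: (B, T, C) -> (B, T, C), computed in the (B, N, T, D) layout."""
    B, T, C = q.shape
    N, D = n_heads, C // n_heads
    split = lambda t: t.reshape(B, T, N, D).transpose(0, 2, 1, 3)   # (B, N, T, D)
    q, k, v = split(q), split(k), split(v)
    weights = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(D)  # (B, N, T, T)
    attn = jnp.matmul(jax.nn.softmax(weights, axis=-1), v)           # (B, N, T, D)
    return attn.transpose(0, 2, 1, 3).reshape(B, T, N * D)          # gather heads


def per_head_loop(q, k, v, n_heads):
    """Reference: single-head attention on each D-wide slice, then concatenate."""
    D = q.shape[-1] // n_heads
    outs = []
    for h in range(n_heads):
        s = slice(h * D, (h + 1) * D)
        w = q[..., s] @ jnp.swapaxes(k[..., s], -2, -1) / math.sqrt(D)
        outs.append(jax.nn.softmax(w, axis=-1) @ v[..., s])
    return jnp.concatenate(outs, axis=-1)


B, T, C, N = 2, 5, 192, 3
kx, kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (B, T, C))
Wq, Wk, Wv = (jax.random.normal(key, (C, C)) / math.sqrt(C) for key in (kq, kk, kv))
q, k, v = x @ Wq, x @ Wk, x @ Wv

out = multi_head(q, k, v, N)
print(out.shape)  # (2, 5, 192)
```

The batched layout lets one matmul handle every head, which is why the article reshapes instead of looping.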

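5. Classification using the [CLS] embedding

Finally comes the MLP head (classification head). Here it is a single Dense layer applied to the encoder output at the [CLS] token position.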

```python
class ViT(nn.Module):
    patch_size: int
    embed_dim: int
    hidden_dim: int
    n_heads: int
    drop_p: float
    num_layers: int
    mlp_dim: int
    num_classes: int

    def setup(self):
        self.patch_extracter = Patches(self.patch_size, self.embed_dim)
        self.patch_encoder = PatchEncoder(self.hidden_dim)
        self.dropout = nn.Dropout(self.drop_p)
        # One encoder block per layer, each with its own parameters
        self.transformer_encoders = [
            TransformerEncoder(self.embed_dim, self.hidden_dim, self.n_heads,
                               self.drop_p, self.mlp_dim)
            for _ in range(self.num_layers)
        ]
        self.cls_head = nn.Dense(features=self.num_classes)

    def __call__(self, x, train=True):
        x = self.patch_extracter(x)
        x = self.patch_encoder(x)
        x = self.dropout(x, deterministic=not train)
        for encoder in self.transformer_encoders:
            x = encoder(x, train)
        # MLP head: classify from the [CLS] token
        x = x[:, 0]
        x = self.cls_head(x)
        return x
```

Training with JAX/Flax

Now that the model is defined, we train it with JAX/Flax.

Dataset

Here we use CIFAR10 directly from torchvision.

First, some utility functions:

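```python
import numpy as np


def image_to_numpy(img):
    # PIL image -> float32 array scaled to [0, 1], then per-channel standardization.
    # DATA_MEANS and DATA_STD are the per-channel mean and std of the training images.
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - DATA_MEANS) / DATA_STD
    return img


def numpy_collate(batch):
    # Collate samples into NumPy arrays (instead of torch tensors) so JAX can consume them
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)
```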
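numpy_collate recurses into tuples, so a list of (image, label) pairs, which is what a torchvision dataset returns per sample, becomes a two-element list of stacked NumPy arrays that JAX functions accept directly. A self-contained check with fake samples (the shapes mirror CIFAR-10 and are otherwise arbitrary):

```python
import numpy as np


def numpy_collate(batch):
    # Same function as above: stack arrays, recurse into (image, label) tuples
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)


# Four fake samples: a 32x32x3 float image and an integer label each
samples = [(np.zeros((32, 32, 3), dtype=np.float32), label) for label in [3, 1, 4, 1]]
imgs, labels = numpy_collate(samples)
print(imgs.shape, labels)  # (4, 32, 32, 3) [3 1 4 1]
```

The images are stacked because they are arrays, while the plain integer labels fall through to np.array.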

Next, the data loaders for training, validation and testing:

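```python
import numpy as np
import torch
import torch.utils.data
from torchvision import transforms
from torchvision.datasets import CIFAR10

# Assumed settings: the article does not list these values
SEED = 42
BATCH_SIZE = 128
IMAGE_SIZE = 32  # CIFAR-10 images are 32x32
CROP_SCALES = (0.8, 1.0)
CROP_RATIO = (0.9, 1.1)

# Assumed: per-channel statistics computed from the raw training images (50000, 32, 32, 3)
_train_pixels = CIFAR10('data', train=True, download=True).data.astype(np.float32) / 255.0
DATA_MEANS = _train_pixels.mean(axis=(0, 1, 2))
DATA_STD = _train_pixels.std(axis=(0, 1, 2))

test_transform = image_to_numpy
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE), scale=CROP_SCALES, ratio=CROP_RATIO),
    image_to_numpy,
])

# Validation set should not use the augmentation.
train_dataset = CIFAR10('data', train=True, transform=train_transform, download=True)
val_dataset = CIFAR10('data', train=True, transform=test_transform, download=True)
# Splitting both copies with the same seed yields disjoint train/val indices
train_set, _ = torch.utils.data.random_split(
    train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
_, val_set = torch.utils.data.random_split(
    val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
test_set = CIFAR10('data', train=False, transform=test_transform, download=True)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,
    num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False,
    num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False,
    num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
```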

Initializing the model

Initialize the ViT model:

```python
import jax
import jax.numpy as jnp
from jax import random


def initialize_model(
    seed=42,
    patch_size=16, embed_dim=192, hidden_dim=192,
    n_heads=3, drop_p=0.1, num_layers=12, mlp_dim=768, num_classes=10
):
    main_rng = jax.random.PRNGKey(seed)
    # Dummy batch of 32x32 RGB images, used only to infer parameter shapes
    x = jnp.ones(shape=(5, 32, 32, 3))
    # ViT
    model = ViT(
        patch_size=patch_size,
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        n_heads=n_heads,
        drop_p=drop_p,
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_classes=num_classes,
    )
    # Separate PRNG streams for parameter initialization and dropout
    main_rng, init_rng, drop_rng = random.split(main_rng, 3)
    params = model.init({'params': init_rng, 'dropout': drop_rng}, x, train=True)['params']
    return model, params, main_rng


vit_model, vit_params, vit_rng = initialize_model()
```
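vit_params is a nested dictionary (a pytree) of arrays, so a quick way to check the size of the initialized model is to sum the sizes of its leaves. count_params below is a hypothetical helper; it is demonstrated on a small hand-built pytree so that it runs without Flax, and calling count_params(vit_params) works the same way.

```python
import jax
import jax.numpy as jnp


def count_params(params) -> int:
    """Total number of scalar parameters in a pytree of arrays."""
    return sum(int(leaf.size) for leaf in jax.tree_util.tree_leaves(params))


# Toy pytree shaped like a Flax params dict: a Dense(192 -> 10) head plus a [CLS] token
toy_params = {
    "cls_head": {"kernel": jnp.zeros((192, 10)), "bias": jnp.zeros(10)},
    "patch_encoder": {"cls_token": jnp.zeros((1, 1, 192))},
}
print(count_params(toy_params))  # 2122
```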

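Creating the TrainState

A common pattern in Flax is to create a class that manages the training state, including the step count, the optimizer state, and the model parameters. Passing the model's forward pass as apply_fn when creating this state also shortens the argument lists of the functions in the training loop.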

```python
import optax
from flax.training import train_state


def create_train_state(model, params, learning_rate):
    optimizer = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        tx=optimizer,
        params=params,
    )


state = create_train_state(vit_model, vit_params, 3e-4)
```
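The TrainState created above bundles the step counter, apply_fn, the parameters, the optax optimizer (tx) and its state (opt_state), and its apply_gradients method returns a new state instead of mutating the old one. The sketch below is a hypothetical, stripped-down analogue of that pattern, not Flax's implementation: MiniTrainState is an invented name, and it uses hand-written momentum SGD instead of optax.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class MiniTrainState(NamedTuple):
    step: int
    params: Any
    momentum: Any  # optimizer state: one buffer per parameter
    lr: float = 0.1
    beta: float = 0.9

    @classmethod
    def create(cls, params, lr=0.1, beta=0.9):
        zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
        return cls(step=0, params=params, momentum=zeros, lr=lr, beta=beta)

    def apply_gradients(self, grads):
        # Momentum SGD: m <- beta * m + g ; p <- p - lr * m
        m = jax.tree_util.tree_map(lambda m, g: self.beta * m + g, self.momentum, grads)
        p = jax.tree_util.tree_map(lambda p, m: p - self.lr * m, self.params, m)
        # Return a new state; the old one is left untouched
        return self._replace(step=self.step + 1, params=p, momentum=m)


state = MiniTrainState.create({"w": jnp.array([1.0, -2.0])})
grads = {"w": jnp.array([0.5, 0.5])}
state = state.apply_gradients(grads)
print(state.step, state.params["w"])  # step is 1; w is approximately [0.95, -2.05]
```

Keeping the state immutable is what lets a jitted train_step take the state in and hand a new one back.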

Training loop

```python
from collections import defaultdict

import jax
import numpy as np
from tqdm import tqdm

# `logger` is assumed to be a TensorBoard-style writer (add_scalar / flush) and
# `save_model` a checkpointing helper; neither is shown in the article.
# `test_loader` is the global test data loader.


def train_model(train_loader, val_loader, state, rng, num_epochs=100):
    best_eval = 0.0
    for epoch_idx in tqdm(range(1, num_epochs + 1)):
        state, rng = train_epoch(train_loader, epoch_idx, state, rng)
        if epoch_idx % 1 == 0:  # evaluate every epoch
            eval_acc = eval_model(val_loader, state, rng)
            logger.add_scalar('val/acc', eval_acc, global_step=epoch_idx)
            if eval_acc >= best_eval:
                best_eval = eval_acc
                save_model(state, step=epoch_idx)
            logger.flush()
    # Evaluate after training
    test_acc = eval_model(test_loader, state, rng)
    print(f'test_acc: {test_acc}')


def train_epoch(train_loader, epoch_idx, state, rng):
    metrics = defaultdict(list)
    for batch in tqdm(train_loader, desc='Training', leave=False):
        state, rng, loss, acc = train_step(state, rng, batch)
        metrics['loss'].append(loss)
        metrics['acc'].append(acc)
    for key in metrics.keys():
        arg_val = np.stack(jax.device_get(metrics[key])).mean()
        logger.add_scalar('train/' + key, arg_val, global_step=epoch_idx)
        print(f'[epoch {epoch_idx}] {key}: {arg_val}')
    return state, rng
```

Evaluation

```python
def eval_model(data_loader, state, rng):
    # Test model on all images of a data loader and return the average accuracy
    correct_class, count = 0, 0
    for batch in data_loader:
        rng, acc = eval_step(state, rng, batch)
        # acc is a per-batch mean, so weight it by the batch size
        correct_class += acc * batch[0].shape[0]
        count += batch[0].shape[0]
    eval_acc = (correct_class / count).item()
    return eval_acc
```
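eval_model calls an eval_step function that the article does not show. A minimal sketch consistent with the calculate_loss function shown later might reuse the loss computation with train=False, so dropout is disabled, and return only the accuracy. This eval_step is an assumption, not the author's code; SimpleState and linear_apply are hypothetical stand-ins that exist only so the snippet runs on its own, and the cross-entropy is written out in plain JAX instead of optax. With the article's Flax TrainState, whose apply_fn is a static (non-pytree) field, eval_step can be decorated with @jax.jit just like train_step.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class SimpleState(NamedTuple):
    # Stand-in for Flax's TrainState: only the fields calculate_loss needs
    params: Any
    apply_fn: Callable


def linear_apply(variables, imgs, train, rngs):
    # Stand-in model: flatten images and apply one linear layer (ignores train/rngs)
    x = imgs.reshape(imgs.shape[0], -1)
    return x @ variables["params"]["w"]


def calculate_loss(params, state, rng, batch, train):
    # Mirrors the article's calculate_loss with a hand-written cross-entropy
    imgs, labels = batch
    rng, drop_rng = jax.random.split(rng)
    logits = state.apply_fn({"params": params}, imgs, train=train, rngs={"dropout": drop_rng})
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1).mean()
    acc = (logits.argmax(axis=-1) == labels).mean()
    return loss, (acc, rng)


def eval_step(state, rng, batch):
    # Assumed implementation: forward pass without dropout, keep only the accuracy
    _, (acc, rng) = calculate_loss(state.params, state, rng, batch, train=False)
    return rng, acc


imgs = jnp.ones((4, 2, 2, 3))
labels = jnp.array([0, 1, 0, 1])
state = SimpleState(params={"w": jnp.zeros((12, 2))}, apply_fn=linear_apply)
rng, acc = eval_step(state, jax.random.PRNGKey(0), (imgs, labels))
print(acc)  # 0.5
```

With all-zero weights every logit ties, argmax picks class 0, and half of the labels match.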

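Training step

train_step defines the loss function, computes the gradients of the loss with respect to the model parameters, and updates the parameters with those gradients. jax.value_and_grad returns both the loss and its gradients with respect to state.params, and apply_gradients then updates the TrainState. The cross-entropy loss is computed from the logits produced by apply_fn (the same as model.apply), which was specified when the TrainState was created.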

```python
@jax.jit
def train_step(state, rng, batch):
    loss_fn = lambda params: calculate_loss(params, state, rng, batch, train=True)
    # Get loss, gradients for loss, and other outputs of loss function
    (loss, (acc, rng)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    # Update the parameters, optimizer state and step counter
    state = state.apply_gradients(grads=grads)
    return state, rng, loss, acc
```
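The key call in train_step is jax.value_and_grad(loss_fn, has_aux=True): loss_fn must return a pair (loss, aux), and the transformed function returns ((loss, aux), grads), where grads has the same pytree structure as the parameters and aux is passed through without being differentiated. A toy quadratic loss (illustrative, not from the article) shows the structure:

```python
import jax
import jax.numpy as jnp


def loss_fn(params):
    # Toy quadratic loss; aux plays the role of (acc, rng) in calculate_loss
    pred = params["w"] * 3.0 + params["b"]
    loss = (pred - 1.0) ** 2
    aux = {"pred": pred}
    return loss, aux


params = {"w": jnp.array(1.0), "b": jnp.array(0.0)}
(loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
print(float(loss), float(aux["pred"]))      # 4.0 3.0
# d/dw = 2 * (pred - 1) * 3 = 12 and d/db = 2 * (pred - 1) = 4
print(float(grads["w"]), float(grads["b"]))  # 12.0 4.0
```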

Computing the loss

```python
def calculate_loss(params, state, rng, batch, train):
    imgs, labels = batch
    # Split off a fresh key for dropout and hand the remaining key back to the caller
    rng, drop_rng = random.split(rng)
    logits = state.apply_fn({'params': params}, imgs, train=train, rngs={'dropout': drop_rng})
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
    acc = (logits.argmax(axis=-1) == labels).mean()
    return loss, (acc, rng)
```

Results

The training results are shown below. On a standard GPU in Colab Pro, training took about 1.5 hours.

test_acc: 0.7704000473022461

If you are interested in JAX, the full code for this article is available here:

https://github.com/satojkovic/vit-jax-flax

Author: satojkovic


Editor: 黃繼彥

