Source code for pytorch_ood.model.gru

"""
Text classifier used by Hendrycks et al.
"""
import torch
from torch import nn


class GRUClassifier(nn.Module):
    """
    Classifier with token embedding and multi layer gated recurrent unit (GRU)
    for text classification, as used in the OOD/Error Detection baseline paper.

    The model embeds each token, runs the embedded sequence through a
    2-layer unidirectional GRU with 128 hidden units, and maps the final
    hidden state of the last layer to class logits with a linear layer.

    :see Implementation: `GitHub <https://github.com/hendrycks/outlier-exposure/blob/master/NLP_classification/train.py>`__
    """

    def __init__(self, num_classes, n_vocab, embedding_dim=50):
        """
        :param num_classes: number of classes in the dataset
        :param n_vocab: size of the vocabulary, i.e. number of distinct tokens
        :param embedding_dim: embedding size
        """
        super().__init__()
        # Token index 1 is treated as padding by the embedding layer.
        self.embedding = nn.Embedding(n_vocab, embedding_dim, padding_idx=1)
        self.gru = nn.GRU(
            # Must match the embedding width; a hard-coded 50 here would break
            # every model constructed with a non-default ``embedding_dim``.
            input_size=embedding_dim,
            hidden_size=128,
            num_layers=2,
            bias=True,
            batch_first=True,
            bidirectional=False,
        )
        self.fc = nn.Linear(128, num_classes)

    def features(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: batch of tokens (integer indices; the GRU is batch-first,
            so the expected layout is ``(batch, sequence)``)
        :returns: features, i.e. the final hidden state of the 2nd GRU layer
            (128 values per sample)
        """
        embeds = self.embedding(x)
        # nn.GRU returns (output, h_n); h_n is indexed by layer first.
        _, h_n = self.gru(embeds)
        return h_n[1]  # select the 2nd (last) layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: batch of tokens
        :returns: logits
        """
        # Reuse ``features`` so both code paths always agree.
        z = self.features(x)
        logits = self.fc(z)
        return logits