Building a Spam Detection System Using BERT and PyTorch


In this tutorial, we will build a spam detection system by fine-tuning a pretrained BERT model with PyTorch and the Hugging Face transformers library. The dataset contains labeled SMS messages, which we use to train a classifier that distinguishes spam from non-spam (ham) messages.

Dataset Overview

The dataset can be downloaded from this GitHub repository. It has two columns: Category (either ham for a legitimate message or spam) and Message (the SMS text). We will load it into a pandas DataFrame and preprocess it for training.
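Before training, it can help to check how the two classes are distributed. If ham heavily outnumbers spam, accuracy alone can look good even for a weak model. The sketch below uses only the standard library and a few hypothetical rows in the Category/Message layout described above; the helper name count_categories is an assumption for illustration.

```python
import csv
import io
from collections import Counter

# Hypothetical sample rows in the Category/Message layout (not from the real dataset).
# The quoted message shows that commas inside a message are handled by the csv module.
SAMPLE_CSV = """Category,Message
ham,See you at lunch?
spam,You have won a free prize! Reply WIN now
ham,"Running late, be there soon"
"""

def count_categories(handle):
    """Return a Counter of the values in the Category column."""
    reader = csv.DictReader(handle)
    return Counter(row["Category"] for row in reader)

counts = count_categories(io.StringIO(SAMPLE_CSV))
total = sum(counts.values())
for category, n in counts.most_common():
    print(f"{category}: {n} ({n / total:.1%})")
```

With the real file, you would pass `open('dataset/spam.csv', encoding='latin1', newline='')` instead of the in-memory string.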

Setting Up the Environment

First, let’s import the necessary libraries and check if a GPU is available for training.

import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

We will use the pretrained bert-base-uncased checkpoint with a sequence-classification head (BertForSequenceClassification), which assigns each message to one of two classes: ham (non-spam) or spam.

Data Preprocessing

Next, we’ll load and preprocess the dataset:

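df = pd.read_csv('dataset/spam.csv', encoding='latin1')
df.rename(columns={'Category': 'label', 'Message': 'text'}, inplace=True)
# map() yields integer labels directly (any other value becomes NaN); replace() on a
# string column can trigger a downcasting FutureWarning in recent pandas versions
df['label'] = df['label'].map({'ham': 0, 'spam': 1})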

Here, we rename the columns to label and text for easier reference and convert the labels to numerical values (0 for ham, 1 for spam).
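The label conversion above assumes every row is exactly ham or spam. A small helper like the hypothetical encode_labels below makes that assumption explicit by failing loudly on anything else (for example an unexpected category), instead of letting a bad value reach the later torch.tensor call.

```python
LABEL_TO_ID = {"ham": 0, "spam": 1}

def encode_labels(raw_labels):
    """Map 'ham'/'spam' strings to 0/1, raising on any unexpected value."""
    encoded = []
    for position, label in enumerate(raw_labels):
        key = label.strip().lower()  # tolerate capitalisation and surrounding whitespace
        if key not in LABEL_TO_ID:
            raise ValueError(f"Unexpected label {label!r} at position {position}")
        encoded.append(LABEL_TO_ID[key])
    return encoded

print(encode_labels(["ham", "spam", " Spam "]))  # [0, 1, 1]
```

With pandas, the equivalent check is to run `df['label'].map(LABEL_TO_ID)` and then verify that `isna().any()` is False, since unmapped values become NaN.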

Tokenization

We tokenize the SMS messages using the BERT tokenizer:

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_data(text):
    return tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )

df['tokenized'] = df['text'].apply(tokenize_data)
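To see what add_special_tokens, truncation, padding='max_length' and the attention mask do, here is a simplified pure-Python sketch that works on already-tokenized IDs. It assumes the bert-base-uncased special-token IDs ([CLS] = 101, [SEP] = 102, [PAD] = 0) and skips the WordPiece step itself; the real tokenizer handles all of this internally, and pad_and_truncate is a hypothetical helper.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special-token IDs in the bert-base-uncased vocabulary

def pad_and_truncate(token_ids, max_length=128):
    """Mimic add_special_tokens + truncation + max_length padding for one sequence."""
    # Reserve two positions for [CLS] and [SEP], then truncate the content tokens.
    content = token_ids[: max_length - 2]
    input_ids = [CLS_ID] + content + [SEP_ID]
    # 1 marks real tokens the model should attend to; 0 marks padding.
    attention_mask = [1] * len(input_ids)
    padding = max_length - len(input_ids)
    input_ids += [PAD_ID] * padding
    attention_mask += [0] * padding
    return input_ids, attention_mask

ids, mask = pad_and_truncate([7592, 2088], max_length=6)  # two example token IDs
print(ids)   # [101, 7592, 2088, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
```

Every message therefore becomes exactly 128 IDs long, which is what allows the per-message tensors to be concatenated into one tensor in the next step.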

Preparing the DataLoader

Next, we concatenate the per-message tensors into single input-ID, attention-mask and label tensors, then split the dataset into training (80%) and validation (20%) sets:

input_ids = torch.cat([item['input_ids'] for item in df['tokenized']], dim=0)
attention_masks = torch.cat([item['attention_mask'] for item in df['tokenized']], dim=0)
labels = torch.tensor(df['label'].values)

dataset = TensorDataset(input_ids, attention_masks, labels)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

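# A fixed generator seed makes the split reproducible across runs
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

# Shuffle the training data every epoch; batch_size=2 is very small and can be
# raised (e.g. to 16 or 32) if GPU memory allows
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=2)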
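The 80/20 split and the batch size together determine how many optimizer steps each epoch takes, which is the len(train_dataloader) value the scheduler uses below. With the default drop_last=False, a DataLoader also yields a final partial batch, so the count is a ceiling division. The helper names are hypothetical and the dataset size is a placeholder, not the real file's row count.

```python
import math

def split_sizes(n_examples, train_fraction=0.8):
    """Reproduce the int(0.8 * n) / remainder split passed to random_split."""
    train_size = int(train_fraction * n_examples)
    return train_size, n_examples - train_size

def batches_per_epoch(n_examples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a dataset of n_examples."""
    if drop_last:
        return n_examples // batch_size
    return math.ceil(n_examples / batch_size)

train_size, val_size = split_sizes(1000)   # placeholder dataset size
print(train_size, val_size)                # 800 200
print(batches_per_epoch(train_size, 2))    # 400
print(batches_per_epoch(train_size, 32))   # 25
```

This shows why a batch size of 2 makes each epoch slow: it needs 16 times as many steps as a batch size of 32.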

Model Initialization and Training

We initialize the BERT model and define the optimizer and learning rate scheduler:

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
model.to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)

epochs = 3
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
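get_linear_schedule_with_warmup scales the optimizer's base learning rate by a factor that rises linearly from 0 to 1 during warmup and then decays linearly to 0 at num_training_steps. The sketch below re-implements that multiplier in plain Python to make its shape visible; linear_warmup_factor is a hypothetical name, not part of the transformers API.

```python
def linear_warmup_factor(step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier for linear warmup followed by linear decay."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))

base_lr = 2e-5
total_steps = 10
for step in (0, 5, 10):
    # With zero warmup the rate starts at base_lr and reaches 0 at the last step.
    print(step, base_lr * linear_warmup_factor(step, 0, total_steps))
```

Because scheduler.step() is called once per batch, num_training_steps must count batches across all epochs, which is why total_steps is len(train_dataloader) * epochs.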

Next, we define the training and validation functions:

def train_model(dataloader, model, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))

    for i, batch in progress_bar:
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
        model.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()
        progress_bar.set_description(f"Training loss: {loss.item():.4f}")

    avg_loss = total_loss / len(dataloader)
    print(f"Average training loss: {avg_loss:.4f}")
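When integer labels are passed, BertForSequenceClassification with num_labels=2 computes a cross-entropy loss over the two logits, and train_model averages that per-batch loss over the epoch. The pure-Python sketch below shows what the loss is for a single pair of logits; it illustrates the formula and is not the library's own code path.

```python
import math

def cross_entropy(logits, label):
    """Negative log-probability of the true class under a softmax over the logits."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

print(round(cross_entropy([0.0, 0.0], 1), 4))   # 0.6931 (undecided model: log 2)
print(round(cross_entropy([-2.0, 3.0], 1), 4))  # small loss: confident and correct
print(round(cross_entropy([-2.0, 3.0], 0), 4))  # large loss: confident and wrong
```

A freshly initialised classification head tends to produce near-equal logits, so early training losses close to log 2 (about 0.69) are expected before they fall.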

def validate_model(dataloader, model, device):
    model.eval()
    preds, true_labels = [], []

    for batch in dataloader:
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(true_labels, preds)
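    # zero_division=0 avoids warnings if the model predicts no spam at all in an epoch
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, preds, average='binary', zero_division=0)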
    print(f"Validation Accuracy: {accuracy}")
    print(f"Precision: {precision}, Recall: {recall}, F1 Score: {f1}")
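With average='binary', scikit-learn reports precision, recall and F1 for the positive class only, which here is label 1 (spam). The sketch below computes the same four numbers from confusion-matrix counts so their meaning is explicit; the binary_metrics helper and the toy labels are hypothetical. Returning 0.0 when a denominator is zero matches scikit-learn's behaviour with zero_division=0.

```python
def binary_metrics(true_labels, preds, positive=1):
    """Accuracy plus precision/recall/F1 for the positive (spam) class."""
    pairs = list(zip(true_labels, preds))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    accuracy = sum(1 for t, p in pairs if t == p) / len(pairs)
    # Of the messages flagged as spam, the fraction that really were spam.
    precision = tp / (tp + fp) if tp + fp else 0.0
    # Of the actual spam messages, the fraction that were caught.
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1

true_labels = [1, 0, 1, 1, 0]
preds = [1, 0, 0, 1, 1]
print(binary_metrics(true_labels, preds))
```

When ham outnumbers spam, a model that always predicts ham can still reach high accuracy while its spam recall is 0, which is why precision, recall and F1 are reported alongside accuracy.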

Finally, we train and validate the model:

for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    train_model(train_dataloader, model, optimizer, scheduler, device)
    validate_model(val_dataloader, model, device)

Saving the Model

After training, the model and tokenizer are saved for later use:

model.save_pretrained("saved_model")
tokenizer.save_pretrained("saved_tokenizer")

print("Model and tokenizer have been saved.")

Inference with the Trained Model

To use the trained model for inference, for example from a separate script, reload the saved model and tokenizer:

import torch
from transformers import BertTokenizer, BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("saved_model")
tokenizer = BertTokenizer.from_pretrained("saved_tokenizer")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def predict(text):
    model.eval()
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    predicted_class = torch.argmax(logits, dim=1).item()
    return predicted_class

text_to_predict = "gary suggested that the deal maker contact these producers verbally , and then"
prediction = predict(text_to_predict)
print(f"Prediction: {'Spam' if prediction == 1 else 'Not Spam'}")

With this code, you can classify any new message: predict returns 1 for spam and 0 for non-spam (ham).
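predict returns the argmax class. For two classes this is equivalent to converting the logits to probabilities with a softmax and flagging spam when its probability exceeds 0.5. Exposing that probability lets you move the threshold, for example raising it to flag fewer legitimate messages as spam. The sketch below works on a pair of logits such as outputs.logits[0].tolist(); the function name spam_decision and the 0.9 threshold are illustrative assumptions, not tuned values.

```python
import math

def spam_decision(logits, threshold=0.5):
    """Return (is_spam, spam_probability) for a [ham_logit, spam_logit] pair."""
    ham_logit, spam_logit = logits
    # The two-class softmax probability of spam equals a sigmoid of the logit difference.
    z = spam_logit - ham_logit
    if z >= 0:
        spam_prob = 1.0 / (1.0 + math.exp(-z))
    else:
        e = math.exp(z)  # stable form for large negative z, avoids overflow
        spam_prob = e / (1.0 + e)
    return spam_prob > threshold, spam_prob

print(spam_decision([0.2, 1.0]))                 # flagged at the default 0.5 threshold
print(spam_decision([0.2, 1.0], threshold=0.9))  # not flagged at a stricter threshold
```

A suitable threshold is best chosen on the validation set, trading spam recall against precision.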

Conclusion

In this guide, we’ve built a spam detection system using a BERT model. We covered loading and preprocessing the dataset, tokenizing text data, training the BERT model, and using it for inference. This tutorial should give you a solid foundation to further explore text classification tasks with BERT and PyTorch.
