Building a Spam Detection System Using BERT and PyTorch
In this tutorial, we will build a spam detection system by fine-tuning a pre-trained BERT model with PyTorch and the Hugging Face Transformers library. The dataset contains labeled SMS messages, which we will use to train a classifier that distinguishes spam from non-spam messages.
Dataset Overview
The dataset can be downloaded from this GitHub repository. It consists of two columns: Category (which is either ham for non-spam or spam for spam) and Message (the SMS content). We will load this data into a Pandas DataFrame and preprocess it for training.
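If you want to confirm the file looks as expected before going further, a quick peek at the raw CSV helps. This short check is optional and assumes you saved the file as dataset/spam.csv with latin1 encoding, the same path and encoding used later in this tutorial:
import pandas as pd

# Optional: inspect the raw columns and class balance before preprocessing
df_raw = pd.read_csv('dataset/spam.csv', encoding='latin1')
print(df_raw.head())                      # expect columns: Category, Message
print(df_raw['Category'].value_counts())  # counts of 'ham' vs 'spam'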
Setting Up the Environment
First, let’s import the necessary libraries and check if a GPU is available for training.
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
We will use BERT (bert-base-uncased) for sequence classification, which can classify text into different categories, in this case ham (non-spam) or spam.
Data Preprocessing
Next, we’ll load and preprocess the dataset:
df = pd.read_csv('dataset/spam.csv', encoding='latin1')
df.rename(columns = {'Category': 'label', 'Message': 'text'}, inplace=True)
df['label'] = df['label'].replace({'ham': 0, 'spam': 1})
Here, we rename the columns to label and text for easier reference and convert the labels to numerical values (0 for ham, 1 for spam).
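As an optional sanity check (not part of the original pipeline), you can confirm that the mapping produced only 0s and 1s and see how imbalanced the two classes are:
# Optional sanity check: labels should now be 0 (ham) and 1 (spam)
print(df['label'].value_counts())
print(df[['label', 'text']].head())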
Tokenization
We tokenize the SMS messages using the BERT tokenizer:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def tokenize_data(text):
    return tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
df['tokenized'] = df['text'].apply(tokenize_data)
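To see what encode_plus returns, you can inspect a single tokenized message. Each entry is a dictionary of tensors with shape (1, 128), because every message is padded or truncated to max_length=128. This inspection is optional and reuses the df built above:
# Optional: inspect the first tokenized message (a dict of tensors)
sample = df['tokenized'].iloc[0]
print(sample['input_ids'].shape)       # torch.Size([1, 128])
print(sample['attention_mask'].shape)  # torch.Size([1, 128])
print(tokenizer.decode(sample['input_ids'][0][:20]))  # first few tokens, including [CLS]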
Preparing the DataLoader
We will prepare the dataset for training by converting the tokens into tensors and splitting the dataset into training and validation sets:
input_ids = torch.cat([item['input_ids'] for item in df['tokenized']], dim=0)
attention_masks = torch.cat([item['attention_mask'] for item in df['tokenized']], dim=0)
labels = torch.tensor(df['label'].values)
dataset = TensorDataset(input_ids, attention_masks, labels)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# Shuffle the training data each epoch; keep validation order fixed
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=2)
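Before training, it is worth confirming that a batch comes out in the order the training loop expects (input IDs, attention masks, labels). This check is an optional addition to the tutorial code:
# Optional: inspect one batch to confirm tensor shapes and ordering
batch_input_ids, batch_attention_mask, batch_labels = next(iter(train_dataloader))
print(batch_input_ids.shape)       # torch.Size([2, 128]) with batch_size=2
print(batch_attention_mask.shape)  # torch.Size([2, 128])
print(batch_labels.shape)          # torch.Size([2])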
Model Initialization and Training
We initialize the BERT model and define the optimizer and learning rate scheduler:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
model.to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
epochs = 3
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
Next, we define the training and validation functions:
def train_model(dataloader, model, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for i, batch in progress_bar:
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
        model.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()
        progress_bar.set_description(f"Training loss: {loss.item():.4f}")
    avg_loss = total_loss / len(dataloader)
    print(f"Training loss: {avg_loss}")
def validate_model(dataloader, model, device):
    model.eval()
    preds, true_labels = [], []
    for batch in dataloader:
        input_ids, attention_mask, labels = batch
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        true_labels.extend(labels.cpu().numpy())
    accuracy = accuracy_score(true_labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, preds, average='binary')
    print(f"Validation Accuracy: {accuracy}")
    print(f"Precision: {precision}, Recall: {recall}, F1 Score: {f1}")
Finally, we train and validate the model:
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    train_model(train_dataloader, model, optimizer, scheduler, device)
    validate_model(val_dataloader, model, device)
Saving the Model
After training, the model and tokenizer are saved for later use:
model.save_pretrained("saved_model")
tokenizer.save_pretrained("saved_tokenizer")
print("Model and tokenizer have been saved.")
Inference with the Trained Model
Now, let’s see how to use the trained model for inference:
import torch
from transformers import BertTokenizer, BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained("saved_model")
tokenizer = BertTokenizer.from_pretrained("saved_tokenizer")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def predict(text):
    model.eval()
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    predicted_class = torch.argmax(logits, dim=1).item()
    return predicted_class
text_to_predict = "gary suggested that the deal maker contact these producers verbally , and then"
prediction = predict(text_to_predict)
print(f"Prediction: {'Spam' if prediction == 1 else 'Not Spam'}")
With this code, you can predict whether a given text is spam or not.
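The predict function returns only the hard class index. If you also want a confidence score, you can apply a softmax over the logits. The variant below is an optional extension, not part of the original tutorial code; the function name and example message are ours:
import torch.nn.functional as F

def predict_with_confidence(text):
    # Same preprocessing as predict(), but also return the spam probability
    model.eval()
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    with torch.no_grad():
        outputs = model(
            input_ids=inputs['input_ids'].to(device),
            attention_mask=inputs['attention_mask'].to(device)
        )
    probs = F.softmax(outputs.logits, dim=1).squeeze()
    predicted_class = int(torch.argmax(probs).item())
    return predicted_class, probs[1].item()  # (class index, P(spam))

label, spam_prob = predict_with_confidence("Congratulations! You won a free prize, call now!")
print(f"Prediction: {'Spam' if label == 1 else 'Not Spam'} (P(spam) = {spam_prob:.3f})")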
Conclusion
In this guide, we’ve built a spam detection system using a BERT model. We covered loading and preprocessing the dataset, tokenizing text data, training the BERT model, and using it for inference. This tutorial should give you a solid foundation to further explore text classification tasks with BERT and PyTorch.