EPS's CODE Magazine division is dedicated to being responsive to our readers, and one of the ways we do that is by letting readers contact us through a page on our website, https://codemag.com/contact. As you can imagine, this page attracts a lot of spam, and part of someone's job has been to review every message that comes in through it, delete the spam, and reply to the legitimate ones. The staff members who do this have gone from seeing a few spam messages in the early days of the website, 23 years ago, to seeing a massive amount of spam today. Much of the job has become removing spam rather than the more fulfilling work of responding to readers.
Naturally, we assumed that AI could solve this problem and relieve the tedium of weeding out spam. With all the big news in AI lately, it's easy to equate AI with large language models (LLMs) and generative pretrained transformers (GPT) like OpenAI's ChatGPT and Google's Bard. Although those tools are great at natural language interactions and can even be coaxed into doing an okay job of classifying spam, they're not the ideal tools for the job. They're relatively slow and can be fairly expensive, and although they're great at general knowledge, they're not built for specific, narrow tasks like this one. A better solution is another branch of AI: machine learning (ML). ML has been around for a while and isn't as shiny and new as LLMs and GPT, but it's incredibly well suited to spam classification. One thing that pointed us toward ML is that we have a lot of very specific training data: every message submitted through our website over the years is in our database, and the staff has already identified which of them are spam. What better training data could we have than that?
Our first step was to run our data through some of the readily available ML algorithms and find one that was effective enough to use. Microsoft's ML.NET tools and libraries were the obvious starting point because our website and the services it relies on are all built on .NET. We threw together an application that let us try different algorithms and compare their effectiveness, and one soon stood out. ML.NET groups its algorithms by the type of prediction, or task, they perform, and spam identification falls under the task called Classification/Categorization.
In particular, we needed a binary classification algorithm: something that evaluates the data and makes a true-or-false determination, or in this case, spam or not spam. ML.NET organizes its algorithms into catalogs, and the BinaryClassificationCatalog contains about a dozen trainers for us to try, so we started running tests. An algorithm called SdcaLogisticRegression gave us the best results for this particular task with our particular data. After training it on a random sample of about 80% of our data, we tested the resulting model on the remaining 20% and found that the results were excellent: the model nearly always got it right. This was a huge benefit to the magazine staff. As new inquiries come in from the website, the model determines whether each one is legitimate or spam, so the staff only has to deal with the legitimate ones.
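To make the binary decision concrete, here's a minimal sketch of how a trained logistic-regression model, the kind of linear model SdcaLogisticRegression learns, turns a feature vector into a spam-or-not-spam answer. It uses only the standard library. The LogisticSketch class, the weights, and the feature values are hypothetical; in ML.NET, the features come from text featurization and the weights come from training.

```csharp
using System;

// Hypothetical sketch: scoring with an already-trained logistic-regression model.
// Weights, bias, and features are illustrative placeholders, not values from ML.NET.
public static class LogisticSketch
{
    // Linear score: bias + sum of weight[i] * feature[i]
    public static double Score(double[] weights, double bias, double[] features)
    {
        if (weights.Length != features.Length)
            throw new ArgumentException("weights and features must have the same length");

        double score = bias;
        for (int i = 0; i < weights.Length; i++)
            score += weights[i] * features[i];
        return score;
    }

    // The logistic (sigmoid) function maps any score to a probability in (0, 1).
    public static double Probability(double score) => 1.0 / (1.0 + Math.Exp(-score));

    // Binary decision: spam when the probability exceeds the threshold.
    // With a 0.5 threshold this is the same as asking whether the score is positive.
    public static bool IsSpam(double[] weights, double bias, double[] features,
                              double threshold = 0.5)
        => Probability(Score(weights, bias, features)) > threshold;
}
```

Because the sigmoid output behaves like a probability, a logistic-regression trainer is a natural fit for a yes/no question such as "is this spam?"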
Although some of the production code is proprietary, you can look at the interesting parts in the following sample code. The website is built on microservices, and this code lives in the Inquiry microservice. The first method, TrainAIForSpam() in Listing 1, reads fresh training data from the database, trains the model, stores the trained model in Azure BLOB storage, and returns metrics from the process to the caller.
Listing 1: TrainAIForSpam gets data, trains the model, stores the trained model and returns metrics.
using System;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.ML;

public async Task<TrainAIForSpamResponse> TrainAIForSpam()
{
    // Hold back 20% of the data for testing; train on the other 80%
    const double testFraction = 0.2;

    // Get Inquiry data from database for training
    using var biz = new InquiryBusinessObject();
    var trainingData = biz.GetInquiryTrainingData();

    // Add metrics to response
    var response = new TrainAIForSpamResponse();
    response.SpamCount = trainingData.Count(t => t.IsSpam);
    response.HamCount = trainingData.Count(t => !t.IsSpam);

    // Load training data into ML.NET
    var mlContext = new MLContext();
    var dataView = mlContext.Data.LoadFromEnumerable(trainingData);

    // Randomly split the data into training and test sets
    var split = mlContext.Data.TrainTestSplit(dataView, testFraction: testFraction);
    var trainData = split.TrainSet;
    var testData = split.TestSet;

    // Turn the Message and Subject text into numeric feature vectors,
    // then combine them into the single "Features" column the trainer reads
    var pipelineDataProcessing = mlContext.Transforms.Text
        .FeaturizeText("MessageFeaturized", nameof(SpamTrainingData.Message))
        .Append(mlContext.Transforms.Text
            .FeaturizeText("SubjectFeaturized", nameof(SpamTrainingData.Subject)))
        .Append(mlContext.Transforms.Concatenate(
            "Features",
            "MessageFeaturized",
            "SubjectFeaturized"))
        .AppendCacheCheckpoint(mlContext);

    // Create the trainer (label column, feature column)
    IEstimator<ITransformer> trainer =
        mlContext.BinaryClassification.Trainers
            .SdcaLogisticRegression(nameof(SpamTrainingData.IsSpam), "Features");

    // Train the model and time how long it takes
    var trainingPipeline = pipelineDataProcessing.Append(trainer);
    var start = DateTime.UtcNow;
    var model = trainingPipeline.Fit(trainData);
    var trainingDurationMilliseconds =
        (DateTime.UtcNow - start).TotalMilliseconds;

    // Add training metrics to response
    response.Metrics.TrainingStartTimeStamp = start;
    response.Metrics.TrainingDurationMilliseconds = trainingDurationMilliseconds;

    // Do a test run: score the held-back test data with the trained model
    var predictions = model.Transform(testData);

    // Retrieve metrics from the test run; use the calibrated evaluator
    // when the model outputs a Probability column
    var metrics = predictions.Schema.Any(c => c.Name == "Probability")
        ? mlContext.BinaryClassification
            .Evaluate(predictions, nameof(SpamTrainingData.IsSpam))
        : mlContext.BinaryClassification
            .EvaluateNonCalibrated(predictions, nameof(SpamTrainingData.IsSpam));
    response.Metrics.Accuracy = metrics.Accuracy;
    response.Metrics.F1Score = metrics.F1Score;
    response.Metrics.PositivePrecision = metrics.PositivePrecision;
    response.Metrics.NegativePrecision = metrics.NegativePrecision;
    response.Metrics.PositiveRecall = metrics.PositiveRecall;
    response.Metrics.NegativeRecall = metrics.NegativeRecall;
    response.Metrics.TestFraction = testFraction;

    // Update model in Azure BLOB storage used in production
    using var stream = new MemoryStream();
    mlContext.Model.Save(model, dataView.Schema, stream);
    stream.Position = 0;
    await BlobHelper.UploadAsync(
        stream,
        "ml-models",
        "SpamClassificationModel.zip",
        "Trained ML Models");
    response.TrainedModelSavedToAzure = true;

    return response;
}
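ML.NET computes these numbers for you, but it helps to know what Listing 1's metrics actually measure. The following sketch derives each of them from the four cells of a confusion matrix, treating spam as the positive class. ConfusionCounts and MetricsSketch are hypothetical names for illustration; the formulas are the standard definitions, and the sketch returns 0 whenever a denominator is 0.

```csharp
using System;

// Counts from a confusion matrix; "positive" means spam.
public readonly record struct ConfusionCounts(int TruePositive, int FalsePositive,
                                              int TrueNegative, int FalseNegative);

public static class MetricsSketch
{
    // Safe division: returns 0 instead of NaN when the denominator is 0.
    static double Ratio(double numerator, double denominator) =>
        denominator == 0 ? 0 : numerator / denominator;

    // Fraction of all predictions that were correct.
    public static double Accuracy(ConfusionCounts c) =>
        Ratio(c.TruePositive + c.TrueNegative,
              c.TruePositive + c.FalsePositive + c.TrueNegative + c.FalseNegative);

    // Of the messages predicted as spam, how many really were spam?
    public static double PositivePrecision(ConfusionCounts c) =>
        Ratio(c.TruePositive, c.TruePositive + c.FalsePositive);

    // Of the messages predicted as legitimate, how many really were legitimate?
    public static double NegativePrecision(ConfusionCounts c) =>
        Ratio(c.TrueNegative, c.TrueNegative + c.FalseNegative);

    // Of the actual spam, how much did the model catch?
    public static double PositiveRecall(ConfusionCounts c) =>
        Ratio(c.TruePositive, c.TruePositive + c.FalseNegative);

    // Of the actual legitimate messages, how many did the model let through?
    public static double NegativeRecall(ConfusionCounts c) =>
        Ratio(c.TrueNegative, c.TrueNegative + c.FalsePositive);

    // Harmonic mean of positive precision and positive recall.
    public static double F1Score(ConfusionCounts c)
    {
        double p = PositivePrecision(c), r = PositiveRecall(c);
        return Ratio(2 * p * r, p + r);
    }
}
```

For a spam filter, NegativeRecall deserves a close look: a low value means legitimate messages are being classified as SuspectedSpam and hidden from the staff.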
The following snippet shows the SQL query that retrieves the training data from the database. Not shown is the code that converts the results into the List<SpamTrainingData> used by the training code.
SELECT Subject,
       Message,
       SpamStatus
FROM Inquiry
WHERE SpamStatus = 1 OR SpamStatus = 4
This next snippet shows the possible values for the SpamStatus column in the database. The query above selects only records that a human has verified as spam (4) or not spam (1), because we don't want to train our model on guesses it has previously made.
public enum SpamStatus
{
    Unknown = 0,
    ConfirmedNotSpam = 1,
    SuspectedNotSpam = 2,
    SuspectedSpam = 3,
    ConfirmedSpam = 4
}
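The code that converts query results into a List<SpamTrainingData> isn't shown, so the following is only a hypothetical sketch of what it could look like. It assumes each row exposes Subject, Message, and SpamStatus. The InquiryRow record and TrainingDataSketch class are invented for illustration, SpamTrainingData here is a stand-in with just the three properties Listing 1 references, and the SpamStatus enum is repeated so the block compiles on its own.

```csharp
using System.Collections.Generic;
using System.Linq;

public enum SpamStatus
{
    Unknown = 0, ConfirmedNotSpam = 1, SuspectedNotSpam = 2, SuspectedSpam = 3, ConfirmedSpam = 4
}

// Hypothetical shape of a row returned by the SQL query.
public record InquiryRow(string Subject, string Message, SpamStatus SpamStatus);

// Hypothetical stand-in for the article's SpamTrainingData class (not shown in the article).
public class SpamTrainingData
{
    public string Subject { get; set; } = "";
    public string Message { get; set; } = "";
    public bool IsSpam { get; set; }
}

public static class TrainingDataSketch
{
    public static List<SpamTrainingData> ToTrainingData(IEnumerable<InquiryRow> rows) =>
        rows
            // Keep only human-verified labels (mirrors WHERE SpamStatus = 1 OR SpamStatus = 4)
            .Where(r => r.SpamStatus is SpamStatus.ConfirmedNotSpam or SpamStatus.ConfirmedSpam)
            // ConfirmedSpam becomes IsSpam = true; ConfirmedNotSpam becomes false
            .Select(r => new SpamTrainingData
            {
                Subject = r.Subject ?? "",
                Message = r.Message ?? "",
                IsSpam = r.SpamStatus == SpamStatus.ConfirmedSpam
            })
            .ToList();
}
```

Repeating the query's filter in code is a defensive choice: even if the query changes later, only human-verified labels reach the trainer.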
Although the system is nearly perfect, it isn't perfect. As staff respond to legitimate messages, the SpamStatus is changed from SuspectedNotSpam to ConfirmedNotSpam. When staff members come across a message marked legitimate that's actually spam, they click a button to change its status from SuspectedNotSpam to ConfirmedSpam. Those records can then be used to periodically retrain the model so that it remains effective as new variations of spam emerge. On occasion, staff also reviews the automatically classified SuspectedSpam records to make sure we didn't accidentally blacklist a legitimate message. Although this is often just a formality, we do sometimes find misclassified records. After review, every SuspectedSpam record is set to ConfirmedSpam, or to ConfirmedNotSpam if it was misclassified. We can then use these records to retrain the model as well, giving it more verified examples each time we retrain.
We typically classify new messages in bulk just prior to a staff member attending to them. Listing 2 shows the code that works that magic by calling on the model.
Listing 2: The code to classify messages as spam
using Microsoft.ML;

public ClassifySpamResponse ClassifySpam()
{
    // Retrieve unclassified inquiries from the database
    using var biz = new InquiryBusinessObject();
    var unclassifiedEntries = biz.GetAllSpamUnclassifiedInquiries();

    // Retrieve our trained model from Azure BLOB storage
    using var stream = BlobHelper.Download("ml-models", "SpamClassificationModel.zip");

    // Load the model into ML.NET
    var mlContext = new MLContext();
    var spamModel = mlContext.Model.Load(stream, out var modelInputSchema);

    // Create an instance of the prediction engine with the model
    var predictEngine = mlContext.Model
        .CreatePredictionEngine<SpamTrainingData, SpamPrediction>(spamModel);

    // Create a response object to return to the caller
    var response = new ClassifySpamResponse();

    // Process each unclassified message
    foreach (var entry in unclassifiedEntries)
    {
        var result = predictEngine.Predict(new SpamTrainingData
        {
            Subject = entry.Subject,
            Message = entry.Message
        });

        // Update spam classification in the database
        biz.UpdateSpamStatus(entry.Id, result.PredictedLabel
            ? SpamStatus.SuspectedSpam
            : SpamStatus.SuspectedNotSpam);

        response.TotalInquiriesProcessed++;
        if (result.PredictedLabel)
            response.ClassifiedAsLikelySpam++;
        else
            response.ClassifiedAsLikelyNotSpam++;
    }

    return response;
}
Let's review. We started by doing some testing to find the best algorithm for classifying spam. We knew that we needed a binary classification algorithm, and we had existing data to use for testing, so we tried all of the algorithms in ML.NET's BinaryClassificationCatalog. We also tried including data from more columns, such as the sender's name and company name, but it didn't seem to make predictions any better, so we removed those columns. After training each algorithm on a random 80% of the data and testing it on the remaining 20%, we found the best algorithm for the job.
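ML.NET's TrainTestSplit handles the random 80/20 split for you. For readers who want to see the idea without ML.NET, here's a minimal sketch of a random train/test split built on a Fisher-Yates shuffle. SplitSketch is a hypothetical helper, not part of ML.NET, and the optional seed exists only to make runs repeatable.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper: random train/test split using only the standard library.
public static class SplitSketch
{
    public static (List<T> Train, List<T> Test) TrainTestSplit<T>(
        IReadOnlyList<T> items, double testFraction = 0.2, int? seed = null)
    {
        if (testFraction < 0 || testFraction > 1)
            throw new ArgumentOutOfRangeException(nameof(testFraction));

        var rng = seed is int s ? new Random(s) : new Random();
        var shuffled = items.ToList();

        // Fisher-Yates shuffle: every ordering of the items is equally likely.
        for (int i = shuffled.Count - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            (shuffled[i], shuffled[j]) = (shuffled[j], shuffled[i]);
        }

        // The first testCount shuffled items become the test set; the rest are for training.
        int testCount = (int)Math.Round(shuffled.Count * testFraction);
        return (shuffled.Skip(testCount).ToList(), shuffled.Take(testCount).ToList());
    }
}
```

When comparing algorithms, a fixed seed means every candidate is trained and tested on the same rows, which makes the comparison fairer; ML.NET's TrainTestSplit also accepts an optional seed for the same purpose.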
We added code (based on our algorithm testing code) to our Inquiry service to retrain our model using the winning algorithm and to upload the updated model (a relatively small .ZIP file) to Azure BLOB storage. We could then periodically call this service to retrain our model to make it aware of new types of spam we encounter.
Now, before staff responds to new inquiries, the software automatically downloads our model from Azure and runs it against the new inquiries in the database, classifying each entry as either SuspectedSpam or SuspectedNotSpam, and the staff only sees the SuspectedNotSpam messages. As they respond to the messages, each message is marked as ConfirmedNotSpam. If spam does somehow get through, it's manually marked as ConfirmedSpam and ignored. On occasion, we do a cursory examination of the SuspectedSpam messages and mark them as ConfirmedSpam. Any misclassified messages are marked ConfirmedNotSpam and responded to. Through these processes, we continually accumulate more and better data to occasionally retrain our model and keep it in top shape. We're still working out how often we should retrain, but currently, we do so roughly once per quarter.
Spam classification through ML continues to save many precious staff hours, as well as a little sanity. It handles some of our mundane tasks with just a little oversight, so our staff can stay focused on the work that provides higher business value.