basicbert

A wrapper class and usage guide for Google's Bidirectional Encoder Representations from Transformers (BERT) text classifier.

Written by David Stein (david@djstein.com).

Also available at https://www.github.com/neuron-whisperer/basicbert.

The Short Version

The purpose of this project is to provide a wrapper class for the Google BERT transformer-based machine learning model and a usage guide for text classification. The objective is to enable developers to apply BERT out-of-the-box for ordinary text classification tasks.

Background

Transformers have become the primary machine learning technology for text processing tasks. One of the best-known transformer platforms is the Google BERT model, which features several different pretrained models that may be generally applied to a variety of tasks with a modest amount of training. The BERT codebase includes a basic file (run_classifier.py) that can be configured for different tasks via a set of command-line parameters.
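For illustration, training a classifier with the stock run_classifier.py involves a command along these lines (adapted from the examples in the BERT repository; the paths and hyperparameter values shown here are placeholders):

    python run_classifier.py \
        --task_name=MRPC \
        --do_train=true \
        --do_eval=true \
        --data_dir=./glue_data/MRPC \
        --vocab_file=./bert_base/vocab.txt \
        --bert_config_file=./bert_base/bert_config.json \
        --init_checkpoint=./bert_base/bert_model.ckpt \
        --max_seq_length=128 \
        --train_batch_size=32 \
        --learning_rate=2e-5 \
        --num_train_epochs=3.0 \
        --output_dir=./output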

Despite the impressive capabilities of Google BERT, the codebase suffers from a variety of limitations and disadvantages.

These and many other problems arose during my initial experimentation with BERT for a simple project. The codebase and documentation fail to answer basic questions, such as: How do I export a trained model, or use one to predict the class of an input on the fly, in the manner of an API?

My initial work with BERT required a significant amount of time examining and experimenting with the codebase to understand and circumvent these problems, and to wrangle BERT into a form that can be used with a minimum of hassle. The result is a simple wrapper class that can be (a) configured via a simple text configuration file and (b) invoked with simple commands to perform everyday classification tasks.

Implementation

The heart of this project is basicbert.py, which is designed to run in a Python 3 / TensorFlow 1.15.0 environment.

basicbert.py has been adapted from the Processor subclasses of run_classifier.py, and it reuses as much of the base code as possible. The wrapper exposes a few simple functions: reset(), train(), eval(), test(), export(), and predict(). It can be used to perform text classification of .tsv files with an arbitrary set of labels. A set of utility functions is also provided to prepare the training data and to reset the training state.

basicbert.py can be configured by creating or editing a file called config.txt in the same folder as basicbert.py. The configuration file has a simple key/value syntax (e.g.: num_train_epochs = 10). If the file does not exist or omits any options, basicbert.py falls back on standard default values.
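For illustration, a minimal config.txt might look like the following (num_train_epochs and tf_output_file are the options discussed in this README; the file name assigned to tf_output_file is only a placeholder - consult basicbert.py for the full list of recognized options and their defaults):

    num_train_epochs = 10
    tf_output_file = tf_output.txt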

basicbert.py subclasses the logging.Filter class and hooks a filter function to the TensorFlow logging process, which redirects all TensorFlow output to filter(self, record). The filter function scrapes a minimal amount of needed data (training progress and loss) from the voluminous TensorFlow output and discards the rest. For debugging, basicbert.py can be configured to save the complete TensorFlow output to a separate text file (by setting the tf_output_file configuration parameter).
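The general technique looks roughly like the following sketch (a simplified illustration of the approach, not the actual basicbert.py code):

    import logging

    class ProgressFilter(logging.Filter):
        """Intercepts TensorFlow log records, keeping only the lines of interest."""
        def filter(self, record):
            message = record.getMessage()
            if 'loss' in message:
                print(message)   # surface training progress / loss
            return False         # suppress the record from normal TensorFlow output

    # Attach the filter to TensorFlow's logger so every record passes through filter().
    logging.getLogger('tensorflow').addFilter(ProgressFilter())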

basicbert.py can export the model from the latest checkpoint and load it to perform inference. This requires saving the labels used for training, which BERT does not do by default. basicbert.py saves the list as labels.txt in the input folder, but this is configurable via config.txt.

Usage

The following steps will train a BERT model and perform some testing and prediction.

Step 1: Prepare Codebase

...as follows:

    output_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        # Added: log the training loss at every step via a LoggingTensorHook.
        training_hooks=[tf.train.LoggingTensorHook({'loss': total_loss}, every_n_iter=1)],
        scaffold_fn=scaffold_fn)

(Why is this necessary? Because BERT calculates the loss during training, but only reports the per-epoch loss if you request it - and run_classifier.py does not. See this GitHub thread for more information about this modification.)

Step 2: Prepare Data

Step 3: Use basicbert

By default, basicbert.py will train a BERT model on ten epochs of the training data, reporting the loss for each epoch and saving checkpoints along the way. The training process can be canceled at any point and, when restarted, will automatically resume from the last checkpoint.

If basicbert.py finishes training for the number of epochs indicated in config.txt, then subsequent training commands will be disregarded unless the number of epochs is increased. Alternatively, you may specify the number of training epochs, which will be completed irrespective of the number of previously completed epochs:

    python3 basicbert.py train 3

basicbert can also be used programmatically:

    from basicbert import *
    bert = BERT()
    bert.train()     # returns loss for the last training epoch

The BERT() initializer attempts to load its configuration from config.txt in the same folder as basicbert.py. If config.txt is not present, BERT will use predefined defaults. The initializer also optionally accepts a configuration dictionary; values in the dictionary take highest priority, falling back on config.txt or the defaults for any values missing from the dictionary.
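For example, individual options can be overridden by passing a dictionary (a minimal sketch; num_train_epochs is the option shown earlier in this README, and the value here is only illustrative):

    from basicbert import *

    # Dictionary values override config.txt and the built-in defaults.
    bert = BERT({'num_train_epochs': 3})
    bert.train()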

eval() returns a dictionary of results, with keys: eval_accuracy, eval_loss, global_step, loss.

test() returns an array of tuples, each representing the test result for one example. Each tuple has the following format: (sample_id, best_label, best_confidence, {each_label: each_confidence}).
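A sketch of how these return values might be consumed (based on the formats described above):

    from basicbert import *
    bert = BERT()

    results = bert.eval()
    print(results['eval_accuracy'], results['eval_loss'])

    for sample_id, best_label, best_confidence, confidences in bert.test():
        print(sample_id, best_label, best_confidence)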

As previously noted, BERT is configured by default to export models to a subfolder of the output folder, where the subfolder is named by a timestamp. You may move the files to any other path you choose, and may indicate the new location in config.txt. If you leave them in the output folder, then when basicbert.py loads the model during prediction, it will examine the subfolders and choose the one with the largest number (presumably the latest and best checkpoint). export() returns the path of the exported model.
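For example (the command-line form is an assumption, mirroring the train and predict commands shown elsewhere in this README):

    python3 basicbert.py export

    from basicbert import *
    bert = BERT()
    model_path = bert.export()   # path of the exported model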

Example:

    python3 basicbert.py predict The quick brown fox jumped over the lazy dogs.
    bert.predict('The quick brown fox jumped over the lazy dogs.')

The command-line call displays the predicted class, the probability, and the complete list of classes and probabilities. predict() returns a tuple of (string predicted_class, float probability, {string class: float probability}).
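For instance, a prediction might return a tuple along these lines (the labels and probabilities shown are purely illustrative):

    ('label_1', 0.97, {'label_1': 0.97, 'label_2': 0.03})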

As noted above, an exported BERT model does not include the label set. Without the labels, BERT has no way to map the predicted categories to the assigned labels. To address this deficiency, predict() looks for either labels.txt or train.tsv to retrieve the label set. A path to the label set file can be specified in config.txt.

Utility Functions

The following utility functions are also available:

prepare_data() prepares .tsv files for use with BERT. It reads the specified file (or, by default, data.csv in the script folder), which should be a CSV that is formatted as follows:

    unique_per_sample_identifier, label, text

For example:

    sentence_001, label_1, This is a first sentence to be classified.

    sentence_002, label_2, This is a second sentence to be classified.

Rows are separated by newline characters. Sentences may contain or omit quote marks. Sentences may contain commas (even without quote marks).

The function accepts two floating-point parameters: train and dev, each indicating the fraction of the rows to store in the corresponding file. The fraction of samples for the test set is calculated as (1.0 - train - dev). The function reads the sample data, shuffles the rows, and determines the number of samples to store in each file. It then writes the following files to the same folder (a usage sketch follows the file list below):

train.tsv: tab-separated file for training data set

dev.tsv: tab-separated file for validation data set

test.tsv: tab-separated file for test data set

labels.txt: newline-separated list of labels

test-labels.tsv: tab-separated file of correct labels for test data, formatted as follows:

    unique_per_sample_identifier \t label
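A usage sketch for prepare_data() (the exact parameter names and ordering are assumptions; check the function signature in basicbert.py):

    from basicbert import *

    # Read data.csv, shuffle the rows, and write train.tsv, dev.tsv, test.tsv,
    # labels.txt, and test-labels.tsv to the same folder: 80% train, 10% dev, 10% test.
    prepare_data('data.csv', train=0.8, dev=0.1)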

More Projects