Next Word Prediction using Katz Backoff Model - Part 3: Prediction Model Implementation

Executive Summary

The Capstone Project of the Johns Hopkins Data Science Specialization is to build an NLP application that predicts the next word of a user's text input. In Part 1, we analysed the training dataset and found some characteristics that can be exploited in the implementation. In Part 2, we discussed the Good-Turing smoothing estimate and the Katz backoff model that power our text prediction application.

This part of the report will focus on the implementation of the text prediction, and some thoughts on the project.


The implementation consists of four main parts:

  • data preprocessing: construct the training data from the given raw text files.
  • frequency smoothing: update the counts of the $N$-grams in the training data using Good-Turing estimation.
  • $N$-grams preparation: create different $N$-gram datasets according to the user input.
  • calculation: apply the Katz backoff model to predict the next word.

The program uses the quanteda, readtext, and data.table packages to process and analyse the text data, together with stringi and stringr for string manipulation.


Data Preprocessing

To build a text prediction model, the first step is to prepare the dataset used to train it. In this project, we use three text files containing blog, news, and Twitter content in the en_US locale. First, we read each text file, split it into lines, and then combine the lines into one character vector. Note that we sample the lines of each file according to a specific sampling fraction to reduce the size of the training data.

library(readtext)
library(stringi)

sample_size = 0.8
# readtext() returns a data frame; its text column holds the file content
news <- stri_split_lines1(readtext(paste0(path,"en_US/"))$text)
news <- sample(news, size=length(news)*sample_size)
blogs <- stri_split_lines1(readtext(paste0(path,"en_US/en_US.blogs.txt"))$text)
blogs <- sample(blogs, size=length(blogs)*sample_size)
twitter <- stri_split_lines1(readtext(paste0(path,"en_US/en_US.twitter.txt"))$text)
twitter <- sample(twitter, size=length(twitter)*sample_size)
txt <- c(news, blogs, twitter)
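As a minimal base-R sketch of this sampling step, the snippet below draws a fixed fraction of lines from each source and combines them. It assumes in-memory stand-in lines instead of the real files; the helper `sample_lines()` and the `set.seed()` call are our own additions for illustration, not part of the original pipeline.

```r
# Hypothetical helper: keep a random fraction of the lines of one corpus
sample_lines <- function(lines, fraction) {
  n_keep <- floor(length(lines) * fraction)  # integer number of lines to keep
  sample(lines, size = n_keep)               # draw without replacement
}

set.seed(2018)  # assumed: fix the RNG so the sample is reproducible

# Stand-ins for the news, blogs, and twitter files, already split into lines
news    <- c("news line 1", "news line 2", "news line 3", "news line 4", "news line 5")
blogs   <- c("blog line 1", "blog line 2", "blog line 3", "blog line 4", "blog line 5")
twitter <- c("tweet 1", "tweet 2", "tweet 3", "tweet 4", "tweet 5")

sample_size <- 0.8
txt <- c(sample_lines(news, sample_size),
         sample_lines(blogs, sample_size),
         sample_lines(twitter, sample_size))
length(txt)  # 4 lines from each source: 12
```

Wrapping the size in `floor()` makes the number of kept lines an explicit integer, and fixing the seed makes the sampled training set reproducible between runs.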

To count the raw frequencies of unigrams, bigrams, and trigrams, we need a function that creates a document-feature matrix (dfm) from the character vector. We also clean the text: the character vector is tokenized, and elements that the prediction model does not handle, such as numbers, punctuation, symbols, URLs, and English stopwords, are removed. The three document-feature matrices are then created from the resulting token objects.

library(quanteda)

## Returns a quanteda dfm from a given character vector
## txt - Character vector of text; each element in the vector is a document in the dfm
## ng  - The 'N' of N-gram
<- function(txt, ng) {
  text.dfm <- txt %>%
    tokens(remove_numbers=T, remove_punct=T, remove_symbols=T, remove_hyphens=T, remove_twitter=T, remove_url=T) %>%
    tokens_remove(stopwords("en")) %>%
    tokens_ngrams(n=ng) %>%
    dfm()
  return(text.dfm)
}

UniG <-, 1)  # dfm containing unigrams
BiG <-, 2)  # dfm containing bigrams
TriG <-, 3)  # dfm containing trigrams
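To make the effect of the `tokens()` to `tokens_ngrams()` chain concrete, here is a simplified base-R sketch of the same idea on a single line of text. It is an approximation under stated assumptions: `make_ngrams()` and `clean_tokens()` are hypothetical helpers, the tiny `stop_words` vector stands in for `stopwords("en")`, and splitting on non-letters only roughly mimics quanteda's removal of numbers, punctuation, and symbols. Words are joined with `_`, the separator the $N$-gram terms use throughout this report.

```r
# Hypothetical helper: build N-grams from one tokenised document,
# joining consecutive words with "_"
make_ngrams <- function(tokens, n) {
  if (length(tokens) < n) return(character(0))
  starts <- seq_len(length(tokens) - n + 1)
  vapply(starts, function(i) paste(tokens[i:(i + n - 1)], collapse = "_"),
         character(1))
}

# Assumed, tiny stopword list standing in for stopwords("en")
stop_words <- c("the", "a", "to", "he", "of")

# Hypothetical helper: lowercase, split on non-letters, drop stopwords
clean_tokens <- function(line) {
  words <- tolower(unlist(strsplit(line, "[^A-Za-z']+")))
  words <- words[nzchar(words)]   # drop empty strings
  words[!words %in% stop_words]   # remove stopwords
}

toks <- clean_tokens("He likes to eat ice cream, 3 times a day!")
toks                     # "likes" "eat" "ice" "cream" "times" "day"
make_ngrams(toks, 2)     # "likes_eat" "eat_ice" "ice_cream" "cream_times" "times_day"
make_ngrams(toks, 3)[1]  # "likes_eat_ice"
```

Because stopwords are removed before the $N$-grams are formed, the trigram `likes_eat_ice` skips the word "to", so the model learns sequences of content words.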

According to the findings of the exploratory analysis in Part 1, uncommon $N$-grams can be removed to greatly reduce the size of the training data without noticeably affecting prediction performance.

library(data.table)

## Returns a data.table of terms and their raw counts from an N-gram dfm
CountNGramFreq <- function(NGrDfm) {
  FreqV <- colSums(NGrDfm)
  return(data.table(term=names(FreqV), c=FreqV))
}
UniFreq <- CountNGramFreq(UniG)
BiFreq <- CountNGramFreq(BiG)
TriFreq <- CountNGramFreq(TriG)

# Terms with raw count <= min_count are dropped
min_count = 4

UniFreq <- UniFreq[c>min_count,]
BiFreq <- BiFreq[c>min_count,]
TriFreq <- TriFreq[c>min_count,]
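The counting and pruning logic can be illustrated in base R on a handful of made-up bigram occurrences. In this sketch `table()` plays the role of `colSums()` on the dfm, a plain data frame replaces the data.table, and the data are invented for illustration.

```r
# Toy list of bigram occurrences (one element per occurrence in the corpus)
bigrams <- c(rep("ice_cream", 6), rep("prime_minister", 5),
             rep("nuclear_power", 4), "eat_ice", "likes_eat")

counts <- table(bigrams)
BiFreq <- data.frame(term = names(counts), c = as.integer(counts),
                     stringsAsFactors = FALSE)

min_count <- 4
BiFreq <- BiFreq[BiFreq$c > min_count, ]  # keep only counts strictly above min_count
BiFreq$term  # "ice_cream" "prime_minister"
```

Note that `nuclear_power`, seen exactly `min_count` times, is dropped: the filter keeps only counts strictly greater than `min_count`.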

Good-Turing Smoothing

There are four steps to perform Good-Turing (GT) smoothing:

  1. Count the frequency of frequencies $N_r$, i.e. the number of distinct $N$-grams that occur exactly $r$ times.

    ## Calculate the "frequency of frequency r" (N_r)
    CountNC <- function(FreqVec) {
      CountTbl <- table(FreqVec[, c])
      return(data.table(c=as.integer(names(CountTbl)), Nr=as.integer(CountTbl)))
    }
    UniBins <- CountNC(UniFreq)
    BiBins <- CountNC(BiFreq)
    TriBins <- CountNC(TriFreq)
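  2. Smooth $N_r$ by averaging each observed count over the gap around it: $Z_r=\frac{N_r}{0.5(t-q)}$, where $q$ and $t$ are the observed counts immediately below and above $r$ ($q=0$ for the smallest $r$, and $t=2r-q$ for the largest).

    ## Average non-zero count, replace N_r with Z_r
    avg.zr <- function(Bins) {
      last <- nrow(Bins)
      Bins[1, Zr:=2*Nr/Bins[2,c]]  # first r: q=0, Zr=Nr/(0.5t)
      if (last > 2) {
        for (r in 2:(last-1)) {
          Bins[r, Zr:=2*Nr/(Bins[r+1,c]-Bins[r-1,c])]  # else, Zr=Nr/(0.5(t-q))
        }
      }
      Bins[last, Zr:=Nr/(c-Bins[(last-1),c])]  # last r: t=2r-q, Zr=Nr/(r-q)
      return(Bins)
    }
    UniBins <- avg.zr(UniBins)
    BiBins <- avg.zr(BiBins)
    TriBins <- avg.zr(TriBins)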
  3. Fit a linear regression model $\log(Z_r)=a+b\log(r)$.

    ## Replace Z_r with value computed from a linear regression that is fit to map Z_r to c in log space
    ## log(Z_r) = a + b*log(c)
    FitLM <- function(CountTbl) {
      return(lm(log(Zr) ~ log(c), data = CountTbl))
    }
    UniLM <- FitLM(UniBins)
    BiLM <- FitLM(BiBins)
    TriLM <- FitLM(TriBins)
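  4. For counts $r \le k$, replace $r$ with the discounted count $r^*$ given by Katz's formula with constant $k$, reading each $Z_r$ from the fitted regression line: $r^*=\frac{(r+1)\frac{Z_{r+1}}{Z_r}-r\frac{(k+1)Z_{k+1}}{Z_1}}{1-\frac{(k+1)Z_{k+1}}{Z_1}}$. Counts above $k$ are left unchanged.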

    ## Only perform the discounting to small count (c) n-grams, where c <= k, using Katz's formula
    k <- 5  # assumed value: Katz's suggested threshold
    Cal_GTDiscount <- function(cnt, N) {
      if (N==1) {
        model <- UniLM
      } else if (N==2) {
        model <- BiLM
      } else if (N==3) {
        model <- TriLM
      }
      # Common parts
      Z1 <- exp(predict(model, newdata=data.frame(c=1)))
      Zr <- exp(predict(model, newdata=data.frame(c=cnt)))
      Zrp1 <- exp(predict(model, newdata=data.frame(c=(cnt+1))))
      Zkp1 <- exp(predict(model, newdata=data.frame(c=(k+1))))
      sub <- ((k+1)*Zkp1)/(Z1)
      new_r <- ((cnt+1)*(Zrp1)/(Zr)-cnt*sub)/(1-sub)
      return(new_r)
    }
    UpdateCount <- function(FreqTbl, N) {
      FreqTbl[c>k ,cDis:=as.numeric(c)]
      FreqTbl[c<=k, cDis:=Cal_GTDiscount(c, N)]
      return(FreqTbl)
    }
    UniFreq <- UpdateCount(UniFreq, 1)
    BiFreq <- UpdateCount(BiFreq, 2)
    TriFreq <- UpdateCount(TriFreq, 3)
    setkey(UniFreq, term)
    setkey(BiFreq, term)
    setkey(TriFreq, term)
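The four steps can be traced end to end on a toy table without data.table. This sketch is ours, not the report's code: the bin values are made up, the vectorised `avg_zr()` and `katz_rstar()` are hypothetical stand-ins for `avg.zr()` and `Cal_GTDiscount()`, and $k=5$ is assumed.

```r
# Toy frequency-of-frequency table: Nr = number of N-grams seen exactly r times
# (made-up values; consecutive counts 1..6 for simplicity)
Bins <- data.frame(r = 1:6, Nr = c(120, 40, 20, 12, 8, 6))

# Step 2: Z_r = N_r / (0.5 * (t - q)), q and t being the neighbouring counts
avg_zr <- function(r, Nr) {
  n <- length(r)
  q_prev <- c(0, r[-n])                         # previous count, 0 for the first bin
  t_next <- c(r[-1], 2 * r[n] - q_prev[n])      # next count, t = 2r - q for the last bin
  Nr / (0.5 * (t_next - q_prev))
}
Bins$Zr <- avg_zr(Bins$r, Bins$Nr)

# Step 3: fit log(Z_r) = a + b * log(r) and read smoothed Z values from the line
fit <- lm(log(Zr) ~ log(r), data = Bins)
Z <- function(r) unname(exp(predict(fit, newdata = data.frame(r = r))))

# Step 4: Katz's discounted count r* for 1 <= r <= k
k <- 5
katz_rstar <- function(r) {
  sub_k <- (k + 1) * Z(k + 1) / Z(1)
  ((r + 1) * Z(r + 1) / Z(r) - r * sub_k) / (1 - sub_k)
}
r_star <- katz_rstar(1:k)
round(r_star, 2)
```

With a fitted slope steeper than $-1$, every $r^*$ lands strictly between 0 and $r$, so each small count gives up some probability mass for the unseen $N$-grams while keeping its rank.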

$N$-grams Preparation

The Katz backoff model requires several sets of $N$-gram and $(N-1)$-gram data, selected according to the user input, to calculate all the probabilities needed to compare the next-word candidates and choose the most likely one. In particular, for a trigram model with user input $(x,y)$, the following sets of trigrams, bigrams, and unigrams are needed (see Part 2 for the details of the equations):

If $C(x,y) \gt 0$:

  1. All observed trigrams $\mathbf{OT}=\{(x,y,z) \mid C(x,y,z)\gt0\}$
    To calculate $P^*(z|x,y)$ and the numerator of $\alpha(x,y)$ in the case where $C(x,y,z) \gt 0$

  2. All unobserved trigrams $\mathbf{UOT}=\{(x,y,z) \mid C(x,y,z)=0\}$
    To calculate $P_\text{katz}(z|y)$

  3. All observed bigrams whose trigram is unobserved, $\mathbf{OB}=\{(y,z) \mid C(y,z)\gt0,\ (x,y,z) \in \mathbf{UOT}\}$
    To calculate $P^*(z|y)$ and the numerator of $\alpha(y)$ in the case where $C(y,z) \gt 0$

  4. All “tail” unigrams $\mathbf{UOBT}=\{z\}$ that end the unobserved bigrams $\mathbf{UOB}=\{(y,z) \mid C(y,z)=0\}$ (since $C(y,z)=0$ implies $C(x,y,z)=0$, every such $(x,y,z)$ is in $\mathbf{UOT}$)
    To calculate $P_{ML}(z)$

If $C(x,y) = 0$:

  1. If $C(y) \gt 0$:

    1. All observed bigrams $\mathbf{OB}=\{(y,z) \mid C(y,z) \gt 0\}$
      To calculate $P^*(z|y)$ and the numerator of $\alpha(y)$ in the case where $C(y,z) \gt 0$

    2. All “tail” unigrams $\mathbf{UOBT}=\{z\}$ that end the unobserved bigrams $\mathbf{UOB}=\{(y,z) \mid C(y,z)=0\}$
      To calculate $P_{ML}(z)$

  2. If $C(y) = 0$:

    1. All unigrams $\{z\}$
      To calculate $P_{ML}(z)$
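The case analysis above decides which sets have to be built. A minimal sketch of that decision, assuming the bigram and unigram vocabularies are available as plain character vectors (`needed_sets()` and the toy vocabularies are hypothetical), is:

```r
# Hypothetical helper: which N-gram sets are needed for the history (x, y)?
needed_sets <- function(x, y, bigram_terms, unigram_terms) {
  if (paste(x, y, sep = "_") %in% bigram_terms) {
    c("OT", "UOT", "OB", "UOBT")  # C(x,y) > 0: start at the trigram level
  } else if (y %in% unigram_terms) {
    c("OB", "UOBT")               # C(x,y) = 0, C(y) > 0: start at the bigram level
  } else {
    "all unigrams"                # C(y) = 0: unigram probabilities only
  }
}

# Toy vocabularies (invented)
bigram_terms  <- c("eat_ice", "ice_cream", "prime_minister")
unigram_terms <- c("eat", "ice", "cream", "prime", "minister")

needed_sets("eat", "ice", bigram_terms, unigram_terms)     # "OT" "UOT" "OB" "UOBT"
needed_sets("buy", "ice", bigram_terms, unigram_terms)     # "OB" "UOBT"
needed_sets("buy", "gelato", bigram_terms, unigram_terms)  # "all unigrams"
```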

The dataset preparation steps above require two operations:

  1. Retrieve the observed $N$-grams that start with a given $(N-1)$-gram

    ## Return all the observed N-grams given the previous (N-1)-gram
    ## - wordseq: character vector of (N-1)-gram separated by underscore, e.g. "x1_x2_..._x(N-1)"
    ## - NgramFreq: datatable of N-grams
    <- function(wordseq, NgramFreq) {
      PreTxt <- sprintf("%s%s%s", "^", wordseq, "_")  # anchored prefix, e.g. "^x_y_"
      NgramFreq[grep(PreTxt, NgramFreq[,term], perl=T, useBytes=T),]
    }
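  2. Retrieve all the unigrams that end unobserved $N$-grams

    library(stringr)

    ## Return all the unigrams that end unobserved Ngrams
    get.unobs.Ngram.tails <- function(ObsNgrams, N) {
      # Last word of every observed N-gram
      ObsTails <- str_split_fixed(ObsNgrams[,term], "_", N)[,N]
      # Not-join: every unigram that does not end an observed N-gram
      return(UniFreq[!ObsTails, .(term), on="term"])
    }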
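The two retrieval operations can be sketched in base R on character vectors. This is an illustration under stated assumptions rather than the data.table code above: the hypothetical `obs_ngrams()` swaps the anchored `grep()` regex for `startsWith()`, which performs a literal prefix match, and the hypothetical `unobs_tails()` uses `setdiff()` to collect the unigrams that do not end any observed $N$-gram.

```r
# Toy N-gram vocabularies (invented); words are joined with "_"
bigram_terms  <- c("ice_cream", "ice_age", "ice_cold", "eat_ice", "nuclear_power")
unigram_terms <- c("ice", "cream", "age", "cold", "eat", "nuclear", "power", "plant")

# Hypothetical helper: observed N-grams whose first N-1 words equal wordseq
obs_ngrams <- function(wordseq, terms) {
  terms[startsWith(terms, paste0(wordseq, "_"))]  # literal prefix match
}

# Hypothetical helper: unigrams that do not end any of the observed N-grams
unobs_tails <- function(obs, unigram_terms) {
  last_words <- vapply(strsplit(obs, "_", fixed = TRUE),
                       function(w) w[length(w)], character(1))
  setdiff(unigram_terms, last_words)
}

ob <- obs_ngrams("ice", bigram_terms)
ob                              # "ice_cream" "ice_age" "ice_cold"
unobs_tails(ob, unigram_terms)  # "ice" "eat" "nuclear" "power" "plant"
```

`startsWith()` avoids having to escape regular-expression metacharacters in the history, although after punctuation removal such characters should be rare in this corpus.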


Once all the necessary data are ready, only two calculations are needed:

  1. Compute the probabilities of observed $N$-grams

    ## Compute the probabilities of observed N-grams.
    ## The denominator is the count from the (N-1)-gram table, because the corpus
    ## does not include <EOS> explicitly: summing up all the counts in the
    ## N-gram table would give a smaller denominator
    cal.obs.prob <- function(ObsNgrams, Nm1Grams, wordseq) {
      PreCount <- Nm1Grams[wordseq, c, on=.(term)]
      ObsNgrams[, Prob:=cDis/PreCount]  # c_dis/c
      return(ObsNgrams)
    }
  2. Calculate $\alpha$

    ## Compute Alpha
    ## Return the normalization factor Alpha
    ## - ObsNGrams: datatable contains all observed ngrams starting with wordseq
    ## - Nm1Grams: datatable of (N-1)-grams containing count of wordseq
    ## - wordseq: an observed history: w_{i-N+1}^{i-1}
    cal.alpha <- function(ObsNGrams, Nm1Grams, wordseq) {
      if (nrow(ObsNGrams) != 0) {
        # return(1-sum(ObsNGrams[,.(Qbo)]))  # We don't use this formula because End Of Sentence is not counted
        return(sum(ObsNGrams[,c-cDis]/Nm1Grams[wordseq, c, on=.(term)]))
      } else {
        return(1)  # no observed N-grams: all probability mass is backed off
      }
    }
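A small worked example of $P^*$ with invented counts shows why the denominator is taken from the $(N-1)$-gram table. We assume the bigram `eat_ice` occurs 12 times, but only 10 of those occurrences are followed by another word, for example because the other two end a line.

```r
# Invented counts for trigrams that start with the history "eat_ice"
obs_tri <- data.frame(term = c("eat_ice_cream", "eat_ice_cold"),
                      c    = c(8, 2),      # raw counts C(x,y,z)
                      cDis = c(8, 1.4),    # discounted counts C*(x,y,z)
                      stringsAsFactors = FALSE)

C_xy <- 12       # raw count of the bigram "eat_ice" (from the bigram table)
sum(obs_tri$c)   # 10: two occurrences of "eat ice" have no following word

# P*(z | x,y) = C*(x,y,z) / C(x,y), with the (N-1)-gram count as denominator
obs_tri$Prob <- obs_tri$cDis / C_xy
round(obs_tri$Prob, 4)  # 0.6667 0.1167
```

Dividing by `sum(obs_tri$c)` instead would give 0.8 and 0.14, overstating each probability because the sentence-final occurrences would be ignored.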

Calculating $\alpha$ involves a subtlety. The original formula $\alpha(w_{i-N+1}^{i-1})=\frac{1-\sum_{w:C(w_{i-N+1}^{i-1},w)\gt0} P^*(w|w_{i-N+1}^{i-1})}{1-\sum_{w:C(w_{i-N+1}^{i-1},w)\gt0} P^*(w|w_{i-N+2}^{i-1})}$ is not used, because we do not add an end-of-sentence (<EOS>) token to each sentence. As a result, the counts of all observed bigrams $(y,w)$ that start with $y$ do not sum to the count of the unigram $y$, since the occurrences where $y$ ends a sentence are not counted. Instead, for each observed $N$-gram we subtract its discounted count from its raw count to find how much probability mass can be taken off, and divide the total by the raw count of the corresponding $(N-1)$-gram to normalize it, i.e.

$$ \alpha(w_{i-N+1}^{i-1})=\sum_{w:C(w_{i-N+1}^{i-1},w)\gt0}\frac{C(w_{i-N+1}^{i-1},w)-C^*(w_{i-N+1}^{i-1},w)}{C(w_{i-N+1}^{i-1})} $$

where $C^*()$ is the discounted count.
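Plugging the same kind of invented counts into both expressions shows the size of the correction. Under the assumption that $C(x,y)=12$ while the observed continuations only add up to 10, the quantity $1-\sum P^*$ also absorbs the mass of the two sentence-final occurrences, whereas the formula above counts only the mass freed by discounting:

```r
C_h  <- 12         # raw count of the history, e.g. C(x,y)
cnt  <- c(8, 2)    # raw counts of its observed continuations
cDis <- c(8, 1.4)  # their discounted counts C*

p_star    <- cDis / C_h
alpha     <- sum(cnt - cDis) / C_h  # formula used in this project
left_over <- 1 - sum(p_star)        # 1 - sum of P*, as in the original numerator

round(alpha, 4)      # 0.05
round(left_over, 4)  # 0.2167
```

The difference, $2/12$, is exactly the share of occurrences of the history that are not followed by any word.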

The final part is to put all the steps above together.

## Find next word
## Return a list of predicted next words according to previous 2 user input words
## - xy: character vector containing user-input bigram, separated by a space
## - words_num: number of candidates of next words returned
Find_Next_word <- function(xy, words_num) {
  xy <- gsub(" ", "_", xy)
  if (xy %in% BiFreq$term) {  # C(x,y) > 0
    ## N-grams preparation
    # Retrieve all observed trigrams beginning with xy: OT
    ObsTriG <-, TriFreq)
    y <- str_split_fixed(xy,"_", 2)[,2]
    # Retrieve all observed bigrams beginning with y: OB
    ObsBiG <-, BiFreq)
    # Retrieve all unigrams that end the unobserved bigrams, UOBT: z where C(y,z) = 0, UOB in UOT
    UnObsBiTails <- get.unobs.Ngram.tails(ObsBiG, 2)
    # Exclude observed bigrams that also appear in observed trigrams: OB in UOT
    ObsBiG <- ObsBiG[!str_split_fixed(ObsTriG[,term], "_", 2)[,2], on="term"]

    ## Calculation part
    # Calculate probabilities of all observed trigrams: P^*(z|x,y)
    ObsTriG <- cal.obs.prob(ObsTriG, BiFreq, xy)
    # Calculate Alpha(x,y)
    Alpha_xy <- cal.alpha(ObsTriG, BiFreq, xy)
    # Calculate probabilities of all observed bigrams: P^*(z|y), (y,z) in UOT
    ObsBiG <- cal.obs.prob(ObsBiG, UniFreq, y)
    # Calculate Alpha(y)
    Alpha_y <- cal.alpha(ObsBiG, UniFreq, y)
    # Calculate P_{ML}(z), where (y,z) in UOB: Alpha_y * P_{ML}(z)
    UnObsBiTails[, Prob:=UniFreq[UnObsBiTails, c, on=.(term)]/UniFreq[UnObsBiTails, sum(c), on=.(term)]]
    UnObsBiTails[, Prob:=Alpha_xy*Alpha_y*Prob]
    # Remove unused columns in ObsTriG and ObsBiG, keeping only the last word as term
    ObsTriG[, c("c", "cDis"):=NULL]
    ObsTriG[, term:=str_remove(ObsTriG[, term], "([^_]+_)+")]
    ObsBiG[, c("c", "cDis"):=NULL]
    ObsBiG[, term:=str_remove(ObsBiG[, term], "([^_]+_)+")]
    # Compare OT, Alpha_xy * P_{Katz}(z|y)
    # P_{Katz}(z|y) = 1. P^*(z|y), 2. Alpha_y * P_{ML}(z)
    AllTriG <- setorder(rbind(ObsTriG, ObsBiG, UnObsBiTails), -Prob)
    # head() also returns an empty table when no candidate has a nonzero probability
    return(head(AllTriG[Prob!=0], words_num))
  } else {  # C(x,y) = 0
    y <- str_split_fixed(xy,"_", 2)[,2]
    # C(y) > 0
    if (y %in% UniFreq$term) {
      # Retrieve all observed bigrams beginning with y: OB
      ObsBiG <-, BiFreq)
      # Calculate probabilities of all observed bigrams: P^*(z|y)
      ObsBiG <- cal.obs.prob(ObsBiG, UniFreq, y)
      # Calculate Alpha(y)
      Alpha_y <- cal.alpha(ObsBiG, UniFreq, y)
      # Retrieve all unigrams that end the unobserved bigrams, UOBT: z where C(y,z) = 0
      UnObsBiTails <- get.unobs.Ngram.tails(ObsBiG, 2)
      # Calculate P_{ML}(z), where (y,z) in UOB: Alpha_y * P_{ML}(z)
      UnObsBiTails[, Prob:=UniFreq[UnObsBiTails, c, on=.(term)]/UniFreq[UnObsBiTails, sum(c), on=.(term)]]
      UnObsBiTails[, Prob:=Alpha_y*Prob]
      # Remove unused columns in ObsBiG, keeping only the last word as term
      ObsBiG[, c("c", "cDis"):=NULL]
      ObsBiG[, term:=str_remove(ObsBiG[, term], "([^_]+_)+")]
      AllBiG <- setorder(rbind(ObsBiG, UnObsBiTails), -Prob)
      return(head(AllBiG[Prob!=0], words_num))
    } else {  # C(y) = 0
      # Unigram probabilities from the discounted counts;
      # order() leaves UniFreq (and its key) untouched, unlike setorder()
      return(head(UniFreq[order(-cDis), .(term, Prob=cDis/sum(c))], words_num))
    }
  }
}
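The bigram branch of `Find_Next_word` can be traced on toy data with base data frames. The sketch below is a hypothetical re-implementation of the $C(x,y)=0$, $C(y)\gt0$ case, not the code above: the tables, the counts, and the helper `backoff_next()` are invented, and the discounted counts are assumed to have been computed already.

```r
# Invented tables: bigrams with raw and discounted counts, and unigram counts
BiFreq <- data.frame(term = c("nuclear_power", "nuclear_plant", "nuclear_war"),
                     c = c(6, 3, 1), cDis = c(6, 2.4, 0.5),
                     stringsAsFactors = FALSE)
UniFreq <- data.frame(term = c("nuclear", "power", "plant", "war", "station", "energy"),
                      c = c(12, 30, 20, 15, 10, 5),
                      stringsAsFactors = FALSE)

# Hypothetical helper: Katz backoff at the bigram level for the history word y
backoff_next <- function(y, words_num) {
  C_y    <- UniFreq$c[UniFreq$term == y]
  obs    <- BiFreq[startsWith(BiFreq$term, paste0(y, "_")), ]
  z_obs  <- sub("^[^_]+_", "", obs$term)         # last word of each observed bigram
  p_obs  <- obs$cDis / C_y                       # P*(z|y)
  alpha  <- sum(obs$c - obs$cDis) / C_y          # mass freed by discounting
  tails  <- UniFreq[!UniFreq$term %in% z_obs, ]  # z with C(y,z) = 0
  p_tail <- alpha * tails$c / sum(tails$c)       # alpha(y) * P_ML(z), normalised over the tails
  cand <- data.frame(term = c(z_obs, tails$term), Prob = c(p_obs, p_tail),
                     stringsAsFactors = FALSE)
  cand <- cand[cand$Prob > 0, ]
  head(cand[order(-cand$Prob), ], words_num)
}

backoff_next("nuclear", 5)
```

With these counts, `backoff_next("nuclear", 5)` ranks `power`, `plant`, `war`, `nuclear`, `station`. Because the tail probabilities are normalised over the unobserved tails, all candidate probabilities together sum to the share of occurrences of `nuclear` that have a continuation, here $10/12$.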

The Result

We now test the prediction model with some sample inputs. Note that the user input is preprocessed in the same way as the training dataset, i.e. elements that the prediction model does not handle are removed, before it is passed to the prediction model.

## Remove elements not being used by prediction model
Preprocess <- function(wordseq) {
  names(wordseq) <- NULL
  quest <- wordseq %>% tokens(remove_numbers=T, remove_punct=T, remove_symbols=T, remove_hyphens=T, remove_twitter=T, remove_url=T) %>% tokens_remove(stopwords("en")) %>% tokens_tolower()
  # Keep the last two remaining words as the bigram history
  return(paste(tail(as.list(quest)[[1]], 2), collapse = " "))
}

Next_word <- function(prephrase, words_num=5) {
  bigr <- Preprocess(prephrase)
  result <- Find_Next_word(bigr, words_num)
  if (nrow(result) == 0) {
    result <- rbind(result, list("<Please input more text>", 1))
  }
  return(result)
}
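The effect of `Preprocess()` on the sample inputs below can be approximated in base R. This sketch rests on an assumption: `last_two_words()` and the short `stop_words` vector are hypothetical stand-ins for the quanteda pipeline, which also removes numbers, symbols, URLs, and the full English stopword list.

```r
# Assumed, tiny stand-in for stopwords("en")
stop_words <- c("the", "a", "to", "he", "of", "is", "and")

# Hypothetical helper: lowercase, drop stopwords, keep the last two words
last_two_words <- function(text) {
  words <- tolower(unlist(strsplit(text, "[^A-Za-z']+")))
  words <- words[nzchar(words) & !words %in% stop_words]
  paste(tail(words, 2), collapse = " ")
}

last_two_words("He likes to eat ice")  # "eat ice"
last_two_words("the prime minister")   # "prime minister"
last_two_words("a nuclear power")      # "nuclear power"
```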

Each of the test inputs below contains only a few words, since a trigram model considers only the last two words of the input.

Next_word("He likes to eat ice")
##     term      Prob
## 1: cream 0.9558824
Next_word("the prime minister")
##        term       Prob
## 1:    david 0.03930818
## 2: benjamin 0.03144654
## 3:    najib 0.02830189
## 4:     said 0.02044025
## 5:    datuk 0.01729560
Next_word("a nuclear power")
##        term         Prob
## 1:    plant 0.3120567376
## 2:   plants 0.1773049645
## 3: stations 0.0250246165
## 4:     play 0.0001301713
## 5:   outage 0.0001032972


The text prediction model built in this project is far from sophisticated: it does not handle the start and end of sentences, it ignores all symbols even though they carry hints about meaning, and it removes all the “stopwords” without further processing. However, it is a good starting point for experimenting with methods that improve prediction speed, memory usage, accuracy, and even the interface.

As suggested in Part 2, there are other smoothing and backoff algorithms worth trying, and they can be combined with other machine learning methods, such as recurrent neural networks with long short-term memory (LSTM) units, to build a neural network language model (NNLM).

Finally, the Coursera Capstone project gave me a taste of how to tackle a real-life problem from scratch. Investigating the data, analysing the problem and the related methods, and building an application that implements the solution all gave me solid experience for tackling other problems in the future.

Leo Mak
Make the world a better place, piece by piece.