Multiple-k: Picking the number of folds for cross-validation

Ludvig Renbo Olsen



When performing cross-validation, we tend to go with the common 10 folds (k=10). In this vignette, we try different settings of the number of folds and assess the differences in performance. To make our results robust to this choice, we average the results of multiple settings.   The functions of interest are cross_validate_fn() and groupdata2::fold().  


When performing cross-validation, it is common to use 10 folds. Why? It is the common thing to do of course! Not 9 or 11, but 10, and sometimes 5, and sometimes n folds (i.e. leave-one-out cross-validation).

While having a standard setting means one less thing to worry about, let’s spend a few minutes discussing this choice. Whether it is reasonable comes down to the context of course, but these are some general thoughts on the topic:

In the plot below, some generated data has been split into 3 (left) and 10 (right) folds. Each line represents the best linear model for one of the folds (i.e. the model that would have the lowest prediction error when testing on that fold). When k=3, a single fold with a highly different distribution from the other two folds can have a big impact on the cross-validated prediction error. When k=10, a few of the folds may differ greatly as well, but on average, the model will be closer to the model that overall reduces the prediction error the most:
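A minimal sketch of how such a figure can be generated (the data-generating process, sample size, and seed here are arbitrary assumptions, not the ones behind the original plot):

```r
library(groupdata2)
library(dplyr)
library(ggplot2)

set.seed(4)  # arbitrary

# Generate some data with a linear relationship
df <- dplyr::tibble(x = runif(60), y = 1.5 * x + rnorm(60))

# Split into 3 folds (use k = 10 for the right-hand version)
df <- groupdata2::fold(df, k = 3)

# One best-fitting line per fold
ggplot(df, aes(x = x, y = y, color = .folds)) +
  geom_point() +
  geom_smooth(method = "lm", se = FALSE) +
  theme_minimal()
```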

Note that this picture changes with different random seeds. To check whether the lower number of folds indeed tend to give higher prediction errors, we run this 100 times and average the results. That is, we randomly generate 100 datasets and cross-validate a linear model (y ~ x) on each of them. We then average the RMSE (Root Mean Square Error) and MAE (Mean Absolute Error) to get the following results:
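The simulation can be sketched roughly like this (the data-generating process, sample size, and seed are assumptions; cross_validate() is the cvms cross-validation function for formula-based models):

```r
library(cvms)
library(groupdata2)
library(dplyr)

set.seed(1)  # arbitrary

# Cross-validate `y ~ x` on one random dataset with k = 3 and k = 10
run_once <- function() {
  df <- dplyr::tibble(x = runif(30), y = 1.5 * x + rnorm(30))
  dplyr::bind_rows(
    "k = 3" = cvms::cross_validate(
      groupdata2::fold(df, k = 3), formulas = "y ~ x", family = "gaussian"),
    "k = 10" = cvms::cross_validate(
      groupdata2::fold(df, k = 10), formulas = "y ~ x", family = "gaussian"),
    .id = "Fold Column"
  )
}

# Repeat 100 times and average RMSE and MAE per setting
results <- dplyr::bind_rows(replicate(100, run_once(), simplify = FALSE))
results %>%
  dplyr::group_by(`Fold Column`) %>%
  dplyr::summarise(RMSE = mean(RMSE), MAE = mean(MAE))
```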

#> # A tibble: 2 x 3
#>   `Fold Column`  RMSE   MAE
#>   <fct>         <dbl> <dbl>
#> 1 k = 3         0.976 0.766
#> 2 k = 10        0.902 0.758

Both the RMSE and MAE are higher in the k=3 setting. As a matter of fact, this was the case in 91% of the runs. This supports the expectation that, on average, the prediction error is lower with a larger k. Let’s see a violin plot of the simulations as well:

So… Why not just always use the highest possible number of folds?

A higher number of folds means training a lot more models, which can be computationally heavy and time-consuming. So finding a lower k that yields a similar prediction error most of the time can be very useful. For the rest of this vignette, we won’t go in-depth with such limited-resource scenarios, though.

Does it matter?

We might ask whether this even matters. If the goal of our cross-validation is to compare a set of models and then choose the best one, what matters most is whether the same model would be picked with different settings of k. But how can we make sure that the same model is selected without trying multiple settings? And if it is not, which result do we choose (without cherry-picking)?

An approach to minimizing the effect of k on our model selection could be to run the cross-validation with multiple ks and then average the results. In general, repeated cross-validation (where we average over the results from multiple fold splits) is a great choice when possible, as it is more robust to the random fold splits. In the next section, we will run repeated cross-validation with different k settings and plot the results.

Multiple-k repeated cross-validation

Our goal here is two-fold:

  1. Try multiple values of k (different numbers of folds) and see the effect on the prediction error.

  2. Repeat each scenario multiple times to get more robust results.

Whereas the previous section used a regression example (continuous y-variable), we will now perform multiclass classification on the iris dataset. This fairly well-known dataset has three species of iris flowers with 50 flowers from each species. The predictors are length and width measurements of the sepals and petals.

First, we attach the needed packages and set a random seed:

library(cvms)  # version >= 1.2.2
library(groupdata2)  # version >= 1.4.1
library(dplyr)

set.seed(1)  # arbitrary seed, for reproducibility

As the fold creation and cross-validation will take some time to run, we can enable parallelization to speed up the processes:

# Enable parallelization
# NOTE: Uncomment to run
# library(doParallel)
# doParallel::registerDoParallel(6)

Now, we load the data and convert it to a tibble:

# Load the iris dataset
data("iris")

# Convert iris to a tibble
iris <- dplyr::as_tibble(iris)
iris
#> # A tibble: 150 x 5
#>    Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#>           <dbl>       <dbl>        <dbl>       <dbl> <fct>  
#>  1          5.1         3.5          1.4         0.2 setosa 
#>  2          4.9         3            1.4         0.2 setosa 
#>  3          4.7         3.2          1.3         0.2 setosa 
#>  4          4.6         3.1          1.5         0.2 setosa 
#>  5          5           3.6          1.4         0.2 setosa 
#>  6          5.4         3.9          1.7         0.4 setosa 
#>  7          4.6         3.4          1.4         0.3 setosa 
#>  8          5           3.4          1.5         0.2 setosa 
#>  9          4.4         2.9          1.4         0.2 setosa 
#> 10          4.9         3.1          1.5         0.1 setosa 
#> # … with 140 more rows

We count the rows per species, to ensure everything is in order:

iris %>% 
  dplyr::count(Species)
#> # A tibble: 3 x 2
#>   Species        n
#>   <fct>      <int>
#> 1 setosa        50
#> 2 versicolor    50
#> 3 virginica     50

Creating folds

When creating the folds, we would like to balance them such that the distribution of the species is similar in all the folds. In groupdata2::fold(), this is possible with the cat_col argument. This also ensures that there’s at least one of each species in every fold. With this approach, our maximum number of folds becomes 50. The lowest number of folds we can meaningfully generate is technically 2, but that is rarely used in practice, so we will set our lower limit at 3 (arbitrary yes, but I get to choose here!). We thus pick 10 ks in that range using the seq() function.
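One way to pick those 10 settings; rounding the sequence reproduces the fold counts seen in the summary further down:

```r
# 10 evenly spaced k settings between 3 and 50
fold_counts <- round(seq(from = 3, to = 50, length.out = 10))
fold_counts
#> [1]  3  8 13 19 24 29 34 40 45 50
```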

As we are interested in comparing the results at each k setting, we repeat each setting 3 times to be more robust to the randomness of the splitting. You might want to increase this to 10 repetitions, but that would increase the running time too much for this tutorial. If you are only interested in the average results, you might not need to repeat each setting, as the multiple settings of k become a kind of repeated cross-validation in themselves.

We pass this sequence of counts to the k argument in groupdata2::fold(). We must also set the num_fold_cols argument to match the length of our sequence. As explained previously, we set cat_col = "Species" to ensure a balanced distribution of the species in all folds. Finally, we enable parallelization to speed things up:
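A sketch of that call (argument names from groupdata2::fold(); the fold counts are written out here for self-containment):

```r
# 10 settings of k, each repeated 3 times -> 30 fold columns
fold_counts <- c(3, 8, 13, 19, 24, 29, 34, 40, 45, 50)

data <- groupdata2::fold(
  data = iris,
  k = rep(fold_counts, each = 3),            # one k per fold column
  cat_col = "Species",                       # balance the species
  num_fold_cols = 3 * length(fold_counts),   # 30 fold columns
  parallel = TRUE  # requires a registered parallel backend
)
```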

We see that the .folds_* columns have been added with the fold identifiers. We will need the names of the generated fold columns in a second, so here’s my favorite approach to generating them with paste0():
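With 30 fold columns, that could look like:

```r
# Names of the 30 generated fold columns
fold_columns <- paste0(".folds_", seq_len(30))
head(fold_columns)
#> [1] ".folds_1" ".folds_2" ".folds_3" ".folds_4" ".folds_5" ".folds_6"
```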

groupdata2 has the tool summarize_group_cols() for inspecting the generated fold columns (and factors in general). We can use this to assure ourselves that the right number of folds were created in each of the fold columns:
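A sketch of that inspection (the renaming to `Fold Column` is an assumption, made to match the summary table below):

```r
fold_stats <- groupdata2::summarize_group_cols(
  data = data,
  group_cols = paste0(".folds_", seq_len(30))
) %>%
  dplyr::rename(`Fold Column` = `Group Column`)

# Show one fold column per k setting (every third column)
fold_stats[seq(1, 30, by = 3), ]
```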

Fold Column  Num Groups  Mean Rows  Median Rows   Std Rows  IQR Rows  Min Rows  Max Rows
.folds_1              3  50.000000           51  1.7320508      1.50        48        51
.folds_4              8  18.750000           18  1.3887301      0.75        18        21
.folds_7             13  11.538462           12  1.1266014      0.00         9        12
.folds_10            19   7.894737            9  1.4867839      3.00         6         9
.folds_13            24   6.250000            6  0.8469896      0.00         6         9
.folds_16            29   5.172414            6  1.3645765      3.00         3         6
.folds_19            34   4.411765            3  1.5199212      3.00         3         6
.folds_22            40   3.750000            3  1.3155870      0.75         3         6
.folds_25            45   3.333333            3  0.9534626      0.00         3         6
.folds_28            50   3.000000            3  0.0000000      0.00         3         3

Performing cross-validation

Now we are ready to cross-validate a model on our data. We will use the e1071::svm() Support Vector Machine model function. To use this with cross_validate_fn(), we can use the included model_fn and predict_fn functions. We further need to specify the kernel and cost hyperparameters.

For more elaborate examples of cross_validate_fn(), see its documentation and the other cvms vignettes.

We define an (arbitrary) set of formulas to cross-validate:
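The fourth formula is given later in this vignette; the first three are reconstructed from the shortened names used in the results tables, so treat them as assumptions:

```r
formulas <- c(
  "Species ~ Sepal.Length + Sepal.Width",
  "Species ~ Petal.Length + Petal.Width",
  "Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width",
  "Species ~ Sepal.Length * Sepal.Width + Petal.Length * Petal.Width"
)
```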

Now, we are ready to run the cross-validation! We pass our data, formulas, functions, hyperparameters and fold column names to cross_validate_fn() and specify that the type of task is multiclass classification (i.e. multinomial). We also enable parallelization.

NOTE: This number of fold columns and formulas requires fitting 3180 model instances. That can take a few minutes to run, depending on your computer. Unfortunately, I have not yet found a way to include a progress bar when running in parallel.
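A sketch of the call (the svm_multinomial model/predict functions ship with cvms; the kernel and cost values are assumptions):

```r
cv <- cvms::cross_validate_fn(
  data = data,
  formulas = formulas,
  type = "multinomial",
  model_fn = cvms::model_functions("svm_multinomial"),
  predict_fn = cvms::predict_functions("svm_multinomial"),
  hyperparameters = list("kernel" = "linear", "cost" = 1),
  fold_cols = paste0(".folds_", seq_len(30)),
  parallel = TRUE
)

cv
```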

Above, we see the averaged results from all the fold columns. The last model formula seems to have performed the best as it has the highest Overall Accuracy, Balanced Accuracy, F1, and so on. If our objective was to average the results with different settings of k to increase robustness to that choice, we could stop now. In the following section, we will have a look at the results from the different settings of k.

Inspecting the effect of k

There is a list of nested tibbles (data frames) in the cross-validation output called Results. This has the results from each fold column. Let’s extract it and format it a bit.

Note: In regression tasks, the Results tibbles would have the results from each fold, from each fold column, but in classification we gather the predictions from all folds within a fold column before evaluation.

As this is currently a list of tibbles (one for each formula), we first name it by the model formulas and then combine the tibbles to a single data frame with dplyr::bind_rows().
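A sketch of those steps (in cvms output, the `Fixed` column holds the fixed effects of each formula; treat the details as assumptions):

```r
fold_column_results <- cv$Results

# Name the list elements by the model formulas
names(fold_column_results) <- cv$Fixed

# Combine to a single data frame with a `Formula` column
fold_column_results <- dplyr::bind_rows(fold_column_results, .id = "Formula")
```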

We now have a single tibble where a new column (Formula) specifies what model the results came from. When plotting, the full model formula strings are a bit long though, so let’s convert them to something shorter:
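One way to shorten them, based on the abbreviations used in the plots (SL, SW, PL, PW):

```r
fold_column_results <- fold_column_results %>%
  dplyr::mutate(
    Formula = gsub("Sepal.Length", "SL", Formula, fixed = TRUE),
    Formula = gsub("Sepal.Width", "SW", Formula, fixed = TRUE),
    Formula = gsub("Petal.Length", "PL", Formula, fixed = TRUE),
    Formula = gsub("Petal.Width", "PW", Formula, fixed = TRUE),
    # Remove spaces and multiplication signs, e.g. "SL * SW" -> "SLSW"
    Formula = gsub(" ", "", Formula, fixed = TRUE),
    Formula = gsub("*", "", Formula, fixed = TRUE)
  )
```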

This leaves us with the following data frame:

Currently, we lack the number of folds for each of the fold columns. We have stored those in the fold column statistics we looked at in the beginning (fold_stats), so let’s add them with dplyr::left_join():
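A sketch, assuming the fold column summary is stored as `fold_stats` with a `Fold Column` name column:

```r
fold_column_results <- fold_column_results %>%
  dplyr::left_join(
    fold_stats %>%
      dplyr::select(`Fold Column`, `Num Folds` = `Num Groups`),
    by = "Fold Column"
  )
```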

We further add a column that indicates which repetition of the k setting a fold column belongs to. This is done with the l_starts method in groupdata2::group(), which automatically starts a new group every time the value of a column changes. So c(".folds_1", ".folds_1", ".folds_2", ".folds_2") would give the groups c(1, 1, 2, 2). By first grouping the data frame by the number of folds, these group indices start over for each setting of k. Note, though, that this assumes that the Fold Column is sorted correctly. We should also remember to remove the grouping at the end, if we don’t need it in the following steps.
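A sketch of that step (argument names from groupdata2::group(); with method = "l_starts", n = "auto" detects the value changes automatically):

```r
fold_column_results <- fold_column_results %>%
  # Restart the repetition index for each k setting
  dplyr::group_by(`Num Folds`) %>%
  groupdata2::group(
    n = "auto",
    method = "l_starts",
    starts_col = "Fold Column",
    col_name = "Repetition"
  ) %>%
  dplyr::ungroup()
```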

Fold Column  Num Folds  Formula      Repetition
.folds_1             3  SL+SW                 1
.folds_1             3  PL+PW                 1
.folds_1             3  SL+SW+PL+PW           1
.folds_1             3  SLSW+PLPW             1
.folds_2             3  SL+SW                 2
.folds_2             3  PL+PW                 2
.folds_2             3  SL+SW+PL+PW           2
.folds_2             3  SLSW+PLPW             2
.folds_3             3  SL+SW                 3
.folds_3             3  PL+PW                 3
.folds_3             3  SL+SW+PL+PW           3
.folds_3             3  SLSW+PLPW             3
.folds_4             8  SL+SW                 1
.folds_4             8  PL+PW                 1
.folds_4             8  SL+SW+PL+PW           1
.folds_4             8  SLSW+PLPW             1
.folds_5             8  SL+SW                 2
.folds_5             8  PL+PW                 2
.folds_5             8  SL+SW+PL+PW           2
.folds_5             8  SLSW+PLPW             2

To plot the average lines for each formula, we calculate the average Balanced Accuracy for each formula, for each number of folds setting:
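A sketch, assuming the combined fold column results are in a data frame with `Formula`, `Num Folds`, and `Balanced Accuracy` columns (hypothetical names from the steps above):

```r
avg_balanced_acc <- fold_column_results %>%
  dplyr::group_by(Formula, `Num Folds`) %>%
  dplyr::summarise(
    `Balanced Accuracy` = mean(`Balanced Accuracy`),
    .groups = "drop"
  )
```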

With the data ready, we plot the effect of k on the Balanced Accuracy metric. Feel free to plot one of the other metrics as well!
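A sketch of such a plot with ggplot2, using stat_summary() to draw the average line per formula (aesthetic choices are arbitrary):

```r
library(ggplot2)

ggplot(
  fold_column_results,
  aes(x = `Num Folds`, y = `Balanced Accuracy`, color = Formula)
) +
  # The three repetitions at each k setting
  geom_point() +
  # Average line per formula across repetitions
  stat_summary(fun = mean, geom = "line") +
  theme_minimal()
```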

For this dataset, the ranking of the models seems somewhat stable, although the PL+PW and SL+SW+PL+PW models are so close that their ranking might switch at times. The variation in the three points at each k setting (for each formula) shows why repeated cross-validation is a good idea when possible. Without it, the random split can have a much bigger effect on the results (and potentially the model ranking). By relying on the average of multiple k settings and repetitions, our results are more robust to fluctuations.

If you don’t want to always run this many models (e.g. in production), running this analysis at the beginning of a project (and perhaps once in a while, in case of data drift) can help you check whether the choice of k makes a difference for your type of data.

Class level results

Finally, let’s have a look at the class level fold column results from the one-vs-all evaluations. These describe how well a model did for each of the species, for each of the fold columns. We will make a plot similar to the above to see whether the k setting affects the class level performance.

The cross-validation output has a nested tibble called Class Level Results for each formula. Within such tibble, we find another nested tibble (Results) that contains the results for each fold column, for each species.

To get the class level fold column results for the best model (i.e. Sepal.Length * Sepal.Width + Petal.Length * Petal.Width), we thus first get the Class Level Results for this (fourth) model and then extract the Results from that.

This gives us a list with the fold column results for each species which we concatenate to a single data frame.
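A sketch of those two steps (the `Class` column name inside the Class Level Results is an assumption):

```r
# Class Level Results for the fourth model formula
class_level <- cv$`Class Level Results`[[4]]

# One fold-column results tibble per species
class_results <- class_level$Results
names(class_results) <- class_level$Class

# Concatenate to a single data frame
class_results <- dplyr::bind_rows(class_results, .id = "Species")
```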

Again, we add the repetition column, using the l_starts method in groupdata2::group():

To plot the average lines for each formula, we calculate the average Balanced Accuracy for each class, for each number of folds setting:

Now, we can plot the Balanced Accuracy by the number of folds:

The Versicolor and Virginica species seem affected by the number of folds. They both have slightly lower average balanced accuracies at the lower k settings.

In this vignette, we have covered the choice of the “number of folds” setting when using cross-validation. We have discussed why larger k settings should give lower prediction errors on average and shown how to make results robust to this setting by averaging over a range of k values. The groupdata2::fold() and cvms cross-validation functions enable this type of analysis.

This concludes the vignette. If elements are unclear or you need help to apply this to your context, you can leave feedback in a mail or in a GitHub issue :-)