Setting up a simple GridSearchCV in Scikit-learn

In most machine learning projects we need to train the model with different hyperparameter values to get the best results out of it. We also need to split the training data into training and validation sets while training the model, so that each candidate setting is judged on data it was not fitted to. This helps us choose settings that generalize well instead of ones that merely over-fit the training data. GridSearchCV in the Scikit-learn package provides both of these functionalities: it tries every combination of parameter values we give it and builds the training and validation splits automatically, without our needing to code them ourselves. Here we will see how to set up a simple GridSearchCV with the k-nearest neighbors (KNN) algorithm.

 

Let's say we have a training dataset, with features train_X and labels train_y, and we want to set up a grid search that trains a k-nearest neighbors classifier and finds the best value for the number of neighbors (k). We also want to divide the training dataset into training and validation sets.
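The post does not say where train_X and train_y come from. As a minimal sketch, assuming Scikit-learn's built-in iris dataset as a stand-in for your own data, they could be prepared like this. The extra test_X and test_y names are hypothetical: they hold back a separate test set that the grid search never sees.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Assumption: the iris dataset stands in for the reader's own data.
X, y = load_iris(return_X_y=True)

# Hold out 25% as a final test set; only train_X/train_y go into the grid search.
# stratify=y keeps the class proportions similar in both parts.
train_X, test_X, train_y, test_y = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y
)

print(train_X.shape, train_y.shape)  # (112, 4) (112,)
```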

First we import the required packages.

from sklearn.neighbors import KNeighborsClassifier

from sklearn.model_selection import GridSearchCV

from sklearn.metrics import accuracy_score

 

Define the model

 
model = KNeighborsClassifier()

 

Define which values of k we want to test


ks = list(range(1, 11))

This gives the list:

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

Define grid parameters
grid_params = dict(n_neighbors=ks)
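To see what GridSearchCV will actually iterate over, Scikit-learn's ParameterGrid expands a parameter dictionary into the list of candidate settings. The second dictionary, which adds KNeighborsClassifier's weights parameter, is only an illustration of how extra keys multiply the grid; it is not part of the search in this post.

```python
from sklearn.model_selection import ParameterGrid

ks = list(range(1, 11))
grid_params = dict(n_neighbors=ks)

# ParameterGrid expands the dict into the candidate settings GridSearchCV tries.
candidates = list(ParameterGrid(grid_params))
print(len(candidates))  # 10
print(candidates[:3])   # [{'n_neighbors': 1}, {'n_neighbors': 2}, {'n_neighbors': 3}]

# Illustration only: a second key multiplies the grid
# (10 values of k x 2 weighting schemes = 20 candidates).
bigger = dict(n_neighbors=ks, weights=["uniform", "distance"])
print(len(ParameterGrid(bigger)))  # 20
```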
 
Define GridSearchCV
grid = GridSearchCV(model, grid_params, cv=10, scoring='accuracy',
                    return_train_score=False, verbose=1)

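The parameter cv=10 splits the training dataset into ten parts (folds). In the first iteration of cross-validation, nine folds are joined to form the training set, the remaining fold is used as the validation set, and the model's performance is measured on that validation set. This is repeated ten times, with a different fold serving as the validation set each time, so every sample is used for validation exactly once. The score for a given value of k is the mean accuracy over the ten validation folds. Because KNeighborsClassifier is a classifier, Scikit-learn uses stratified folds by default, which keep the class proportions roughly the same in every fold.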
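The fold rotation described above can be sketched with KFold on a hypothetical 20-sample array. This only illustrates the splitting pattern: inside GridSearchCV a classifier gets stratified folds, so the exact indices differ, but the nine-train/one-validate rotation is the same.

```python
import numpy as np
from sklearn.model_selection import KFold

# Hypothetical stand-in for a training set with 20 samples.
X = np.arange(20).reshape(-1, 1)

kf = KFold(n_splits=10)  # same number of splits as cv=10
folds = list(kf.split(X))

for i, (train_idx, val_idx) in enumerate(folds, start=1):
    # Nine folds (18 samples) train the model; one fold (2 samples) validates it.
    print(f"iteration {i}: train on {len(train_idx)} samples, validate on {val_idx}")

# Across the ten iterations, every sample is used for validation exactly once.
all_val = np.sort(np.concatenate([val for _, val in folds]))
print(np.array_equal(all_val, np.arange(20)))  # True
```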

Finally, we train the model with the following command:
grid_search = grid.fit(train_X, train_y)
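This tests every value of k on all ten cross-validation splits, i.e. 10 candidates × 10 folds = 100 fits (with verbose=1, a short summary of this is printed). Because refit=True by default, GridSearchCV then retrains the model with the best k on the whole training set.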
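Once the search has run, the results live on the fitted object. The sketch below, again assuming iris and a 25% hold-out split as stand-ins for real data, reads best_params_, best_score_ and cv_results_, checks best_score_ against a plain cross_val_score run for the chosen k, and finally uses the accuracy_score function imported earlier to score the refitted model on the untouched test set. The exact numbers depend on the data, so none are shown here.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Assumption: iris stands in for the reader's data, with a held-out test split.
X, y = load_iris(return_X_y=True)
train_X, test_X, train_y, test_y = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y
)

grid = GridSearchCV(KNeighborsClassifier(), dict(n_neighbors=list(range(1, 11))),
                    cv=10, scoring="accuracy")
grid_search = grid.fit(train_X, train_y)

# Best k and its mean validation accuracy over the ten folds.
print(grid_search.best_params_)
print(grid_search.best_score_)

# Mean validation accuracy for every candidate k, in grid order.
for k, score in zip(grid_search.cv_results_["param_n_neighbors"],
                    grid_search.cv_results_["mean_test_score"]):
    print(k, round(score, 3))

# Sanity check: cross_val_score with cv=10 uses the same stratified folds,
# so scoring the best k this way reproduces best_score_.
best_k = grid_search.best_params_["n_neighbors"]
manual = cross_val_score(KNeighborsClassifier(n_neighbors=best_k), train_X, train_y,
                         cv=10, scoring="accuracy").mean()
print(manual)

# refit=True (the default) retrained the best model on all of train_X,
# so the search object can predict directly on the untouched test set.
test_acc = accuracy_score(test_y, grid_search.predict(test_X))
print(test_acc)
```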
