import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import AllChem, MolFromSmiles, MolToSmiles
from fastai.vision.all import *  # also provides partial, Tensor, ColSplitter, Datasets, cnn_learner, etc.

Hi! This will be the first post on my new blog. I hope you guys enjoy it.

I'm taking the Fast.AI course, and I decided to try to solve some problems from my field of research (Cheminformatics) using the fastai package. In this blog, I will try to follow the chapters of the book Deep Learning for Coders with fastai & PyTorch. This notebook, for instance, was inspired by Chapter 2 (the notebook version can be found here).

Let's begin!

So, what are we going to do?

In this notebook we will try to solve a quantitative structure-activity relationship (QSAR) problem. What does this mean? Well, QSAR is a major area of Cheminformatics. The goal of QSAR is to find useful information in chemical and biological data and use it for inference on new data. For example, we might want to predict whether a molecule is active on a particular biological target. We could start with a dataset of experimentally measured bioactivities (e.g. $IC_{50}$ values, inhibition constants, etc.) and train a model for bioactivity prediction. This model can then be used to predict the bioactivity of other molecules of interest. With QSAR, Big Pharma and research groups can generate new hypotheses much faster and more cheaply than by testing a bunch of molecules in vitro.

Traditionally, machine learning methods such as random forests, support vector machines and gradient boosting have dominated the field. That's because these classical methods usually give very good results on a range of datasets and are quite easy to train. Until recently, researchers did not apply deep learning to bioactivity prediction at large scale. When they did, it was usually in the form of fully connected neural networks with just 2-5 layers. However, the last 5 years saw a BOOM in the number of publications using deep learning in very interesting ways! For example, recurrent neural networks are being applied to generate molecules, convnets are showing SOTA performance on binding affinity and pose prediction, and multi-task learning was used successfully to win a Kaggle competition for bioactivity prediction!

The most common type of data for QSAR is tabular. Researchers usually calculate many chemical features to describe a collection of molecules. One of the most common is a fingerprint: a binary vector indicating the presence or absence of chemical groups in a molecule. We can then use this fingerprint to train a machine learning model for bioactivity prediction.
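To make the idea concrete, here is a minimal, pure-Python sketch of a presence/absence vector. The group vocabulary `GROUP_VOCAB` and the `groups_found` input are hypothetical placeholders. Real fingerprints, such as RDKit's MACCS keys or Morgan fingerprints, are computed directly from the molecular graph, and this toy does not attempt that.

```python
# Hypothetical vocabulary of chemical groups; real fingerprints use hundreds or thousands of bits
GROUP_VOCAB = ["amide", "amidine", "carboxylic_acid", "halogen", "pyridine"]

def presence_vector(groups_found):
    """Return a binary vector: 1 if the group at that position was found, else 0."""
    found = set(groups_found)
    return [1 if group in found else 0 for group in GROUP_VOCAB]

# Groups assumed (not detected) for an example molecule
print(presence_vector({"amide", "halogen", "pyridine"}))  # [1, 0, 0, 1, 1]
```

Every molecule gets a vector of the same length, so a whole dataset becomes a table with one row per molecule and one column per bit.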

In this notebook we will use a different strategy. Instead of calculating a bunch of vectors, we'll convert each molecule to an image and feed that input to a neural network! As Jeremy said in the book:

Another point to consider is that although your problem might not look like a computer vision problem, it might be possible with a little imagination to turn it into one. For instance, if what you are trying to classify are sounds, you might try converting the sounds into images of their acoustic waveforms and then training a model on those images.

Let's try that!

How can we convert molecules to images?

In reality, we are not going to convert molecules to pictures of molecules. What we actually need is a way to represent molecules the same way as images, and that means using arrays. For example, an image can be represented as a 3D array of shape $(W, H, C)$, where $W$ is the width, $H$ is the height and $C$ is the number of channels. If we can do the same for molecules, then it is straightforward to use them as input to a model.
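PyTorch models expect the channel axis first rather than last. The `chemcepterize_mol` function later in this post gets there with NumPy's `.T`, which reverses all three axes. The dependency-free sketch below performs the same reversal on nested lists, so the index bookkeeping is explicit. `reverse_axes` is a hypothetical helper, not part of any library.

```python
def reverse_axes(img):
    """Reverse all three axes of a nested-list array, like NumPy's .T on a 3D array.

    Input is indexed img[x][y][c]; output is indexed out[c][y][x].
    """
    n_x, n_y, n_c = len(img), len(img[0]), len(img[0][0])
    return [[[img[x][y][c] for x in range(n_x)] for y in range(n_y)] for c in range(n_c)]

# A tiny 2 x 3 "image" with 2 channels; each value encodes (x, y, c) as 100*x + 10*y + c
img = [[[100 * x + 10 * y + c for c in range(2)] for y in range(3)] for x in range(2)]
out = reverse_axes(img)
print(len(out), len(out[0]), len(out[0][0]))  # 2 3 2
print(out[1][2][0])  # 21, i.e. the value at x=0, y=2, channel 1
```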

There are many ways to do that, but we are going to use one that I think is very interesting. In 2017, Garrett B. Goh, Charles Siegel, Abhinav Vishnu, Nathan O. Hodas and Nathan Baker published a preprint showing that machine learning models actually don't need to know much about chemistry or biology to make a prediction!

In their original manuscript, the authors called their model Chemception and showed that, using very, very simple image-like inputs, it was possible to achieve SOTA performance on some public datasets. That's quite an achievement! Until very recently, the cheminformatics community relied on handcrafted features, and now it seems we don't need to tell a model much about molecules for it to make good predictions!

As the authors mentioned in the preprint:

In addition, it should be emphasized that no additional source of chemistry-inspired features or inputs, such as molecular descriptors or fingerprints were used in training the model. This means that Chemception was not explicitly provided with even the most basic chemical concepts like “valency” or “periodicity”.

This means that Chemception had to learn A LOT about chemistry from scratch, using only inputs that are not very informative (to humans, at least)!

I really find this amazing!

Just to clarify: I'm not saying the model would be useful in real settings. But it is quite amazing to see good performance without elaborate chemical descriptors.

The Chemception model is a convolutional neural network for QSAR tasks. An overview of their method is shown below:

Load data

mols = pd.read_csv('/home/marcossantana/Documentos/GitHub/fiocruzcheminformatics/_data/fxa_ic50_processed.csv',sep=';')
mols.head(2)
| | doc_id | standard_value | standard_type | standard_relation | pchembl | molregno | canonical_smiles | chembl_id | target_dictionary | target_chembl_id | l1 | l2 | l3 | confidence_score | act | processed_smiles | is_valid |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 47181 | 1.5 | IC50 | = | 8.82 | 459679 | COc1ccc(NC(=O)c2ccc(C(=N)N(C)C)cc2)c(C(=O)Nc2ccc(Cl)cn2)c1 | CHEMBL512351 | 194 | CHEMBL244 | Enzyme | Protease | Serine protease | 8 | Active | COc1ccc(NC(=O)c2ccc(C(=N)N(C)C)cc2)c(C(=O)Nc2ccc(Cl)cn2)c1 | False |
| 1 | 30088 | 29000.0 | IC50 | = | 4.54 | 655811 | Cc1ccc(Oc2nc(Oc3cccc(C(=N)N)c3)c(F)c(NC(C)CCc3ccccc3)c2F)c(C(=O)O)c1 | CHEMBL193933 | 194 | CHEMBL244 | Enzyme | Protease | Serine protease | 9 | Inactive | Cc1ccc(Oc2nc(Oc3cccc(C(=N)N)c3)c(F)c(NC(C)CCc3ccccc3)c2F)c(C(=O)O)c1 | False |

The chemcepterize_mol function below takes care of converting the molecules (parsed from their SMILES strings, a one-line text representation of the chemical structure) to the image-like format that we want.

Our workflow will go like this: first we define an embedding dimension and a resolution. You can think of this as a black canvas where the structures will be drawn. The resolution sets how large each pixel is, and therefore how many pixels are used to represent each atom and bond. The dimensions of the canvas will be given by:

$$DIM = \frac{2 \cdot EMBED}{RES}$$

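where $EMBED$ is the embedding size (half the width of the canvas, in the units of the molecule's 2D coordinates) and $RES$ is the resolution (the size of one pixel in those same units).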
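As a quick check of the formula, the sketch below computes the canvas side for a few settings. It also adds the 2 padding pixels that `chemcepterize_mol` allocates, which is why the shapes printed later in this post are 2 pixels larger than $DIM$. The helper name `canvas_side` is hypothetical.

```python
def canvas_side(embed, res, pad=2):
    """Side length in pixels: DIM = int(2 * embed / res), plus `pad` extra pixels."""
    return int(embed * 2 / res) + pad

for embed, res in [(20, 0.5), (16, 0.5), (56, 0.5), (32, 0.25)]:
    print(embed, res, canvas_side(embed, res))
# 20 0.5 82
# 16 0.5 66
# 56 0.5 226
# 32 0.25 258
```

The last three values match the (66, 66, 3), (226, 226, 3) and [3, 258, 258] shapes reported further down.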

The next step consists of calculating some basic chemical information from the structure: bond order, atomic number, hybridization state and Gasteiger partial charge. This information will be converted into an array of shape $(P, DIM, DIM)$, where $P$ is the number of properties (the channels of the image) and $DIM$ is the dimension of the canvas. The function computes four channels but keeps only the first three (bond order, atomic number and hybridization), so that the input fits the three-channel pretrained models we will use.
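The key step maps each 2D coordinate, which ranges from $-EMBED$ to $+EMBED$, onto a pixel index via $\mathrm{round}((x + EMBED)/RES)$. Each bond is drawn as a few evenly spaced points between its two atoms. Below is a minimal pure-Python sketch of that logic, mirroring what `chemcepterize_mol` does. The helper names and the example bond coordinates are hypothetical.

```python
def to_pixel(x, y, embed, res):
    """Map a 2D coordinate in [-embed, +embed] to integer pixel indices."""
    return int(round((x + embed) / res)), int(round((y + embed) / res))

def bond_pixels(begin, end, embed, res):
    """Pixels touched by a bond, sampled at int(2 / res) evenly spaced points (like np.linspace)."""
    n = int(1 / res * 2)
    fracs = [i / (n - 1) for i in range(n)] if n > 1 else [0.0] * n
    pixels = set()
    for f in fracs:
        # Linear interpolation between the two atom positions
        cx = f * begin[0] + (1 - f) * end[0]
        cy = f * begin[1] + (1 - f) * end[1]
        pixels.add(to_pixel(cx, cy, embed, res))
    return pixels

# An atom at the origin lands in the middle of an 80 x 80 canvas (embed=20, res=0.5)
print(to_pixel(0.0, 0.0, embed=20, res=0.5))  # (40, 40)
# A horizontal bond of length 1.5 covers four neighbouring pixels
print(sorted(bond_pixels((0.0, 0.0), (1.5, 0.0), embed=20, res=0.5)))
# [(40, 40), (41, 40), (42, 40), (43, 40)]
```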

The properties we will calculate are shown in the figure below, from the original manuscript.

def chemcepterize_mol(mol, embed=20.0, res=0.5):
    dims = int(embed*2/res)

    # Work on a copy so the input molecule is not modified
    cmol = Chem.Mol(mol.ToBinary())
    cmol.ComputeGasteigerCharges()
    AllChem.Compute2DCoords(cmol)
    coords = cmol.GetConformer(0).GetPositions()

    # I added 2 pixels to the height and width because this function sometimes fails if the molecule is too big.
    # A layout that still extends past the canvas raises an IndexError (or, for negative indices, silently wraps around).
    vect = np.zeros((dims+2, dims+2, 4))

    # Channel 0: bond order, drawn as evenly spaced points between the two bonded atoms
    for bond in cmol.GetBonds():
        bondorder = bond.GetBondTypeAsDouble()
        bcoords = coords[bond.GetBeginAtomIdx()]
        ecoords = coords[bond.GetEndAtomIdx()]
        frac = np.linspace(0, 1, int(1/res*2))
        for f in frac:
            c = f*bcoords + (1-f)*ecoords
            idx = int(round((c[0] + embed)/res))
            idy = int(round((c[1] + embed)/res))
            vect[idx, idy, 0] = bondorder

    # Channels 1-3: atom properties, written at each atom's pixel
    for i, atom in enumerate(cmol.GetAtoms()):
        idx = int(round((coords[i][0] + embed)/res))
        idy = int(round((coords[i][1] + embed)/res))
        vect[idx, idy, 1] = atom.GetAtomicNum()                     # atomic number
        vect[idx, idy, 2] = int(atom.GetHybridization())            # hybridization (enum value)
        vect[idx, idy, 3] = atom.GetDoubleProp("_GasteigerCharge")  # Gasteiger partial charge

    # We omit the last channel just to fit our 3-channel fastai models,
    # but you can also adapt the architecture to deal with 4 or more channels.
    # .T reverses the axes: (x, y, channel) -> (channel, y, x)
    return Tensor(vect[:, :, :3].T)

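Note: with res=0.5, an embedding of 32 gives a 128 x 128 canvas. Since the canvas side scales linearly with the embedding, we can use this proportion to generate images of any size: for a 224 x 224 canvas, the embedding should be (224 * 32)/128 = 56.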
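Rearranging $DIM = 2 \cdot EMBED / RES$ gives $EMBED = DIM \cdot RES / 2$, which is the same proportion as in the note. Below is a tiny sketch with a hypothetical helper name; it ignores the 2 padding pixels.

```python
def embed_for_canvas(target_dim, res=0.5):
    """Invert DIM = 2 * embed / res to find the embedding that gives a target canvas side."""
    return target_dim * res / 2

print(embed_for_canvas(128))  # 32.0
print(embed_for_canvas(224))  # 56.0
```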

First, we will create a column called mol that maps our molecular structures to rdkit.Chem.rdchem.Mol objects. This is essential because we are going to use RDKit to calculate everything, and rdkit.Chem.rdchem.Mol has a bunch of nice functionality for working with molecular graphs.

mol = MolFromSmiles('c1ccccc1')  # an rdkit.Chem.rdchem.Mol object representing benzene
type(mol)
rdkit.Chem.rdchem.Mol
mols['mol'] = mols['processed_smiles'].apply(MolFromSmiles)

Now we are going to vectorize our molecules and transform them into image-like matrices. But first, let's test our function.

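def vectorize(mol, embed, res):
    # Thin wrapper so that embed and res can be fixed later with partial()
    return chemcepterize_mol(mol, embed=embed, res=res)

# .T restores the (x, y, channel) layout that matplotlib's imshow expects
v = vectorize(mols["mol"][0], embed=56, res=0.5).T.numpy()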

Ok, now let's see what that does!

As you can see, the image is mostly black space and the molecule is just a tiny, tiny part of it (shown in red). The black pixels carry no chemical information at all, and the coloured ones encode only a few basic atom and bond properties. That's why Chemception has to learn so much chemistry from scratch!

print(v.shape)
plt.imshow(v)
plt.show()
(226, 226, 3)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

We can make the molecule occupy a larger part of the image by reducing the embedding size. But beware: this also reduces the total image size.
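We can put numbers on this trade-off. Assuming the 2D layout is centred on the origin, a molecule spanning $S$ coordinate units covers roughly $S / (2 \cdot EMBED)$ of the canvas width. It only fits if all its coordinates stay within $[-EMBED, +EMBED]$; beyond that, the pixel index falls outside the array, which the 2 padding pixels in `chemcepterize_mol` only partly guard against. Below is a small sketch with hypothetical helper names. The span of 15 units is an arbitrary example, not measured from the dataset.

```python
def width_fraction(span, embed):
    """Approximate fraction of the canvas width covered by a layout spanning `span` units."""
    return span / (2 * embed)

def fits_on_canvas(xs, ys, embed):
    """True if every coordinate lies within [-embed, +embed] (layout assumed centred on the origin)."""
    return all(-embed <= v <= embed for v in list(xs) + list(ys))

for embed in (56, 32, 16):
    print(embed, round(width_fraction(15.0, embed), 3))
# 56 0.134
# 32 0.234
# 16 0.469
```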

larger_img = vectorize(mols["mol"][0], embed=16, res=0.5).T.numpy()
print(larger_img.shape)
plt.imshow(larger_img)
plt.show()
(66, 66, 3)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Now we vectorize every molecule, using embed=32 and res=0.25:

mols["molimage"] = mols["mol"].apply(partial(vectorize, embed=32, res=0.25))
plt.imshow(mols["molimage"][0].T.numpy())
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Training our model

Now that we know how to convert the molecules to the desired format, we are ready to train a model using fastai! To do that, we need to define three things:

1) How to split data

First, let's define how to split our data into training and validation sets. Luckily, our dataset comes with a column called "is_valid" indicating which molecules should be used for validation (True) and which for training (False). Therefore, we will use the ColSplitter class from fastai to get the indices. We could also do a random split here, or any other kind of split. Fastai is very flexible!

splits = ColSplitter('is_valid')(mols)
splits
((#1703) [0,1,2,3,4,5,6,7,8,9...],
 (#426) [1703,1704,1705,1706,1707,1708,1709,1710,1711,1712...])
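The output shows 1703 training indices and 426 validation indices. Conceptually, a column-based split just collects the row positions where the flag is False (training) and True (validation). The pure-Python sketch below illustrates that idea only; `col_split` is a hypothetical helper, not fastai's source code.

```python
def col_split(is_valid):
    """Split row positions into (train, valid) lists using a boolean column."""
    train = [i for i, v in enumerate(is_valid) if not v]
    valid = [i for i, v in enumerate(is_valid) if v]
    return train, valid

print(col_split([False, False, True, False, True]))  # ([0, 1, 3], [2, 4])
```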

2) Define how to get the items

We need to tell fastai how to get the items that we'll feed to our model. In this case, we will use the images we created and stored in the "molimage" column of our dataframe.

x_tfms = ColReader('molimage')

3) Define the targets

Now we need to tell fastai where our targets are. In this case, they are in the column "act", which holds the bioactivity class of each molecule (Active or Inactive). We will also tell fastai to treat the values of this column as categories, so that we train a classification model.

y_tfms = [ColReader('act'),Categorize]
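As far as I understand fastai's `Categorize`, it builds a vocabulary from the sorted unique labels and encodes each label as its position in that vocabulary. This is consistent with `Active` being 0 and `Inactive` being 1 in the outputs later in this post. Below is a pure-Python sketch of the idea, with hypothetical helper names.

```python
def build_vocab(labels):
    """Sorted unique labels; a label's index in this list is its integer code."""
    return sorted(set(labels))

def encode(labels, vocab):
    lookup = {label: i for i, label in enumerate(vocab)}
    return [lookup[label] for label in labels]

def decode(codes, vocab):
    return [vocab[c] for c in codes]

labels = ["Active", "Inactive", "Inactive", "Active"]
vocab = build_vocab(labels)
print(vocab)                  # ['Active', 'Inactive']
print(encode(labels, vocab))  # [0, 1, 1, 0]
```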

The fastai book uses the DataBlock API to create the dataloaders in Chapter 2, and Jeremy says that we actually need four things:

  1. What kind of data to work with;
  2. How to get the items;
  3. How to label these items;
  4. How to create a validation set.

But since we are using a custom data type, we'll skip the step defining the kind of data.

Create dataset

Now we can create our dataset:

mol_dataset = Datasets(mols,[x_tfms,y_tfms], splits=splits)

x,y = mol_dataset[0]
x.shape,y
(torch.Size([3, 258, 258]), TensorCategory(0))

Create dataloaders

dls = mol_dataset.dataloaders(batch_size=8)

Let's inspect one batch to check that everything looks right:

x,y = dls.one_batch()
x.shape, y

(torch.Size([8, 3, 258, 258]), TensorCategory([0, 0, 1, 1, 0, 1, 1, 0], device='cuda:0'))

It seems everything is in order. Alright! Let's train this beast!

Fit

metrics = [Recall(pos_label=1),Recall(pos_label = 0), Precision(pos_label = 0), MatthewsCorrCoef()]
learn = cnn_learner(dls, resnet18, metrics=metrics,ps=0.25)
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /home/marcossantana/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth

Note: We are going to use ResNet18 instead of the original Inception architecture.
learn.fine_tune(10)
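| epoch | train_loss | valid_loss | recall_score (pos_label=1, Inactive) | recall_score (pos_label=0, Active) | precision_score (pos_label=0, Active) | matthews_corrcoef | time |
|---|---|---|---|---|---|---|---|
| 0 | 1.108745 | 1.240608 | 0.331169 | 0.716912 | 0.654362 | 0.050385 | 00:15 |

| epoch | train_loss | valid_loss | recall_score (pos_label=1, Inactive) | recall_score (pos_label=0, Active) | precision_score (pos_label=0, Active) | matthews_corrcoef | time |
|---|---|---|---|---|---|---|---|
| 0 | 0.754158 | 0.685205 | 0.292208 | 0.897059 | 0.691218 | 0.241307 | 00:18 |
| 1 | 0.698066 | 0.897160 | 0.201299 | 0.786765 | 0.635015 | -0.014106 | 00:18 |
| 2 | 0.630345 | 0.794150 | 0.396104 | 0.750000 | 0.686869 | 0.152768 | 00:18 |
| 3 | 0.607187 | 0.601677 | 0.512987 | 0.794118 | 0.742268 | 0.317116 | 00:18 |
| 4 | 0.556304 | 0.543766 | 0.538961 | 0.830882 | 0.760943 | 0.386714 | 00:19 |
| 5 | 0.392019 | 0.590917 | 0.558442 | 0.849265 | 0.772575 | 0.428208 | 00:18 |
| 6 | 0.336817 | 0.561924 | 0.487013 | 0.882353 | 0.752351 | 0.409180 | 00:19 |
| 7 | 0.271662 | 0.635271 | 0.461039 | 0.882353 | 0.743034 | 0.385314 | 00:19 |
| 8 | 0.192005 | 0.622967 | 0.577922 | 0.834559 | 0.777397 | 0.426781 | 00:19 |
| 9 | 0.189609 | 0.632492 | 0.636364 | 0.819853 | 0.799283 | 0.461058 | 00:18 |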

It seems everything went pretty well! The Matthews correlation coefficient (MCC) on the validation set is quite decent after the final epoch (~0.46).
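MCC ranges from -1 to +1, with 0 meaning no better than random guessing. Because it uses all four cells of the confusion matrix, it is a reasonable single number for a binary task with unbalanced classes. For reference, here is a minimal implementation from confusion-matrix counts. Returning 0 when the denominator vanishes is a common convention, assumed here.

```python
import math

def mcc(tp, tn, fp, fn):
    """Matthews correlation coefficient from binary confusion-matrix counts (0.0 if undefined)."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0

print(mcc(tp=50, tn=50, fp=0, fn=0))    # 1.0 (perfect)
print(mcc(tp=25, tn=25, fp=25, fn=25))  # 0.0 (no better than chance)
```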

Interpretation

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

The performance is pretty decent, even without any kind of data augmentation or hyperparameter optimization.

Let's try with a little bit of data augmentation.

In the original paper, the authors used random rotations of the images. Why is that? Well, since most of the image is empty space and each molecule is drawn as a precise pattern of pixels, distorting it even a little bit, by cropping or squishing, might completely change the molecule represented (for example, by cutting atoms off or changing bond lengths). Rotation is a safe augmentation in this particular case because it doesn't change the meaning of our images: a rotated molecule is still the same molecule.
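A quick way to see why rotation is safe while squishing is not: rotating 2D coordinates about the origin preserves every interatomic distance, whereas scaling one axis changes them. The sketch below works on made-up coordinates rather than pixels. fastai's `Rotate` acts on the image tensor, and resampling the pixel grid is a separate effect that this toy ignores.

```python
import math

def rotate(points, degrees):
    """Rotate 2D points about the origin."""
    t = math.radians(degrees)
    c, s = math.cos(t), math.sin(t)
    return [(c * x - s * y, s * x + c * y) for x, y in points]

def squish(points, fx):
    """Stretch or squish the x axis by a factor fx."""
    return [(fx * x, y) for x, y in points]

# Two bonded atoms 1.5 units apart (made-up coordinates)
atoms = [(0.0, 0.0), (1.5, 0.0)]
print(round(math.dist(*rotate(atoms, 73)), 6))   # 1.5 -> bond length preserved
print(round(math.dist(*squish(atoms, 0.8)), 6))  # 1.2 -> bond length changed
```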

dls = mol_dataset.dataloaders(batch_size=8,after_batch=Rotate(max_deg=180))
learn = cnn_learner(dls, resnet18, metrics=metrics,ps=0.25)
learn.fine_tune(10)
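| epoch | train_loss | valid_loss | recall_score (pos_label=1, Inactive) | recall_score (pos_label=0, Active) | precision_score (pos_label=0, Active) | matthews_corrcoef | time |
|---|---|---|---|---|---|---|---|
| 0 | 1.085553 | 0.814067 | 0.428571 | 0.746324 | 0.697595 | 0.180596 | 00:13 |

| epoch | train_loss | valid_loss | recall_score (pos_label=1, Inactive) | recall_score (pos_label=0, Active) | precision_score (pos_label=0, Active) | matthews_corrcoef | time |
|---|---|---|---|---|---|---|---|
| 0 | 0.740874 | 0.696819 | 0.500000 | 0.742647 | 0.724014 | 0.245222 | 00:18 |
| 1 | 0.703036 | 0.758111 | 0.461039 | 0.720588 | 0.702509 | 0.183554 | 00:19 |
| 2 | 0.653901 | 0.586844 | 0.383117 | 0.867647 | 0.712991 | 0.289424 | 00:18 |
| 3 | 0.637761 | 0.676066 | 0.272727 | 0.794118 | 0.658537 | 0.076307 | 00:19 |
| 4 | 0.537117 | 0.690890 | 0.435065 | 0.768382 | 0.706081 | 0.212265 | 00:18 |
| 5 | 0.483605 | 0.649515 | 0.448052 | 0.860294 | 0.733542 | 0.341583 | 00:19 |
| 6 | 0.348445 | 0.609656 | 0.500000 | 0.867647 | 0.753994 | 0.400096 | 00:19 |
| 7 | 0.269874 | 0.642070 | 0.571429 | 0.827206 | 0.773196 | 0.411629 | 00:18 |
| 8 | 0.203648 | 0.646935 | 0.506494 | 0.871324 | 0.757188 | 0.411164 | 00:19 |
| 9 | 0.198849 | 0.674963 | 0.506494 | 0.871324 | 0.757188 | 0.411164 | 00:18 |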
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

Exporting the model

Now that we trained the model, we can export it and use it for inference.

learn.export()
learn_inf = load_learner('export.pkl')
learn_inf.predict(mols['molimage'][2])
('Inactive', tensor(1), tensor([0.0191, 0.9809]))
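The tuple returned by `predict` holds the decoded label, its integer index, and the per-class probabilities in vocabulary order. The pure-Python sketch below shows how the first two follow from the third by taking the most probable class. `decode_prediction` is a hypothetical helper, and the vocabulary is assumed to be `['Active', 'Inactive']`.

```python
def decode_prediction(probs, vocab):
    """Return (label, index, probs) by picking the most probable class."""
    idx = max(range(len(probs)), key=lambda i: probs[i])
    return vocab[idx], idx, probs

print(decode_prediction([0.0191, 0.9809], ["Active", "Inactive"]))
# ('Inactive', 1, [0.0191, 0.9809])
```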

Fin