For some time now I've been wondering how much chemistry is necessary to train a model for QSAR. Do we need hand-crafted descriptors, such as ECFP and MACCS fingerprints? Or maybe complex quantum mechanical properties calculated with state-of-the-art software?

In the last post, we used image-like input to train a bioactivity classifier with fastai. But what if we tried with REAL images?

First, we will import some libraries, including rdkit to deal with molecules.

Import modules

from rdkit import Chem
from collections import Counter
from rdkit.Chem import Draw, AllChem
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import rdDepictor
import os
import numpy as np
import pandas as pd
from PIL import Image, ImageOps
from io import BytesIO
from fastai.vision.all import *
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

Load data

moldf = pd.read_csv('/home/marcossantana/Downloads/fxa_ic50_processed.csv',sep=';').reset_index(drop=True)
moldf['filename'] = [f'mols_imgs/mol_{i}.png' for i in moldf.index]
moldf.head()

doc_id standard_value standard_type standard_relation pchembl molregno canonical_smiles chembl_id target_dictionary target_chembl_id l1 l2 l3 confidence_score act processed_smiles filename
0 3476 44.4 IC50 = 7.35 192068 N=C(N)N1CCC[C@H](NC(=O)CNC(=O)[C@@H](CCNC(=O)c2ccc3ccccc3n2)NS(=O)(=O)Cc2ccccc2)C1O CHEMBL117716 194 CHEMBL244 Enzyme Protease Serine protease 8 Active N=C(N)N1CCC[C@H](NC(=O)CNC(=O)[C@@H](CCNC(=O)c2ccc3ccccc3n2)NS(=O)(=O)Cc2ccccc2)C1O mols_imgs/mol_0.png
1 6512 180.0 IC50 = 6.75 203908 Cc1cc(NC(=O)Cc2ccc3[nH]c(-c4ccc(Cl)s4)nc3c2)ccc1-n1ccccc1=O CHEMBL337921 194 CHEMBL244 Enzyme Protease Serine protease 8 Active Cc1cc(NC(=O)Cc2ccc3[nH]c(-c4ccc(Cl)s4)nc3c2)ccc1-n1ccccc1=O mols_imgs/mol_1.png
2 6512 120.0 IC50 = 6.92 204329 Cc1cc(NC(=O)Cc2ccc3[nH]c(-c4ccc(Cl)s4)nc3c2)ccc1N1CCOCC1=O CHEMBL340500 194 CHEMBL244 Enzyme Protease Serine protease 8 Active Cc1cc(NC(=O)Cc2ccc3[nH]c(-c4ccc(Cl)s4)nc3c2)ccc1N1CCOCC1=O mols_imgs/mol_2.png
3 3476 311.0 IC50 = 6.51 192044 N=C(N)N1CCC[C@H](NC(=O)CNC(=O)[C@@H](CNC(=O)c2cnccn2)NS(=O)(=O)Cc2ccccc2)C1O CHEMBL117721 194 CHEMBL244 Enzyme Protease Serine protease 8 Active N=C(N)N1CCC[C@H](NC(=O)CNC(=O)[C@@H](CNC(=O)c2cnccn2)NS(=O)(=O)Cc2ccccc2)C1O mols_imgs/mol_3.png
4 3476 6.1 IC50 = 8.21 191486 N=C(N)N1CCC[C@H](NC(=O)CNC(=O)[C@@H](CCNC(=O)c2ccc(O)nc2)NS(=O)(=O)Cc2ccccc2)C1O CHEMBL331807 194 CHEMBL244 Enzyme Protease Serine protease 8 Active N=C(N)N1CCC[C@H](NC(=O)CNC(=O)[C@@H](CCNC(=O)c2ccc(O)nc2)NS(=O)(=O)Cc2ccccc2)C1O mols_imgs/mol_4.png
moldf.shape
(2129, 17)

Note: Never forget to separate a validation set!
train,valid = train_test_split(moldf,stratify=moldf['act'],test_size = 0.2)
train_idx, valid_idx = train.index, valid.index
del train, valid
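
To make explicit what stratify=moldf['act'] buys us, here is a small hand-rolled sketch of a stratified split in plain Python. It is only an illustration of the idea, not how scikit-learn implements it; the stratified_split helper and the toy labels are made up for this example.

```python
import random
from collections import Counter, defaultdict

def stratified_split(labels, test_size=0.2, seed=42):
    """Split positions into train/valid while keeping class ratios similar.

    Hypothetical helper for illustration only; in the post this job is done
    by sklearn's train_test_split with the `stratify` argument.
    """
    rng = random.Random(seed)

    # Group the positions of the samples by their class label
    by_class = defaultdict(list)
    for pos, label in enumerate(labels):
        by_class[label].append(pos)

    train_idx, valid_idx = [], []
    for positions in by_class.values():
        rng.shuffle(positions)
        n_valid = round(len(positions) * test_size)
        valid_idx.extend(positions[:n_valid])
        train_idx.extend(positions[n_valid:])
    return sorted(train_idx), sorted(valid_idx)

# Toy labels with made-up proportions: 70 actives and 30 inactives
labels = ['Active'] * 70 + ['Inactive'] * 30
train_idx, valid_idx = stratified_split(labels)

valid_counts = Counter(labels[i] for i in valid_idx)
print(len(train_idx), len(valid_idx), dict(valid_counts))
```

Each class contributes the same fraction of its members to the validation set, so the Active/Inactive ratio is about the same in both splits. That matters when one class is much rarer than the other.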

Save images

Now we need a way to convert our molecules to images. We can use RDKit's drawing code (the rdMolDraw2D module) for that. I think the API is quite complex, especially for non-coders, so I won't explain it in detail here. Please check the official RDKit documentation and this short tutorial about the new drawing code.

def generate_mol_images(smiles_list):
    '''Generate one PNG image per molecule from a collection of SMILES strings'''
    mols = [Chem.MolFromSmiles(smi) for smi in smiles_list]
    print(len(mols))
    for i, mol in enumerate(tqdm(mols, total=len(mols))):
        # The position i is used in the file name, matching moldf['filename']
        mol2image(mol, filename=f'mol_{i}.png')

def mol2image(m, filename='', save_dir='mols_imgs'):
    '''Draw an RDKit molecule and save it as a PNG image'''
    d2d = rdMolDraw2D.MolDraw2DCairo(400, 400)  # or MolDraw2DSVG to get SVGs
    d2d.drawOptions().bondLineWidth = 5  # if positive, overrides the default line width for bonds
    d2d.drawOptions().padding = 0.05
    d2d.DrawMolecule(m)
    d2d.FinishDrawing()

    # Create the output folder if it does not exist yet, then write the PNG
    os.makedirs(save_dir, exist_ok=True)
    d2d.WriteDrawingText(os.path.join(save_dir, filename))

The mol2image function receives a molecule as input and converts it to an image. Briefly, we set up a canvas of 400 x 400 pixels and then define some drawing options using the drawOptions method; in this example we make the bonds thicker (bondLineWidth = 5) and add a little bit of padding (padding = 0.05) to the canvas. This gives us a good baseline for drawing molecules. You can play around with the drawing options if you wish.

After drawing, we save the images to a folder. Take a look at a sample:

Now let's save our images!

generate_mol_images(moldf['processed_smiles'].values)
2129

The ls method allows us to see the saved image files.

path = Path('mols_imgs')
path.ls()
(#2129) [Path('mols_imgs/mol_538.png'),Path('mols_imgs/mol_508.png'),Path('mols_imgs/mol_1741.png'),Path('mols_imgs/mol_628.png'),Path('mols_imgs/mol_265.png'),Path('mols_imgs/mol_1477.png'),Path('mols_imgs/mol_90.png'),Path('mols_imgs/mol_2056.png'),Path('mols_imgs/mol_1046.png'),Path('mols_imgs/mol_2099.png')...]
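
Since the DataBlock will later read the paths stored in the filename column, it can be worth checking that every expected file was actually written. Below is a minimal sketch of such a check using only the standard library. The missing_images helper is a hypothetical name introduced here, and the demo runs in a temporary folder with empty placeholder files so that it is self-contained.

```python
import tempfile
from pathlib import Path

def missing_images(filenames, root='.'):
    """Return the expected image paths that do not exist under `root`."""
    root = Path(root)
    return [f for f in filenames if not (root / f).is_file()]

# Self-contained demo: a temporary folder with three empty placeholder files
with tempfile.TemporaryDirectory() as tmp:
    img_dir = Path(tmp) / 'mols_imgs'
    img_dir.mkdir()
    for i in range(3):
        (img_dir / f'mol_{i}.png').touch()

    # We expect four files, but mol_3.png was never written
    expected = [f'mols_imgs/mol_{i}.png' for i in range(4)]
    missing = missing_images(expected, root=tmp)

print(missing)  # ['mols_imgs/mol_3.png']
```

In the notebook, calling missing_images(moldf['filename']) from the working directory should return an empty list if all 2129 images were saved.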

Datablock

In order to train our model, we first need to create a DataLoaders object containing our training and validation sets. First, we create our datablock, which is basically a blueprint that tells fastai how to get our dependent and independent variables, how to treat them (e.g. are they images, text or tabular? is the dependent variable categorical or continuous?) and some transformations, including resizing the images and applying augmentation methods to them.

d_block = DataBlock(blocks=(ImageBlock(), CategoryBlock()),
                    get_x=ColReader('filename'),        # path to the image file
                    get_y=ColReader('act'),             # Active / Inactive label
                    splitter=IndexSplitter(valid_idx),  # reuse our validation split
                    item_tfms=Resize(256),
                    batch_tfms=Rotate(180))
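
To get a feeling for what the splitter does, here is a simplified pure-Python stand-in. As far as I understand, fastai's IndexSplitter works on item positions rather than DataFrame index labels; the index_splitter function below is a hypothetical re-implementation for illustration, not fastai's actual code.

```python
def index_splitter(valid_idx):
    """Simplified stand-in for fastai's IndexSplitter (illustration only)."""
    valid = set(valid_idx)

    def _split(items):
        # Every position that is not in the validation set goes to training
        train = [i for i in range(len(items)) if i not in valid]
        return train, sorted(valid)
    return _split

items = ['mol_0', 'mol_1', 'mol_2', 'mol_3', 'mol_4']
train_pos, valid_pos = index_splitter([1, 3])(items)
print(train_pos, valid_pos)  # [0, 2, 4] [1, 3]
```

Because we called reset_index(drop=True) when loading the data, the index labels stored in valid_idx are the same as the row positions, so passing them to the splitter is safe here.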

The DataBlock API is very handy! We can check whether our blueprint is alright by using the summary method, which tries to create a batch of our data. If it fails, we can see at which step. It's a very nice debugging strategy.

d_block.summary(moldf)

Setting-up type transforms pipelines
Collecting items from       doc_id  standard_value standard_type standard_relation  pchembl  \
0       3476            44.4          IC50                 =     7.35   
1       6512           180.0          IC50                 =     6.75   
2       6512           120.0          IC50                 =     6.92   
3       3476           311.0          IC50                 =     6.51   
4       3476             6.1          IC50                 =     8.21   
...      ...             ...           ...               ...      ...   
2124  109537          1023.0          IC50                 =     5.99   
2125  109537          3026.0          IC50                 =     5.52   
2126  109537          8451.0          IC50                 =     5.07   
2127  109537          5735.0          IC50                 =     5.24   
2128  109537             3.4          IC50                 =     8.47   

      molregno  \
0       192068   
1       203908   
2       204329   
3       192044   
4       191486   
...        ...   
2124   2333070   
2125   2319885   
2126   2333406   
2127   2325502   
2128    709600   

                                                                         canonical_smiles  \
0     N=C(N)N1CCC[C@H](NC(=O)CNC(=O)[C@@H](CCNC(=O)c2ccc3ccccc3n2)NS(=O)(=O)Cc2ccccc2)C1O   
1                             Cc1cc(NC(=O)Cc2ccc3[nH]c(-c4ccc(Cl)s4)nc3c2)ccc1-n1ccccc1=O   
2                              Cc1cc(NC(=O)Cc2ccc3[nH]c(-c4ccc(Cl)s4)nc3c2)ccc1N1CCOCC1=O   
3            N=C(N)N1CCC[C@H](NC(=O)CNC(=O)[C@@H](CNC(=O)c2cnccn2)NS(=O)(=O)Cc2ccccc2)C1O   
4        N=C(N)N1CCC[C@H](NC(=O)CNC(=O)[C@@H](CCNC(=O)c2ccc(O)nc2)NS(=O)(=O)Cc2ccccc2)C1O   
...                                                                                   ...   
2124               CC1(C)OCC([C@]2(C)C=C3CC[C@@H]4C(C)(C)[C@H](O)CC[C@@]4(C)[C@@H]3CC2)O1   
2125            CC(=O)O[C@@H]1CC[C@@]2(C)[C@@H]3CC[C@](C)([C@@H](O)CO)C=C3CC[C@@H]2C1(C)C   
2126                     CC(=O)O[C@@H]1CC[C@@]2(C)[C@@H]3CC[C@](C)(CO)C=C3CC[C@@H]2C1(C)C   
2127                     CC(=O)O[C@@H]1CC[C@@]2(C)[C@@H]3CC[C@](C)(O)CC=C3CC[C@@H]2C1(C)C   
2128    CN1CCc2nc(C(=O)N[C@@H]3C[C@@H](C(=O)N(C)C)CC[C@@H]3NC(=O)C(=O)Nc3ccc(Cl)cn3)sc2C1   

          chembl_id  target_dictionary target_chembl_id      l1        l2  \
0      CHEMBL117716                194        CHEMBL244  Enzyme  Protease   
1      CHEMBL337921                194        CHEMBL244  Enzyme  Protease   
2      CHEMBL340500                194        CHEMBL244  Enzyme  Protease   
3      CHEMBL117721                194        CHEMBL244  Enzyme  Protease   
4      CHEMBL331807                194        CHEMBL244  Enzyme  Protease   
...             ...                ...              ...     ...       ...   
2124  CHEMBL4293622                194        CHEMBL244  Enzyme  Protease   
2125  CHEMBL4280434                194        CHEMBL244  Enzyme  Protease   
2126  CHEMBL4293958                194        CHEMBL244  Enzyme  Protease   
2127  CHEMBL4286054                194        CHEMBL244  Enzyme  Protease   
2128  CHEMBL1269025                194        CHEMBL244  Enzyme  Protease   

                   l3  confidence_score       act  \
0     Serine protease                 8    Active   
1     Serine protease                 8    Active   
2     Serine protease                 8    Active   
3     Serine protease                 8    Active   
4     Serine protease                 8    Active   
...               ...               ...       ...   
2124  Serine protease                 9  Inactive   
2125  Serine protease                 9  Inactive   
2126  Serine protease                 9  Inactive   
2127  Serine protease                 9  Inactive   
2128  Serine protease                 9    Active   

                                                                         processed_smiles  \
0     N=C(N)N1CCC[C@H](NC(=O)CNC(=O)[C@@H](CCNC(=O)c2ccc3ccccc3n2)NS(=O)(=O)Cc2ccccc2)C1O   
1                             Cc1cc(NC(=O)Cc2ccc3[nH]c(-c4ccc(Cl)s4)nc3c2)ccc1-n1ccccc1=O   
2                              Cc1cc(NC(=O)Cc2ccc3[nH]c(-c4ccc(Cl)s4)nc3c2)ccc1N1CCOCC1=O   
3            N=C(N)N1CCC[C@H](NC(=O)CNC(=O)[C@@H](CNC(=O)c2cnccn2)NS(=O)(=O)Cc2ccccc2)C1O   
4        N=C(N)N1CCC[C@H](NC(=O)CNC(=O)[C@@H](CCNC(=O)c2ccc(O)nc2)NS(=O)(=O)Cc2ccccc2)C1O   
...                                                                                   ...   
2124               CC1(C)OCC([C@]2(C)C=C3CC[C@@H]4C(C)(C)[C@H](O)CC[C@@]4(C)[C@@H]3CC2)O1   
2125            CC(=O)O[C@@H]1CC[C@@]2(C)[C@@H]3CC[C@](C)([C@@H](O)CO)C=C3CC[C@@H]2C1(C)C   
2126                     CC(=O)O[C@@H]1CC[C@@]2(C)[C@@H]3CC[C@](C)(CO)C=C3CC[C@@H]2C1(C)C   
2127                     CC(=O)O[C@@H]1CC[C@@]2(C)[C@@H]3CC[C@](C)(O)CC=C3CC[C@@H]2C1(C)C   
2128    CN1CCc2nc(C(=O)N[C@@H]3C[C@@H](C(=O)N(C)C)CC[C@@H]3NC(=O)C(=O)Nc3ccc(Cl)cn3)sc2C1   

                    filename  
0        mols_imgs/mol_0.png  
1        mols_imgs/mol_1.png  
2        mols_imgs/mol_2.png  
3        mols_imgs/mol_3.png  
4        mols_imgs/mol_4.png  
...                      ...  
2124  mols_imgs/mol_2124.png  
2125  mols_imgs/mol_2125.png  
2126  mols_imgs/mol_2126.png  
2127  mols_imgs/mol_2127.png  
2128  mols_imgs/mol_2128.png  

[2129 rows x 17 columns]
Found 2129 items
2 datasets of sizes 1703,426
Setting up Pipeline: ColReader -- {'cols': 'filename', 'pref': '', 'suff': '', 'label_delim': None} -> PILBase.create
Setting up Pipeline: ColReader -- {'cols': 'act', 'pref': '', 'suff': '', 'label_delim': None} -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}

Building one sample
  Pipeline: ColReader -- {'cols': 'filename', 'pref': '', 'suff': '', 'label_delim': None} -> PILBase.create
    starting from
      doc_id                                                                     6512
standard_value                                                              120
standard_type                                                              IC50
standard_relation                                                             =
pchembl                                                                    6.92
molregno                                                                 204329
canonical_smiles     Cc1cc(NC(=O)Cc2ccc3[nH]c(-c4ccc(Cl)s4)nc3c2)ccc1N1CCOCC1=O
chembl_id                                                          CHEMBL340500
target_dictionary                                                           194
target_chembl_id                                                      CHEMBL244
l1                                                                       Enzyme
l2                                                                     Protease
l3                                                              Serine protease
confidence_score                                                              8
act                                                                      Active
processed_smiles     Cc1cc(NC(=O)Cc2ccc3[nH]c(-c4ccc(Cl)s4)nc3c2)ccc1N1CCOCC1=O
filename                                                    mols_imgs/mol_2.png
Name: 2, dtype: object
    applying ColReader -- {'cols': 'filename', 'pref': '', 'suff': '', 'label_delim': None} gives
      mols_imgs/mol_2.png
    applying PILBase.create gives
      PILImage mode=RGB size=400x400
  Pipeline: ColReader -- {'cols': 'act', 'pref': '', 'suff': '', 'label_delim': None} -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}
    starting from
      doc_id                                                                     6512
standard_value                                                              120
standard_type                                                              IC50
standard_relation                                                             =
pchembl                                                                    6.92
molregno                                                                 204329
canonical_smiles     Cc1cc(NC(=O)Cc2ccc3[nH]c(-c4ccc(Cl)s4)nc3c2)ccc1N1CCOCC1=O
chembl_id                                                          CHEMBL340500
target_dictionary                                                           194
target_chembl_id                                                      CHEMBL244
l1                                                                       Enzyme
l2                                                                     Protease
l3                                                              Serine protease
confidence_score                                                              8
act                                                                      Active
processed_smiles     Cc1cc(NC(=O)Cc2ccc3[nH]c(-c4ccc(Cl)s4)nc3c2)ccc1N1CCOCC1=O
filename                                                    mols_imgs/mol_2.png
Name: 2, dtype: object
    applying ColReader -- {'cols': 'act', 'pref': '', 'suff': '', 'label_delim': None} gives
      Active
    applying Categorize -- {'vocab': None, 'sort': True, 'add_na': False} gives
      TensorCategory(0)

Final sample: (PILImage mode=RGB size=400x400, TensorCategory(0))


Collecting items from [same DataFrame printout as above, 2129 rows x 17 columns]
Found 2129 items
2 datasets of sizes 1703,426
Setting up Pipeline: ColReader -- {'cols': 'filename', 'pref': '', 'suff': '', 'label_delim': None} -> PILBase.create
Setting up Pipeline: ColReader -- {'cols': 'act', 'pref': '', 'suff': '', 'label_delim': None} -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}
Setting up after_item: Pipeline: Resize -- {'size': (256, 256), 'method': 'crop', 'pad_mode': 'reflection', 'resamples': (2, 0), 'p': 1.0} -> ToTensor
Setting up before_batch: Pipeline: 
Setting up after_batch: Pipeline: IntToFloatTensor -- {'div': 255.0, 'div_mask': 1} -> Rotate -- {'size': None, 'mode': 'bilinear', 'pad_mode': 'reflection', 'mode_mask': 'nearest', 'align_corners': True, 'p': 1.0}

Building one batch
Applying item_tfms to the first sample:
  Pipeline: Resize -- {'size': (256, 256), 'method': 'crop', 'pad_mode': 'reflection', 'resamples': (2, 0), 'p': 1.0} -> ToTensor
    starting from
      (PILImage mode=RGB size=400x400, TensorCategory(0))
    applying Resize -- {'size': (256, 256), 'method': 'crop', 'pad_mode': 'reflection', 'resamples': (2, 0), 'p': 1.0} gives
      (PILImage mode=RGB size=256x256, TensorCategory(0))
    applying ToTensor gives
      (TensorImage of size 3x256x256, TensorCategory(0))

Adding the next 3 samples

No before_batch transform to apply

Collating items in a batch

Applying batch_tfms to the batch built
  Pipeline: IntToFloatTensor -- {'div': 255.0, 'div_mask': 1} -> Rotate -- {'size': None, 'mode': 'bilinear', 'pad_mode': 'reflection', 'mode_mask': 'nearest', 'align_corners': True, 'p': 1.0}
    starting from
      (TensorImage of size 4x3x256x256, TensorCategory([0, 0, 0, 0], device='cuda:0'))
    applying IntToFloatTensor -- {'div': 255.0, 'div_mask': 1} gives
      (TensorImage of size 4x3x256x256, TensorCategory([0, 0, 0, 0], device='cuda:0'))
    applying Rotate -- {'size': None, 'mode': 'bilinear', 'pad_mode': 'reflection', 'mode_mask': 'nearest', 'align_corners': True, 'p': 1.0} gives
      (TensorImage of size 4x3x256x256, TensorCategory([0, 0, 0, 0], device='cuda:0'))

No errors? Nice!

Now we can create our dataloaders.

dls = d_block.dataloaders(moldf,bs=32)
dls.show_batch(max_n=4)

Train model

learn = cnn_learner(dls, resnet34, metrics=MatthewsCorrCoef())
learn.fine_tune(10,base_lr = 5e-4, freeze_epochs=5)

epoch  train_loss  valid_loss  matthews_corrcoef  time
0      1.279663    0.927331    0.156844           00:20
1      1.133440    0.756082    0.219890           00:19
2      1.048430    0.702097    0.249759           00:19
3      0.950931    0.667921    0.327793           00:19
4      0.885842    0.627575    0.313518           00:18

epoch  train_loss  valid_loss  matthews_corrcoef  time
0      0.798137    0.593731    0.372278           00:26
1      0.741893    0.603974    0.332780           00:26
2      0.706047    0.553811    0.410596           00:26
3      0.661693    0.566272    0.448788           00:26
4      0.626662    0.556986    0.485671           00:26
5      0.585711    0.572504    0.400680           00:26
6      0.559231    0.550696    0.504990           00:25
7      0.548869    0.506300    0.508802           00:26
8      0.508134    0.506491    0.496422           00:25
9      0.481029    0.506984    0.485368           00:25

The Matthews correlation coefficient (MCC) on the validation set reaches ~0.50. Not bad at all!
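
For readers who have not met this metric before, the MCC can be computed directly from the four cells of a binary confusion matrix. The sketch below uses made-up counts, not the results of this model, and the mcc function is written here just for illustration (fastai's MatthewsCorrCoef does the real work in the notebook).

```python
import math

def mcc(tp, tn, fp, fn):
    """Matthews correlation coefficient from binary confusion matrix counts."""
    numerator = tp * tn - fp * fn
    denominator = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # By convention, return 0 when a row or column of the matrix is empty
    return numerator / denominator if denominator else 0.0

# Made-up counts for illustration only
print(round(mcc(tp=90, tn=70, fp=30, fn=10), 3))  # 0.612
print(mcc(tp=50, tn=50, fp=0, fn=0))              # 1.0, perfect predictions
print(mcc(tp=25, tn=25, fp=25, fn=25))            # 0.0, no better than chance
```

The MCC ranges from -1 to 1 and uses all four cells of the confusion matrix, which makes it a sensible choice when the classes are not balanced.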

learn.save('fit1')
Path('models/fit1.pth')
learn.export()

Interpretation

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

As we can see, it is possible to train a model using only images and get a decent MCC. We didn't even optimize the model or try different data augmentation methods.

I wonder how the model is actually making its predictions. Is it focusing on specific chemical groups? Can we apply interpretation tools used in computer vision? Those are interesting questions that we might take a look at in future posts!

Do you have any ideas about possible improvements? Please let me know!

Fin