hiitsmeme
added grover code, hf api files
f986893
"""
The molecule dataset for finetuning.
This implementation is adapted from
https://github.com/chemprop/chemprop/blob/master/chemprop/data/data.py
"""
import random
from argparse import Namespace
from typing import Callable, List, Union
import numpy as np
from rdkit import Chem
from torch.utils.data.dataset import Dataset
from grover.data.molfeaturegenerator import get_features_generator
from grover.data.scaler import StandardScaler
class MoleculeDatapoint:
    """A MoleculeDatapoint contains a single molecule and its associated features and targets."""

    def __init__(self,
                 line: List[str],
                 args: Namespace = None,
                 features: np.ndarray = None,
                 use_compound_names: bool = False):
        """
        Initializes a MoleculeDatapoint, which contains a single molecule.

        :param line: A list of strings generated by separating a line in a data CSV file by comma.
            Layout is ``[compound_name,] smiles, target_1, ..., target_n``; empty target cells become None.
        :param args: Arguments. If it has a ``features_generator`` attribute, that value is used as the
            list of generator names; ``args.num_bits`` (when present) is forwarded to the Morgan generators.
        :param features: A numpy array containing additional features (ex. Morgan fingerprint).
        :param use_compound_names: Whether the data CSV includes the compound name on each line.
        :raises ValueError: If both ``features`` and a (non-None) features generator are provided.
        """
        self.features_generator = None
        self.args = None
        if args is not None:
            if hasattr(args, "features_generator"):
                self.features_generator = args.features_generator
            self.args = args

        if features is not None and self.features_generator is not None:
            raise ValueError('Currently cannot provide both loaded features and a features generator.')

        self.features = features

        if use_compound_names:
            self.compound_name = line[0]  # str
            line = line[1:]  # rebinds; the caller's list is not modified
        else:
            self.compound_name = None

        self.smiles = line[0]  # str

        # Generate additional features if given a generator
        if self.features_generator is not None:
            self.features = self._generate_features(args)

        # Fix nans in features
        if self.features is not None:
            replace_token = 0
            self.features = np.where(np.isnan(self.features), replace_token, self.features)

        # Create targets: every column after the SMILES; empty strings mark missing labels.
        self.targets = [float(x) if x != '' else None for x in line[1:]]

    def _generate_features(self, args: Namespace) -> np.ndarray:
        """
        Concatenates the outputs of every generator named in ``self.features_generator`` for ``self.smiles``.

        :param args: Arguments; ``args.num_bits`` is forwarded to 'morgan' / 'morgan_count' when present.
        :return: A numpy array holding the concatenated features of all generators.
        """
        features = []
        mol = Chem.MolFromSmiles(self.smiles)
        for fg in self.features_generator:
            features_generator = get_features_generator(fg)
            kwargs = {}
            if fg in ('morgan', 'morgan_count') and hasattr(args, 'num_bits'):
                kwargs['num_bits'] = args.num_bits

            if mol is None:
                # NOTE(review): an unparseable SMILES still contributes no features (unchanged behavior),
                # so its vector is shorter than the others — confirm callers filter invalid SMILES.
                continue

            if mol.GetNumHeavyAtoms() > 0:
                features.extend(features_generator(mol, **kwargs))
            else:
                # Molecules without heavy atoms (e.g. H2) used to contribute nothing here, leaving their
                # vector shorter than everyone else's. Pad with zeros instead; generator outputs differ in
                # length, so methane is used as a dummy molecule to determine this generator's width.
                dummy = Chem.MolFromSmiles('C')
                features.extend(np.zeros(len(features_generator(dummy, **kwargs))))

        return np.array(features)

    def set_features(self, features: np.ndarray):
        """
        Sets the features of the molecule.

        :param features: A 1-D numpy array of features for the molecule.
        """
        self.features = features

    def num_tasks(self) -> int:
        """
        Returns the number of prediction tasks.

        :return: The number of tasks (the number of target columns on this datapoint).
        """
        return len(self.targets)

    def set_targets(self, targets: List[float]):
        """
        Sets the targets of a molecule.

        :param targets: A list of floats containing the targets.
        """
        self.targets = targets
class MoleculeDataset(Dataset):
    """A MoleculeDataset contains a list of molecules and their associated features and targets."""

    def __init__(self, data: List['MoleculeDatapoint']):
        """
        Initializes a MoleculeDataset, which contains a list of MoleculeDatapoints (i.e. a list of molecules).

        :param data: A list of MoleculeDatapoints. The list object itself is stored (not copied), so
            ``shuffle`` and ``sort`` reorder the caller's list in place.
        """
        self.data = data
        # Dataset-level args are taken from the first datapoint (None for an empty dataset).
        self.args = self.data[0].args if len(self.data) > 0 else None
        self.scaler = None

    def compound_names(self) -> List[str]:
        """
        Returns the compound names associated with the molecule (if they exist).

        :return: A list of compound names or None if the dataset does not contain compound names.
        """
        if len(self.data) == 0 or self.data[0].compound_name is None:
            return None
        return [d.compound_name for d in self.data]

    def smiles(self) -> List[str]:
        """
        Returns the smiles strings associated with the molecules.

        :return: A list of smiles strings.
        """
        return [d.smiles for d in self.data]

    def features(self) -> List[np.ndarray]:
        """
        Returns the features associated with each molecule (if they exist).

        :return: A list of 1D numpy arrays containing the features for each molecule or None if there are no features.
        """
        if len(self.data) == 0 or self.data[0].features is None:
            return None
        return [d.features for d in self.data]

    def targets(self) -> List[List[float]]:
        """
        Returns the targets associated with each molecule.

        :return: A list of lists of floats containing the targets.
        """
        return [d.targets for d in self.data]

    def num_tasks(self) -> int:
        """
        Returns the number of prediction tasks.

        For ``args.dataset_type == 'multiclass'`` this is the number of classes, computed as the largest
        label in the first target column plus one (missing labels are ignored).

        :return: The number of tasks, or None if the dataset is empty (or, for multiclass, has no labels).
        """
        if len(self.data) == 0:
            return None
        # getattr: datapoints may have been built without args, in which case self.args is None.
        if getattr(self.args, 'dataset_type', None) == 'multiclass':
            labels = [t[0] for t in self.targets() if t and t[0] is not None]
            return int(max(labels)) + 1 if labels else None
        return self.data[0].num_tasks()

    def features_size(self) -> int:
        """
        Returns the size of the features array associated with each molecule.

        :return: The size of the features of the first molecule, or None if there are no features.
        """
        return len(self.data[0].features) if len(self.data) > 0 and self.data[0].features is not None else None

    def shuffle(self, seed: int = None):
        """
        Shuffles the dataset in place.

        :param seed: Optional random seed. When given, it reseeds the global ``random`` module
            (affecting any later ``random`` calls in the process).
        """
        if seed is not None:
            random.seed(seed)
        random.shuffle(self.data)

    def normalize_features(self, scaler: 'StandardScaler' = None, replace_nan_token: int = 0) -> 'StandardScaler':
        """
        Normalizes the features of the dataset using a StandardScaler (subtract mean, divide by standard deviation).

        If a scaler is provided, uses that scaler to perform the normalization. Otherwise fits a scaler to the
        features in the dataset (only if none has been fitted yet) and then performs the normalization.

        NOTE: features are overwritten in place, so calling this twice normalizes already-normalized features.

        :param scaler: A fitted StandardScaler. Used if provided. Otherwise a StandardScaler is fit on
            this dataset and is then used.
        :param replace_nan_token: What to replace nans with.
        :return: A fitted StandardScaler. If a scaler is provided, this is the same scaler. Otherwise, this is
            a scaler fit on this dataset. None if the dataset is empty or has no features.
        """
        if len(self.data) == 0 or self.data[0].features is None:
            return None

        if scaler is not None:
            self.scaler = scaler
        elif self.scaler is None:
            features = np.vstack([d.features for d in self.data])
            self.scaler = StandardScaler(replace_nan_token=replace_nan_token)
            self.scaler.fit(features)

        for d in self.data:
            # transform works on 2-D input, so each vector is wrapped as a single row and unwrapped again.
            d.set_features(self.scaler.transform(d.features.reshape(1, -1))[0])

        return self.scaler

    def set_targets(self, targets: List[List[float]]):
        """
        Sets the targets for each molecule in the dataset. Assumes the targets are aligned with the datapoints.

        :param targets: A list of lists of floats containing targets for each molecule. This must be the
            same length as the underlying dataset.
        """
        assert len(self.data) == len(targets)
        for i, datapoint in enumerate(self.data):
            datapoint.set_targets(targets[i])

    def sort(self, key: Callable):
        """
        Sorts the dataset in place using the provided key.

        :param key: A function on a MoleculeDatapoint to determine the sorting order.
        """
        self.data.sort(key=key)

    def __len__(self) -> int:
        """
        Returns the length of the dataset (i.e. the number of molecules).

        :return: The length of the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx) -> Union['MoleculeDatapoint', List['MoleculeDatapoint']]:
        """
        Gets one or more MoleculeDatapoints via an index or slice.

        :param idx: An index (int) or a slice object.
        :return: A MoleculeDatapoint if an int is provided or a list of MoleculeDatapoints if a slice is provided.
        """
        return self.data[idx]