Skip to content

Commit 4b207aa

Browse files
committed
Improved code clarity and readability
1 parent 75cb0bc commit 4b207aa

File tree

7 files changed

+494
-527
lines changed

7 files changed

+494
-527
lines changed

src/cnn_affinity.py

Lines changed: 81 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -25,157 +25,140 @@
2525
int_smiles = dict(zip(elements_smiles, range(1, len(elements_smiles)+1)))
2626
int_fasta = dict(zip(elements_fasta, range(1, len(elements_fasta)+1)))
2727

28-
def convert(arx=file_path):
29-
30-
#Function to convert all elements (both smiles and fasta) into int, in order to be trained in the model
31-
32-
smiles_w_numbers = [] # Smiles obtained with int_smiles[1] and the smiles of the df
33-
for i in arx.smiles:
34-
smiles_list = []
35-
for elements in i: # Elements refers to the elements that make up elements_smile
36-
try:
37-
smiles_list.append(int_smiles[elements])
38-
except:
39-
pass
40-
while (len(smiles_list) != max_smiles):
41-
smiles_list.append(0)
28+
def convert(file_path=file_path):
29+
30+
'''
31+
Function to convert all elements (both smiles and fasta) into int, in order to be trained in the model
32+
33+
Parameters:
34+
35+
file_path (path): DataFrame containing the SMILES, FASTA and IC50 columns. Columns must be named "smiles", "sequence" and "IC50". This file is generated from src/fix_data_for_models.py
36+
37+
Returns:
38+
39+
smiles_w_numbers (list): List of SMILES converted to integers
40+
fasta_w_numbers (list): List of FASTA converted to integers
41+
42+
'''
43+
44+
smiles_w_numbers = []
45+
for i in file_path.smiles:
46+
smiles_list = [int_smiles.get(element, 0) for element in i]
47+
smiles_list.extend([0] * (max_smiles - len(smiles_list)))
4248
smiles_w_numbers.append(smiles_list)
4349

4450
fasta_w_numbers = []
45-
for i in arx.sequence:
46-
fasta_list = []
47-
for elements in i: # Elements fa referència a els elements que formen elements_smile
48-
try:
49-
fasta_list.append(int_fasta[elements])
50-
except:
51-
pass
52-
while (len(fasta_list) != max_fasta):
53-
fasta_list.append(0)
51+
for i in file_path.sequence:
52+
fasta_list = [int_fasta.get(element, 0) for element in i]
53+
fasta_list.extend([0] * (max_fasta - len(fasta_list)))
5454
fasta_w_numbers.append(fasta_list)
5555

56-
ic50_numeros = list(arx.IC50)
56+
ic50_numeros = list(file_path.IC50)
5757

5858
return smiles_w_numbers, fasta_w_numbers, ic50_numeros
5959

6060

61-
X_test_smile, X_test_fasta, T_test_IC50 = convert(arx[350000:])
6261

62+
X_test_smile, X_test_fasta, T_test_IC50 = convert(file_path[350000:])
6363

64-
def model_cnn():
65-
# model to train
66-
67-
# kernel regularizer
68-
regulatos = l2(0.001)
69-
70-
# model per a smiles
71-
smiles_input = Input(
72-
shape=(max_smiles,), dtype='int32', name='smiles_input')
73-
embed = Embedding(input_dim=len(
74-
elements_smiles)+1, input_length=max_smiles, output_dim=128)(smiles_input)
75-
x = Conv1D(
76-
filters=32, kernel_size=3, padding="SAME", input_shape=(50700, max_smiles))(embed)
64+
65+
def model_cnn(file_path=file_path):
66+
67+
'''
68+
Function to train a model using CNN. The model is trained using the SMILES and FASTA sequences.
69+
The model is trained using the IC50 values.
70+
71+
Parameters:
72+
file_path (path): DataFrame containing the SMILES, FASTA and IC50 columns. Columns must be named "smiles", "sequence" and "IC50". This file is generated from src/fix_data_for_models.py
73+
74+
'''
75+
regulator = l2(0.001)
76+
77+
# Model for SMILES
78+
smiles_input = Input(shape=(max_smiles,), dtype='int32', name='smiles_input')
79+
embed_smiles = Embedding(input_dim=len(elements_smiles)+1, input_length=max_smiles, output_dim=128)(smiles_input)
80+
x = Conv1D(filters=32, kernel_size=3, padding="SAME", kernel_regularizer=regulator)(embed_smiles)
7781
x = PReLU()(x)
7882

7983
x = Conv1D(filters=64, kernel_size=3, padding="SAME")(x)
8084
x = BatchNormalization()(x)
8185
x = PReLU()(x)
82-
x = Conv1D(
83-
filters=128, kernel_size=3, padding="SAME")(x)
86+
x = Conv1D(filters=128, kernel_size=3, padding="SAME")(x)
8487
x = BatchNormalization()(x)
8588
x = PReLU()(x)
86-
pool = GlobalMaxPooling1D()(
87-
x) # maxpool to get a 1d vector
89+
pool_smiles = GlobalMaxPooling1D()(x)
8890

89-
# model per fastas
91+
# Model for FASTA
9092
fasta_input = Input(shape=(max_fasta,), name='fasta_input')
91-
embed2 = Embedding(input_dim=len(
92-
elements_fasta)+1, input_length=max_fasta, output_dim=256)(fasta_input)
93-
x2 = Conv1D(
94-
filters=32, kernel_size=3, padding="SAME", input_shape=(50700, max_fasta))(embed2)
95-
x2 = PReLU()(embed2)
96-
97-
x2 = Conv1D(
98-
filters=64, kernel_size=3, padding="SAME")(x2)
93+
embed_fasta = Embedding(input_dim=len(elements_fasta)+1, input_length=max_fasta, output_dim=256)(fasta_input)
94+
x2 = Conv1D(filters=32, kernel_size=3, padding="SAME")(embed_fasta)
95+
x2 = PReLU()(x2)
96+
97+
x2 = Conv1D(filters=64, kernel_size=3, padding="SAME")(x2)
9998
x2 = BatchNormalization()(x2)
10099
x2 = PReLU()(x2)
101-
x2 = Conv1D(
102-
filters=128, kernel_size=3, padding="SAME")(x2)
100+
x2 = Conv1D(filters=128, kernel_size=3, padding="SAME")(x2)
103101
x2 = BatchNormalization()(x2)
104102
x2 = PReLU()(x2)
105-
pool2 = GlobalMaxPooling1D()(
106-
x2) #maxpool to get a 1d vector
107-
108-
junt = concatenate(inputs=[pool, pool2])
109-
110-
# dense
103+
pool_fasta = GlobalMaxPooling1D()(x2)
111104

112-
de = Dense(units=1024, activation="relu")(junt)
113-
dr = Dropout(0.3)(de)
114-
de = Dense(units=1024, activation="relu")(dr)
115-
dr = Dropout(0.3)(de)
116-
de2 = Dense(units=512, activation="relu")(dr)
105+
# Concatenate and Dense layers
106+
combined = concatenate([pool_smiles, pool_fasta])
107+
dense = Dense(units=1024, activation="relu")(combined)
108+
dense = Dropout(0.3)(dense)
109+
dense = Dense(units=1024, activation="relu")(dense)
110+
dense = Dropout(0.3)(dense)
111+
dense = Dense(units=512, activation="relu")(dense)
117112

118-
# output
113+
output = Dense(1, activation="relu", name="output")(dense)
119114

120-
output = Dense(
121-
1, activation="relu", name="output", kernel_initializer="normal")(de2)
115+
model = tf.keras.models.Model(inputs=[smiles_input, fasta_input], outputs=[output])
122116

123-
model = tf.keras.models.Model(
124-
inputs=[smiles_input, fasta_input], outputs=[output])
125-
126-
127-
# funció per mirar la precisió del model (serà la nostra metric)
128117
def r2_score(y_true, y_pred):
129118
SS_res = K.sum(K.square(y_true - y_pred))
130119
SS_tot = K.sum(K.square(y_true - K.mean(y_true)))
131-
return (1-SS_res/(SS_tot)+K.epsilon())
120+
return (1 - SS_res / (SS_tot + K.epsilon()))
132121

133-
model.load_weights(
134-
r"")
135-
# In case you want to continue training a model
136-
137122
model.compile(optimizer="adam",
138-
loss={'output': "mean_squared_logarithmic_error"},
139-
metrics={'output': r2_score})
140-
141-
# To do checkpoints
123+
loss={'output': "mean_squared_logarithmic_error"},
124+
metrics={'output': r2_score})
125+
142126
save_model_path = "models/cnn_model.hdf5"
143-
checkpoint = ModelCheckpoint(save_model_path,
144-
monitor='val_loss',
145-
verbose=1,
146-
save_best_only=True)
127+
checkpoint = ModelCheckpoint(save_model_path, monitor='val_loss', verbose=1, save_best_only=True)
147128

148-
# We use a high value to get better results
149129
size_per_epoch = 50700
150-
151-
train = arx[:355000]
130+
train = file_path[:355000]
152131
loss = []
153132
loss_validades = []
154133
epochs = 50
155134

156-
for epoch in range(epochs): #Amount of epochs you want to use
135+
for epoch in range(epochs):
157136
start = 0
158137
end = size_per_epoch
159-
print(f"Començant el epoch {epoch+1}")
138+
print(f"Comenzando el epoch {epoch+1}")
160139

161-
while final < 355000:
140+
while end <= 355000:
162141
X_smiles, X_fasta, y_train = convert(train[start:end])
163142

164143
r = model.fit({'smiles_input': np.array(X_smiles),
165-
'fasta_input': np.array(X_fasta)}, {'output': np.array(y_train)},
144+
'fasta_input': np.array(X_fasta)},
145+
{'output': np.array(y_train)},
166146
validation_data=({'smiles_input': np.array(X_test_smile),
167-
'fasta_input': np.array(X_test_fasta)}, {'output': np.array(T_test_IC50)}), callbacks=[checkpoint], epochs=20, batch_size=64, shuffle=True)
147+
'fasta_input': np.array(X_test_fasta)},
148+
{'output': np.array(T_test_IC50)}),
149+
callbacks=[checkpoint], epochs=1, batch_size=64, shuffle=True)
168150

169-
inici += size_per_epoch
170-
final += size_per_epoch
151+
start += size_per_epoch
152+
end += size_per_epoch
171153

172-
loss.append(r.history["loss"])
173-
loss_validades.append(r.history["val_loss"])
154+
loss.append(np.mean(r.history["loss"]))
155+
loss_validades.append(np.mean(r.history["val_loss"]))
174156

175157
plt.plot(range(epochs), loss, label="loss")
176158
plt.plot(range(epochs), loss_validades, label="val_loss")
177159
plt.legend()
178160
plt.show()
179161

180162

181-
model_cnn()
163+
# Example usage
164+
model_cnn(file_path=file_path)

0 commit comments

Comments
 (0)