import glob
import pickle
import numpy
from music21 import converter, instrument, note, chord
from keras.models import Sequential
from keras.layers import Dense, Dropout, LSTM, Activation
from keras.layers import BatchNormalization as BatchNorm
from keras.utils import np_utils
from keras.callbacks import ModelCheckpoint

def train_network():
    """ Train a Neural Network to generate music """
    notes = get_notes()

    # get the number of distinct pitch names
    n_vocab = len(set(notes))

    network_input, network_output = prepare_sequences(notes, n_vocab)

    model = create_network(network_input, n_vocab)

    train(model, network_input, network_output)

def get_notes():
    """ Get all the notes and chords from the midi files in the ./midi_songs directory """
    notes = []

    for file in glob.glob("midi_songs/*.mid"):
        midi = converter.parse(file)

        print("Parsing %s" % file)

        notes_to_parse = None

        try: # file has instrument parts
            s2 = instrument.partitionByInstrument(midi)
            notes_to_parse = s2.parts[0].recurse()
        except: # file has notes in a flat structure
            notes_to_parse = midi.flat.notes

        for element in notes_to_parse:
            if isinstance(element, note.Note):
                notes.append(str(element.pitch))
            elif isinstance(element, chord.Chord):
                notes.append('.'.join(str(n) for n in element.normalOrder))

    with open('data/notes', 'wb') as filepath:
        pickle.dump(notes, filepath)

    return notes

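# Note on the encoding above (illustrative, not part of the original script):
# single notes are stored as pitch strings such as 'E4' or 'F#5', while chords
# are stored as their normal-order pitch classes joined by dots, so a C major
# triad would appear as '0.4.7'. The network vocabulary is built from these
# strings, which is why n_vocab = len(set(notes)) in train_network().
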
def prepare_sequences(notes, n_vocab):
    """ Prepare the sequences used by the Neural Network """
    sequence_length = 100

    # get all pitch names
    pitchnames = sorted(set(item for item in notes))

    # create a dictionary to map pitches to integers
    note_to_int = dict((note, number) for number, note in enumerate(pitchnames))

    network_input = []
    network_output = []

    # create input sequences and the corresponding outputs
    for i in range(0, len(notes) - sequence_length, 1):
        sequence_in = notes[i:i + sequence_length]
        sequence_out = notes[i + sequence_length]
        network_input.append([note_to_int[char] for char in sequence_in])
        network_output.append(note_to_int[sequence_out])

    n_patterns = len(network_input)

    # reshape the input into a format compatible with LSTM layers
    network_input = numpy.reshape(network_input, (n_patterns, sequence_length, 1))
    # normalize input
    network_input = network_input / float(n_vocab)

    network_output = np_utils.to_categorical(network_output)

    return (network_input, network_output)

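# Shape sanity check (hypothetical toy data, not part of the original script):
# for notes = ['C4', 'E4', 'G4'] * 50 (150 symbols, vocabulary of 3),
# prepare_sequences(notes, 3) should return network_input of shape
# (50, 100, 1) with values scaled into [0, 1], and network_output of shape
# (50, 3) holding one-hot targets.
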
def create_network(network_input, n_vocab):
    """ create the structure of the neural network """
    model = Sequential()
    model.add(LSTM(
        512,
        input_shape=(network_input.shape[1], network_input.shape[2]),
        recurrent_dropout=0.3,
        return_sequences=True
    ))
    model.add(LSTM(512, return_sequences=True, recurrent_dropout=0.3))
    model.add(LSTM(512))
    model.add(BatchNorm())
    model.add(Dropout(0.3))
    model.add(Dense(256))
    model.add(Activation('relu'))
    model.add(BatchNorm())
    model.add(Dropout(0.3))
    model.add(Dense(n_vocab))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

    return model

def train(model, network_input, network_output):
    """ train the neural network """
    filepath = "weights-improvement-{epoch:02d}-{loss:.4f}-bigger.hdf5"
    checkpoint = ModelCheckpoint(
        filepath,
        monitor='loss',
        verbose=0,
        save_best_only=True,
        mode='min'
    )
    callbacks_list = [checkpoint]

    model.fit(network_input, network_output, epochs=200, batch_size=128, callbacks=callbacks_list)

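# Optional sketch (an assumption, not part of the original script): training can
# be resumed from an earlier checkpoint by rebuilding the network and loading
# saved weights before calling train(), e.g.
#     model = create_network(network_input, n_vocab)
#     model.load_weights('weights-improvement-XX-X.XXXX-bigger.hdf5')  # existing checkpoint file
# model.load_weights() is a standard Keras method; the filename must point to a
# checkpoint previously written by the ModelCheckpoint callback above.
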
if __name__ == '__main__':
    train_network()