Skip to content

Commit 47303d9

Browse files
author
Cannon Lock
committed
Temp
1 parent f161ab7 commit 47303d9

File tree

1 file changed

+79
-7
lines changed

1 file changed

+79
-7
lines changed

DeepLearning.py

+79-7
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,58 @@ def get_element_str(el, includeDuration):
4040
note_strings = [get_note_str(n, duration) for n in el.notes]
4141
return " ".join(sorted(note_strings))
4242

43-
def get_notes(includeDuration = False):
43+
def get_dataset(includeDuration = False, byGenre = False, regenerate = False):
44+
45+
fileString = "musicDataset_includeDuration=" +\
46+
str(includeDuration) +\
47+
"_byGenre=" +\
48+
str(byGenre) +\
49+
".pickle"
50+
51+
# Check if this dataset is already generated
52+
if not regenerate:
53+
try:
54+
previouslyGenerateDataset = open(fileString, "rb")
55+
return pickle.load(previouslyGenerateDataset)
56+
except:
57+
print("Re/Creating Music Dataset")
58+
59+
# Create the data set by converting all the midi files into string vectors
60+
dataset = {"total": []}
61+
62+
for genre in GENRES:
63+
for file in glob.glob("../TrainingData/" + genre + "/*.mid"):
64+
65+
print("Parsing %s" % file)
66+
midi = converter.parse(file)
67+
68+
try:
69+
s2 = instrument.partitionByInstrument(midi)
70+
notes_to_parse = s2.parts[0].recurse()
71+
except:
72+
notes_to_parse = midi.flat.notes
73+
74+
# Create the song vector
75+
song = []
76+
for el in notes_to_parse:
77+
if isinstance(el, note.Note) or isinstance(el, chord.Chord) or isinstance(el, note.Rest):
78+
song.append(get_element_str(el, includeDuration))
79+
80+
# Add the song to the list
81+
if byGenre:
82+
if genre in dataset:
83+
dataset[genre].append(song)
84+
else:
85+
dataset[genre] = [song]
86+
87+
dataset["total"].append(song)
88+
89+
with open(fileString, "wb") as output:
90+
pickle.dump(dataset, output, pickle.HIGHEST_PROTOCOL)
91+
92+
return dataset
93+
94+
def build_vocab(includeDuration = False):
4495
""" Get all the notes and chords from the midi files in the ./midi_songs directory """
4596
notes = []
4697

@@ -62,16 +113,36 @@ def get_notes(includeDuration = False):
62113
if isinstance(el, note.Note) or isinstance(el, chord.Chord) or isinstance(el, note.Rest):
63114
notes.append(get_element_str(el, includeDuration))
64115

65-
np.save("notes", notes)
116+
np.save("notes_duration:" + str(includeDuration), notes)
66117

67118
return notes
68119

69-
def prepare_sequences(notes, n_vocab):
120+
def standardize_songs(songs):
121+
122+
max_song_length = max(map(len, songs))
123+
124+
adjusted_songs = []
125+
for song in songs:
126+
127+
song_length = len(song)
128+
full_addition = max_song_length // song_length
129+
part_addition = max_song_length % song_length
130+
131+
adjusted_song = song*full_addition
132+
adjusted_song.extend(song[:part_addition])
133+
134+
adjusted_songs.append(adjusted_song)
135+
136+
return adjusted_songs
137+
138+
139+
def prepare_sequences(notes):
70140
""" Prepare the sequences used by the Neural Network """
71141
sequence_length = 100
142+
vocab_length = len(set(notes))
72143

73144
# get all pitch names
74-
pitchnames = sorted(set(item for item in notes))
145+
pitchnames = sorted(set(notes))
75146

76147
# create a dictionary to map pitches to integers
77148
note_to_int = dict((note, number) for number, note in enumerate(pitchnames))
@@ -91,7 +162,7 @@ def prepare_sequences(notes, n_vocab):
91162
# reshape the input into a format compatible with LSTM layers
92163
network_input = np.reshape(network_input, (n_patterns, sequence_length, 1))
93164
# normalize input
94-
network_input = network_input / float(n_vocab)
165+
network_input = network_input / float(vocab_length)
95166

96167
network_output = np_utils.to_categorical(network_output)
97168

@@ -135,5 +206,6 @@ def train(model, network_input, network_output):
135206
model.fit(network_input, network_output, epochs=200, batch_size=128, callbacks=callbacks_list)
136207

137208
if __name__ == '__main__':
138-
a = get_notes()
139-
209+
songs = get_dataset()
210+
l = standardize_songs(songs["total"])
211+
print("Fart")

0 commit comments

Comments
 (0)