@@ -40,7 +40,58 @@ def get_element_str(el, includeDuration):
40
40
note_strings = [get_note_str (n , duration ) for n in el .notes ]
41
41
return " " .join (sorted (note_strings ))
42
42
43
- def get_notes (includeDuration = False ):
43
def get_dataset(includeDuration=False, byGenre=False, regenerate=False):
    """Build (or load a cached copy of) the dataset of songs as string vectors.

    Parses every MIDI file under ``../TrainingData/<genre>/`` for each genre
    in GENRES, converting each note/chord/rest element to its string token via
    ``get_element_str``. The finished dataset is pickled to a file whose name
    encodes the option flags, so later calls with the same flags load the
    cache instead of re-parsing.

    Args:
        includeDuration: include each element's duration in its string token.
        byGenre: additionally group songs under one key per genre.
        regenerate: force re-parsing even when a cached pickle exists.

    Returns:
        dict mapping "total" (and, when byGenre is True, each genre name)
        to a list of songs, where each song is a list of element strings.
    """
    fileString = "musicDataset_includeDuration=" + \
        str(includeDuration) + \
        "_byGenre=" + \
        str(byGenre) + \
        ".pickle"

    # Check if this dataset is already generated
    if not regenerate:
        try:
            # 'with' guarantees the cache file is closed even if load fails;
            # catch only missing/corrupt-cache errors, not everything.
            with open(fileString, "rb") as cached:
                return pickle.load(cached)
        except (OSError, pickle.UnpicklingError, EOFError):
            print("Re/Creating Music Dataset")

    # Create the data set by converting all the midi files into string vectors
    dataset = {"total": []}

    for genre in GENRES:
        for file in glob.glob("../TrainingData/" + genre + "/*.mid"):

            print("Parsing %s" % file)
            midi = converter.parse(file)

            try:
                # Prefer the first instrument part when the score has parts.
                s2 = instrument.partitionByInstrument(midi)
                notes_to_parse = s2.parts[0].recurse()
            except Exception:
                # No usable instrument parts: fall back to the flat note list.
                notes_to_parse = midi.flat.notes

            # Create the song vector
            song = [
                get_element_str(el, includeDuration)
                for el in notes_to_parse
                if isinstance(el, (note.Note, chord.Chord, note.Rest))
            ]

            # Add the song to the list
            if byGenre:
                dataset.setdefault(genre, []).append(song)

            dataset["total"].append(song)

    # Cache the dataset for future calls with the same flags.
    with open(fileString, "wb") as output:
        pickle.dump(dataset, output, pickle.HIGHEST_PROTOCOL)

    return dataset
94
+ def build_vocab (includeDuration = False ):
44
95
""" Get all the notes and chords from the midi files in the ./midi_songs directory """
45
96
notes = []
46
97
@@ -62,16 +113,36 @@ def get_notes(includeDuration = False):
62
113
if isinstance (el , note .Note ) or isinstance (el , chord .Chord ) or isinstance (el , note .Rest ):
63
114
notes .append (get_element_str (el , includeDuration ))
64
115
65
- np .save ("notes" , notes )
116
+ np .save ("notes_duration:" + str ( includeDuration ) , notes )
66
117
67
118
return notes
68
119
69
- def prepare_sequences (notes , n_vocab ):
120
def standardize_songs(songs):
    """Pad every song to the length of the longest song by looping it.

    Each song is repeated whole as many times as fits into the maximum
    length, then topped up with a prefix of itself, so every returned song
    has the same length. Input lists are not mutated.

    Args:
        songs: list of songs, each a non-empty list of element strings.

    Returns:
        A new list of equal-length songs (empty list for empty input).

    Raises:
        ZeroDivisionError: if any individual song is empty.
    """
    if not songs:
        # max() over an empty sequence raises ValueError; nothing to do.
        return []

    max_song_length = max(map(len, songs))

    adjusted_songs = []
    for song in songs:

        song_length = len(song)
        full_addition = max_song_length // song_length   # whole repeats
        part_addition = max_song_length % song_length     # leftover prefix

        adjusted_song = song * full_addition
        adjusted_song.extend(song[:part_addition])

        adjusted_songs.append(adjusted_song)

    return adjusted_songs
137
+
138
+
139
+ def prepare_sequences (notes ):
70
140
""" Prepare the sequences used by the Neural Network """
71
141
sequence_length = 100
142
+ vocab_length = len (set (notes ))
72
143
73
144
# get all pitch names
74
- pitchnames = sorted (set (item for item in notes ))
145
+ pitchnames = sorted (set (notes ))
75
146
76
147
# create a dictionary to map pitches to integers
77
148
note_to_int = dict ((note , number ) for number , note in enumerate (pitchnames ))
@@ -91,7 +162,7 @@ def prepare_sequences(notes, n_vocab):
91
162
# reshape the input into a format compatible with LSTM layers
92
163
network_input = np .reshape (network_input , (n_patterns , sequence_length , 1 ))
93
164
# normalize input
94
- network_input = network_input / float (n_vocab )
165
+ network_input = network_input / float (vocab_length )
95
166
96
167
network_output = np_utils .to_categorical (network_output )
97
168
@@ -135,5 +206,6 @@ def train(model, network_input, network_output):
135
206
model .fit (network_input , network_output , epochs = 200 , batch_size = 128 , callbacks = callbacks_list )
136
207
137
208
# Script entry point: build (or load) the dataset, then normalize the
# lengths of all songs in the combined "total" collection.
if __name__ == '__main__':
    dataset = get_dataset()
    standardized = standardize_songs(dataset["total"])
    print("Fart")
0 commit comments