Skip to content

Commit e9f4034

Browse files
committed
Add tests for Params#new_segment_callback=
1 parent 5e350b1 commit e9f4034

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

bindings/ruby/ext/ruby_whisper.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ extern "C" {
3535
VALUE mWhisper;
3636
VALUE cContext;
3737
VALUE cParams;
38+
VALUE cSegment;
3839

3940
static VALUE ruby_whisper_s_lang_max_id(VALUE self) {
4041
return INT2NUM(whisper_lang_max_id());
@@ -476,6 +477,7 @@ void Init_whisper() {
476477
mWhisper = rb_define_module("Whisper");
477478
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
478479
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
480+
cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
479481

480482
rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
481483
rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);

bindings/ruby/tests/test_whisper.rb

+50
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,56 @@ def test_whisper
127127
}
128128
end
129129

130+
def test_new_segment_callback
131+
whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
132+
133+
@params.new_segment_callback = ->(context, state, n_new, user_data) {
134+
assert_kind_of Integer, n_new
135+
assert n_new > 0
136+
assert_same whisper, context
137+
138+
n_segments = context.full_n_segments
139+
n_new.times do |i|
140+
i_segment = n_segments - 1 + i
141+
start_time = context.full_get_segment_t0(i_segment) * 10
142+
end_time = context.full_get_segment_t1(i_segment) * 10
143+
text = context.full_get_segment_text(i_segment)
144+
145+
assert_kind_of Integer, start_time
146+
assert start_time >= 0
147+
assert_kind_of Integer, end_time
148+
assert end_time > 0
149+
assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0
150+
end
151+
}
152+
153+
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
154+
whisper.transcribe(jfk, @params)
155+
end
156+
157+
def test_new_segment_callback_closure
158+
whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
159+
160+
search_word = "what"
161+
@params.new_segment_callback = ->(context, state, n_new, user_data) {
162+
n_segments = context.full_n_segments
163+
n_new.times do |i|
164+
i_segment = n_segments - 1 + i
165+
text = context.full_get_segment_text(i_segment)
166+
if text.include?(search_word)
167+
t0 = context.full_get_segment_t0(i_segment)
168+
t1 = context.full_get_segment_t1(i_segment)
169+
raise "search word '#{search_word}' found at between #{t0} and #{t1}"
170+
end
171+
end
172+
}
173+
174+
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
175+
assert_raise RuntimeError do
176+
whisper.transcribe(jfk, @params)
177+
end
178+
end
179+
130180
sub_test_case "After transcription" do
131181
class << self
132182
attr_reader :whisper

0 commit comments

Comments
 (0)