Skip to content

Commit b28f340

Browse files
committed
fix loop blocking
1 parent d6f5f4d commit b28f340

File tree

3 files changed

+26
-22
lines changed

3 files changed

+26
-22
lines changed

src/rust/container/lib.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use tokio::runtime::Builder;
2525
pub mod io;
2626
use io::channel;
2727
use io::signo_as_string;
28+
use tokio::sync::Notify;
2829

2930
#[derive(Debug, Error)]
3031
pub enum ContainerError {
@@ -99,6 +100,7 @@ impl ContainerService {
99100
return false;
100101
}
101102
self.r#impl.sender.try_send(data.to_vec()).unwrap();
103+
self.r#impl.notify.notify_one();
102104
true
103105
}
104106

@@ -118,6 +120,7 @@ impl ContainerService {
118120

119121
struct Impl {
120122
sender: mpsc::SyncSender<Vec<u8>>,
123+
notify: Arc<Notify>,
121124
write_shutdown: std::sync::atomic::AtomicBool,
122125
}
123126

@@ -128,8 +131,10 @@ impl Impl {
128131
mut messages_callback: Pin<&'static mut ffi::MessageCallback>,
129132
) -> Result<Self, ContainerError> {
130133
let (input_sender, input_receiver) = mpsc::sync_channel::<Vec<u8>>(1000);
131-
let (mut output_sender, mut output_receiver) = channel();
132-
let (mut tokio_input_sender, mut tokio_input_receiver) = channel();
134+
let (output_sender, mut output_receiver) = channel();
135+
let (mut tokio_input_sender, tokio_input_receiver) = channel();
136+
let output_notify = Arc::new(Notify::new());
137+
let notifiy_awaiter = output_notify.clone();
133138

134139
let server = Server::connect(address, container_name)?;
135140
let runtime = Builder::new_current_thread().enable_all().build().unwrap();
@@ -152,7 +157,7 @@ impl Impl {
152157
tokio::task::spawn_local(rpc_system);
153158
});
154159
local.spawn_local(async move {
155-
while true {
160+
loop {
156161
let mut buf = [0; 8096];
157162
let len = output_receiver.read(&mut buf).await;
158163
if len.is_err() {
@@ -163,9 +168,12 @@ impl Impl {
163168
dbg!("OUTPUT_RECEIVER.RECV() ENDED");
164169
});
165170
local.spawn_local(async move {
166-
while let Ok(msg) = input_receiver.recv() {
167-
if tokio_input_sender.write_all(&msg).await.is_err() {
168-
break;
171+
loop {
172+
notifiy_awaiter.notified().await;
173+
while let Ok(msg) = input_receiver.try_recv() {
174+
if tokio_input_sender.write(&msg).await.is_err() {
175+
break;
176+
}
169177
}
170178
}
171179
dbg!("INPUT_RECEIVER.RECV() ENDED");
@@ -176,6 +184,7 @@ impl Impl {
176184

177185
Ok(Impl {
178186
sender: input_sender,
187+
notify: output_notify,
179188
write_shutdown: std::sync::atomic::AtomicBool::new(false),
180189
})
181190
}

src/workerd/io/container-client.c++

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,11 @@ kj::Promise<size_t> ContainerAsyncStream::tryRead(void* buffer, size_t minBytes,
7070

7171
kj::Promise<void> ContainerAsyncStream::write(kj::ArrayPtr<const kj::byte> buffer) {
7272
KJ_DBG("WRITE");
73-
if (service->write_data(buffer.as<Rust>())) {
74-
return kj::READY_NOW;
75-
} else {
73+
if (!service->write_data(buffer.as<Rust>())) {
7674
KJ_DBG("WRITE FAILED");
7775
return KJ_EXCEPTION(DISCONNECTED, "Write failed: stream is disconnected");
7876
}
77+
return kj::READY_NOW;
7978
}
8079

8180
kj::Promise<void> ContainerAsyncStream::write(

src/workerd/server/server.c++

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2277,28 +2277,24 @@ class Server::WorkerService final: public Service,
22772277
static constexpr uint16_t hibernationEventTypeId = 8;
22782278

22792279
kj::Maybe<rpc::Container::Client> containerClient = kj::none;
2280-
kj::Own<io::ContainerAsyncStream> containerStream;
22812280
kj::Own<capnp::TwoPartyClient> containerRpcClient;
22822281

22832282
KJ_IF_SOME(config, containerOptions) {
2284-
KJ_IF_SOME(path, service.containerAddress) {
2285-
KJ_REQUIRE(config.hasName(), "Container name is required");
2286-
auto container_name = config.getName();
2287-
containerStream = io::createContainerRpcStream(kj::str(path), kj::str(container_name));
2288-
containerRpcClient = kj::heap<capnp::TwoPartyClient>(*containerStream);
2289-
containerClient = containerRpcClient->bootstrap().castAs<rpc::Container>();
2290-
KJ_DBG("CREATED CONTAINER CLIENT");
2291-
} else {
2292-
KJ_FAIL_REQUIRE(
2293-
"container address needs to be defined in order enable containers on this durable object.");
2294-
}
2283+
auto& path = KJ_ASSERT_NONNULL(service.containerAddress,
2284+
"container address needs to be defined in order enable containers on this durable object.");
2285+
KJ_REQUIRE(config.hasName(), "Container name is required");
2286+
auto container_name = config.getName();
2287+
auto containerStream =
2288+
io::createContainerRpcStream(kj::str(path), kj::str(container_name));
2289+
containerRpcClient =
2290+
kj::heap<capnp::TwoPartyClient>(*containerStream).attach(kj::mv(containerStream));
2291+
containerClient = containerRpcClient->bootstrap().castAs<rpc::Container>();
22952292
}
22962293

22972294
auto& actorRef = *actor.emplace(kj::refcounted<Worker::Actor>(*service.worker, getTracker(),
22982295
Worker::Actor::cloneId(id), true, kj::mv(makeActorCache), parent.className,
22992296
kj::mv(makeStorage), kj::mv(loopback), timerChannel, kj::refcounted<ActorObserver>(),
23002297
tryGetManagerRef(), hibernationEventTypeId, kj::mv(containerClient))
2301-
.attach(kj::mv(containerStream))
23022298
.attach(kj::mv(containerRpcClient)));
23032299
onBrokenTask = monitorOnBroken(actorRef);
23042300
}

0 commit comments

Comments
 (0)