Skip to content

Commit 31da5ff

Browse files
authored
fix(handler): do call handler methods when initialize server (#118)
1 parent d0962ec commit 31da5ff

File tree

3 files changed

+40
-14
lines changed

3 files changed

+40
-14
lines changed

crates/rmcp/src/service.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -504,15 +504,16 @@ where
504504
T: IntoTransport<R, E, A>,
505505
E: std::error::Error + Send + Sync + 'static,
506506
{
507-
serve_inner(service, transport, peer_info, Default::default(), ct).await
507+
let (peer, peer_rx) = Peer::new(Arc::new(AtomicU32RequestIdProvider::default()), peer_info);
508+
serve_inner(service, transport, peer, peer_rx, ct).await
508509
}
509510

510511
#[instrument(skip_all)]
511512
async fn serve_inner<R, S, T, E, A>(
512513
mut service: S,
513514
transport: T,
514-
peer_info: R::PeerInfo,
515-
id_provider: Arc<AtomicU32Provider>,
515+
peer: Peer<R>,
516+
mut peer_rx: tokio::sync::mpsc::Receiver<PeerSinkMessage<R>>,
516517
ct: CancellationToken,
517518
) -> Result<RunningService<R, S>, E>
518519
where
@@ -525,14 +526,13 @@ where
525526
const SINK_PROXY_BUFFER_SIZE: usize = 64;
526527
let (sink_proxy_tx, mut sink_proxy_rx) =
527528
tokio::sync::mpsc::channel::<TxJsonRpcMessage<R>>(SINK_PROXY_BUFFER_SIZE);
528-
529+
let peer_info = peer.peer_info();
529530
if R::IS_CLIENT {
530531
tracing::info!(?peer_info, "Service initialized as client");
531532
} else {
532533
tracing::info!(?peer_info, "Service initialized as server");
533534
}
534535

535-
let (peer, mut peer_proxy) = <Peer<R>>::new(id_provider, peer_info);
536536
service.set_peer(peer.clone());
537537
let mut local_responder_pool = HashMap::new();
538538
let mut local_ct_pool = HashMap::<RequestId, CancellationToken>::new();
@@ -576,7 +576,7 @@ where
576576
break QuitReason::Closed
577577
}
578578
}
579-
m = peer_proxy.recv() => {
579+
m = peer_rx.recv() => {
580580
if let Some(m) = m {
581581
Event::ProxyMessage(m)
582582
} else {

crates/rmcp/src/service/client.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ where
174174
}),
175175
);
176176
sink.send(notification).await?;
177-
serve_inner(service, (sink, stream), initialize_result, id_provider, ct).await
177+
let (peer, peer_rx) = Peer::new(id_provider, initialize_result);
178+
serve_inner(service, (sink, stream), peer, peer_rx, ct).await
178179
}
179180

180181
macro_rules! method {

crates/rmcp/src/service/server.rs

+32-7
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use super::*;
55
use crate::model::{
66
CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
77
ClientNotification, ClientRequest, ClientResult, CreateMessageRequest,
8-
CreateMessageRequestParam, CreateMessageResult, ListRootsRequest, ListRootsResult,
8+
CreateMessageRequestParam, CreateMessageResult, ErrorData, ListRootsRequest, ListRootsResult,
99
LoggingMessageNotification, LoggingMessageNotificationParam, ProgressNotification,
1010
ProgressNotificationParam, PromptListChangedNotification, ResourceListChangedNotification,
1111
ResourceUpdatedNotification, ResourceUpdatedNotificationParam, ServerInfo, ServerNotification,
@@ -41,6 +41,12 @@ pub enum ServerError {
4141
#[error("connection closed: {0}")]
4242
ConnectionClosed(String),
4343

44+
#[error("unexpected initialize result: {0:?}")]
45+
UnexpectedInitializeResponse(ServerResult),
46+
47+
#[error("initialize failed: {0}")]
48+
InitializeFailed(ErrorData),
49+
4450
#[error("IO error: {0}")]
4551
Io(#[from] std::io::Error),
4652
}
@@ -144,14 +150,34 @@ where
144150
.await
145151
.map_err(handle_server_error)?;
146152

147-
let ClientRequest::InitializeRequest(peer_info) = request else {
153+
let ClientRequest::InitializeRequest(peer_info) = &request else {
148154
return Err(handle_server_error(ServerError::ExpectedInitRequest(Some(
149155
ClientJsonRpcMessage::request(request, id),
150156
))));
151157
};
152-
158+
let (peer, peer_rx) = Peer::new(id_provider, peer_info.params.clone());
159+
let context = RequestContext {
160+
ct: ct.child_token(),
161+
id: id.clone(),
162+
meta: request.get_meta().clone(),
163+
extensions: request.extensions().clone(),
164+
peer: peer.clone(),
165+
};
153166
// Send initialize response
154-
let mut init_response = service.get_info();
167+
let init_response = service.handle_request(request.clone(), context).await;
168+
let mut init_response = match init_response {
169+
Ok(ServerResult::InitializeResult(init_response)) => init_response,
170+
Ok(result) => {
171+
return Err(handle_server_error(
172+
ServerError::UnexpectedInitializeResponse(result),
173+
));
174+
}
175+
Err(e) => {
176+
sink.send(ServerJsonRpcMessage::error(e.clone(), id))
177+
.await?;
178+
return Err(handle_server_error(ServerError::InitializeFailed(e)));
179+
}
180+
};
155181
let protocol_version = match peer_info
156182
.params
157183
.protocol_version
@@ -174,15 +200,14 @@ where
174200
let notification = expect_notification(&mut stream, "initialize notification")
175201
.await
176202
.map_err(handle_server_error)?;
177-
178203
let ClientNotification::InitializedNotification(_) = notification else {
179204
return Err(handle_server_error(ServerError::ExpectedInitNotification(
180205
Some(ClientJsonRpcMessage::notification(notification)),
181206
)));
182207
};
183-
208+
let _ = service.handle_notification(notification).await;
184209
// Continue processing service
185-
serve_inner(service, (sink, stream), peer_info.params, id_provider, ct).await
210+
serve_inner(service, (sink, stream), peer, peer_rx, ct).await
186211
}
187212

188213
macro_rules! method {

0 commit comments

Comments
 (0)