diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index d938cd66..ee947aec 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -51,7 +51,7 @@ use crate::{ JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Meta, NumberOrString, ProgressToken, RequestId, }, - transport::{DynamicTransportError, IntoTransport, Transport}, + transport::{DynamicTransportError, IntoTransport, Transport, TransportSessionIdHandle}, }; #[cfg(feature = "client")] mod client; @@ -508,6 +508,7 @@ pub struct RunningService> { handle: Option>, cancellation_token: CancellationToken, dg: DropGuard, + session_id_handle: Option, } impl> Deref for RunningService { type Target = Peer; @@ -530,6 +531,12 @@ impl> RunningService { pub fn cancellation_token(&self) -> RunningServiceCancellationToken { RunningServiceCancellationToken(self.cancellation_token.clone()) } + #[inline] + pub fn session_id(&self) -> Option> { + self.session_id_handle + .as_ref() + .and_then(|handle| handle.session_id()) + } /// Returns true if the service has been closed or cancelled. #[inline] @@ -755,6 +762,7 @@ where let (sink_proxy_tx, mut sink_proxy_rx) = tokio::sync::mpsc::channel::>(SINK_PROXY_BUFFER_SIZE); let peer_info = peer.peer_info(); + let session_id_handle = transport.session_id_handle(); if R::IS_CLIENT { tracing::info!(?peer_info, "Service initialized as client"); } else { @@ -1094,5 +1102,6 @@ where handle: Some(handle), cancellation_token: ct.clone(), dg: ct.drop_guard(), + session_id_handle, } } diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 89568b3d..3c4918d4 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -145,6 +145,39 @@ where /// Close the transport fn close(&mut self) -> impl Future> + Send; + + /// Returns a handle for reading the current transport session ID, when the + /// transport protocol has one. + fn session_id_handle(&self) -> Option { + None + } + + /// Returns the current transport session ID, when the transport protocol has one. + fn session_id(&self) -> Option> { + self.session_id_handle() + .and_then(|handle| handle.session_id()) + } +} + +/// Read-only provider for transports that negotiate a session ID. +pub trait TransportSessionIdProvider: std::fmt::Debug + Send + Sync + 'static { + fn session_id(&self) -> Option>; +} + +/// Cloneable handle for observing a transport's current session ID. +#[derive(Debug, Clone)] +pub struct TransportSessionIdHandle { + provider: Arc, +} + +impl TransportSessionIdHandle { + pub fn new(provider: Arc) -> Self { + Self { provider } + } + + pub fn session_id(&self) -> Option> { + self.provider.session_id() + } } pub trait IntoTransport: Send + 'static diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index a2c1a7b1..a2bebe7e 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -1,4 +1,9 @@ -use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration}; +use std::{ + borrow::Cow, + collections::HashMap, + sync::{Arc, RwLock}, + time::Duration, +}; use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream}; use http::{HeaderName, HeaderValue}; @@ -16,6 +21,7 @@ use crate::{ ServerResult, }, transport::{ + TransportSessionIdHandle, TransportSessionIdProvider, common::client_side_sse::SseAutoReconnectStream, worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport}, }, @@ -23,6 +29,30 @@ use crate::{ type BoxedSseStream = BoxStream<'static, Result>; +/// Cloneable read-only handle for the negotiated streamable HTTP session ID. +#[derive(Debug, Clone, Default)] +pub struct StreamableHttpClientSession { + session_id: Arc>>>, +} + +impl StreamableHttpClientSession { + pub fn session_id(&self) -> Option> { + self.session_id.read().ok().and_then(|guard| guard.clone()) + } + + fn set_session_id(&self, session_id: Option>) { + if let Ok(mut guard) = self.session_id.write() { + *guard = session_id; + } + } +} + +impl TransportSessionIdProvider for StreamableHttpClientSession { + fn session_id(&self) -> Option> { + self.session_id() + } +} + #[derive(Debug)] #[non_exhaustive] pub struct AuthRequiredError { @@ -277,6 +307,7 @@ struct SessionCleanupInfo { pub struct StreamableHttpClientWorker { pub client: C, pub config: StreamableHttpClientTransportConfig, + session: StreamableHttpClientSession, } impl StreamableHttpClientWorker { @@ -287,13 +318,18 @@ impl StreamableHttpClientWorker { uri: url.into(), ..Default::default() }, + session: StreamableHttpClientSession::default(), } } } impl StreamableHttpClientWorker { pub fn new(client: C, config: StreamableHttpClientTransportConfig) -> Self { - Self { client, config } + Self { + client, + config, + session: StreamableHttpClientSession::default(), + } } } @@ -447,6 +483,11 @@ impl Worker for StreamableHttpClientWorker { fn err_join(e: tokio::task::JoinError) -> Self::Error { StreamableHttpError::TokioJoinError(e) } + fn session_id_handle(&self) -> Option { + Some(TransportSessionIdHandle::new(Arc::new( + self.session.clone(), + ))) + } fn config(&self) -> super::worker::WorkerConfig { super::worker::WorkerConfig { name: Some("StreamableHttpClientWorker".into()), @@ -505,6 +546,7 @@ impl Worker for StreamableHttpClientWorker { } None }; + self.session.set_session_id(session_id.clone()); // Extract the negotiated protocol version from the init response // and build a custom headers map that includes MCP-Protocol-Version // for all subsequent HTTP requests (per MCP 2025-06-18 spec). @@ -684,6 +726,7 @@ impl Worker for StreamableHttpClientWorker { streams.abort_all(); session_id = new_session_id; + self.session.set_session_id(session_id.clone()); protocol_headers = new_protocol_headers; session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo { @@ -872,6 +915,7 @@ impl Worker for StreamableHttpClientWorker { } } } + self.session.set_session_id(None); loop_result } diff --git a/crates/rmcp/src/transport/worker.rs b/crates/rmcp/src/transport/worker.rs index 5294640e..67059ade 100644 --- a/crates/rmcp/src/transport/worker.rs +++ b/crates/rmcp/src/transport/worker.rs @@ -3,7 +3,7 @@ use std::{borrow::Cow, time::Duration}; use tokio_util::sync::CancellationToken; use tracing::{Instrument, Level}; -use super::{IntoTransport, Transport}; +use super::{IntoTransport, Transport, TransportSessionIdHandle}; use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage}; #[derive(Debug, thiserror::Error)] @@ -46,6 +46,9 @@ pub trait Worker: Sized + Send + 'static { type Role: ServiceRole; fn err_closed() -> Self::Error; fn err_join(e: tokio::task::JoinError) -> Self::Error; + fn session_id_handle(&self) -> Option { + None + } fn run( self, context: WorkerContext, @@ -67,6 +70,7 @@ pub struct WorkerTransport { join_handle: Option>>>, _drop_guard: tokio_util::sync::DropGuard, ct: CancellationToken, + session_id_handle: Option, } #[non_exhaustive] @@ -96,12 +100,21 @@ impl WorkerTransport { pub fn cancel_token(&self) -> CancellationToken { self.ct.clone() } + pub fn session_id_handle(&self) -> Option { + self.session_id_handle.clone() + } + pub fn session_id(&self) -> Option> { + self.session_id_handle + .as_ref() + .and_then(|handle| handle.session_id()) + } pub fn spawn(worker: W) -> Self { Self::spawn_with_ct(worker, CancellationToken::new()) } pub fn spawn_with_ct(worker: W, transport_task_ct: CancellationToken) -> Self { let config = worker.config(); let worker_name = config.name; + let session_id_handle = worker.session_id_handle(); let (to_transport_tx, from_handler_rx) = tokio::sync::mpsc::channel::>(config.channel_buffer_capacity); let (to_handler_tx, from_transport_rx) = @@ -145,6 +158,7 @@ impl WorkerTransport { join_handle: Some(join_handle), ct: transport_task_ct.clone(), _drop_guard: transport_task_ct.drop_guard(), + session_id_handle, } } } @@ -214,4 +228,8 @@ impl Transport for WorkerTransport { Ok(()) } } + + fn session_id_handle(&self) -> Option { + WorkerTransport::session_id_handle(self) + } } diff --git a/crates/rmcp/tests/test_streamable_http_session_store.rs b/crates/rmcp/tests/test_streamable_http_session_store.rs index 91e77029..2d176874 100644 --- a/crates/rmcp/tests/test_streamable_http_session_store.rs +++ b/crates/rmcp/tests/test_streamable_http_session_store.rs @@ -108,7 +108,22 @@ async fn test_session_state_persisted_to_store() -> anyhow::Result<()> { let transport = StreamableHttpClientTransport::from_config( StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")), ); + assert!( + transport.session_id().is_none(), + "session ID should not be set before initialization" + ); + let session_id_handle = transport + .session_id_handle() + .expect("streamable HTTP transport should expose a session ID handle"); let client = ().serve(transport).await?; + let client_session_id = client + .session_id() + .expect("session ID should be exposed after initialization"); + assert_eq!( + session_id_handle.session_id().as_deref(), + Some(client_session_id.as_ref()), + "transport handle and running service should expose the same session ID" + ); // Make a real request so the session is fully active. let _resources = client.list_all_resources().await?; @@ -122,6 +137,10 @@ async fn test_session_state_persisted_to_store() -> anyhow::Result<()> { // Verify the stored state contains the expected client info. let entries = store.0.read().await; + assert!( + entries.contains_key(client_session_id.as_ref()), + "exposed session ID should match the persisted store key" + ); let state = entries.values().next().expect("store entry should exist"); assert_eq!( state.initialize_params.client_info.name, "rmcp",