Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion crates/rmcp/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ use crate::{
JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Meta,
NumberOrString, ProgressToken, RequestId,
},
transport::{DynamicTransportError, IntoTransport, Transport},
transport::{DynamicTransportError, IntoTransport, Transport, TransportSessionIdHandle},
};
#[cfg(feature = "client")]
mod client;
Expand Down Expand Up @@ -508,6 +508,7 @@ pub struct RunningService<R: ServiceRole, S: Service<R>> {
handle: Option<tokio::task::JoinHandle<QuitReason>>,
cancellation_token: CancellationToken,
dg: DropGuard,
session_id_handle: Option<TransportSessionIdHandle>,
}
impl<R: ServiceRole, S: Service<R>> Deref for RunningService<R, S> {
type Target = Peer<R>;
Expand All @@ -530,6 +531,12 @@ impl<R: ServiceRole, S: Service<R>> RunningService<R, S> {
pub fn cancellation_token(&self) -> RunningServiceCancellationToken {
RunningServiceCancellationToken(self.cancellation_token.clone())
}
#[inline]
pub fn session_id(&self) -> Option<Arc<str>> {
self.session_id_handle
.as_ref()
.and_then(|handle| handle.session_id())
}

/// Returns true if the service has been closed or cancelled.
#[inline]
Expand Down Expand Up @@ -755,6 +762,7 @@ where
let (sink_proxy_tx, mut sink_proxy_rx) =
tokio::sync::mpsc::channel::<TxJsonRpcMessage<R>>(SINK_PROXY_BUFFER_SIZE);
let peer_info = peer.peer_info();
let session_id_handle = transport.session_id_handle();
if R::IS_CLIENT {
tracing::info!(?peer_info, "Service initialized as client");
} else {
Expand Down Expand Up @@ -1094,5 +1102,6 @@ where
handle: Some(handle),
cancellation_token: ct.clone(),
dg: ct.drop_guard(),
session_id_handle,
}
}
33 changes: 33 additions & 0 deletions crates/rmcp/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,39 @@ where

/// Close the transport
fn close(&mut self) -> impl Future<Output = Result<(), Self::Error>> + Send;

/// Returns a handle for reading the current transport session ID, when the
/// transport protocol has one.
fn session_id_handle(&self) -> Option<TransportSessionIdHandle> {
None
}

/// Returns the current transport session ID, when the transport protocol has one.
fn session_id(&self) -> Option<Arc<str>> {
self.session_id_handle()
.and_then(|handle| handle.session_id())
}
}

/// Read-only provider for transports that negotiate a session ID.
pub trait TransportSessionIdProvider: std::fmt::Debug + Send + Sync + 'static {
fn session_id(&self) -> Option<Arc<str>>;
}

/// Cloneable handle for observing a transport's current session ID.
#[derive(Debug, Clone)]
pub struct TransportSessionIdHandle {
provider: Arc<dyn TransportSessionIdProvider>,
}

impl TransportSessionIdHandle {
pub fn new(provider: Arc<dyn TransportSessionIdProvider>) -> Self {
Self { provider }
}

pub fn session_id(&self) -> Option<Arc<str>> {
self.provider.session_id()
}
}

pub trait IntoTransport<R, E, A>: Send + 'static
Expand Down
48 changes: 46 additions & 2 deletions crates/rmcp/src/transport/streamable_http_client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
use std::{
borrow::Cow,
collections::HashMap,
sync::{Arc, RwLock},
time::Duration,
};

use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use http::{HeaderName, HeaderValue};
Expand All @@ -16,13 +21,38 @@ use crate::{
ServerResult,
},
transport::{
TransportSessionIdHandle, TransportSessionIdProvider,
common::client_side_sse::SseAutoReconnectStream,
worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport},
},
};

type BoxedSseStream = BoxStream<'static, Result<Sse, SseError>>;

/// Cloneable read-only handle for the negotiated streamable HTTP session ID.
#[derive(Debug, Clone, Default)]
pub struct StreamableHttpClientSession {
session_id: Arc<RwLock<Option<Arc<str>>>>,
}

impl StreamableHttpClientSession {
pub fn session_id(&self) -> Option<Arc<str>> {
self.session_id.read().ok().and_then(|guard| guard.clone())
}

fn set_session_id(&self, session_id: Option<Arc<str>>) {
if let Ok(mut guard) = self.session_id.write() {
*guard = session_id;
}
}
}

impl TransportSessionIdProvider for StreamableHttpClientSession {
fn session_id(&self) -> Option<Arc<str>> {
self.session_id()
}
}

#[derive(Debug)]
#[non_exhaustive]
pub struct AuthRequiredError {
Expand Down Expand Up @@ -277,6 +307,7 @@ struct SessionCleanupInfo<C> {
pub struct StreamableHttpClientWorker<C: StreamableHttpClient> {
pub client: C,
pub config: StreamableHttpClientTransportConfig,
session: StreamableHttpClientSession,
}

impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> {
Expand All @@ -287,13 +318,18 @@ impl<C: StreamableHttpClient + Default> StreamableHttpClientWorker<C> {
uri: url.into(),
..Default::default()
},
session: StreamableHttpClientSession::default(),
}
}
}

impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
pub fn new(client: C, config: StreamableHttpClientTransportConfig) -> Self {
Self { client, config }
Self {
client,
config,
session: StreamableHttpClientSession::default(),
}
}
}

Expand Down Expand Up @@ -447,6 +483,11 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
fn err_join(e: tokio::task::JoinError) -> Self::Error {
StreamableHttpError::TokioJoinError(e)
}
fn session_id_handle(&self) -> Option<TransportSessionIdHandle> {
Some(TransportSessionIdHandle::new(Arc::new(
self.session.clone(),
)))
}
fn config(&self) -> super::worker::WorkerConfig {
super::worker::WorkerConfig {
name: Some("StreamableHttpClientWorker".into()),
Expand Down Expand Up @@ -505,6 +546,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
}
None
};
self.session.set_session_id(session_id.clone());
// Extract the negotiated protocol version from the init response
// and build a custom headers map that includes MCP-Protocol-Version
// for all subsequent HTTP requests (per MCP 2025-06-18 spec).
Expand Down Expand Up @@ -684,6 +726,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
streams.abort_all();

session_id = new_session_id;
self.session.set_session_id(session_id.clone());
protocol_headers = new_protocol_headers;
session_cleanup_info =
session_id.as_ref().map(|sid| SessionCleanupInfo {
Expand Down Expand Up @@ -872,6 +915,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
}
}
}
self.session.set_session_id(None);

loop_result
}
Expand Down
20 changes: 19 additions & 1 deletion crates/rmcp/src/transport/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{borrow::Cow, time::Duration};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, Level};

use super::{IntoTransport, Transport};
use super::{IntoTransport, Transport, TransportSessionIdHandle};
use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage};

#[derive(Debug, thiserror::Error)]
Expand Down Expand Up @@ -46,6 +46,9 @@ pub trait Worker: Sized + Send + 'static {
type Role: ServiceRole;
fn err_closed() -> Self::Error;
fn err_join(e: tokio::task::JoinError) -> Self::Error;
fn session_id_handle(&self) -> Option<TransportSessionIdHandle> {
None
}
fn run(
self,
context: WorkerContext<Self>,
Expand All @@ -67,6 +70,7 @@ pub struct WorkerTransport<W: Worker> {
join_handle: Option<tokio::task::JoinHandle<Result<(), WorkerQuitReason<W::Error>>>>,
_drop_guard: tokio_util::sync::DropGuard,
ct: CancellationToken,
session_id_handle: Option<TransportSessionIdHandle>,
}

#[non_exhaustive]
Expand Down Expand Up @@ -96,12 +100,21 @@ impl<W: Worker> WorkerTransport<W> {
pub fn cancel_token(&self) -> CancellationToken {
self.ct.clone()
}
pub fn session_id_handle(&self) -> Option<TransportSessionIdHandle> {
self.session_id_handle.clone()
}
pub fn session_id(&self) -> Option<std::sync::Arc<str>> {
self.session_id_handle
.as_ref()
.and_then(|handle| handle.session_id())
}
pub fn spawn(worker: W) -> Self {
Self::spawn_with_ct(worker, CancellationToken::new())
}
pub fn spawn_with_ct(worker: W, transport_task_ct: CancellationToken) -> Self {
let config = worker.config();
let worker_name = config.name;
let session_id_handle = worker.session_id_handle();
let (to_transport_tx, from_handler_rx) =
tokio::sync::mpsc::channel::<WorkerSendRequest<W>>(config.channel_buffer_capacity);
let (to_handler_tx, from_transport_rx) =
Expand Down Expand Up @@ -145,6 +158,7 @@ impl<W: Worker> WorkerTransport<W> {
join_handle: Some(join_handle),
ct: transport_task_ct.clone(),
_drop_guard: transport_task_ct.drop_guard(),
session_id_handle,
}
}
}
Expand Down Expand Up @@ -214,4 +228,8 @@ impl<W: Worker> Transport<W::Role> for WorkerTransport<W> {
Ok(())
}
}

fn session_id_handle(&self) -> Option<TransportSessionIdHandle> {
WorkerTransport::session_id_handle(self)
}
}
19 changes: 19 additions & 0 deletions crates/rmcp/tests/test_streamable_http_session_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,22 @@ async fn test_session_state_persisted_to_store() -> anyhow::Result<()> {
let transport = StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")),
);
assert!(
transport.session_id().is_none(),
"session ID should not be set before initialization"
);
let session_id_handle = transport
.session_id_handle()
.expect("streamable HTTP transport should expose a session ID handle");
let client = ().serve(transport).await?;
let client_session_id = client
.session_id()
.expect("session ID should be exposed after initialization");
assert_eq!(
session_id_handle.session_id().as_deref(),
Some(client_session_id.as_ref()),
"transport handle and running service should expose the same session ID"
);

// Make a real request so the session is fully active.
let _resources = client.list_all_resources().await?;
Expand All @@ -122,6 +137,10 @@ async fn test_session_state_persisted_to_store() -> anyhow::Result<()> {

// Verify the stored state contains the expected client info.
let entries = store.0.read().await;
assert!(
entries.contains_key(client_session_id.as_ref()),
"exposed session ID should match the persisted store key"
);
let state = entries.values().next().expect("store entry should exist");
assert_eq!(
state.initialize_params.client_info.name, "rmcp",
Expand Down
Loading