context server: Make requests type safe (#32254)

This changes the context server crate so that the input/output for a
request are encoded at the type level, similar to how it is done for LSP
requests.

This also makes it easier to write tests that mock context servers, e.g.
you can write something like this now when using the `test-support`
feature of the `context-server` crate:

```rust
create_fake_transport("mcp-1", cx.background_executor())
    .on_request::<context_server::types::request::PromptsList>(|_params| {
        PromptsListResponse {
            prompts: vec![/* some prompts */],
            ..
        }
    })
```

Release Notes:

- N/A
This commit is contained in:
Bennet Bo Fenner 2025-06-06 17:47:21 +02:00 committed by GitHub
parent 454adfacae
commit 95d78ff8d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 320 additions and 433 deletions

View file

@ -86,20 +86,26 @@ impl SlashCommand for ContextServerSlashCommand {
cx.foreground_executor().spawn(async move {
let protocol = server.client().context("Context server not initialized")?;
let completion_result = protocol
.completion(
context_server::types::CompletionReference::Prompt(
context_server::types::PromptReference {
r#type: context_server::types::PromptReferenceType::Prompt,
name: prompt_name,
let response = protocol
.request::<context_server::types::request::CompletionComplete>(
context_server::types::CompletionCompleteParams {
reference: context_server::types::CompletionReference::Prompt(
context_server::types::PromptReference {
ty: context_server::types::PromptReferenceType::Prompt,
name: prompt_name,
},
),
argument: context_server::types::CompletionArgument {
name: arg_name,
value: arg_value,
},
),
arg_name,
arg_value,
meta: None,
},
)
.await?;
let completions = completion_result
let completions = response
.completion
.values
.into_iter()
.map(|value| ArgumentCompletion {
@ -138,10 +144,18 @@ impl SlashCommand for ContextServerSlashCommand {
if let Some(server) = store.get_running_server(&server_id) {
cx.foreground_executor().spawn(async move {
let protocol = server.client().context("Context server not initialized")?;
let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
let response = protocol
.request::<context_server::types::request::PromptsGet>(
context_server::types::PromptsGetParams {
name: prompt_name.clone(),
arguments: Some(prompt_args),
meta: None,
},
)
.await?;
anyhow::ensure!(
result
response
.messages
.iter()
.all(|msg| matches!(msg.role, context_server::types::Role::User)),
@ -149,7 +163,7 @@ impl SlashCommand for ContextServerSlashCommand {
);
// Extract text from user messages into a single prompt string
let mut prompt = result
let mut prompt = response
.messages
.into_iter()
.filter_map(|msg| match msg.content {
@ -167,7 +181,7 @@ impl SlashCommand for ContextServerSlashCommand {
range: 0..(prompt.len()),
icon: IconName::ZedAssistant,
label: SharedString::from(
result
response
.description
.unwrap_or(format!("Result from {}", prompt_name)),
),