context server: Make requests type safe (#32254)
This changes the context server crate so that the input/output for a request are encoded at the type level, similar to how it is done for LSP requests. This also makes it easier to write tests that mock context servers, e.g. you can write something like this now when using the `test-support` feature of the `context-server` crate: ```rust create_fake_transport("mcp-1", cx.background_executor()) .on_request::<context_server::types::request::PromptsList>(|_params| { PromptsListResponse { prompts: vec![/* some prompts */], .. } }) ``` Release Notes: - N/A
This commit is contained in:
parent
454adfacae
commit
95d78ff8d5
11 changed files with 320 additions and 433 deletions
|
@ -86,20 +86,26 @@ impl SlashCommand for ContextServerSlashCommand {
|
|||
cx.foreground_executor().spawn(async move {
|
||||
let protocol = server.client().context("Context server not initialized")?;
|
||||
|
||||
let completion_result = protocol
|
||||
.completion(
|
||||
context_server::types::CompletionReference::Prompt(
|
||||
context_server::types::PromptReference {
|
||||
r#type: context_server::types::PromptReferenceType::Prompt,
|
||||
name: prompt_name,
|
||||
let response = protocol
|
||||
.request::<context_server::types::request::CompletionComplete>(
|
||||
context_server::types::CompletionCompleteParams {
|
||||
reference: context_server::types::CompletionReference::Prompt(
|
||||
context_server::types::PromptReference {
|
||||
ty: context_server::types::PromptReferenceType::Prompt,
|
||||
name: prompt_name,
|
||||
},
|
||||
),
|
||||
argument: context_server::types::CompletionArgument {
|
||||
name: arg_name,
|
||||
value: arg_value,
|
||||
},
|
||||
),
|
||||
arg_name,
|
||||
arg_value,
|
||||
meta: None,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let completions = completion_result
|
||||
let completions = response
|
||||
.completion
|
||||
.values
|
||||
.into_iter()
|
||||
.map(|value| ArgumentCompletion {
|
||||
|
@ -138,10 +144,18 @@ impl SlashCommand for ContextServerSlashCommand {
|
|||
if let Some(server) = store.get_running_server(&server_id) {
|
||||
cx.foreground_executor().spawn(async move {
|
||||
let protocol = server.client().context("Context server not initialized")?;
|
||||
let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
|
||||
let response = protocol
|
||||
.request::<context_server::types::request::PromptsGet>(
|
||||
context_server::types::PromptsGetParams {
|
||||
name: prompt_name.clone(),
|
||||
arguments: Some(prompt_args),
|
||||
meta: None,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
result
|
||||
response
|
||||
.messages
|
||||
.iter()
|
||||
.all(|msg| matches!(msg.role, context_server::types::Role::User)),
|
||||
|
@ -149,7 +163,7 @@ impl SlashCommand for ContextServerSlashCommand {
|
|||
);
|
||||
|
||||
// Extract text from user messages into a single prompt string
|
||||
let mut prompt = result
|
||||
let mut prompt = response
|
||||
.messages
|
||||
.into_iter()
|
||||
.filter_map(|msg| match msg.content {
|
||||
|
@ -167,7 +181,7 @@ impl SlashCommand for ContextServerSlashCommand {
|
|||
range: 0..(prompt.len()),
|
||||
icon: IconName::ZedAssistant,
|
||||
label: SharedString::from(
|
||||
result
|
||||
response
|
||||
.description
|
||||
.unwrap_or(format!("Result from {}", prompt_name)),
|
||||
),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue