context_servers: Completion support for context server slash commands (#17085)

This PR adds support for completions via MCP. The protocol now supports
a new request type "completion/complete"
that can either complete a resource URI template (which we currently
don't use in Zed), or a prompt argument.
We use this to add autocompletion to our context server slash commands!


https://github.com/user-attachments/assets/08c9cf04-cbeb-49a7-903f-5049fb3b3d9f



Release Notes:

- context_servers: Added support for argument completions for context
server prompts. These show up as regular completions to slash commands.
This commit is contained in:
David Soria Parra 2024-08-29 21:56:58 +01:00 committed by GitHub
parent 01f8d27f22
commit 5bae6eb493
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 179 additions and 5 deletions

View file

@ -127,6 +127,35 @@ impl InitializedContextServerProtocol {
Ok(response)
}
pub async fn completion<P: Into<String>>(
&self,
reference: types::CompletionReference,
argument: P,
value: P,
) -> Result<types::Completion> {
let params = types::CompletionCompleteParams {
r#ref: reference,
argument: types::CompletionArgument {
name: argument.into(),
value: value.into(),
},
};
let result: types::CompletionCompleteResponse = self
.inner
.request(types::RequestType::CompletionComplete.as_str(), params)
.await?;
let completion = types::Completion {
values: result.completion.values,
total: types::CompletionTotal::from_options(
result.completion.has_more,
result.completion.total,
),
};
Ok(completion)
}
}
impl InitializedContextServerProtocol {

View file

@ -14,6 +14,7 @@ pub enum RequestType {
LoggingSetLevel,
PromptsGet,
PromptsList,
CompletionComplete,
}
impl RequestType {
@ -28,6 +29,7 @@ impl RequestType {
RequestType::LoggingSetLevel => "logging/setLevel",
RequestType::PromptsGet => "prompts/get",
RequestType::PromptsList => "prompts/list",
RequestType::CompletionComplete => "completion/complete",
}
}
}
@ -78,6 +80,50 @@ pub struct PromptsGetParams {
pub arguments: Option<HashMap<String, String>>,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionCompleteParams {
pub r#ref: CompletionReference,
pub argument: CompletionArgument,
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum CompletionReference {
Prompt(PromptReference),
Resource(ResourceReference),
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptReference {
pub r#type: PromptReferenceType,
pub name: String,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum PromptReferenceType {
#[serde(rename = "ref/prompt")]
Prompt,
#[serde(rename = "ref/resource")]
Resource,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceReference {
pub r#type: String,
pub uri: String,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionArgument {
pub name: String,
pub value: String,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResponse {
@ -112,6 +158,20 @@ pub struct PromptsListResponse {
pub prompts: Vec<PromptInfo>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionCompleteResponse {
pub completion: CompletionResult,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionResult {
pub values: Vec<String>,
pub total: Option<u32>,
pub has_more: Option<bool>,
}
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PromptInfo {
@ -233,3 +293,26 @@ pub struct ProgressParams {
pub progress: f64,
pub total: Option<f64>,
}
// Helper Types that don't map directly to the protocol
pub enum CompletionTotal {
Exact(u32),
HasMore,
Unknown,
}
impl CompletionTotal {
pub fn from_options(has_more: Option<bool>, total: Option<u32>) -> Self {
match (has_more, total) {
(_, Some(count)) => CompletionTotal::Exact(count),
(Some(true), _) => CompletionTotal::HasMore,
_ => CompletionTotal::Unknown,
}
}
}
pub struct Completion {
pub values: Vec<String>,
pub total: CompletionTotal,
}