context_servers: Completion support for context server slash commands (#17085)
This PR adds support for completions via MCP. The protocol now supports a new request type "completion/complete" that can either complete a resource URI template (which we currently don't use in Zed), or a prompt argument. We use this to add autocompletion to our context server slash commands! https://github.com/user-attachments/assets/08c9cf04-cbeb-49a7-903f-5049fb3b3d9f Release Notes: - context_servers: Added support for argument completions for context server prompts. These show up as regular completions to slash commands.
This commit is contained in:
parent
01f8d27f22
commit
5bae6eb493
3 changed files with 179 additions and 5 deletions
|
@ -127,6 +127,35 @@ impl InitializedContextServerProtocol {
|
|||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
pub async fn completion<P: Into<String>>(
|
||||
&self,
|
||||
reference: types::CompletionReference,
|
||||
argument: P,
|
||||
value: P,
|
||||
) -> Result<types::Completion> {
|
||||
let params = types::CompletionCompleteParams {
|
||||
r#ref: reference,
|
||||
argument: types::CompletionArgument {
|
||||
name: argument.into(),
|
||||
value: value.into(),
|
||||
},
|
||||
};
|
||||
let result: types::CompletionCompleteResponse = self
|
||||
.inner
|
||||
.request(types::RequestType::CompletionComplete.as_str(), params)
|
||||
.await?;
|
||||
|
||||
let completion = types::Completion {
|
||||
values: result.completion.values,
|
||||
total: types::CompletionTotal::from_options(
|
||||
result.completion.has_more,
|
||||
result.completion.total,
|
||||
),
|
||||
};
|
||||
|
||||
Ok(completion)
|
||||
}
|
||||
}
|
||||
|
||||
impl InitializedContextServerProtocol {
|
||||
|
|
|
@ -14,6 +14,7 @@ pub enum RequestType {
|
|||
LoggingSetLevel,
|
||||
PromptsGet,
|
||||
PromptsList,
|
||||
CompletionComplete,
|
||||
}
|
||||
|
||||
impl RequestType {
|
||||
|
@ -28,6 +29,7 @@ impl RequestType {
|
|||
RequestType::LoggingSetLevel => "logging/setLevel",
|
||||
RequestType::PromptsGet => "prompts/get",
|
||||
RequestType::PromptsList => "prompts/list",
|
||||
RequestType::CompletionComplete => "completion/complete",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -78,6 +80,50 @@ pub struct PromptsGetParams {
|
|||
pub arguments: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CompletionCompleteParams {
|
||||
pub r#ref: CompletionReference,
|
||||
pub argument: CompletionArgument,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CompletionReference {
|
||||
Prompt(PromptReference),
|
||||
Resource(ResourceReference),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptReference {
|
||||
pub r#type: PromptReferenceType,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PromptReferenceType {
|
||||
#[serde(rename = "ref/prompt")]
|
||||
Prompt,
|
||||
#[serde(rename = "ref/resource")]
|
||||
Resource,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ResourceReference {
|
||||
pub r#type: String,
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CompletionArgument {
|
||||
pub name: String,
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeResponse {
|
||||
|
@ -112,6 +158,20 @@ pub struct PromptsListResponse {
|
|||
pub prompts: Vec<PromptInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CompletionCompleteResponse {
|
||||
pub completion: CompletionResult,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CompletionResult {
|
||||
pub values: Vec<String>,
|
||||
pub total: Option<u32>,
|
||||
pub has_more: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptInfo {
|
||||
|
@ -233,3 +293,26 @@ pub struct ProgressParams {
|
|||
pub progress: f64,
|
||||
pub total: Option<f64>,
|
||||
}
|
||||
|
||||
// Helper Types that don't map directly to the protocol
|
||||
|
||||
pub enum CompletionTotal {
|
||||
Exact(u32),
|
||||
HasMore,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl CompletionTotal {
|
||||
pub fn from_options(has_more: Option<bool>, total: Option<u32>) -> Self {
|
||||
match (has_more, total) {
|
||||
(_, Some(count)) => CompletionTotal::Exact(count),
|
||||
(Some(true), _) => CompletionTotal::HasMore,
|
||||
_ => CompletionTotal::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Completion {
|
||||
pub values: Vec<String>,
|
||||
pub total: CompletionTotal,
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue