context_servers: Fix argument handling (#16402)

This commit is contained in:
David Soria Parra 2024-08-18 04:04:34 +01:00 committed by GitHub
parent 5e6e465294
commit 10a996cbc4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -67,7 +67,11 @@ impl SlashCommand for ContextServerSlashCommand {
) -> Task<Result<SlashCommandOutput>> { ) -> Task<Result<SlashCommandOutput>> {
let server_id = self.server_id.clone(); let server_id = self.server_id.clone();
let prompt_name = self.prompt.name.clone(); let prompt_name = self.prompt.name.clone();
let argument = arguments.first().cloned();
let prompt_args = match prompt_arguments(&self.prompt, arguments) {
Ok(args) => args,
Err(e) => return Task::ready(Err(e)),
};
let manager = ContextServerManager::global(cx); let manager = ContextServerManager::global(cx);
let manager = manager.read(cx); let manager = manager.read(cx);
@ -76,10 +80,7 @@ impl SlashCommand for ContextServerSlashCommand {
let Some(protocol) = server.client.read().clone() else { let Some(protocol) = server.client.read().clone() else {
return Err(anyhow!("Context server not initialized")); return Err(anyhow!("Context server not initialized"));
}; };
let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
let result = protocol
.run_prompt(&prompt_name, prompt_arguments(&self.prompt, argument)?)
.await?;
Ok(SlashCommandOutput { Ok(SlashCommandOutput {
sections: vec![SlashCommandOutputSection { sections: vec![SlashCommandOutputSection {
@ -97,19 +98,27 @@ impl SlashCommand for ContextServerSlashCommand {
} }
} }
fn prompt_arguments( fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
prompt: &PromptInfo,
argument: Option<String>,
) -> Result<HashMap<String, String>> {
match &prompt.arguments { match &prompt.arguments {
Some(args) if args.len() >= 2 => Err(anyhow!( Some(args) if args.len() > 1 => Err(anyhow!(
"Prompt has more than one argument, which is not supported" "Prompt has more than one argument, which is not supported"
)), )),
Some(args) if args.len() == 1 => match argument { Some(args) if args.len() == 1 => {
Some(value) => Ok(HashMap::from_iter([(args[0].name.clone(), value)])), if !arguments.is_empty() {
None => Err(anyhow!("Prompt expects argument but none given")), let mut map = HashMap::default();
}, map.insert(args[0].name.clone(), arguments.join(" "));
Some(_) | None => Ok(HashMap::default()), Ok(map)
} else {
Err(anyhow!("Prompt expects argument but none given"))
}
}
Some(_) | None => {
if arguments.is_empty() {
Ok(HashMap::default())
} else {
Err(anyhow!("Prompt expects no arguments but some were given"))
}
}
} }
} }