context_servers: Fix argument handling (#16402)

This commit is contained in:
David Soria Parra 2024-08-18 04:04:34 +01:00 committed by GitHub
parent 5e6e465294
commit 10a996cbc4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -67,7 +67,11 @@ impl SlashCommand for ContextServerSlashCommand {
) -> Task<Result<SlashCommandOutput>> {
let server_id = self.server_id.clone();
let prompt_name = self.prompt.name.clone();
let argument = arguments.first().cloned();
let prompt_args = match prompt_arguments(&self.prompt, arguments) {
Ok(args) => args,
Err(e) => return Task::ready(Err(e)),
};
let manager = ContextServerManager::global(cx);
let manager = manager.read(cx);
@ -76,10 +80,7 @@ impl SlashCommand for ContextServerSlashCommand {
let Some(protocol) = server.client.read().clone() else {
return Err(anyhow!("Context server not initialized"));
};
let result = protocol
.run_prompt(&prompt_name, prompt_arguments(&self.prompt, argument)?)
.await?;
let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
Ok(SlashCommandOutput {
sections: vec![SlashCommandOutputSection {
@ -97,19 +98,27 @@ impl SlashCommand for ContextServerSlashCommand {
}
}
fn prompt_arguments(
prompt: &PromptInfo,
argument: Option<String>,
) -> Result<HashMap<String, String>> {
fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
match &prompt.arguments {
Some(args) if args.len() >= 2 => Err(anyhow!(
Some(args) if args.len() > 1 => Err(anyhow!(
"Prompt has more than one argument, which is not supported"
)),
Some(args) if args.len() == 1 => match argument {
Some(value) => Ok(HashMap::from_iter([(args[0].name.clone(), value)])),
None => Err(anyhow!("Prompt expects argument but none given")),
},
Some(_) | None => Ok(HashMap::default()),
Some(args) if args.len() == 1 => {
if !arguments.is_empty() {
let mut map = HashMap::default();
map.insert(args[0].name.clone(), arguments.join(" "));
Ok(map)
} else {
Err(anyhow!("Prompt expects argument but none given"))
}
}
Some(_) | None => {
if arguments.is_empty() {
Ok(HashMap::default())
} else {
Err(anyhow!("Prompt expects no arguments but some were given"))
}
}
}
}