Use LLM service for tool call requests (#16046)

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Max Brunsfeld 2024-08-09 13:22:58 -07:00 committed by GitHub
parent d96afde5bf
commit fbebb73d7b
2 changed files with 330 additions and 131 deletions

@@ -322,25 +322,33 @@ async fn perform_completion(
 }
 
 fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
-    match provider {
-        LanguageModelProvider::Anthropic => {
-            for prefix in &[
-                "claude-3-5-sonnet",
-                "claude-3-haiku",
-                "claude-3-opus",
-                "claude-3-sonnet",
-            ] {
-                if name.starts_with(prefix) {
-                    return prefix.to_string();
-                }
-            }
-        }
-        LanguageModelProvider::OpenAi => {}
-        LanguageModelProvider::Google => {}
-        LanguageModelProvider::Zed => {}
-    }
-    name
+    let prefixes: &[_] = match provider {
+        LanguageModelProvider::Anthropic => &[
+            "claude-3-5-sonnet",
+            "claude-3-haiku",
+            "claude-3-opus",
+            "claude-3-sonnet",
+        ],
+        LanguageModelProvider::OpenAi => &[
+            "gpt-3.5-turbo",
+            "gpt-4-turbo-preview",
+            "gpt-4o-mini",
+            "gpt-4o",
+            "gpt-4",
+        ],
+        LanguageModelProvider::Google => &[],
+        LanguageModelProvider::Zed => &[],
+    };
+
+    if let Some(prefix) = prefixes
+        .iter()
+        .filter(|&&prefix| name.starts_with(prefix))
+        .max_by_key(|&&prefix| prefix.len())
+    {
+        prefix.to_string()
+    } else {
+        name
+    }
 }
 
 async fn check_usage_limit(
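
Two things change in this hunk: OpenAI model names are now normalized too (previously only Anthropic names were), and the longest matching prefix wins rather than the first, so a name like "gpt-4o-mini-2024-07-18" maps to "gpt-4o-mini" instead of being swallowed by the shorter "gpt-4o" or "gpt-4". A minimal, self-contained sketch of that matching logic; the enum here is a simplified stand-in for the real LanguageModelProvider:

#[derive(Clone, Copy)]
enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
    Zed,
}

fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
    let prefixes: &[_] = match provider {
        LanguageModelProvider::Anthropic => {
            &["claude-3-5-sonnet", "claude-3-haiku", "claude-3-opus", "claude-3-sonnet"]
        }
        LanguageModelProvider::OpenAi => {
            &["gpt-3.5-turbo", "gpt-4-turbo-preview", "gpt-4o-mini", "gpt-4o", "gpt-4"]
        }
        LanguageModelProvider::Google | LanguageModelProvider::Zed => &[],
    };

    // Longest match wins: "gpt-4o-mini-..." also starts with "gpt-4o" and
    // "gpt-4", but max_by_key picks the most specific prefix.
    if let Some(prefix) = prefixes
        .iter()
        .filter(|&&prefix| name.starts_with(prefix))
        .max_by_key(|&&prefix| prefix.len())
    {
        prefix.to_string()
    } else {
        name
    }
}

fn main() {
    assert_eq!(
        normalize_model_name(LanguageModelProvider::OpenAi, "gpt-4o-mini-2024-07-18".into()),
        "gpt-4o-mini"
    );
    assert_eq!(
        normalize_model_name(LanguageModelProvider::OpenAi, "gpt-4-0613".into()),
        "gpt-4"
    );
}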

@@ -590,7 +590,7 @@ impl LanguageModel for CloudLanguageModel {
         tool_name: String,
         tool_description: String,
         input_schema: serde_json::Value,
-        _cx: &AsyncAppContext,
+        cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<serde_json::Value>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
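
The only change here is that the context parameter is no longer ignored: the hunks below use it to check LlmServiceFeatureFlag and pick between the new LLM service and the old collab RPC path. A hedged sketch of that gating shape, with stand-in types since the real AsyncAppContext and feature-flag machinery live in gpui:

struct AsyncAppContext {
    llm_service_enabled: bool, // stand-in for has_flag::<LlmServiceFeatureFlag>()
}

impl AsyncAppContext {
    // Mirrors the fallible update() on the real context: it errors once the
    // app is gone, which is why the call sites end in .unwrap_or(false).
    fn update<R>(&self, f: impl FnOnce(&AsyncAppContext) -> R) -> Result<R, &'static str> {
        Ok(f(self))
    }

    fn has_llm_service_flag(&self) -> bool {
        self.llm_service_enabled
    }
}

fn choose_path(cx: &AsyncAppContext) -> &'static str {
    // Any failure to read the flag falls back to the old RPC path.
    if cx.update(|cx| cx.has_llm_service_flag()).unwrap_or(false) {
        "new: HTTP request via Self::perform_llm_completion"
    } else {
        "old: proto::CompleteWithLanguageModel over collab RPC"
    }
}

fn main() {
    let cx = AsyncAppContext { llm_service_enabled: false };
    println!("{}", choose_path(&cx));
}
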
@@ -605,34 +605,106 @@ impl LanguageModel for CloudLanguageModel {
                     input_schema,
                 }];
-                self.request_limiter
-                    .run(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let response = client
-                            .request(proto::CompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::Anthropic as i32,
-                                request,
-                            })
-                            .await?;
-                        let response: anthropic::Response =
-                            serde_json::from_str(&response.completion)?;
-                        response
-                            .content
-                            .into_iter()
-                            .find_map(|content| {
-                                if let anthropic::Content::ToolUse { name, input, .. } = content {
-                                    if name == tool_name {
-                                        Some(input)
-                                    } else {
-                                        None
-                                    }
-                                } else {
-                                    None
-                                }
-                            })
-                            .context("tool not used")
-                    })
-                    .boxed()
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    self.request_limiter
+                        .run(async move {
+                            let response = Self::perform_llm_completion(
+                                client.clone(),
+                                llm_api_token,
+                                PerformCompletionParams {
+                                    provider: client::LanguageModelProvider::Anthropic,
+                                    model: request.model.clone(),
+                                    provider_request: RawValue::from_string(
+                                        serde_json::to_string(&request)?,
+                                    )?,
+                                },
+                            )
+                            .await?;
+
+                            let mut tool_use_index = None;
+                            let mut tool_input = String::new();
+                            let mut body = BufReader::new(response.into_body());
+                            let mut line = String::new();
+                            while body.read_line(&mut line).await? > 0 {
+                                let event: anthropic::Event = serde_json::from_str(&line)?;
+                                line.clear();
+
+                                match event {
+                                    anthropic::Event::ContentBlockStart {
+                                        content_block,
+                                        index,
+                                    } => {
+                                        if let anthropic::Content::ToolUse { name, .. } =
+                                            content_block
+                                        {
+                                            if name == tool_name {
+                                                tool_use_index = Some(index);
+                                            }
+                                        }
+                                    }
+                                    anthropic::Event::ContentBlockDelta { index, delta } => {
+                                        match delta {
+                                            anthropic::ContentDelta::TextDelta { .. } => {}
+                                            anthropic::ContentDelta::InputJsonDelta {
+                                                partial_json,
+                                            } => {
+                                                if Some(index) == tool_use_index {
+                                                    tool_input.push_str(&partial_json);
+                                                }
+                                            }
+                                        }
+                                    }
+                                    anthropic::Event::ContentBlockStop { index } => {
+                                        if Some(index) == tool_use_index {
+                                            return Ok(serde_json::from_str(&tool_input)?);
+                                        }
+                                    }
+                                    _ => {}
+                                }
+                            }
+
+                            if tool_use_index.is_some() {
+                                Err(anyhow!("tool content incomplete"))
+                            } else {
+                                Err(anyhow!("tool not used"))
+                            }
+                        })
+                        .boxed()
+                } else {
+                    self.request_limiter
+                        .run(async move {
+                            let request = serde_json::to_string(&request)?;
+                            let response = client
+                                .request(proto::CompleteWithLanguageModel {
+                                    provider: proto::LanguageModelProvider::Anthropic as i32,
+                                    request,
+                                })
+                                .await?;
+                            let response: anthropic::Response =
+                                serde_json::from_str(&response.completion)?;
+                            response
+                                .content
+                                .into_iter()
+                                .find_map(|content| {
+                                    if let anthropic::Content::ToolUse { name, input, .. } = content
+                                    {
+                                        if name == tool_name {
+                                            Some(input)
+                                        } else {
+                                            None
+                                        }
+                                    } else {
+                                        None
+                                    }
+                                })
+                                .context("tool not used")
+                        })
+                        .boxed()
+                }
             }
             CloudModel::OpenAi(model) => {
                 let mut request = request.into_open_ai(model.id().into());
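
In the new flagged path the response body is a stream of newline-delimited Anthropic events, and the tool call's JSON input arrives as fragments spread across content_block_delta events. A simplified sketch of the accumulation state machine; Event and Content here are stand-ins for the anthropic crate's types:

enum Content {
    ToolUse { name: String },
}

enum Event {
    ContentBlockStart { index: usize, content_block: Content },
    ContentBlockDelta { index: usize, partial_json: String },
    ContentBlockStop { index: usize },
    Other,
}

fn collect_tool_input(events: Vec<Event>, tool_name: &str) -> Result<String, &'static str> {
    let mut tool_use_index = None;
    let mut tool_input = String::new();
    for event in events {
        match event {
            Event::ContentBlockStart { index, content_block } => {
                // Remember which content block carries the tool call we asked for.
                if let Content::ToolUse { name } = content_block {
                    if name == tool_name {
                        tool_use_index = Some(index);
                    }
                }
            }
            Event::ContentBlockDelta { index, partial_json } => {
                // The input JSON arrives in fragments; concatenate them in order.
                if Some(index) == tool_use_index {
                    tool_input.push_str(&partial_json);
                }
            }
            Event::ContentBlockStop { index } => {
                if Some(index) == tool_use_index {
                    return Ok(tool_input);
                }
            }
            Event::Other => {}
        }
    }
    if tool_use_index.is_some() {
        Err("tool content incomplete")
    } else {
        Err("tool not used")
    }
}

fn main() {
    let events = vec![
        Event::ContentBlockStart {
            index: 0,
            content_block: Content::ToolUse { name: "get_weather".into() },
        },
        Event::ContentBlockDelta { index: 0, partial_json: "{\"city\":".into() },
        Event::ContentBlockDelta { index: 0, partial_json: "\"Berlin\"}".into() },
        Event::ContentBlockStop { index: 0 },
    ];
    assert_eq!(collect_tool_input(events, "get_weather").unwrap(), "{\"city\":\"Berlin\"}");
}
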
@@ -650,56 +722,116 @@ impl LanguageModel for CloudLanguageModel {
                 function.description = Some(tool_description);
                 function.parameters = Some(input_schema);
                 request.tools = vec![open_ai::ToolDefinition::Function { function }];
-                self.request_limiter
-                    .run(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let response = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::OpenAi as i32,
-                                request,
-                            })
-                            .await?;
-
-                        // Call arguments are gonna be streamed in over multiple chunks.
-                        let mut load_state = None;
-                        let mut response = response.map(
-                            |item: Result<
-                                proto::StreamCompleteWithLanguageModelResponse,
-                                anyhow::Error,
-                            >| {
-                                Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
-                                    serde_json::from_str(&item?.event)?,
-                                )
-                            },
-                        );
-
-                        while let Some(Ok(part)) = response.next().await {
-                            for choice in part.choices {
-                                let Some(tool_calls) = choice.delta.tool_calls else {
-                                    continue;
-                                };
-
-                                for call in tool_calls {
-                                    if let Some(func) = call.function {
-                                        if func.name.as_deref() == Some(tool_name.as_str()) {
-                                            load_state = Some((String::default(), call.index));
-                                        }
-                                        if let Some((arguments, (output, index))) =
-                                            func.arguments.zip(load_state.as_mut())
-                                        {
-                                            if call.index == *index {
-                                                output.push_str(&arguments);
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-
-                        if let Some((arguments, _)) = load_state {
-                            return Ok(serde_json::from_str(&arguments)?);
-                        } else {
-                            bail!("tool not used");
-                        }
-                    })
-                    .boxed()
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    self.request_limiter
+                        .run(async move {
+                            let response = Self::perform_llm_completion(
+                                client.clone(),
+                                llm_api_token,
+                                PerformCompletionParams {
+                                    provider: client::LanguageModelProvider::OpenAi,
+                                    model: request.model.clone(),
+                                    provider_request: RawValue::from_string(
+                                        serde_json::to_string(&request)?,
+                                    )?,
+                                },
+                            )
+                            .await?;
+
+                            let mut body = BufReader::new(response.into_body());
+                            let mut line = String::new();
+                            let mut load_state = None;
+                            while body.read_line(&mut line).await? > 0 {
+                                let part: open_ai::ResponseStreamEvent =
+                                    serde_json::from_str(&line)?;
+                                line.clear();
+
+                                for choice in part.choices {
+                                    let Some(tool_calls) = choice.delta.tool_calls else {
+                                        continue;
+                                    };
+
+                                    for call in tool_calls {
+                                        if let Some(func) = call.function {
+                                            if func.name.as_deref() == Some(tool_name.as_str()) {
+                                                load_state = Some((String::default(), call.index));
+                                            }
+                                            if let Some((arguments, (output, index))) =
+                                                func.arguments.zip(load_state.as_mut())
+                                            {
+                                                if call.index == *index {
+                                                    output.push_str(&arguments);
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+
+                            if let Some((arguments, _)) = load_state {
+                                return Ok(serde_json::from_str(&arguments)?);
+                            } else {
+                                bail!("tool not used");
+                            }
+                        })
+                        .boxed()
+                } else {
+                    self.request_limiter
+                        .run(async move {
+                            let request = serde_json::to_string(&request)?;
+                            let response = client
+                                .request_stream(proto::StreamCompleteWithLanguageModel {
+                                    provider: proto::LanguageModelProvider::OpenAi as i32,
+                                    request,
+                                })
+                                .await?;
+
+                            let mut load_state = None;
+                            let mut response = response.map(
+                                |item: Result<
+                                    proto::StreamCompleteWithLanguageModelResponse,
+                                    anyhow::Error,
+                                >| {
+                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
+                                        serde_json::from_str(&item?.event)?,
+                                    )
+                                },
+                            );
+
+                            while let Some(Ok(part)) = response.next().await {
+                                for choice in part.choices {
+                                    let Some(tool_calls) = choice.delta.tool_calls else {
+                                        continue;
+                                    };
+
+                                    for call in tool_calls {
+                                        if let Some(func) = call.function {
+                                            if func.name.as_deref() == Some(tool_name.as_str()) {
+                                                load_state = Some((String::default(), call.index));
+                                            }
+                                            if let Some((arguments, (output, index))) =
+                                                func.arguments.zip(load_state.as_mut())
+                                            {
+                                                if call.index == *index {
+                                                    output.push_str(&arguments);
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+
+                            if let Some((arguments, _)) = load_state {
+                                return Ok(serde_json::from_str(&arguments)?);
+                            } else {
+                                bail!("tool not used");
+                            }
+                        })
+                        .boxed()
+                }
             }
             CloudModel::Google(_) => {
                 future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
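
Both OpenAI code paths (flagged and legacy) share the same accumulation trick: each streamed chunk may carry a fragment of the tool call's arguments tagged with the call's index, and fragments are concatenated until the stream ends. A simplified sketch; ToolCallDelta stands in for the fields this code reads from open_ai::ResponseStreamEvent:

struct ToolCallDelta {
    index: usize,
    name: Option<String>,
    arguments: Option<String>,
}

fn accumulate_arguments(
    deltas: Vec<ToolCallDelta>,
    tool_name: &str,
) -> Result<String, &'static str> {
    // (accumulated JSON so far, index of the call being tracked)
    let mut load_state: Option<(String, usize)> = None;
    for call in deltas {
        // The first chunk for a call carries the function name; start
        // accumulating once the requested tool shows up.
        if call.name.as_deref() == Some(tool_name) {
            load_state = Some((String::new(), call.index));
        }
        // Later chunks carry argument fragments; append those that belong
        // to the tracked call index.
        if let Some((arguments, (output, index))) = call.arguments.zip(load_state.as_mut()) {
            if call.index == *index {
                output.push_str(&arguments);
            }
        }
    }
    load_state.map(|(arguments, _)| arguments).ok_or("tool not used")
}

fn main() {
    let deltas = vec![
        ToolCallDelta { index: 0, name: Some("list_files".into()), arguments: Some("{\"dir\"".into()) },
        ToolCallDelta { index: 0, name: None, arguments: Some(":\".\"}".into()) },
    ];
    assert_eq!(accumulate_arguments(deltas, "list_files").unwrap(), "{\"dir\":\".\"}");
}
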
@@ -721,56 +853,115 @@ impl LanguageModel for CloudLanguageModel {
                 function.description = Some(tool_description);
                 function.parameters = Some(input_schema);
                 request.tools = vec![open_ai::ToolDefinition::Function { function }];
-                self.request_limiter
-                    .run(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let response = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::OpenAi as i32,
-                                request,
-                            })
-                            .await?;
-
-                        // Call arguments are gonna be streamed in over multiple chunks.
-                        let mut load_state = None;
-                        let mut response = response.map(
-                            |item: Result<
-                                proto::StreamCompleteWithLanguageModelResponse,
-                                anyhow::Error,
-                            >| {
-                                Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
-                                    serde_json::from_str(&item?.event)?,
-                                )
-                            },
-                        );
-
-                        while let Some(Ok(part)) = response.next().await {
-                            for choice in part.choices {
-                                let Some(tool_calls) = choice.delta.tool_calls else {
-                                    continue;
-                                };
-
-                                for call in tool_calls {
-                                    if let Some(func) = call.function {
-                                        if func.name.as_deref() == Some(tool_name.as_str()) {
-                                            load_state = Some((String::default(), call.index));
-                                        }
-                                        if let Some((arguments, (output, index))) =
-                                            func.arguments.zip(load_state.as_mut())
-                                        {
-                                            if call.index == *index {
-                                                output.push_str(&arguments);
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                        }
-
-                        if let Some((arguments, _)) = load_state {
-                            return Ok(serde_json::from_str(&arguments)?);
-                        } else {
-                            bail!("tool not used");
-                        }
-                    })
-                    .boxed()
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    self.request_limiter
+                        .run(async move {
+                            let response = Self::perform_llm_completion(
+                                client.clone(),
+                                llm_api_token,
+                                PerformCompletionParams {
+                                    provider: client::LanguageModelProvider::Zed,
+                                    model: request.model.clone(),
+                                    provider_request: RawValue::from_string(
+                                        serde_json::to_string(&request)?,
+                                    )?,
+                                },
+                            )
+                            .await?;
+
+                            let mut body = BufReader::new(response.into_body());
+                            let mut line = String::new();
+                            let mut load_state = None;
+                            while body.read_line(&mut line).await? > 0 {
+                                let part: open_ai::ResponseStreamEvent =
+                                    serde_json::from_str(&line)?;
+                                line.clear();
+
+                                for choice in part.choices {
+                                    let Some(tool_calls) = choice.delta.tool_calls else {
+                                        continue;
+                                    };
+
+                                    for call in tool_calls {
+                                        if let Some(func) = call.function {
+                                            if func.name.as_deref() == Some(tool_name.as_str()) {
+                                                load_state = Some((String::default(), call.index));
+                                            }
+                                            if let Some((arguments, (output, index))) =
+                                                func.arguments.zip(load_state.as_mut())
+                                            {
+                                                if call.index == *index {
+                                                    output.push_str(&arguments);
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+
+                            if let Some((arguments, _)) = load_state {
+                                return Ok(serde_json::from_str(&arguments)?);
+                            } else {
+                                bail!("tool not used");
+                            }
+                        })
+                        .boxed()
+                } else {
+                    self.request_limiter
+                        .run(async move {
+                            let request = serde_json::to_string(&request)?;
+                            let response = client
+                                .request_stream(proto::StreamCompleteWithLanguageModel {
+                                    provider: proto::LanguageModelProvider::OpenAi as i32,
+                                    request,
+                                })
+                                .await?;
+
+                            let mut load_state = None;
+                            let mut response = response.map(
+                                |item: Result<
+                                    proto::StreamCompleteWithLanguageModelResponse,
+                                    anyhow::Error,
+                                >| {
+                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
+                                        serde_json::from_str(&item?.event)?,
+                                    )
+                                },
+                            );
+
+                            while let Some(Ok(part)) = response.next().await {
+                                for choice in part.choices {
+                                    let Some(tool_calls) = choice.delta.tool_calls else {
+                                        continue;
+                                    };
+
+                                    for call in tool_calls {
+                                        if let Some(func) = call.function {
+                                            if func.name.as_deref() == Some(tool_name.as_str()) {
+                                                load_state = Some((String::default(), call.index));
+                                            }
+                                            if let Some((arguments, (output, index))) =
+                                                func.arguments.zip(load_state.as_mut())
+                                            {
+                                                if call.index == *index {
+                                                    output.push_str(&arguments);
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+
+                            if let Some((arguments, _)) = load_state {
+                                return Ok(serde_json::from_str(&arguments)?);
+                            } else {
+                                bail!("tool not used");
+                            }
+                        })
+                        .boxed()
+                }
             }
         }
     }
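
All three new paths package the provider-specific request as PerformCompletionParams, holding provider_request as a serde_json RawValue so the already-serialized request is embedded verbatim in the outer JSON body rather than being re-parsed. A sketch of that packaging; the Serialize derive and the plain string provider field are assumptions (the real field is an enum), and serde_json's "raw_value" feature is required:

use serde::Serialize;
use serde_json::value::RawValue;

#[derive(Serialize)]
struct PerformCompletionParams {
    provider: String,
    model: String,
    provider_request: Box<RawValue>,
}

fn main() -> serde_json::Result<()> {
    // Stand-in for serde_json::to_string(&request)? on an Anthropic request.
    let provider_request = r#"{"model":"claude-3-5-sonnet","max_tokens":1024}"#.to_string();

    let params = PerformCompletionParams {
        provider: "anthropic".into(),
        model: "claude-3-5-sonnet".into(),
        provider_request: RawValue::from_string(provider_request)?,
    };

    // The inner JSON appears inline in the output, not as an escaped string.
    println!("{}", serde_json::to_string(&params)?);
    Ok(())
}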