Use LLM service for tool call requests (#16046)

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
Max Brunsfeld 2024-08-09 13:22:58 -07:00 committed by GitHub
parent d96afde5bf
commit fbebb73d7b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 330 additions and 131 deletions

View file

@@ -322,25 +322,33 @@ async fn perform_completion(
} }
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String { fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
match provider { let prefixes: &[_] = match provider {
LanguageModelProvider::Anthropic => { LanguageModelProvider::Anthropic => &[
for prefix in &[
"claude-3-5-sonnet", "claude-3-5-sonnet",
"claude-3-haiku", "claude-3-haiku",
"claude-3-opus", "claude-3-opus",
"claude-3-sonnet", "claude-3-sonnet",
] { ],
if name.starts_with(prefix) { LanguageModelProvider::OpenAi => &[
return prefix.to_string(); "gpt-3.5-turbo",
} "gpt-4-turbo-preview",
} "gpt-4o-mini",
} "gpt-4o",
LanguageModelProvider::OpenAi => {} "gpt-4",
LanguageModelProvider::Google => {} ],
LanguageModelProvider::Zed => {} LanguageModelProvider::Google => &[],
} LanguageModelProvider::Zed => &[],
};
if let Some(prefix) = prefixes
.iter()
.filter(|&&prefix| name.starts_with(prefix))
.max_by_key(|&&prefix| prefix.len())
{
prefix.to_string()
} else {
name name
}
} }
async fn check_usage_limit( async fn check_usage_limit(

View file

@@ -590,7 +590,7 @@ impl LanguageModel for CloudLanguageModel {
tool_name: String, tool_name: String,
tool_description: String, tool_description: String,
input_schema: serde_json::Value, input_schema: serde_json::Value,
_cx: &AsyncAppContext, cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> { ) -> BoxFuture<'static, Result<serde_json::Value>> {
match &self.model { match &self.model {
CloudModel::Anthropic(model) => { CloudModel::Anthropic(model) => {
@@ -605,6 +605,76 @@ impl LanguageModel for CloudLanguageModel {
input_schema, input_schema,
}]; }];
if cx
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
.unwrap_or(false)
{
let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
client.clone(),
llm_api_token,
PerformCompletionParams {
provider: client::LanguageModelProvider::Anthropic,
model: request.model.clone(),
provider_request: RawValue::from_string(
serde_json::to_string(&request)?,
)?,
},
)
.await?;
let mut tool_use_index = None;
let mut tool_input = String::new();
let mut body = BufReader::new(response.into_body());
let mut line = String::new();
while body.read_line(&mut line).await? > 0 {
let event: anthropic::Event = serde_json::from_str(&line)?;
line.clear();
match event {
anthropic::Event::ContentBlockStart {
content_block,
index,
} => {
if let anthropic::Content::ToolUse { name, .. } =
content_block
{
if name == tool_name {
tool_use_index = Some(index);
}
}
}
anthropic::Event::ContentBlockDelta { index, delta } => {
match delta {
anthropic::ContentDelta::TextDelta { .. } => {}
anthropic::ContentDelta::InputJsonDelta {
partial_json,
} => {
if Some(index) == tool_use_index {
tool_input.push_str(&partial_json);
}
}
}
}
anthropic::Event::ContentBlockStop { index } => {
if Some(index) == tool_use_index {
return Ok(serde_json::from_str(&tool_input)?);
}
}
_ => {}
}
}
if tool_use_index.is_some() {
Err(anyhow!("tool content incomplete"))
} else {
Err(anyhow!("tool not used"))
}
})
.boxed()
} else {
self.request_limiter self.request_limiter
.run(async move { .run(async move {
let request = serde_json::to_string(&request)?; let request = serde_json::to_string(&request)?;
@@ -620,7 +690,8 @@ impl LanguageModel for CloudLanguageModel {
.content .content
.into_iter() .into_iter()
.find_map(|content| { .find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content { if let anthropic::Content::ToolUse { name, input, .. } = content
{
if name == tool_name { if name == tool_name {
Some(input) Some(input)
} else { } else {
@@ -634,6 +705,7 @@ impl LanguageModel for CloudLanguageModel {
}) })
.boxed() .boxed()
} }
}
CloudModel::OpenAi(model) => { CloudModel::OpenAi(model) => {
let mut request = request.into_open_ai(model.id().into()); let mut request = request.into_open_ai(model.id().into());
let client = self.client.clone(); let client = self.client.clone();
@@ -650,6 +722,66 @@ impl LanguageModel for CloudLanguageModel {
function.description = Some(tool_description); function.description = Some(tool_description);
function.parameters = Some(input_schema); function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }]; request.tools = vec![open_ai::ToolDefinition::Function { function }];
if cx
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
.unwrap_or(false)
{
let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
client.clone(),
llm_api_token,
PerformCompletionParams {
provider: client::LanguageModelProvider::OpenAi,
model: request.model.clone(),
provider_request: RawValue::from_string(
serde_json::to_string(&request)?,
)?,
},
)
.await?;
let mut body = BufReader::new(response.into_body());
let mut line = String::new();
let mut load_state = None;
while body.read_line(&mut line).await? > 0 {
let part: open_ai::ResponseStreamEvent =
serde_json::from_str(&line)?;
line.clear();
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
} else {
self.request_limiter self.request_limiter
.run(async move { .run(async move {
let request = serde_json::to_string(&request)?; let request = serde_json::to_string(&request)?;
@@ -659,7 +791,6 @@ impl LanguageModel for CloudLanguageModel {
request, request,
}) })
.await?; .await?;
// Call arguments are gonna be streamed in over multiple chunks.
let mut load_state = None; let mut load_state = None;
let mut response = response.map( let mut response = response.map(
|item: Result< |item: Result<
@@ -701,6 +832,7 @@ impl LanguageModel for CloudLanguageModel {
}) })
.boxed() .boxed()
} }
}
CloudModel::Google(_) => { CloudModel::Google(_) => {
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed() future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
} }
@@ -721,6 +853,65 @@ impl LanguageModel for CloudLanguageModel {
function.description = Some(tool_description); function.description = Some(tool_description);
function.parameters = Some(input_schema); function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }]; request.tools = vec![open_ai::ToolDefinition::Function { function }];
if cx
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
.unwrap_or(false)
{
let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
client.clone(),
llm_api_token,
PerformCompletionParams {
provider: client::LanguageModelProvider::Zed,
model: request.model.clone(),
provider_request: RawValue::from_string(
serde_json::to_string(&request)?,
)?,
},
)
.await?;
let mut body = BufReader::new(response.into_body());
let mut line = String::new();
let mut load_state = None;
while body.read_line(&mut line).await? > 0 {
let part: open_ai::ResponseStreamEvent =
serde_json::from_str(&line)?;
line.clear();
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
} else {
self.request_limiter self.request_limiter
.run(async move { .run(async move {
let request = serde_json::to_string(&request)?; let request = serde_json::to_string(&request)?;
@@ -730,7 +921,6 @@ impl LanguageModel for CloudLanguageModel {
request, request,
}) })
.await?; .await?;
// Call arguments are gonna be streamed in over multiple chunks.
let mut load_state = None; let mut load_state = None;
let mut response = response.map( let mut response = response.map(
|item: Result< |item: Result<
@@ -774,6 +964,7 @@ impl LanguageModel for CloudLanguageModel {
} }
} }
} }
}
} }
impl LlmApiToken { impl LlmApiToken {