Use LLM service for tool call requests (#16046)

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
commit fbebb73d7b (parent d96afde5bf)
Max Brunsfeld, 2024-08-09 13:22:58 -07:00 (committed by GitHub)
2 changed files with 330 additions and 131 deletions
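The change gates tool-call requests behind `LlmServiceFeatureFlag`: when the flag is set, the client serializes the provider-specific request and sends it to the new LLM service through `Self::perform_llm_completion` with `PerformCompletionParams` and the cloned `llm_api_token`; otherwise it falls back to the existing `proto::CompleteWithLanguageModel` / `proto::StreamCompleteWithLanguageModel` RPCs. A minimal sketch of the new request shape, assuming serde derives and a snake_case provider encoding that the diff does not show:

```rust
use serde::Serialize;
use serde_json::value::RawValue;

// Simplified stand-ins for the types named in the diff; derives and field
// encodings are assumptions for illustration only.
#[derive(Serialize)]
#[serde(rename_all = "snake_case")]
enum LanguageModelProvider {
    Anthropic,
    OpenAi,
    Google,
    Zed,
}

#[derive(Serialize)]
struct PerformCompletionParams {
    provider: LanguageModelProvider,
    model: String,
    // Forwarded verbatim to the upstream provider as raw JSON.
    provider_request: Box<RawValue>,
}

fn anthropic_params(
    model: String,
    anthropic_request: &impl Serialize,
) -> serde_json::Result<PerformCompletionParams> {
    Ok(PerformCompletionParams {
        provider: LanguageModelProvider::Anthropic,
        model,
        provider_request: RawValue::from_string(serde_json::to_string(anthropic_request)?)?,
    })
}
```

Keeping `provider_request` as a `RawValue` lets the service forward the upstream provider's JSON untouched instead of re-modelling every provider's request schema.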


@@ -322,25 +322,33 @@ async fn perform_completion(
}
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
match provider {
LanguageModelProvider::Anthropic => {
for prefix in &[
"claude-3-5-sonnet",
"claude-3-haiku",
"claude-3-opus",
"claude-3-sonnet",
] {
if name.starts_with(prefix) {
return prefix.to_string();
}
}
}
LanguageModelProvider::OpenAi => {}
LanguageModelProvider::Google => {}
LanguageModelProvider::Zed => {}
}
let prefixes: &[_] = match provider {
LanguageModelProvider::Anthropic => &[
"claude-3-5-sonnet",
"claude-3-haiku",
"claude-3-opus",
"claude-3-sonnet",
],
LanguageModelProvider::OpenAi => &[
"gpt-3.5-turbo",
"gpt-4-turbo-preview",
"gpt-4o-mini",
"gpt-4o",
"gpt-4",
],
LanguageModelProvider::Google => &[],
LanguageModelProvider::Zed => &[],
};
name
if let Some(prefix) = prefixes
.iter()
.filter(|&&prefix| name.starts_with(prefix))
.max_by_key(|&&prefix| prefix.len())
{
prefix.to_string()
} else {
name
}
}
async fn check_usage_limit(
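`normalize_model_name` now strips version suffixes for OpenAI models as well and, instead of returning the first matching prefix, keeps the longest one, so a `gpt-4o-mini` variant is not collapsed to `gpt-4o`. A standalone sketch of the same matching rule (the example model names are illustrative, not taken from the diff):

```rust
// Longest-prefix normalization over the OpenAI list from the diff.
fn normalize(name: &str) -> String {
    let prefixes: &[&str] = &[
        "gpt-3.5-turbo",
        "gpt-4-turbo-preview",
        "gpt-4o-mini",
        "gpt-4o",
        "gpt-4",
    ];
    if let Some(prefix) = prefixes
        .iter()
        .filter(|&&prefix| name.starts_with(prefix))
        .max_by_key(|&&prefix| prefix.len())
    {
        prefix.to_string()
    } else {
        name.to_string()
    }
}

fn main() {
    // Several prefixes match here; taking the longest keeps "gpt-4o-mini"
    // instead of truncating to "gpt-4o" or "gpt-4".
    assert_eq!(normalize("gpt-4o-mini-2024-07-18"), "gpt-4o-mini");
    assert_eq!(normalize("gpt-4-0613"), "gpt-4");
    // Unknown names pass through unchanged.
    assert_eq!(normalize("o1-preview"), "o1-preview");
}
```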


@@ -590,7 +590,7 @@ impl LanguageModel for CloudLanguageModel {
tool_name: String,
tool_description: String,
input_schema: serde_json::Value,
_cx: &AsyncAppContext,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
match &self.model {
CloudModel::Anthropic(model) => {
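The only change in this hunk is renaming `_cx` to `cx`, so the method can consult the feature flag before choosing a path. A toy illustration of that gate with placeholder types (`AsyncAppContext` and `LlmServiceFeatureFlag` here are simplified stand-ins, not the real gpui/feature_flags types):

```rust
struct LlmServiceFeatureFlag;

struct AsyncAppContext {
    llm_service_enabled: bool,
}

impl AsyncAppContext {
    // `update` can fail if the app has already shut down, which is why the
    // call sites below default to `false` (the old RPC path) on error.
    fn update<R>(&self, f: impl FnOnce(&Self) -> R) -> anyhow::Result<R> {
        Ok(f(self))
    }

    fn has_flag<T>(&self) -> bool {
        self.llm_service_enabled
    }
}

fn use_llm_service(cx: &AsyncAppContext) -> bool {
    cx.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
        .unwrap_or(false)
}
```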
@@ -605,34 +605,106 @@ impl LanguageModel for CloudLanguageModel {
input_schema,
}];
self.request_limiter
.run(async move {
let request = serde_json::to_string(&request)?;
let response = client
.request(proto::CompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
if cx
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
.unwrap_or(false)
{
let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
client.clone(),
llm_api_token,
PerformCompletionParams {
provider: client::LanguageModelProvider::Anthropic,
model: request.model.clone(),
provider_request: RawValue::from_string(
serde_json::to_string(&request)?,
)?,
},
)
.await?;
let response: anthropic::Response =
serde_json::from_str(&response.completion)?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
let mut tool_use_index = None;
let mut tool_input = String::new();
let mut body = BufReader::new(response.into_body());
let mut line = String::new();
while body.read_line(&mut line).await? > 0 {
let event: anthropic::Event = serde_json::from_str(&line)?;
line.clear();
match event {
anthropic::Event::ContentBlockStart {
content_block,
index,
} => {
if let anthropic::Content::ToolUse { name, .. } =
content_block
{
if name == tool_name {
tool_use_index = Some(index);
}
}
}
anthropic::Event::ContentBlockDelta { index, delta } => {
match delta {
anthropic::ContentDelta::TextDelta { .. } => {}
anthropic::ContentDelta::InputJsonDelta {
partial_json,
} => {
if Some(index) == tool_use_index {
tool_input.push_str(&partial_json);
}
}
}
}
anthropic::Event::ContentBlockStop { index } => {
if Some(index) == tool_use_index {
return Ok(serde_json::from_str(&tool_input)?);
}
}
_ => {}
}
}
if tool_use_index.is_some() {
Err(anyhow!("tool content incomplete"))
} else {
Err(anyhow!("tool not used"))
}
})
.boxed()
} else {
self.request_limiter
.run(async move {
let request = serde_json::to_string(&request)?;
let response = client
.request(proto::CompleteWithLanguageModel {
provider: proto::LanguageModelProvider::Anthropic as i32,
request,
})
.await?;
let response: anthropic::Response =
serde_json::from_str(&response.completion)?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content
{
if name == tool_name {
Some(input)
} else {
None
}
} else {
None
}
} else {
None
}
})
.context("tool not used")
})
.boxed()
})
.context("tool not used")
})
.boxed()
}
}
CloudModel::OpenAi(model) => {
let mut request = request.into_open_ai(model.id().into());
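In the flagged Anthropic path, the response is no longer a single `anthropic::Response`; it is a stream of events in which a `ContentBlockStart` names the `tool_use` block, `InputJsonDelta` fragments are concatenated into the tool input, and `ContentBlockStop` ends the accumulation. A simplified, synchronous sketch of that state machine with stand-in event types (the real ones come from the `anthropic` crate used above):

```rust
use anyhow::{anyhow, Result};

// Flattened stand-in for the streamed Anthropic events handled in the diff.
enum Event {
    ContentBlockStart { index: usize, tool_name: Option<String> },
    InputJsonDelta { index: usize, partial_json: String },
    ContentBlockStop { index: usize },
    Other,
}

fn collect_tool_input(
    events: impl IntoIterator<Item = Event>,
    tool_name: &str,
) -> Result<serde_json::Value> {
    let mut tool_use_index = None;
    let mut tool_input = String::new();
    for event in events {
        match event {
            // Remember which content block carries the tool call we want.
            Event::ContentBlockStart { index, tool_name: Some(name) } if name == tool_name => {
                tool_use_index = Some(index);
            }
            // Concatenate JSON fragments belonging to that block.
            Event::InputJsonDelta { index, partial_json } if Some(index) == tool_use_index => {
                tool_input.push_str(&partial_json);
            }
            // The block is complete; the accumulated string should now parse.
            Event::ContentBlockStop { index } if Some(index) == tool_use_index => {
                return Ok(serde_json::from_str(&tool_input)?);
            }
            _ => {}
        }
    }
    Err(if tool_use_index.is_some() {
        anyhow!("tool content incomplete")
    } else {
        anyhow!("tool not used")
    })
}
```

The two error cases mirror the diff: a started but unfinished block reports incomplete tool content, while a stream with no matching block reports that the tool was not used.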
@@ -650,56 +722,116 @@ impl LanguageModel for CloudLanguageModel {
function.description = Some(tool_description);
function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }];
self.request_limiter
.run(async move {
let request = serde_json::to_string(&request)?;
let response = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
request,
})
.await?;
// Call arguments are gonna be streamed in over multiple chunks.
let mut load_state = None;
let mut response = response.map(
|item: Result<
proto::StreamCompleteWithLanguageModelResponse,
anyhow::Error,
>| {
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
serde_json::from_str(&item?.event)?,
)
},
);
while let Some(Ok(part)) = response.next().await {
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
if cx
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
.unwrap_or(false)
{
let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
client.clone(),
llm_api_token,
PerformCompletionParams {
provider: client::LanguageModelProvider::OpenAi,
model: request.model.clone(),
provider_request: RawValue::from_string(
serde_json::to_string(&request)?,
)?,
},
)
.await?;
let mut body = BufReader::new(response.into_body());
let mut line = String::new();
let mut load_state = None;
while body.read_line(&mut line).await? > 0 {
let part: open_ai::ResponseStreamEvent =
serde_json::from_str(&line)?;
line.clear();
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
} else {
self.request_limiter
.run(async move {
let request = serde_json::to_string(&request)?;
let response = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
request,
})
.await?;
let mut load_state = None;
let mut response = response.map(
|item: Result<
proto::StreamCompleteWithLanguageModelResponse,
anyhow::Error,
>| {
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
serde_json::from_str(&item?.event)?,
)
},
);
while let Some(Ok(part)) = response.next().await {
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
}
}
CloudModel::Google(_) => {
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
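Both OpenAI-compatible branches accumulate tool-call arguments the same way: each streamed delta may name the tool and carries argument fragments keyed by the call's index, and only the final concatenation parses as JSON. A minimal sketch of that accumulation with a stand-in struct (the diff works on `open_ai::ResponseStreamEvent`):

```rust
use anyhow::{bail, Result};

// Simplified stand-in for one streamed tool-call delta.
struct ToolCallDelta {
    index: usize,
    function_name: Option<String>,
    arguments: Option<String>,
}

fn collect_tool_arguments(
    deltas: impl IntoIterator<Item = ToolCallDelta>,
    tool_name: &str,
) -> Result<serde_json::Value> {
    // (accumulated argument text, index of the call being followed)
    let mut load_state: Option<(String, usize)> = None;
    for call in deltas {
        // A delta naming the tool starts (or restarts) accumulation.
        if call.function_name.as_deref() == Some(tool_name) {
            load_state = Some((String::new(), call.index));
        }
        // Argument fragments are appended only for the tracked call index.
        if let Some((arguments, (output, index))) = call.arguments.zip(load_state.as_mut()) {
            if call.index == *index {
                output.push_str(&arguments);
            }
        }
    }
    match load_state {
        Some((arguments, _)) => Ok(serde_json::from_str(&arguments)?),
        None => bail!("tool not used"),
    }
}
```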
@@ -721,56 +853,115 @@ impl LanguageModel for CloudLanguageModel {
function.description = Some(tool_description);
function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }];
self.request_limiter
.run(async move {
let request = serde_json::to_string(&request)?;
let response = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
request,
})
.await?;
// Call arguments are gonna be streamed in over multiple chunks.
let mut load_state = None;
let mut response = response.map(
|item: Result<
proto::StreamCompleteWithLanguageModelResponse,
anyhow::Error,
>| {
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
serde_json::from_str(&item?.event)?,
)
},
);
while let Some(Ok(part)) = response.next().await {
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
if cx
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
.unwrap_or(false)
{
let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
client.clone(),
llm_api_token,
PerformCompletionParams {
provider: client::LanguageModelProvider::Zed,
model: request.model.clone(),
provider_request: RawValue::from_string(
serde_json::to_string(&request)?,
)?,
},
)
.await?;
let mut body = BufReader::new(response.into_body());
let mut line = String::new();
let mut load_state = None;
while body.read_line(&mut line).await? > 0 {
let part: open_ai::ResponseStreamEvent =
serde_json::from_str(&line)?;
line.clear();
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
} else {
self.request_limiter
.run(async move {
let request = serde_json::to_string(&request)?;
let response = client
.request_stream(proto::StreamCompleteWithLanguageModel {
provider: proto::LanguageModelProvider::OpenAi as i32,
request,
})
.await?;
let mut load_state = None;
let mut response = response.map(
|item: Result<
proto::StreamCompleteWithLanguageModelResponse,
anyhow::Error,
>| {
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
serde_json::from_str(&item?.event)?,
)
},
);
while let Some(Ok(part)) = response.next().await {
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
})
.boxed()
}
}
}
}
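All three flagged branches also read the LLM service response the same way: the body is newline-delimited JSON, wrapped in a `futures` `BufReader` and parsed one line at a time until `read_line` returns 0. A generic sketch of that loop, with the event type and callback as placeholders rather than anything from the diff:

```rust
use anyhow::Result;
use futures::io::{AsyncBufReadExt, AsyncRead, BufReader};
use serde::de::DeserializeOwned;

async fn read_json_lines<T, R>(body: R, mut on_event: impl FnMut(T)) -> Result<()>
where
    T: DeserializeOwned,
    R: AsyncRead + Unpin,
{
    let mut reader = BufReader::new(body);
    let mut line = String::new();
    // read_line returns 0 at end of stream; it appends to `line`, so the
    // buffer has to be cleared after each event is parsed.
    while reader.read_line(&mut line).await? > 0 {
        let event: T = serde_json::from_str(&line)?;
        line.clear();
        on_event(event);
    }
    Ok(())
}
```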