Use LLM service for tool call requests (#16046)
Release Notes: - N/A --------- Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
parent
d96afde5bf
commit
fbebb73d7b
2 changed files with 330 additions and 131 deletions
|
@ -322,25 +322,33 @@ async fn perform_completion(
|
|||
}
|
||||
|
||||
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
|
||||
match provider {
|
||||
LanguageModelProvider::Anthropic => {
|
||||
for prefix in &[
|
||||
"claude-3-5-sonnet",
|
||||
"claude-3-haiku",
|
||||
"claude-3-opus",
|
||||
"claude-3-sonnet",
|
||||
] {
|
||||
if name.starts_with(prefix) {
|
||||
return prefix.to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
LanguageModelProvider::OpenAi => {}
|
||||
LanguageModelProvider::Google => {}
|
||||
LanguageModelProvider::Zed => {}
|
||||
}
|
||||
let prefixes: &[_] = match provider {
|
||||
LanguageModelProvider::Anthropic => &[
|
||||
"claude-3-5-sonnet",
|
||||
"claude-3-haiku",
|
||||
"claude-3-opus",
|
||||
"claude-3-sonnet",
|
||||
],
|
||||
LanguageModelProvider::OpenAi => &[
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4",
|
||||
],
|
||||
LanguageModelProvider::Google => &[],
|
||||
LanguageModelProvider::Zed => &[],
|
||||
};
|
||||
|
||||
name
|
||||
if let Some(prefix) = prefixes
|
||||
.iter()
|
||||
.filter(|&&prefix| name.starts_with(prefix))
|
||||
.max_by_key(|&&prefix| prefix.len())
|
||||
{
|
||||
prefix.to_string()
|
||||
} else {
|
||||
name
|
||||
}
|
||||
}
|
||||
|
||||
async fn check_usage_limit(
|
||||
|
|
|
@ -590,7 +590,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
tool_name: String,
|
||||
tool_description: String,
|
||||
input_schema: serde_json::Value,
|
||||
_cx: &AsyncAppContext,
|
||||
cx: &AsyncAppContext,
|
||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||
match &self.model {
|
||||
CloudModel::Anthropic(model) => {
|
||||
|
@ -605,34 +605,106 @@ impl LanguageModel for CloudLanguageModel {
|
|||
input_schema,
|
||||
}];
|
||||
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client
|
||||
.request(proto::CompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
||||
request,
|
||||
})
|
||||
if cx
|
||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let response = Self::perform_llm_completion(
|
||||
client.clone(),
|
||||
llm_api_token,
|
||||
PerformCompletionParams {
|
||||
provider: client::LanguageModelProvider::Anthropic,
|
||||
model: request.model.clone(),
|
||||
provider_request: RawValue::from_string(
|
||||
serde_json::to_string(&request)?,
|
||||
)?,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let response: anthropic::Response =
|
||||
serde_json::from_str(&response.completion)?;
|
||||
response
|
||||
.content
|
||||
.into_iter()
|
||||
.find_map(|content| {
|
||||
if let anthropic::Content::ToolUse { name, input, .. } = content {
|
||||
if name == tool_name {
|
||||
Some(input)
|
||||
|
||||
let mut tool_use_index = None;
|
||||
let mut tool_input = String::new();
|
||||
let mut body = BufReader::new(response.into_body());
|
||||
let mut line = String::new();
|
||||
while body.read_line(&mut line).await? > 0 {
|
||||
let event: anthropic::Event = serde_json::from_str(&line)?;
|
||||
line.clear();
|
||||
|
||||
match event {
|
||||
anthropic::Event::ContentBlockStart {
|
||||
content_block,
|
||||
index,
|
||||
} => {
|
||||
if let anthropic::Content::ToolUse { name, .. } =
|
||||
content_block
|
||||
{
|
||||
if name == tool_name {
|
||||
tool_use_index = Some(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
anthropic::Event::ContentBlockDelta { index, delta } => {
|
||||
match delta {
|
||||
anthropic::ContentDelta::TextDelta { .. } => {}
|
||||
anthropic::ContentDelta::InputJsonDelta {
|
||||
partial_json,
|
||||
} => {
|
||||
if Some(index) == tool_use_index {
|
||||
tool_input.push_str(&partial_json);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
anthropic::Event::ContentBlockStop { index } => {
|
||||
if Some(index) == tool_use_index {
|
||||
return Ok(serde_json::from_str(&tool_input)?);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if tool_use_index.is_some() {
|
||||
Err(anyhow!("tool content incomplete"))
|
||||
} else {
|
||||
Err(anyhow!("tool not used"))
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
} else {
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client
|
||||
.request(proto::CompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::Anthropic as i32,
|
||||
request,
|
||||
})
|
||||
.await?;
|
||||
let response: anthropic::Response =
|
||||
serde_json::from_str(&response.completion)?;
|
||||
response
|
||||
.content
|
||||
.into_iter()
|
||||
.find_map(|content| {
|
||||
if let anthropic::Content::ToolUse { name, input, .. } = content
|
||||
{
|
||||
if name == tool_name {
|
||||
Some(input)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.context("tool not used")
|
||||
})
|
||||
.boxed()
|
||||
})
|
||||
.context("tool not used")
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
CloudModel::OpenAi(model) => {
|
||||
let mut request = request.into_open_ai(model.id().into());
|
||||
|
@ -650,56 +722,116 @@ impl LanguageModel for CloudLanguageModel {
|
|||
function.description = Some(tool_description);
|
||||
function.parameters = Some(input_schema);
|
||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
||||
request,
|
||||
})
|
||||
.await?;
|
||||
// Call arguments are gonna be streamed in over multiple chunks.
|
||||
let mut load_state = None;
|
||||
let mut response = response.map(
|
||||
|item: Result<
|
||||
proto::StreamCompleteWithLanguageModelResponse,
|
||||
anyhow::Error,
|
||||
>| {
|
||||
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
|
||||
serde_json::from_str(&item?.event)?,
|
||||
)
|
||||
},
|
||||
);
|
||||
while let Some(Ok(part)) = response.next().await {
|
||||
for choice in part.choices {
|
||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for call in tool_calls {
|
||||
if let Some(func) = call.function {
|
||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||
load_state = Some((String::default(), call.index));
|
||||
}
|
||||
if let Some((arguments, (output, index))) =
|
||||
func.arguments.zip(load_state.as_mut())
|
||||
{
|
||||
if call.index == *index {
|
||||
output.push_str(&arguments);
|
||||
if cx
|
||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let response = Self::perform_llm_completion(
|
||||
client.clone(),
|
||||
llm_api_token,
|
||||
PerformCompletionParams {
|
||||
provider: client::LanguageModelProvider::OpenAi,
|
||||
model: request.model.clone(),
|
||||
provider_request: RawValue::from_string(
|
||||
serde_json::to_string(&request)?,
|
||||
)?,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut body = BufReader::new(response.into_body());
|
||||
let mut line = String::new();
|
||||
let mut load_state = None;
|
||||
|
||||
while body.read_line(&mut line).await? > 0 {
|
||||
let part: open_ai::ResponseStreamEvent =
|
||||
serde_json::from_str(&line)?;
|
||||
line.clear();
|
||||
|
||||
for choice in part.choices {
|
||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for call in tool_calls {
|
||||
if let Some(func) = call.function {
|
||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||
load_state = Some((String::default(), call.index));
|
||||
}
|
||||
if let Some((arguments, (output, index))) =
|
||||
func.arguments.zip(load_state.as_mut())
|
||||
{
|
||||
if call.index == *index {
|
||||
output.push_str(&arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some((arguments, _)) = load_state {
|
||||
return Ok(serde_json::from_str(&arguments)?);
|
||||
} else {
|
||||
bail!("tool not used");
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
|
||||
if let Some((arguments, _)) = load_state {
|
||||
return Ok(serde_json::from_str(&arguments)?);
|
||||
} else {
|
||||
bail!("tool not used");
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
} else {
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
||||
request,
|
||||
})
|
||||
.await?;
|
||||
let mut load_state = None;
|
||||
let mut response = response.map(
|
||||
|item: Result<
|
||||
proto::StreamCompleteWithLanguageModelResponse,
|
||||
anyhow::Error,
|
||||
>| {
|
||||
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
|
||||
serde_json::from_str(&item?.event)?,
|
||||
)
|
||||
},
|
||||
);
|
||||
while let Some(Ok(part)) = response.next().await {
|
||||
for choice in part.choices {
|
||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for call in tool_calls {
|
||||
if let Some(func) = call.function {
|
||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||
load_state = Some((String::default(), call.index));
|
||||
}
|
||||
if let Some((arguments, (output, index))) =
|
||||
func.arguments.zip(load_state.as_mut())
|
||||
{
|
||||
if call.index == *index {
|
||||
output.push_str(&arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some((arguments, _)) = load_state {
|
||||
return Ok(serde_json::from_str(&arguments)?);
|
||||
} else {
|
||||
bail!("tool not used");
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
CloudModel::Google(_) => {
|
||||
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
||||
|
@ -721,56 +853,115 @@ impl LanguageModel for CloudLanguageModel {
|
|||
function.description = Some(tool_description);
|
||||
function.parameters = Some(input_schema);
|
||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
||||
request,
|
||||
})
|
||||
.await?;
|
||||
// Call arguments are gonna be streamed in over multiple chunks.
|
||||
let mut load_state = None;
|
||||
let mut response = response.map(
|
||||
|item: Result<
|
||||
proto::StreamCompleteWithLanguageModelResponse,
|
||||
anyhow::Error,
|
||||
>| {
|
||||
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
|
||||
serde_json::from_str(&item?.event)?,
|
||||
)
|
||||
},
|
||||
);
|
||||
while let Some(Ok(part)) = response.next().await {
|
||||
for choice in part.choices {
|
||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for call in tool_calls {
|
||||
if let Some(func) = call.function {
|
||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||
load_state = Some((String::default(), call.index));
|
||||
}
|
||||
if let Some((arguments, (output, index))) =
|
||||
func.arguments.zip(load_state.as_mut())
|
||||
{
|
||||
if call.index == *index {
|
||||
output.push_str(&arguments);
|
||||
if cx
|
||||
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let response = Self::perform_llm_completion(
|
||||
client.clone(),
|
||||
llm_api_token,
|
||||
PerformCompletionParams {
|
||||
provider: client::LanguageModelProvider::Zed,
|
||||
model: request.model.clone(),
|
||||
provider_request: RawValue::from_string(
|
||||
serde_json::to_string(&request)?,
|
||||
)?,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut body = BufReader::new(response.into_body());
|
||||
let mut line = String::new();
|
||||
let mut load_state = None;
|
||||
|
||||
while body.read_line(&mut line).await? > 0 {
|
||||
let part: open_ai::ResponseStreamEvent =
|
||||
serde_json::from_str(&line)?;
|
||||
line.clear();
|
||||
|
||||
for choice in part.choices {
|
||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for call in tool_calls {
|
||||
if let Some(func) = call.function {
|
||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||
load_state = Some((String::default(), call.index));
|
||||
}
|
||||
if let Some((arguments, (output, index))) =
|
||||
func.arguments.zip(load_state.as_mut())
|
||||
{
|
||||
if call.index == *index {
|
||||
output.push_str(&arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some((arguments, _)) = load_state {
|
||||
return Ok(serde_json::from_str(&arguments)?);
|
||||
} else {
|
||||
bail!("tool not used");
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
if let Some((arguments, _)) = load_state {
|
||||
return Ok(serde_json::from_str(&arguments)?);
|
||||
} else {
|
||||
bail!("tool not used");
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
} else {
|
||||
self.request_limiter
|
||||
.run(async move {
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let response = client
|
||||
.request_stream(proto::StreamCompleteWithLanguageModel {
|
||||
provider: proto::LanguageModelProvider::OpenAi as i32,
|
||||
request,
|
||||
})
|
||||
.await?;
|
||||
let mut load_state = None;
|
||||
let mut response = response.map(
|
||||
|item: Result<
|
||||
proto::StreamCompleteWithLanguageModelResponse,
|
||||
anyhow::Error,
|
||||
>| {
|
||||
Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
|
||||
serde_json::from_str(&item?.event)?,
|
||||
)
|
||||
},
|
||||
);
|
||||
while let Some(Ok(part)) = response.next().await {
|
||||
for choice in part.choices {
|
||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for call in tool_calls {
|
||||
if let Some(func) = call.function {
|
||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||
load_state = Some((String::default(), call.index));
|
||||
}
|
||||
if let Some((arguments, (output, index))) =
|
||||
func.arguments.zip(load_state.as_mut())
|
||||
{
|
||||
if call.index == *index {
|
||||
output.push_str(&arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some((arguments, _)) = load_state {
|
||||
return Ok(serde_json::from_str(&arguments)?);
|
||||
} else {
|
||||
bail!("tool not used");
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue