Use LLM service for tool call requests (#16046)
Release Notes: - N/A --------- Co-authored-by: Marshall <marshall@zed.dev>
This commit is contained in:
parent
d96afde5bf
commit
fbebb73d7b
2 changed files with 330 additions and 131 deletions
|
@ -322,25 +322,33 @@ async fn perform_completion(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
|
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
|
||||||
match provider {
|
let prefixes: &[_] = match provider {
|
||||||
LanguageModelProvider::Anthropic => {
|
LanguageModelProvider::Anthropic => &[
|
||||||
for prefix in &[
|
|
||||||
"claude-3-5-sonnet",
|
"claude-3-5-sonnet",
|
||||||
"claude-3-haiku",
|
"claude-3-haiku",
|
||||||
"claude-3-opus",
|
"claude-3-opus",
|
||||||
"claude-3-sonnet",
|
"claude-3-sonnet",
|
||||||
] {
|
],
|
||||||
if name.starts_with(prefix) {
|
LanguageModelProvider::OpenAi => &[
|
||||||
return prefix.to_string();
|
"gpt-3.5-turbo",
|
||||||
}
|
"gpt-4-turbo-preview",
|
||||||
}
|
"gpt-4o-mini",
|
||||||
}
|
"gpt-4o",
|
||||||
LanguageModelProvider::OpenAi => {}
|
"gpt-4",
|
||||||
LanguageModelProvider::Google => {}
|
],
|
||||||
LanguageModelProvider::Zed => {}
|
LanguageModelProvider::Google => &[],
|
||||||
}
|
LanguageModelProvider::Zed => &[],
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(prefix) = prefixes
|
||||||
|
.iter()
|
||||||
|
.filter(|&&prefix| name.starts_with(prefix))
|
||||||
|
.max_by_key(|&&prefix| prefix.len())
|
||||||
|
{
|
||||||
|
prefix.to_string()
|
||||||
|
} else {
|
||||||
name
|
name
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn check_usage_limit(
|
async fn check_usage_limit(
|
||||||
|
|
|
@ -590,7 +590,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
tool_name: String,
|
tool_name: String,
|
||||||
tool_description: String,
|
tool_description: String,
|
||||||
input_schema: serde_json::Value,
|
input_schema: serde_json::Value,
|
||||||
_cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
||||||
match &self.model {
|
match &self.model {
|
||||||
CloudModel::Anthropic(model) => {
|
CloudModel::Anthropic(model) => {
|
||||||
|
@ -605,6 +605,76 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
input_schema,
|
input_schema,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
if cx
|
||||||
|
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
|
self.request_limiter
|
||||||
|
.run(async move {
|
||||||
|
let response = Self::perform_llm_completion(
|
||||||
|
client.clone(),
|
||||||
|
llm_api_token,
|
||||||
|
PerformCompletionParams {
|
||||||
|
provider: client::LanguageModelProvider::Anthropic,
|
||||||
|
model: request.model.clone(),
|
||||||
|
provider_request: RawValue::from_string(
|
||||||
|
serde_json::to_string(&request)?,
|
||||||
|
)?,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut tool_use_index = None;
|
||||||
|
let mut tool_input = String::new();
|
||||||
|
let mut body = BufReader::new(response.into_body());
|
||||||
|
let mut line = String::new();
|
||||||
|
while body.read_line(&mut line).await? > 0 {
|
||||||
|
let event: anthropic::Event = serde_json::from_str(&line)?;
|
||||||
|
line.clear();
|
||||||
|
|
||||||
|
match event {
|
||||||
|
anthropic::Event::ContentBlockStart {
|
||||||
|
content_block,
|
||||||
|
index,
|
||||||
|
} => {
|
||||||
|
if let anthropic::Content::ToolUse { name, .. } =
|
||||||
|
content_block
|
||||||
|
{
|
||||||
|
if name == tool_name {
|
||||||
|
tool_use_index = Some(index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anthropic::Event::ContentBlockDelta { index, delta } => {
|
||||||
|
match delta {
|
||||||
|
anthropic::ContentDelta::TextDelta { .. } => {}
|
||||||
|
anthropic::ContentDelta::InputJsonDelta {
|
||||||
|
partial_json,
|
||||||
|
} => {
|
||||||
|
if Some(index) == tool_use_index {
|
||||||
|
tool_input.push_str(&partial_json);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
anthropic::Event::ContentBlockStop { index } => {
|
||||||
|
if Some(index) == tool_use_index {
|
||||||
|
return Ok(serde_json::from_str(&tool_input)?);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tool_use_index.is_some() {
|
||||||
|
Err(anyhow!("tool content incomplete"))
|
||||||
|
} else {
|
||||||
|
Err(anyhow!("tool not used"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
} else {
|
||||||
self.request_limiter
|
self.request_limiter
|
||||||
.run(async move {
|
.run(async move {
|
||||||
let request = serde_json::to_string(&request)?;
|
let request = serde_json::to_string(&request)?;
|
||||||
|
@ -620,7 +690,8 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
.content
|
.content
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.find_map(|content| {
|
.find_map(|content| {
|
||||||
if let anthropic::Content::ToolUse { name, input, .. } = content {
|
if let anthropic::Content::ToolUse { name, input, .. } = content
|
||||||
|
{
|
||||||
if name == tool_name {
|
if name == tool_name {
|
||||||
Some(input)
|
Some(input)
|
||||||
} else {
|
} else {
|
||||||
|
@ -634,6 +705,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
CloudModel::OpenAi(model) => {
|
CloudModel::OpenAi(model) => {
|
||||||
let mut request = request.into_open_ai(model.id().into());
|
let mut request = request.into_open_ai(model.id().into());
|
||||||
let client = self.client.clone();
|
let client = self.client.clone();
|
||||||
|
@ -650,6 +722,66 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
function.description = Some(tool_description);
|
function.description = Some(tool_description);
|
||||||
function.parameters = Some(input_schema);
|
function.parameters = Some(input_schema);
|
||||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||||
|
|
||||||
|
if cx
|
||||||
|
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
|
self.request_limiter
|
||||||
|
.run(async move {
|
||||||
|
let response = Self::perform_llm_completion(
|
||||||
|
client.clone(),
|
||||||
|
llm_api_token,
|
||||||
|
PerformCompletionParams {
|
||||||
|
provider: client::LanguageModelProvider::OpenAi,
|
||||||
|
model: request.model.clone(),
|
||||||
|
provider_request: RawValue::from_string(
|
||||||
|
serde_json::to_string(&request)?,
|
||||||
|
)?,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut body = BufReader::new(response.into_body());
|
||||||
|
let mut line = String::new();
|
||||||
|
let mut load_state = None;
|
||||||
|
|
||||||
|
while body.read_line(&mut line).await? > 0 {
|
||||||
|
let part: open_ai::ResponseStreamEvent =
|
||||||
|
serde_json::from_str(&line)?;
|
||||||
|
line.clear();
|
||||||
|
|
||||||
|
for choice in part.choices {
|
||||||
|
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
for call in tool_calls {
|
||||||
|
if let Some(func) = call.function {
|
||||||
|
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||||
|
load_state = Some((String::default(), call.index));
|
||||||
|
}
|
||||||
|
if let Some((arguments, (output, index))) =
|
||||||
|
func.arguments.zip(load_state.as_mut())
|
||||||
|
{
|
||||||
|
if call.index == *index {
|
||||||
|
output.push_str(&arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some((arguments, _)) = load_state {
|
||||||
|
return Ok(serde_json::from_str(&arguments)?);
|
||||||
|
} else {
|
||||||
|
bail!("tool not used");
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
} else {
|
||||||
self.request_limiter
|
self.request_limiter
|
||||||
.run(async move {
|
.run(async move {
|
||||||
let request = serde_json::to_string(&request)?;
|
let request = serde_json::to_string(&request)?;
|
||||||
|
@ -659,7 +791,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
request,
|
request,
|
||||||
})
|
})
|
||||||
.await?;
|
.await?;
|
||||||
// Call arguments are gonna be streamed in over multiple chunks.
|
|
||||||
let mut load_state = None;
|
let mut load_state = None;
|
||||||
let mut response = response.map(
|
let mut response = response.map(
|
||||||
|item: Result<
|
|item: Result<
|
||||||
|
@ -701,6 +832,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
CloudModel::Google(_) => {
|
CloudModel::Google(_) => {
|
||||||
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
|
||||||
}
|
}
|
||||||
|
@ -721,6 +853,65 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
function.description = Some(tool_description);
|
function.description = Some(tool_description);
|
||||||
function.parameters = Some(input_schema);
|
function.parameters = Some(input_schema);
|
||||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
||||||
|
|
||||||
|
if cx
|
||||||
|
.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
|
self.request_limiter
|
||||||
|
.run(async move {
|
||||||
|
let response = Self::perform_llm_completion(
|
||||||
|
client.clone(),
|
||||||
|
llm_api_token,
|
||||||
|
PerformCompletionParams {
|
||||||
|
provider: client::LanguageModelProvider::Zed,
|
||||||
|
model: request.model.clone(),
|
||||||
|
provider_request: RawValue::from_string(
|
||||||
|
serde_json::to_string(&request)?,
|
||||||
|
)?,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut body = BufReader::new(response.into_body());
|
||||||
|
let mut line = String::new();
|
||||||
|
let mut load_state = None;
|
||||||
|
|
||||||
|
while body.read_line(&mut line).await? > 0 {
|
||||||
|
let part: open_ai::ResponseStreamEvent =
|
||||||
|
serde_json::from_str(&line)?;
|
||||||
|
line.clear();
|
||||||
|
|
||||||
|
for choice in part.choices {
|
||||||
|
let Some(tool_calls) = choice.delta.tool_calls else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
for call in tool_calls {
|
||||||
|
if let Some(func) = call.function {
|
||||||
|
if func.name.as_deref() == Some(tool_name.as_str()) {
|
||||||
|
load_state = Some((String::default(), call.index));
|
||||||
|
}
|
||||||
|
if let Some((arguments, (output, index))) =
|
||||||
|
func.arguments.zip(load_state.as_mut())
|
||||||
|
{
|
||||||
|
if call.index == *index {
|
||||||
|
output.push_str(&arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some((arguments, _)) = load_state {
|
||||||
|
return Ok(serde_json::from_str(&arguments)?);
|
||||||
|
} else {
|
||||||
|
bail!("tool not used");
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
} else {
|
||||||
self.request_limiter
|
self.request_limiter
|
||||||
.run(async move {
|
.run(async move {
|
||||||
let request = serde_json::to_string(&request)?;
|
let request = serde_json::to_string(&request)?;
|
||||||
|
@ -730,7 +921,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
request,
|
request,
|
||||||
})
|
})
|
||||||
.await?;
|
.await?;
|
||||||
// Call arguments are gonna be streamed in over multiple chunks.
|
|
||||||
let mut load_state = None;
|
let mut load_state = None;
|
||||||
let mut response = response.map(
|
let mut response = response.map(
|
||||||
|item: Result<
|
|item: Result<
|
||||||
|
@ -774,6 +964,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlmApiToken {
|
impl LlmApiToken {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue