diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs
index f35d77533a..5bab54799c 100644
--- a/crates/collab/src/llm.rs
+++ b/crates/collab/src/llm.rs
@@ -322,25 +322,33 @@ async fn perform_completion(
 }
 
 fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
-    match provider {
-        LanguageModelProvider::Anthropic => {
-            for prefix in &[
-                "claude-3-5-sonnet",
-                "claude-3-haiku",
-                "claude-3-opus",
-                "claude-3-sonnet",
-            ] {
-                if name.starts_with(prefix) {
-                    return prefix.to_string();
-                }
-            }
-        }
-        LanguageModelProvider::OpenAi => {}
-        LanguageModelProvider::Google => {}
-        LanguageModelProvider::Zed => {}
-    }
+    let prefixes: &[_] = match provider {
+        LanguageModelProvider::Anthropic => &[
+            "claude-3-5-sonnet",
+            "claude-3-haiku",
+            "claude-3-opus",
+            "claude-3-sonnet",
+        ],
+        LanguageModelProvider::OpenAi => &[
+            "gpt-3.5-turbo",
+            "gpt-4-turbo-preview",
+            "gpt-4o-mini",
+            "gpt-4o",
+            "gpt-4",
+        ],
+        LanguageModelProvider::Google => &[],
+        LanguageModelProvider::Zed => &[],
+    };
 
-    name
+    if let Some(prefix) = prefixes
+        .iter()
+        .filter(|&&prefix| name.starts_with(prefix))
+        .max_by_key(|&&prefix| prefix.len())
+    {
+        prefix.to_string()
+    } else {
+        name
+    }
 }
 
 async fn check_usage_limit(
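Note on the new `normalize_model_name`: several OpenAI prefixes share a common stem ("gpt-4" is a prefix of "gpt-4o", which is a prefix of "gpt-4o-mini"), so the longest match must win; `max_by_key` on prefix length makes the result independent of list order. A minimal standalone sketch of the rule (a hypothetical helper, not the collab crate's code):

```rust
// Sketch of the longest-prefix normalization above. Taking the *first*
// matching prefix could truncate "gpt-4o-mini-2024-07-18" to "gpt-4" if the
// list happened to be ordered that way; taking the longest match cannot.
fn normalize(name: &str, prefixes: &[&str]) -> String {
    prefixes
        .iter()
        .copied()
        .filter(|&prefix| name.starts_with(prefix))
        .max_by_key(|prefix| prefix.len())
        .map(|prefix| prefix.to_string())
        // Unknown names pass through unchanged.
        .unwrap_or_else(|| name.to_string())
}

fn main() {
    let openai = [
        "gpt-3.5-turbo",
        "gpt-4-turbo-preview",
        "gpt-4o-mini",
        "gpt-4o",
        "gpt-4",
    ];
    assert_eq!(normalize("gpt-4o-mini-2024-07-18", &openai), "gpt-4o-mini");
    assert_eq!(normalize("gpt-4-0613", &openai), "gpt-4");
    assert_eq!(normalize("o1-preview", &openai), "o1-preview");
}
```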
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index b413f8d2cb..2f6651a2eb 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -590,7 +590,7 @@ impl LanguageModel for CloudLanguageModel {
         tool_name: String,
         tool_description: String,
         input_schema: serde_json::Value,
-        _cx: &AsyncAppContext,
+        cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<serde_json::Value>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
@@ -605,34 +605,106 @@ impl LanguageModel for CloudLanguageModel {
                     input_schema,
                 }];
 
-                self.request_limiter
-                    .run(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let response = client
-                            .request(proto::CompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::Anthropic as i32,
-                                request,
-                            })
+                if cx
+                    .update(|cx| cx.has_flag::())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    self.request_limiter
+                        .run(async move {
+                            let response = Self::perform_llm_completion(
+                                client.clone(),
+                                llm_api_token,
+                                PerformCompletionParams {
+                                    provider: client::LanguageModelProvider::Anthropic,
+                                    model: request.model.clone(),
+                                    provider_request: RawValue::from_string(
+                                        serde_json::to_string(&request)?,
+                                    )?,
+                                },
+                            )
                             .await?;
-                        let response: anthropic::Response =
-                            serde_json::from_str(&response.completion)?;
-                        response
-                            .content
-                            .into_iter()
-                            .find_map(|content| {
-                                if let anthropic::Content::ToolUse { name, input, .. } = content {
-                                    if name == tool_name {
-                                        Some(input)
+
+                            let mut tool_use_index = None;
+                            let mut tool_input = String::new();
+                            let mut body = BufReader::new(response.into_body());
+                            let mut line = String::new();
+                            while body.read_line(&mut line).await? > 0 {
+                                let event: anthropic::Event = serde_json::from_str(&line)?;
+                                line.clear();
+
+                                match event {
+                                    anthropic::Event::ContentBlockStart {
+                                        content_block,
+                                        index,
+                                    } => {
+                                        if let anthropic::Content::ToolUse { name, .. } =
+                                            content_block
+                                        {
+                                            if name == tool_name {
+                                                tool_use_index = Some(index);
+                                            }
+                                        }
+                                    }
+                                    anthropic::Event::ContentBlockDelta { index, delta } => {
+                                        match delta {
+                                            anthropic::ContentDelta::TextDelta { .. } => {}
+                                            anthropic::ContentDelta::InputJsonDelta {
+                                                partial_json,
+                                            } => {
+                                                if Some(index) == tool_use_index {
+                                                    tool_input.push_str(&partial_json);
+                                                }
+                                            }
+                                        }
+                                    }
+                                    anthropic::Event::ContentBlockStop { index } => {
+                                        if Some(index) == tool_use_index {
+                                            return Ok(serde_json::from_str(&tool_input)?);
+                                        }
+                                    }
+                                    _ => {}
+                                }
+                            }
+
+                            if tool_use_index.is_some() {
+                                Err(anyhow!("tool content incomplete"))
+                            } else {
+                                Err(anyhow!("tool not used"))
+                            }
+                        })
+                        .boxed()
+                } else {
+                    self.request_limiter
+                        .run(async move {
+                            let request = serde_json::to_string(&request)?;
+                            let response = client
+                                .request(proto::CompleteWithLanguageModel {
+                                    provider: proto::LanguageModelProvider::Anthropic as i32,
+                                    request,
+                                })
+                                .await?;
+                            let response: anthropic::Response =
+                                serde_json::from_str(&response.completion)?;
+                            response
+                                .content
+                                .into_iter()
+                                .find_map(|content| {
+                                    if let anthropic::Content::ToolUse { name, input, .. } = content
+                                    {
+                                        if name == tool_name {
+                                            Some(input)
+                                        } else {
+                                            None
+                                        }
                                     } else {
                                         None
                                     }
-                                } else {
-                                    None
-                                }
-                            })
-                            .context("tool not used")
-                    })
-                    .boxed()
+                                })
+                                .context("tool not used")
+                        })
+                        .boxed()
+                }
             }
             CloudModel::OpenAi(model) => {
                 let mut request = request.into_open_ai(model.id().into());
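In the new Anthropic path, tool arguments arrive as `input_json_delta` fragments scoped to the content block that announced the tool use: the loop pins the block index at `content_block_start`, concatenates fragments until `content_block_stop`, and only then parses the accumulated string. A self-contained sketch of that accumulation, using simplified stand-ins for the `anthropic` crate's event types (the real enums carry more variants and fields):

```rust
// Simplified stand-ins for the anthropic crate's streaming events.
enum Event {
    ContentBlockStart { index: usize, tool_name: Option<String> },
    InputJsonDelta { index: usize, partial_json: String },
    ContentBlockStop { index: usize },
}

/// Accumulate streamed JSON fragments for the named tool and parse the
/// result once its content block closes.
fn collect_tool_input(
    events: impl Iterator<Item = Event>,
    tool_name: &str,
) -> Result<serde_json::Value, String> {
    let mut tool_use_index = None;
    let mut tool_input = String::new();
    for event in events {
        match event {
            Event::ContentBlockStart { index, tool_name: Some(name) } if name == tool_name => {
                tool_use_index = Some(index);
            }
            Event::InputJsonDelta { index, partial_json } if Some(index) == tool_use_index => {
                // Fragments are not valid JSON on their own; only the full
                // concatenation parses.
                tool_input.push_str(&partial_json);
            }
            Event::ContentBlockStop { index } if Some(index) == tool_use_index => {
                return serde_json::from_str(&tool_input).map_err(|e| e.to_string());
            }
            _ => {}
        }
    }
    // Mirror the diff's two failure modes: the block started but never
    // closed, or the model never used the tool at all.
    Err(if tool_use_index.is_some() {
        "tool content incomplete".into()
    } else {
        "tool not used".into()
    })
}
```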
@@ -650,56 +722,116 @@ impl LanguageModel for CloudLanguageModel {
                 function.description = Some(tool_description);
                 function.parameters = Some(input_schema);
                 request.tools = vec![open_ai::ToolDefinition::Function { function }];
-                self.request_limiter
-                    .run(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let response = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::OpenAi as i32,
-                                request,
-                            })
-                            .await?;
-                        // Call arguments are gonna be streamed in over multiple chunks.
-                        let mut load_state = None;
-                        let mut response = response.map(
-                            |item: Result<
-                                proto::StreamCompleteWithLanguageModelResponse,
-                                anyhow::Error,
-                            >| {
-                                Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
-                                    serde_json::from_str(&item?.event)?,
-                                )
-                            },
-                        );
-                        while let Some(Ok(part)) = response.next().await {
-                            for choice in part.choices {
-                                let Some(tool_calls) = choice.delta.tool_calls else {
-                                    continue;
-                                };
-                                for call in tool_calls {
-                                    if let Some(func) = call.function {
-                                        if func.name.as_deref() == Some(tool_name.as_str()) {
-                                            load_state = Some((String::default(), call.index));
-                                        }
-                                        if let Some((arguments, (output, index))) =
-                                            func.arguments.zip(load_state.as_mut())
-                                        {
-                                            if call.index == *index {
-                                                output.push_str(&arguments);
+                if cx
+                    .update(|cx| cx.has_flag::())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    self.request_limiter
+                        .run(async move {
+                            let response = Self::perform_llm_completion(
+                                client.clone(),
+                                llm_api_token,
+                                PerformCompletionParams {
+                                    provider: client::LanguageModelProvider::OpenAi,
+                                    model: request.model.clone(),
+                                    provider_request: RawValue::from_string(
+                                        serde_json::to_string(&request)?,
+                                    )?,
+                                },
+                            )
+                            .await?;
+
+                            let mut body = BufReader::new(response.into_body());
+                            let mut line = String::new();
+                            let mut load_state = None;
+
+                            while body.read_line(&mut line).await? > 0 {
+                                let part: open_ai::ResponseStreamEvent =
+                                    serde_json::from_str(&line)?;
+                                line.clear();
+
+                                for choice in part.choices {
+                                    let Some(tool_calls) = choice.delta.tool_calls else {
+                                        continue;
+                                    };
+
+                                    for call in tool_calls {
+                                        if let Some(func) = call.function {
+                                            if func.name.as_deref() == Some(tool_name.as_str()) {
+                                                load_state = Some((String::default(), call.index));
+                                            }
+                                            if let Some((arguments, (output, index))) =
+                                                func.arguments.zip(load_state.as_mut())
+                                            {
+                                                if call.index == *index {
+                                                    output.push_str(&arguments);
+                                                }
                                             }
                                         }
                                     }
                                 }
                             }
-                        }
-                        if let Some((arguments, _)) = load_state {
-                            return Ok(serde_json::from_str(&arguments)?);
-                        } else {
-                            bail!("tool not used");
-                        }
-                    })
-                    .boxed()
+
+                            if let Some((arguments, _)) = load_state {
+                                return Ok(serde_json::from_str(&arguments)?);
+                            } else {
+                                bail!("tool not used");
+                            }
+                        })
+                        .boxed()
+                } else {
+                    self.request_limiter
+                        .run(async move {
+                            let request = serde_json::to_string(&request)?;
+                            let response = client
+                                .request_stream(proto::StreamCompleteWithLanguageModel {
+                                    provider: proto::LanguageModelProvider::OpenAi as i32,
+                                    request,
+                                })
+                                .await?;
+                            let mut load_state = None;
+                            let mut response = response.map(
+                                |item: Result<
+                                    proto::StreamCompleteWithLanguageModelResponse,
+                                    anyhow::Error,
+                                >| {
+                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
+                                        serde_json::from_str(&item?.event)?,
+                                    )
+                                },
+                            );
+                            while let Some(Ok(part)) = response.next().await {
+                                for choice in part.choices {
+                                    let Some(tool_calls) = choice.delta.tool_calls else {
+                                        continue;
+                                    };
+
+                                    for call in tool_calls {
+                                        if let Some(func) = call.function {
+                                            if func.name.as_deref() == Some(tool_name.as_str()) {
+                                                load_state = Some((String::default(), call.index));
+                                            }
+                                            if let Some((arguments, (output, index))) =
+                                                func.arguments.zip(load_state.as_mut())
+                                            {
+                                                if call.index == *index {
+                                                    output.push_str(&arguments);
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                            if let Some((arguments, _)) = load_state {
+                                return Ok(serde_json::from_str(&arguments)?);
+                            } else {
+                                bail!("tool not used");
+                            }
+                        })
+                        .boxed()
+                }
             }
             CloudModel::Google(_) => {
                 future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
             }
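Both OpenAI branches assemble tool-call arguments the same way: the first chunk names the call, and later chunks carry argument fragments tagged with the call's `index`. A sketch of that assembly with simplified stand-in types (the real `open_ai::ResponseStreamEvent` deltas have more fields):

```rust
// Simplified stand-in for an OpenAI streamed tool-call delta; field names
// follow the streaming format, but this is not the real crate type.
struct ToolCallDelta {
    index: usize,
    name: Option<String>,      // present only on the first chunk of a call
    arguments: Option<String>, // JSON fragment, streamed over many chunks
}

/// Concatenate the argument fragments belonging to the named tool call and
/// parse them once the stream ends.
fn assemble_tool_arguments(
    deltas: impl Iterator<Item = ToolCallDelta>,
    tool_name: &str,
) -> Option<serde_json::Value> {
    let mut load_state: Option<(String, usize)> = None;
    for call in deltas {
        // The first chunk names the call; start accumulating at its index.
        if call.name.as_deref() == Some(tool_name) {
            load_state = Some((String::new(), call.index));
        }
        // Append fragments only for chunks that share the remembered index.
        if let Some((arguments, (output, index))) = call.arguments.zip(load_state.as_mut()) {
            if call.index == *index {
                output.push_str(&arguments);
            }
        }
    }
    load_state.and_then(|(arguments, _)| serde_json::from_str(&arguments).ok())
}
```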
@@ -721,56 +853,115 @@ impl LanguageModel for CloudLanguageModel {
                 function.description = Some(tool_description);
                 function.parameters = Some(input_schema);
                 request.tools = vec![open_ai::ToolDefinition::Function { function }];
-                self.request_limiter
-                    .run(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let response = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::OpenAi as i32,
-                                request,
-                            })
-                            .await?;
-                        // Call arguments are gonna be streamed in over multiple chunks.
-                        let mut load_state = None;
-                        let mut response = response.map(
-                            |item: Result<
-                                proto::StreamCompleteWithLanguageModelResponse,
-                                anyhow::Error,
-                            >| {
-                                Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
-                                    serde_json::from_str(&item?.event)?,
-                                )
-                            },
-                        );
-                        while let Some(Ok(part)) = response.next().await {
-                            for choice in part.choices {
-                                let Some(tool_calls) = choice.delta.tool_calls else {
-                                    continue;
-                                };
-                                for call in tool_calls {
-                                    if let Some(func) = call.function {
-                                        if func.name.as_deref() == Some(tool_name.as_str()) {
-                                            load_state = Some((String::default(), call.index));
-                                        }
-                                        if let Some((arguments, (output, index))) =
-                                            func.arguments.zip(load_state.as_mut())
-                                        {
-                                            if call.index == *index {
-                                                output.push_str(&arguments);
+                if cx
+                    .update(|cx| cx.has_flag::())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    self.request_limiter
+                        .run(async move {
+                            let response = Self::perform_llm_completion(
+                                client.clone(),
+                                llm_api_token,
+                                PerformCompletionParams {
+                                    provider: client::LanguageModelProvider::Zed,
+                                    model: request.model.clone(),
+                                    provider_request: RawValue::from_string(
+                                        serde_json::to_string(&request)?,
+                                    )?,
+                                },
+                            )
+                            .await?;
+
+                            let mut body = BufReader::new(response.into_body());
+                            let mut line = String::new();
+                            let mut load_state = None;
+
+                            while body.read_line(&mut line).await? > 0 {
+                                let part: open_ai::ResponseStreamEvent =
+                                    serde_json::from_str(&line)?;
+                                line.clear();
+
+                                for choice in part.choices {
+                                    let Some(tool_calls) = choice.delta.tool_calls else {
+                                        continue;
+                                    };
+
+                                    for call in tool_calls {
+                                        if let Some(func) = call.function {
+                                            if func.name.as_deref() == Some(tool_name.as_str()) {
+                                                load_state = Some((String::default(), call.index));
+                                            }
+                                            if let Some((arguments, (output, index))) =
+                                                func.arguments.zip(load_state.as_mut())
+                                            {
+                                                if call.index == *index {
+                                                    output.push_str(&arguments);
+                                                }
                                             }
                                         }
                                     }
                                 }
                             }
-                        }
-                        if let Some((arguments, _)) = load_state {
-                            return Ok(serde_json::from_str(&arguments)?);
-                        } else {
-                            bail!("tool not used");
-                        }
-                    })
-                    .boxed()
+                            if let Some((arguments, _)) = load_state {
+                                return Ok(serde_json::from_str(&arguments)?);
+                            } else {
+                                bail!("tool not used");
+                            }
+                        })
+                        .boxed()
+                } else {
+                    self.request_limiter
+                        .run(async move {
+                            let request = serde_json::to_string(&request)?;
+                            let response = client
+                                .request_stream(proto::StreamCompleteWithLanguageModel {
+                                    provider: proto::LanguageModelProvider::OpenAi as i32,
+                                    request,
+                                })
+                                .await?;
+                            let mut load_state = None;
+                            let mut response = response.map(
+                                |item: Result<
+                                    proto::StreamCompleteWithLanguageModelResponse,
+                                    anyhow::Error,
+                                >| {
+                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
+                                        serde_json::from_str(&item?.event)?,
+                                    )
+                                },
+                            );
+                            while let Some(Ok(part)) = response.next().await {
+                                for choice in part.choices {
+                                    let Some(tool_calls) = choice.delta.tool_calls else {
+                                        continue;
+                                    };
+
+                                    for call in tool_calls {
+                                        if let Some(func) = call.function {
+                                            if func.name.as_deref() == Some(tool_name.as_str()) {
+                                                load_state = Some((String::default(), call.index));
+                                            }
+                                            if let Some((arguments, (output, index))) =
+                                                func.arguments.zip(load_state.as_mut())
+                                            {
+                                                if call.index == *index {
+                                                    output.push_str(&arguments);
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                            if let Some((arguments, _)) = load_state {
+                                return Ok(serde_json::from_str(&arguments)?);
+                            } else {
+                                bail!("tool not used");
+                            }
+                        })
+                        .boxed()
+                }
             }
         }
     }
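All three LLM-service paths frame the response as newline-delimited JSON, reading one event per `read_line` call and clearing the buffer between events. A minimal sketch of that framing, with a synchronous in-memory reader standing in for the async HTTP body (the real code uses an async `BufReader` and `.await`s each read):

```rust
use std::io::{BufRead, BufReader, Cursor};

fn main() -> serde_json::Result<()> {
    // Two newline-delimited JSON events, as the LLM service would stream them.
    let stream = Cursor::new("{\"n\":1}\n{\"n\":2}\n");
    let mut body = BufReader::new(stream);
    let mut line = String::new();
    // read_line returns Ok(0) at end of stream, which ends the loop.
    while body.read_line(&mut line).expect("read failed") > 0 {
        let event: serde_json::Value = serde_json::from_str(&line)?;
        println!("event: {event}");
        // Clear the buffer, or the next read_line would append to this line.
        line.clear();
    }
    Ok(())
}
```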