diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 975e504580..ff109cf55a 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -690,7 +690,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
-                let request = into_open_ai(request, model.id().into(), model.max_output_tokens());
+                let request = into_open_ai(request, model, model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
                     let response = Self::perform_llm_completion(
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 0f02642e25..774f1f7b2a 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -14,7 +14,7 @@ use language_model::{
     LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
     RateLimiter, Role, StopReason,
 };
-use open_ai::{ResponseStreamEvent, stream_completion};
+use open_ai::{Model, ResponseStreamEvent, stream_completion};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
@@ -324,7 +324,7 @@ impl LanguageModel for OpenAiLanguageModel {
         'static,
         Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
     > {
-        let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
+        let request = into_open_ai(request, &self.model, self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
         async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
             .boxed()
@@ -333,10 +333,10 @@ impl LanguageModel for OpenAiLanguageModel {
 
 pub fn into_open_ai(
     request: LanguageModelRequest,
-    model: String,
+    model: &Model,
     max_output_tokens: Option<u32>,
 ) -> open_ai::Request {
-    let stream = !model.starts_with("o1-");
+    let stream = !model.id().starts_with("o1-");
 
     let mut messages = Vec::new();
     for message in request.messages {
@@ -389,12 +389,18 @@ pub fn into_open_ai(
     }
 
     open_ai::Request {
-        model,
+        model: model.id().into(),
         messages,
         stream,
         stop: request.stop,
         temperature: request.temperature.unwrap_or(1.0),
         max_tokens: max_output_tokens,
+        parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
+            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
+            Some(false)
+        } else {
+            None
+        },
         tools: request
             .tools
             .into_iter()
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index 586d864da4..b9aa2ce7f0 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -162,6 +162,23 @@ impl Model {
             _ => None,
         }
     }
+
+    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
+    ///
+    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
+    pub fn supports_parallel_tool_calls(&self) -> bool {
+        match self {
+            Self::ThreePointFiveTurbo
+            | Self::Four
+            | Self::FourTurbo
+            | Self::FourOmni
+            | Self::FourOmniMini
+            | Self::O1
+            | Self::O1Preview
+            | Self::O1Mini => true,
+            _ => false,
+        }
+    }
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -176,6 +193,9 @@ pub struct Request {
     pub temperature: f32,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub tool_choice: Option<ToolChoice>,
+    /// Whether to enable parallel function calling during tool use.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub parallel_tool_calls: Option<bool>,
     #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tools: Vec<ToolDefinition>,
 }
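
Note: the new `parallel_tool_calls` field uses `skip_serializing_if = "Option::is_none"`, so when `into_open_ai` leaves it as `None` (no tools in the request, or a model where `supports_parallel_tool_calls()` is `false`) the parameter is omitted from the request body entirely rather than sent as `null`, which is what keeps unsupported models from returning an error. Below is a minimal standalone sketch of that serialization behavior; the `RequestSketch` struct and model strings are illustrative stand-ins (not Zed's actual `open_ai::Request`), and it assumes `serde` (with `derive`) and `serde_json` are available:

```rust
use serde::Serialize;
use serde_json::json;

// Illustrative stand-in for the request body; only the fields relevant to
// `parallel_tool_calls` are included here.
#[derive(Serialize)]
struct RequestSketch {
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
}

fn main() {
    // Tools present and the model supports the parameter: send an explicit
    // `false`, since the Agent expects at most one tool call per turn.
    let with_tools = RequestSketch {
        model: "gpt-4o".into(),
        parallel_tool_calls: Some(false),
    };
    assert_eq!(
        serde_json::to_value(&with_tools).unwrap(),
        json!({ "model": "gpt-4o", "parallel_tool_calls": false })
    );

    // No tools in the request (or a model that rejects the parameter): the
    // field is skipped, so the serialized body never mentions it.
    let without_tools = RequestSketch {
        model: "gpt-4o".into(),
        parallel_tool_calls: None,
    };
    assert_eq!(
        serde_json::to_value(&without_tools).unwrap(),
        json!({ "model": "gpt-4o" })
    );
}
```

The API treats `parallel_tool_calls` as enabled by default, which is why the change sends an explicit `false` for supported models with tools instead of simply omitting the field.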