open_ai: Disable parallel_tool_calls
(#28056)
This PR disables `parallel_tool_calls` for the models that support it, as the Agent currently expects at most one tool use per turn. It took a bit of trial and error to figure out which models those are. Annoyingly, OpenAI's API returns an error if `parallel_tool_calls` is passed to a model that doesn't support it. Release Notes: - N/A
This commit is contained in:
parent
c6e2d20a02
commit
819bb8fffb
3 changed files with 32 additions and 6 deletions
|
@ -690,7 +690,7 @@ impl LanguageModel for CloudLanguageModel {
|
|||
}
|
||||
CloudModel::OpenAi(model) => {
|
||||
let client = self.client.clone();
|
||||
let request = into_open_ai(request, model.id().into(), model.max_output_tokens());
|
||||
let request = into_open_ai(request, model, model.max_output_tokens());
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let response = Self::perform_llm_completion(
|
||||
|
|
|
@ -14,7 +14,7 @@ use language_model::{
|
|||
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
|
||||
RateLimiter, Role, StopReason,
|
||||
};
|
||||
use open_ai::{ResponseStreamEvent, stream_completion};
|
||||
use open_ai::{Model, ResponseStreamEvent, stream_completion};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
|
@ -324,7 +324,7 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||
'static,
|
||||
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
|
||||
> {
|
||||
let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
|
||||
let request = into_open_ai(request, &self.model, self.max_output_tokens());
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
|
||||
.boxed()
|
||||
|
@ -333,10 +333,10 @@ impl LanguageModel for OpenAiLanguageModel {
|
|||
|
||||
pub fn into_open_ai(
|
||||
request: LanguageModelRequest,
|
||||
model: String,
|
||||
model: &Model,
|
||||
max_output_tokens: Option<u32>,
|
||||
) -> open_ai::Request {
|
||||
let stream = !model.starts_with("o1-");
|
||||
let stream = !model.id().starts_with("o1-");
|
||||
|
||||
let mut messages = Vec::new();
|
||||
for message in request.messages {
|
||||
|
@ -389,12 +389,18 @@ pub fn into_open_ai(
|
|||
}
|
||||
|
||||
open_ai::Request {
|
||||
model,
|
||||
model: model.id().into(),
|
||||
messages,
|
||||
stream,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature.unwrap_or(1.0),
|
||||
max_tokens: max_output_tokens,
|
||||
parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
|
||||
// Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
|
||||
Some(false)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
|
|
|
@ -162,6 +162,23 @@ impl Model {
|
|||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns whether the given model supports the `parallel_tool_calls` parameter.
|
||||
///
|
||||
/// If the model does not support the parameter, do not pass it up, or the API will return an error.
|
||||
pub fn supports_parallel_tool_calls(&self) -> bool {
|
||||
match self {
|
||||
Self::ThreePointFiveTurbo
|
||||
| Self::Four
|
||||
| Self::FourTurbo
|
||||
| Self::FourOmni
|
||||
| Self::FourOmniMini
|
||||
| Self::O1
|
||||
| Self::O1Preview
|
||||
| Self::O1Mini => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
|
@ -176,6 +193,9 @@ pub struct Request {
|
|||
pub temperature: f32,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
/// Whether to enable parallel function calling during tool use.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub tools: Vec<ToolDefinition>,
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue