language_model: Remove dependencies on individual model provider crates (#25503)

This PR removes the dependencies on the individual model provider crates
from the `language_model` crate.

The various conversion methods for converting a `LanguageModelRequest`
into its provider-specific request type have been inlined into the
various provider modules in the `language_models` crate.

The model providers we provide via Zed's cloud offering get to stay, for
now.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-02-24 16:41:35 -05:00 committed by GitHub
parent 2f7a62780a
commit 0acd556106
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 347 additions and 366 deletions

4
Cargo.lock generated
View file

@ -7015,16 +7015,12 @@ dependencies = [
"anyhow",
"base64 0.22.1",
"collections",
"deepseek",
"futures 0.3.31",
"google_ai",
"gpui",
"http_client",
"image",
"lmstudio",
"log",
"mistral",
"ollama",
"open_ai",
"parking_lot",
"proto",

View file

@ -20,16 +20,12 @@ anthropic = { workspace = true, features = ["schemars"] }
anyhow.workspace = true
base64.workspace = true
collections.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
http_client.workspace = true
image.workspace = true
lmstudio = { workspace = true, features = ["schemars"] }
log.workspace = true
mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
parking_lot.workspace = true
proto.workspace = true

View file

@ -1,7 +1,3 @@
pub mod cloud_model;
pub use anthropic::Model as AnthropicModel;
pub use cloud_model::*;
pub use lmstudio::Model as LmStudioModel;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;

View file

@ -241,298 +241,6 @@ pub struct LanguageModelRequest {
pub temperature: Option<f32>,
}
impl LanguageModelRequest {
pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
let stream = !model.starts_with("o1-");
open_ai::Request {
model,
messages: self
.messages
.into_iter()
.map(|msg| match msg.role {
Role::User => open_ai::RequestMessage::User {
content: msg.string_contents(),
},
Role::Assistant => open_ai::RequestMessage::Assistant {
content: Some(msg.string_contents()),
tool_calls: Vec::new(),
},
Role::System => open_ai::RequestMessage::System {
content: msg.string_contents(),
},
})
.collect(),
stream,
stop: self.stop,
temperature: self.temperature.unwrap_or(1.0),
max_tokens: max_output_tokens,
tools: Vec::new(),
tool_choice: None,
}
}
pub fn into_mistral(self, model: String, max_output_tokens: Option<u32>) -> mistral::Request {
let len = self.messages.len();
let merged_messages =
self.messages
.into_iter()
.fold(Vec::with_capacity(len), |mut acc, msg| {
let role = msg.role;
let content = msg.string_contents();
acc.push(match role {
Role::User => mistral::RequestMessage::User { content },
Role::Assistant => mistral::RequestMessage::Assistant {
content: Some(content),
tool_calls: Vec::new(),
},
Role::System => mistral::RequestMessage::System { content },
});
acc
});
mistral::Request {
model,
messages: merged_messages,
stream: true,
max_tokens: max_output_tokens,
temperature: self.temperature,
response_format: None,
tools: self
.tools
.into_iter()
.map(|tool| mistral::ToolDefinition::Function {
function: mistral::FunctionDefinition {
name: tool.name,
description: Some(tool.description),
parameters: Some(tool.input_schema),
},
})
.collect(),
}
}
pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
google_ai::GenerateContentRequest {
model,
contents: self
.messages
.into_iter()
.map(|msg| google_ai::Content {
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
text: msg.string_contents(),
})],
role: match msg.role {
Role::User => google_ai::Role::User,
Role::Assistant => google_ai::Role::Model,
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
},
})
.collect(),
generation_config: Some(google_ai::GenerationConfig {
candidate_count: Some(1),
stop_sequences: Some(self.stop),
max_output_tokens: None,
temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
top_p: None,
top_k: None,
}),
safety_settings: None,
}
}
pub fn into_anthropic(
self,
model: String,
default_temperature: f32,
max_output_tokens: u32,
) -> anthropic::Request {
let mut new_messages: Vec<anthropic::Message> = Vec::new();
let mut system_message = String::new();
for message in self.messages {
if message.contents_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
let cache_control = if message.cache {
Some(anthropic::CacheControl {
cache_type: anthropic::CacheControlType::Ephemeral,
})
} else {
None
};
let anthropic_message_content: Vec<anthropic::RequestContent> = message
.content
.into_iter()
.filter_map(|content| match content {
MessageContent::Text(text) => {
if !text.is_empty() {
Some(anthropic::RequestContent::Text {
text,
cache_control,
})
} else {
None
}
}
MessageContent::Image(image) => {
Some(anthropic::RequestContent::Image {
source: anthropic::ImageSource {
source_type: "base64".to_string(),
media_type: "image/png".to_string(),
data: image.source.to_string(),
},
cache_control,
})
}
MessageContent::ToolUse(tool_use) => {
Some(anthropic::RequestContent::ToolUse {
id: tool_use.id.to_string(),
name: tool_use.name,
input: tool_use.input,
cache_control,
})
}
MessageContent::ToolResult(tool_result) => {
Some(anthropic::RequestContent::ToolResult {
tool_use_id: tool_result.tool_use_id,
is_error: tool_result.is_error,
content: tool_result.content,
cache_control,
})
}
})
.collect();
let anthropic_role = match message.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => unreachable!("System role should never occur here"),
};
if let Some(last_message) = new_messages.last_mut() {
if last_message.role == anthropic_role {
last_message.content.extend(anthropic_message_content);
continue;
}
}
new_messages.push(anthropic::Message {
role: anthropic_role,
content: anthropic_message_content,
});
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.string_contents());
}
}
}
anthropic::Request {
model,
messages: new_messages,
max_tokens: max_output_tokens,
system: Some(system_message),
tools: self
.tools
.into_iter()
.map(|tool| anthropic::Tool {
name: tool.name,
description: tool.description,
input_schema: tool.input_schema,
})
.collect(),
tool_choice: None,
metadata: None,
stop_sequences: Vec::new(),
temperature: self.temperature.or(Some(default_temperature)),
top_k: None,
top_p: None,
}
}
pub fn into_deepseek(self, model: String, max_output_tokens: Option<u32>) -> deepseek::Request {
let is_reasoner = model == "deepseek-reasoner";
let len = self.messages.len();
let merged_messages =
self.messages
.into_iter()
.fold(Vec::with_capacity(len), |mut acc, msg| {
let role = msg.role;
let content = msg.string_contents();
if is_reasoner {
if let Some(last_msg) = acc.last_mut() {
match (last_msg, role) {
(deepseek::RequestMessage::User { content: last }, Role::User) => {
last.push(' ');
last.push_str(&content);
return acc;
}
(
deepseek::RequestMessage::Assistant {
content: last_content,
..
},
Role::Assistant,
) => {
*last_content = last_content
.take()
.map(|c| {
let mut s =
String::with_capacity(c.len() + content.len() + 1);
s.push_str(&c);
s.push(' ');
s.push_str(&content);
s
})
.or(Some(content));
return acc;
}
_ => {}
}
}
}
acc.push(match role {
Role::User => deepseek::RequestMessage::User { content },
Role::Assistant => deepseek::RequestMessage::Assistant {
content: Some(content),
tool_calls: Vec::new(),
},
Role::System => deepseek::RequestMessage::System { content },
});
acc
});
deepseek::Request {
model,
messages: merged_messages,
stream: true,
max_tokens: max_output_tokens,
temperature: if is_reasoner { None } else { self.temperature },
response_format: None,
tools: self
.tools
.into_iter()
.map(|tool| deepseek::ToolDefinition::Function {
function: deepseek::FunctionDefinition {
name: tool.name,
description: Some(tool.description),
parameters: Some(tool.input_schema),
},
})
.collect(),
}
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
pub role: Option<Role>,

View file

@ -45,43 +45,3 @@ impl Display for Role {
}
}
}
impl From<Role> for ollama::Role {
fn from(val: Role) -> Self {
match val {
Role::User => ollama::Role::User,
Role::Assistant => ollama::Role::Assistant,
Role::System => ollama::Role::System,
}
}
}
impl From<Role> for open_ai::Role {
fn from(val: Role) -> Self {
match val {
Role::User => open_ai::Role::User,
Role::Assistant => open_ai::Role::Assistant,
Role::System => open_ai::Role::System,
}
}
}
impl From<Role> for deepseek::Role {
fn from(val: Role) -> Self {
match val {
Role::User => deepseek::Role::User,
Role::Assistant => deepseek::Role::Assistant,
Role::System => deepseek::Role::System,
}
}
}
impl From<Role> for lmstudio::Role {
fn from(val: Role) -> Self {
match val {
Role::User => lmstudio::Role::User,
Role::Assistant => lmstudio::Role::Assistant,
Role::System => lmstudio::Role::System,
}
}
}

View file

@ -13,7 +13,7 @@ use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
};
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema;
@ -396,7 +396,8 @@ impl LanguageModel for AnthropicModel {
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let request = request.into_anthropic(
let request = into_anthropic(
request,
self.model.id().into(),
self.model.default_temperature(),
self.model.max_output_tokens(),
@ -427,7 +428,8 @@ impl LanguageModel for AnthropicModel {
input_schema: serde_json::Value,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let mut request = request.into_anthropic(
let mut request = into_anthropic(
request,
self.model.tool_model_id().into(),
self.model.default_temperature(),
self.model.max_output_tokens(),
@ -456,6 +458,117 @@ impl LanguageModel for AnthropicModel {
}
}
pub fn into_anthropic(
request: LanguageModelRequest,
model: String,
default_temperature: f32,
max_output_tokens: u32,
) -> anthropic::Request {
let mut new_messages: Vec<anthropic::Message> = Vec::new();
let mut system_message = String::new();
for message in request.messages {
if message.contents_empty() {
continue;
}
match message.role {
Role::User | Role::Assistant => {
let cache_control = if message.cache {
Some(anthropic::CacheControl {
cache_type: anthropic::CacheControlType::Ephemeral,
})
} else {
None
};
let anthropic_message_content: Vec<anthropic::RequestContent> = message
.content
.into_iter()
.filter_map(|content| match content {
MessageContent::Text(text) => {
if !text.is_empty() {
Some(anthropic::RequestContent::Text {
text,
cache_control,
})
} else {
None
}
}
MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
source: anthropic::ImageSource {
source_type: "base64".to_string(),
media_type: "image/png".to_string(),
data: image.source.to_string(),
},
cache_control,
}),
MessageContent::ToolUse(tool_use) => {
Some(anthropic::RequestContent::ToolUse {
id: tool_use.id.to_string(),
name: tool_use.name,
input: tool_use.input,
cache_control,
})
}
MessageContent::ToolResult(tool_result) => {
Some(anthropic::RequestContent::ToolResult {
tool_use_id: tool_result.tool_use_id,
is_error: tool_result.is_error,
content: tool_result.content,
cache_control,
})
}
})
.collect();
let anthropic_role = match message.role {
Role::User => anthropic::Role::User,
Role::Assistant => anthropic::Role::Assistant,
Role::System => unreachable!("System role should never occur here"),
};
if let Some(last_message) = new_messages.last_mut() {
if last_message.role == anthropic_role {
last_message.content.extend(anthropic_message_content);
continue;
}
}
new_messages.push(anthropic::Message {
role: anthropic_role,
content: anthropic_message_content,
});
}
Role::System => {
if !system_message.is_empty() {
system_message.push_str("\n\n");
}
system_message.push_str(&message.string_contents());
}
}
}
anthropic::Request {
model,
messages: new_messages,
max_tokens: max_output_tokens,
system: Some(system_message),
tools: request
.tools
.into_iter()
.map(|tool| anthropic::Tool {
name: tool.name,
description: tool.description,
input_schema: tool.input_schema,
})
.collect(),
tool_choice: None,
metadata: None,
stop_sequences: Vec::new(),
temperature: request.temperature.or(Some(default_temperature)),
top_k: None,
top_p: None,
}
}
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {

View file

@ -1,4 +1,3 @@
use super::open_ai::count_open_ai_tokens;
use anthropic::AnthropicError;
use anyhow::{anyhow, Result};
use client::{
@ -43,11 +42,13 @@ use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{prelude::*, TintColor};
use crate::provider::anthropic::map_to_language_model_completion_events;
use crate::provider::anthropic::{
count_anthropic_tokens, into_anthropic, map_to_language_model_completion_events,
};
use crate::provider::google::into_google;
use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};
use crate::AllLanguageModelSettings;
use super::anthropic::count_anthropic_tokens;
pub const PROVIDER_NAME: &str = "Zed";
const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
@ -612,7 +613,7 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
CloudModel::Google(model) => {
let client = self.client.clone();
let request = request.into_google(model.id().into());
let request = into_google(request, model.id().into());
let request = google_ai::CountTokensRequest {
contents: request.contents,
};
@ -638,7 +639,8 @@ impl LanguageModel for CloudLanguageModel {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
match &self.model {
CloudModel::Anthropic(model) => {
let request = request.into_anthropic(
let request = into_anthropic(
request,
model.id().into(),
model.default_temperature(),
model.max_output_tokens(),
@ -666,7 +668,7 @@ impl LanguageModel for CloudLanguageModel {
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
let request = into_open_ai(request, model.id().into(), model.max_output_tokens());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let response = Self::perform_llm_completion(
@ -693,7 +695,7 @@ impl LanguageModel for CloudLanguageModel {
}
CloudModel::Google(model) => {
let client = self.client.clone();
let request = request.into_google(model.id().into());
let request = into_google(request, model.id().into());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let response = Self::perform_llm_completion(
@ -736,7 +738,8 @@ impl LanguageModel for CloudLanguageModel {
match &self.model {
CloudModel::Anthropic(model) => {
let mut request = request.into_anthropic(
let mut request = into_anthropic(
request,
model.tool_model_id().into(),
model.default_temperature(),
model.max_output_tokens(),
@ -776,7 +779,7 @@ impl LanguageModel for CloudLanguageModel {
}
CloudModel::OpenAi(model) => {
let mut request =
request.into_open_ai(model.id().into(), model.max_output_tokens());
into_open_ai(request, model.id().into(), model.max_output_tokens());
request.tool_choice = Some(open_ai::ToolChoice::Other(
open_ai::ToolDefinition::Function {
function: open_ai::FunctionDefinition {

View file

@ -322,7 +322,11 @@ impl LanguageModel for DeepSeekLanguageModel {
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let request = request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
let request = into_deepseek(
request,
self.model.id().to_string(),
self.max_output_tokens(),
);
let stream = self.stream_completion(request, cx);
async move {
@ -357,8 +361,11 @@ impl LanguageModel for DeepSeekLanguageModel {
schema: serde_json::Value,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let mut deepseek_request =
request.into_deepseek(self.model.id().to_string(), self.max_output_tokens());
let mut deepseek_request = into_deepseek(
request,
self.model.id().to_string(),
self.max_output_tokens(),
);
deepseek_request.tools = vec![deepseek::ToolDefinition::Function {
function: deepseek::FunctionDefinition {
@ -402,6 +409,93 @@ impl LanguageModel for DeepSeekLanguageModel {
}
}
pub fn into_deepseek(
request: LanguageModelRequest,
model: String,
max_output_tokens: Option<u32>,
) -> deepseek::Request {
let is_reasoner = model == "deepseek-reasoner";
let len = request.messages.len();
let merged_messages =
request
.messages
.into_iter()
.fold(Vec::with_capacity(len), |mut acc, msg| {
let role = msg.role;
let content = msg.string_contents();
if is_reasoner {
if let Some(last_msg) = acc.last_mut() {
match (last_msg, role) {
(deepseek::RequestMessage::User { content: last }, Role::User) => {
last.push(' ');
last.push_str(&content);
return acc;
}
(
deepseek::RequestMessage::Assistant {
content: last_content,
..
},
Role::Assistant,
) => {
*last_content = last_content
.take()
.map(|c| {
let mut s =
String::with_capacity(c.len() + content.len() + 1);
s.push_str(&c);
s.push(' ');
s.push_str(&content);
s
})
.or(Some(content));
return acc;
}
_ => {}
}
}
}
acc.push(match role {
Role::User => deepseek::RequestMessage::User { content },
Role::Assistant => deepseek::RequestMessage::Assistant {
content: Some(content),
tool_calls: Vec::new(),
},
Role::System => deepseek::RequestMessage::System { content },
});
acc
});
deepseek::Request {
model,
messages: merged_messages,
stream: true,
max_tokens: max_output_tokens,
temperature: if is_reasoner {
None
} else {
request.temperature
},
response_format: None,
tools: request
.tools
.into_iter()
.map(|tool| deepseek::ToolDefinition::Function {
function: deepseek::FunctionDefinition {
name: tool.name,
description: Some(tool.description),
parameters: Some(tool.input_schema),
},
})
.collect(),
}
}
struct ConfigurationView {
api_key_editor: Entity<Editor>,
state: Entity<State>,

View file

@ -272,7 +272,7 @@ impl LanguageModel for GoogleLanguageModel {
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
let request = request.into_google(self.model.id().to_string());
let request = into_google(request, self.model.id().to_string());
let http_client = self.http_client.clone();
let api_key = self.state.read(cx).api_key.clone();
@ -303,7 +303,7 @@ impl LanguageModel for GoogleLanguageModel {
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
> {
let request = request.into_google(self.model.id().to_string());
let request = into_google(request, self.model.id().to_string());
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
@ -341,6 +341,38 @@ impl LanguageModel for GoogleLanguageModel {
}
}
pub fn into_google(
request: LanguageModelRequest,
model: String,
) -> google_ai::GenerateContentRequest {
google_ai::GenerateContentRequest {
model,
contents: request
.messages
.into_iter()
.map(|msg| google_ai::Content {
parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
text: msg.string_contents(),
})],
role: match msg.role {
Role::User => google_ai::Role::User,
Role::Assistant => google_ai::Role::Model,
Role::System => google_ai::Role::User, // Google AI doesn't have a system role
},
})
.collect(),
generation_config: Some(google_ai::GenerationConfig {
candidate_count: Some(1),
stop_sequences: Some(request.stop),
max_output_tokens: None,
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
top_p: None,
top_k: None,
}),
safety_settings: None,
}
}
pub fn count_google_tokens(
request: LanguageModelRequest,
cx: &App,

View file

@ -334,7 +334,11 @@ impl LanguageModel for MistralLanguageModel {
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let request = request.into_mistral(self.model.id().to_string(), self.max_output_tokens());
let request = into_mistral(
request,
self.model.id().to_string(),
self.max_output_tokens(),
);
let stream = self.stream_completion(request, cx);
async move {
@ -369,7 +373,7 @@ impl LanguageModel for MistralLanguageModel {
schema: serde_json::Value,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let mut request = request.into_mistral(self.model.id().into(), self.max_output_tokens());
let mut request = into_mistral(request, self.model.id().into(), self.max_output_tokens());
request.tools = vec![mistral::ToolDefinition::Function {
function: mistral::FunctionDefinition {
name: tool_name.clone(),
@ -411,6 +415,52 @@ impl LanguageModel for MistralLanguageModel {
}
}
pub fn into_mistral(
request: LanguageModelRequest,
model: String,
max_output_tokens: Option<u32>,
) -> mistral::Request {
let len = request.messages.len();
let merged_messages =
request
.messages
.into_iter()
.fold(Vec::with_capacity(len), |mut acc, msg| {
let role = msg.role;
let content = msg.string_contents();
acc.push(match role {
Role::User => mistral::RequestMessage::User { content },
Role::Assistant => mistral::RequestMessage::Assistant {
content: Some(content),
tool_calls: Vec::new(),
},
Role::System => mistral::RequestMessage::System { content },
});
acc
});
mistral::Request {
model,
messages: merged_messages,
stream: true,
max_tokens: max_output_tokens,
temperature: request.temperature,
response_format: None,
tools: request
.tools
.into_iter()
.map(|tool| mistral::ToolDefinition::Function {
function: mistral::FunctionDefinition {
name: tool.name,
description: Some(tool.description),
parameters: Some(tool.input_schema),
},
})
.collect(),
}
}
struct ConfigurationView {
api_key_editor: Entity<Editor>,
state: gpui::Entity<State>,

View file

@ -318,7 +318,7 @@ impl LanguageModel for OpenAiLanguageModel {
'static,
Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
> {
let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
let completions = self.stream_completion(request, cx);
async move {
Ok(open_ai::extract_text_from_events(completions.await?)
@ -336,7 +336,7 @@ impl LanguageModel for OpenAiLanguageModel {
schema: serde_json::Value,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
let mut request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
function: FunctionDefinition {
name: tool_name.clone(),
@ -366,6 +366,39 @@ impl LanguageModel for OpenAiLanguageModel {
}
}
pub fn into_open_ai(
request: LanguageModelRequest,
model: String,
max_output_tokens: Option<u32>,
) -> open_ai::Request {
let stream = !model.starts_with("o1-");
open_ai::Request {
model,
messages: request
.messages
.into_iter()
.map(|msg| match msg.role {
Role::User => open_ai::RequestMessage::User {
content: msg.string_contents(),
},
Role::Assistant => open_ai::RequestMessage::Assistant {
content: Some(msg.string_contents()),
tool_calls: Vec::new(),
},
Role::System => open_ai::RequestMessage::System {
content: msg.string_contents(),
},
})
.collect(),
stream,
stop: request.stop,
temperature: request.temperature.unwrap_or(1.0),
max_tokens: max_output_tokens,
tools: Vec::new(),
tool_choice: None,
}
}
pub fn count_open_ai_tokens(
request: LanguageModelRequest,
model: open_ai::Model,