diff --git a/Cargo.lock b/Cargo.lock
index 0c832b83aa..224aa421a6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -17431,11 +17431,8 @@ name = "vercel"
 version = "0.1.0"
 dependencies = [
  "anyhow",
- "futures 0.3.31",
- "http_client",
  "schemars",
  "serde",
- "serde_json",
  "strum 0.27.1",
  "workspace-hack",
 ]
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 1062d732a4..58902850ea 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -888,7 +888,12 @@ impl LanguageModel for CloudLanguageModel {
             Ok(model) => model,
             Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
         };
-        let request = into_open_ai(request, &model, None);
+        let request = into_open_ai(
+            request,
+            model.id(),
+            model.supports_parallel_tool_calls(),
+            None,
+        );
         let llm_api_token = self.llm_api_token.clone();
         let future = self.request_limiter.stream(async move {
             let PerformLlmCompletionResponse {
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 3fa5334eb0..56a81d36e9 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -344,7 +344,12 @@ impl LanguageModel for OpenAiLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
-        let request = into_open_ai(request, &self.model, self.max_output_tokens());
+        let request = into_open_ai(
+            request,
+            self.model.id(),
+            self.model.supports_parallel_tool_calls(),
+            self.max_output_tokens(),
+        );
         let completions = self.stream_completion(request, cx);
         async move {
             let mapper = OpenAiEventMapper::new();
@@ -356,10 +361,11 @@ impl LanguageModel for OpenAiLanguageModel {
 
 pub fn into_open_ai(
     request: LanguageModelRequest,
-    model: &Model,
+    model_id: &str,
+    supports_parallel_tool_calls: bool,
     max_output_tokens: Option<u64>,
 ) -> open_ai::Request {
-    let stream = !model.id().starts_with("o1-");
+    let stream = !model_id.starts_with("o1-");
 
     let mut messages = Vec::new();
     for message in request.messages {
@@ -435,13 +441,13 @@ pub fn into_open_ai(
     }
 
     open_ai::Request {
-        model: model.id().into(),
+        model: model_id.into(),
         messages,
         stream,
         stop: request.stop,
         temperature: request.temperature.unwrap_or(1.0),
         max_completion_tokens: max_output_tokens,
-        parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
+        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
             // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
             Some(false)
         } else {
diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs
index 46063aceff..65058cbb74 100644
--- a/crates/language_models/src/provider/vercel.rs
+++ b/crates/language_models/src/provider/vercel.rs
@@ -1,8 +1,6 @@
 use anyhow::{Context as _, Result, anyhow};
-use collections::{BTreeMap, HashMap};
+use collections::BTreeMap;
 use credentials_provider::CredentialsProvider;
-
-use futures::Stream;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
 use http_client::HttpClient;
@@ -10,16 +8,13 @@ use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    LanguageModelToolChoice, RateLimiter, Role,
 };
 use menu;
-use open_ai::{ImageUrl, ResponseStreamEvent, stream_completion};
+use open_ai::ResponseStreamEvent;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
-use std::pin::Pin;
-use std::str::FromStr as _;
 use std::sync::Arc;
 use strum::IntoEnumIterator;
 use vercel::Model;
@@ -200,14 +195,12 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 
-        // Add base models from vercel::Model::iter()
         for model in vercel::Model::iter() {
             if !matches!(model, vercel::Model::Custom { .. }) {
                 models.insert(model.id().to_string(), model);
             }
         }
 
-        // Override with available models from settings
         for model in &AllLanguageModelSettings::get_global(cx)
             .vercel
             .available_models
@@ -278,7 +271,8 @@ impl VercelLanguageModel {
 
         let future = self.request_limiter.stream(async move {
             let api_key = api_key.context("Missing Vercel API Key")?;
-            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let request =
+                open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
             let response = request.await?;
             Ok(response)
         });
@@ -354,264 +348,21 @@ impl LanguageModel for VercelLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
-        let request = into_vercel(request, &self.model, self.max_output_tokens());
+        let request = crate::provider::open_ai::into_open_ai(
+            request,
+            self.model.id(),
+            self.model.supports_parallel_tool_calls(),
+            self.max_output_tokens(),
+        );
         let completions = self.stream_completion(request, cx);
         async move {
-            let mapper = VercelEventMapper::new();
+            let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
             Ok(mapper.map_stream(completions.await?).boxed())
         }
         .boxed()
     }
 }
 
-pub fn into_vercel(
-    request: LanguageModelRequest,
-    model: &vercel::Model,
-    max_output_tokens: Option<u64>,
-) -> open_ai::Request {
-    let stream = !model.id().starts_with("o1-");
-
-    let mut messages = Vec::new();
-    for message in request.messages {
-        for content in message.content {
-            match content {
-                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
-                    add_message_content_part(
-                        open_ai::MessagePart::Text { text: text },
-                        message.role,
-                        &mut messages,
-                    )
-                }
-                MessageContent::RedactedThinking(_) => {}
-                MessageContent::Image(image) => {
-                    add_message_content_part(
-                        open_ai::MessagePart::Image {
-                            image_url: ImageUrl {
-                                url: image.to_base64_url(),
-                                detail: None,
-                            },
-                        },
-                        message.role,
-                        &mut messages,
-                    );
-                }
-                MessageContent::ToolUse(tool_use) => {
-                    let tool_call = open_ai::ToolCall {
-                        id: tool_use.id.to_string(),
-                        content: open_ai::ToolCallContent::Function {
-                            function: open_ai::FunctionContent {
-                                name: tool_use.name.to_string(),
-                                arguments: serde_json::to_string(&tool_use.input)
-                                    .unwrap_or_default(),
-                            },
-                        },
-                    };
-
-                    if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
-                        messages.last_mut()
-                    {
-                        tool_calls.push(tool_call);
-                    } else {
-                        messages.push(open_ai::RequestMessage::Assistant {
-                            content: None,
-                            tool_calls: vec![tool_call],
-                        });
-                    }
-                }
-                MessageContent::ToolResult(tool_result) => {
-                    let content = match &tool_result.content {
-                        LanguageModelToolResultContent::Text(text) => {
-                            vec![open_ai::MessagePart::Text {
-                                text: text.to_string(),
-                            }]
-                        }
-                        LanguageModelToolResultContent::Image(image) => {
-                            vec![open_ai::MessagePart::Image {
-                                image_url: ImageUrl {
-                                    url: image.to_base64_url(),
-                                    detail: None,
-                                },
-                            }]
-                        }
-                    };
-
-                    messages.push(open_ai::RequestMessage::Tool {
-                        content: content.into(),
-                        tool_call_id: tool_result.tool_use_id.to_string(),
-                    });
-                }
-            }
-        }
-    }
-
-    open_ai::Request {
-        model: model.id().into(),
-        messages,
-        stream,
-        stop: request.stop,
-        temperature: request.temperature.unwrap_or(1.0),
-        max_completion_tokens: max_output_tokens,
-        parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
-            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
-            Some(false)
-        } else {
-            None
-        },
-        tools: request
-            .tools
-            .into_iter()
-            .map(|tool| open_ai::ToolDefinition::Function {
-                function: open_ai::FunctionDefinition {
-                    name: tool.name,
-                    description: Some(tool.description),
-                    parameters: Some(tool.input_schema),
-                },
-            })
-            .collect(),
-        tool_choice: request.tool_choice.map(|choice| match choice {
-            LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
-            LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
-            LanguageModelToolChoice::None => open_ai::ToolChoice::None,
-        }),
-    }
-}
-
-fn add_message_content_part(
-    new_part: open_ai::MessagePart,
-    role: Role,
-    messages: &mut Vec<open_ai::RequestMessage>,
-) {
-    match (role, messages.last_mut()) {
-        (Role::User, Some(open_ai::RequestMessage::User { content }))
-        | (
-            Role::Assistant,
-            Some(open_ai::RequestMessage::Assistant {
-                content: Some(content),
-                ..
-            }),
-        )
-        | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
-            content.push_part(new_part);
-        }
-        _ => {
-            messages.push(match role {
-                Role::User => open_ai::RequestMessage::User {
-                    content: open_ai::MessageContent::from(vec![new_part]),
-                },
-                Role::Assistant => open_ai::RequestMessage::Assistant {
-                    content: Some(open_ai::MessageContent::from(vec![new_part])),
-                    tool_calls: Vec::new(),
-                },
-                Role::System => open_ai::RequestMessage::System {
-                    content: open_ai::MessageContent::from(vec![new_part]),
-                },
-            });
-        }
-    }
-}
-
-pub struct VercelEventMapper {
-    tool_calls_by_index: HashMap<usize, RawToolCall>,
-}
-
-impl VercelEventMapper {
-    pub fn new() -> Self {
-        Self {
-            tool_calls_by_index: HashMap::default(),
-        }
-    }
-
-    pub fn map_stream(
-        mut self,
-        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
-    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
-    {
-        events.flat_map(move |event| {
-            futures::stream::iter(match event {
-                Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
-            })
-        })
-    }
-
-    pub fn map_event(
-        &mut self,
-        event: ResponseStreamEvent,
-    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
-        let Some(choice) = event.choices.first() else {
-            return Vec::new();
-        };
-
-        let mut events = Vec::new();
-        if let Some(content) = choice.delta.content.clone() {
-            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
-        }
-
-        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
-            for tool_call in tool_calls {
-                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
-
-                if let Some(tool_id) = tool_call.id.clone() {
-                    entry.id = tool_id;
-                }
-
-                if let Some(function) = tool_call.function.as_ref() {
-                    if let Some(name) = function.name.clone() {
-                        entry.name = name;
-                    }
-
-                    if let Some(arguments) = function.arguments.clone() {
-                        entry.arguments.push_str(&arguments);
-                    }
-                }
-            }
-        }
-
-        match choice.finish_reason.as_deref() {
-            Some("stop") => {
-                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
-            }
-            Some("tool_calls") => {
-                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
-                    match serde_json::Value::from_str(&tool_call.arguments) {
-                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
-                            LanguageModelToolUse {
-                                id: tool_call.id.clone().into(),
-                                name: tool_call.name.as_str().into(),
-                                is_input_complete: true,
-                                input,
-                                raw_input: tool_call.arguments.clone(),
-                            },
-                        )),
-                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
-                            id: tool_call.id.into(),
-                            tool_name: tool_call.name.as_str().into(),
-                            raw_input: tool_call.arguments.into(),
-                            json_parse_error: error.to_string(),
-                        }),
-                    }
-                }));
-
-                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
-            }
-            Some(stop_reason) => {
-                log::error!("Unexpected Vercel stop_reason: {stop_reason:?}",);
-                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
-            }
-            None => {}
-        }
-
-        events
-    }
-}
-
-#[derive(Default)]
-struct RawToolCall {
-    id: String,
-    name: String,
-    arguments: String,
-}
-
 pub fn count_vercel_tokens(
     request: LanguageModelRequest,
     model: Model,
@@ -825,43 +576,3 @@ impl Render for ConfigurationView {
         }
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use gpui::TestAppContext;
-    use language_model::LanguageModelRequestMessage;
-
-    use super::*;
-
-    #[gpui::test]
-    fn tiktoken_rs_support(cx: &TestAppContext) {
-        let request = LanguageModelRequest {
-            thread_id: None,
-            prompt_id: None,
-            intent: None,
-            mode: None,
-            messages: vec![LanguageModelRequestMessage {
-                role: Role::User,
-                content: vec![MessageContent::Text("message".into())],
-                cache: false,
-            }],
-            tools: vec![],
-            tool_choice: None,
-            stop: vec![],
-            temperature: None,
-        };
-
-        // Validate that all models are supported by tiktoken-rs
-        for model in Model::iter() {
-            let count = cx
-                .executor()
-                .block(count_vercel_tokens(
-                    request.clone(),
-                    model,
-                    &cx.app.borrow(),
-                ))
-                .unwrap();
-            assert!(count > 0);
-        }
-    }
-}
diff --git a/crates/vercel/Cargo.toml b/crates/vercel/Cargo.toml
index c4e1e4f99d..60fa1a2390 100644
--- a/crates/vercel/Cargo.toml
+++ b/crates/vercel/Cargo.toml
@@ -17,10 +17,7 @@ schemars = ["dep:schemars"]
 
 [dependencies]
 anyhow.workspace = true
-futures.workspace = true
-http_client.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
-serde_json.workspace = true
 strum.workspace = true
 workspace-hack.workspace = true
diff --git a/crates/vercel/src/vercel.rs b/crates/vercel/src/vercel.rs
index 3195355bbc..cce219eca4 100644
--- a/crates/vercel/src/vercel.rs
+++ b/crates/vercel/src/vercel.rs
@@ -1,51 +1,9 @@
-use anyhow::{Context as _, Result, anyhow};
-use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use anyhow::Result;
 use serde::{Deserialize, Serialize};
-use serde_json::Value;
-use std::{convert::TryFrom, future::Future};
 use strum::EnumIter;
 
 pub const VERCEL_API_URL: &str = "https://api.v0.dev/v1";
 
-fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
-    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
-}
-
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
-    User,
-    Assistant,
-    System,
-    Tool,
-}
-
-impl TryFrom<String> for Role {
-    type Error = anyhow::Error;
-
-    fn try_from(value: String) -> Result<Self> {
-        match value.as_str() {
-            "user" => Ok(Self::User),
-            "assistant" => Ok(Self::Assistant),
-            "system" => Ok(Self::System),
-            "tool" => Ok(Self::Tool),
-            _ => anyhow::bail!("invalid role '{value}'"),
-        }
-    }
-}
-
-impl From<Role> for String {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => "user".to_owned(),
-            Role::Assistant => "assistant".to_owned(),
-            Role::System => "system".to_owned(),
-            Role::Tool => "tool".to_owned(),
-        }
-    }
-}
-
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 pub enum Model {
@@ -118,321 +76,3 @@ impl Model {
         }
     }
 }
-
-#[derive(Debug, Serialize, Deserialize)]
-pub struct Request {
-    pub model: String,
-    pub messages: Vec<RequestMessage>,
-    pub stream: bool,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub max_completion_tokens: Option<u64>,
-    #[serde(default, skip_serializing_if = "Vec::is_empty")]
-    pub stop: Vec<String>,
-    pub temperature: f32,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub tool_choice: Option<ToolChoice>,
-    /// Whether to enable parallel function calling during tool use.
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub parallel_tool_calls: Option<bool>,
-    #[serde(default, skip_serializing_if = "Vec::is_empty")]
-    pub tools: Vec<ToolDefinition>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(untagged)]
-pub enum ToolChoice {
-    Auto,
-    Required,
-    None,
-    Other(ToolDefinition),
-}
-
-#[derive(Clone, Deserialize, Serialize, Debug)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum ToolDefinition {
-    #[allow(dead_code)]
-    Function { function: FunctionDefinition },
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct FunctionDefinition {
-    pub name: String,
-    pub description: Option<String>,
-    pub parameters: Option<Value>,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(tag = "role", rename_all = "lowercase")]
-pub enum RequestMessage {
-    Assistant {
-        content: Option<MessageContent>,
-        #[serde(default, skip_serializing_if = "Vec::is_empty")]
-        tool_calls: Vec<ToolCall>,
-    },
-    User {
-        content: MessageContent,
-    },
-    System {
-        content: MessageContent,
-    },
-    Tool {
-        content: MessageContent,
-        tool_call_id: String,
-    },
-}
-
-#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
-#[serde(untagged)]
-pub enum MessageContent {
-    Plain(String),
-    Multipart(Vec<MessagePart>),
-}
-
-impl MessageContent {
-    pub fn empty() -> Self {
-        MessageContent::Multipart(vec![])
-    }
-
-    pub fn push_part(&mut self, part: MessagePart) {
-        match self {
-            MessageContent::Plain(text) => {
-                *self =
-                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
-            }
-            MessageContent::Multipart(parts) if parts.is_empty() => match part {
-                MessagePart::Text { text } => *self = MessageContent::Plain(text),
-                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
-            },
-            MessageContent::Multipart(parts) => parts.push(part),
-        }
-    }
-}
-
-impl From<Vec<MessagePart>> for MessageContent {
-    fn from(mut parts: Vec<MessagePart>) -> Self {
-        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
-            MessageContent::Plain(std::mem::take(text))
-        } else {
-            MessageContent::Multipart(parts)
-        }
-    }
-}
-
-#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
-#[serde(tag = "type")]
-pub enum MessagePart {
-    #[serde(rename = "text")]
-    Text { text: String },
-    #[serde(rename = "image_url")]
-    Image { image_url: ImageUrl },
-}
-
-#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
-pub struct ImageUrl {
-    pub url: String,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub detail: Option<String>,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ToolCall {
-    pub id: String,
-    #[serde(flatten)]
-    pub content: ToolCallContent,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(tag = "type", rename_all = "lowercase")]
-pub enum ToolCallContent {
-    Function { function: FunctionContent },
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct FunctionContent {
-    pub name: String,
-    pub arguments: String,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessageDelta {
-    pub role: Option<Role>,
-    pub content: Option<String>,
-    #[serde(default, skip_serializing_if = "is_none_or_empty")]
-    pub tool_calls: Option<Vec<ToolCallChunk>>,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ToolCallChunk {
-    pub index: usize,
-    pub id: Option<String>,
-
-    // There is also an optional `type` field that would determine if a
-    // function is there. Sometimes this streams in with the `function` before
-    // it streams in the `type`
-    pub function: Option<FunctionChunk>,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct FunctionChunk {
-    pub name: Option<String>,
-    pub arguments: Option<String>,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct Usage {
-    pub prompt_tokens: u32,
-    pub completion_tokens: u32,
-    pub total_tokens: u32,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct ChoiceDelta {
-    pub index: u32,
-    pub delta: ResponseMessageDelta,
-    pub finish_reason: Option<String>,
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-#[serde(untagged)]
-pub enum ResponseStreamResult {
-    Ok(ResponseStreamEvent),
-    Err { error: String },
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct ResponseStreamEvent {
-    pub model: String,
-    pub choices: Vec<ChoiceDelta>,
-    pub usage: Option<Usage>,
-}
-
-pub async fn stream_completion(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    request: Request,
-) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
-    let uri = format!("{api_url}/chat/completions");
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(uri)
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key));
-
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
-    let mut response = client.send(request).await?;
-    if response.status().is_success() {
-        let reader = BufReader::new(response.into_body());
-        Ok(reader
-            .lines()
-            .filter_map(|line| async move {
-                match line {
-                    Ok(line) => {
-                        let line = line.strip_prefix("data: ")?;
-                        if line == "[DONE]" {
-                            None
-                        } else {
-                            match serde_json::from_str(line) {
-                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
-                                Ok(ResponseStreamResult::Err { error }) => {
-                                    Some(Err(anyhow!(error)))
-                                }
-                                Err(error) => Some(Err(anyhow!(error))),
-                            }
-                        }
-                    }
-                    Err(error) => Some(Err(anyhow!(error))),
-                }
-            })
-            .boxed())
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct VercelResponse {
-            error: VercelError,
-        }
-
-        #[derive(Deserialize)]
-        struct VercelError {
-            message: String,
-        }
-
-        match serde_json::from_str::<VercelResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
-                "Failed to connect to Vercel API: {}",
-                response.error.message,
-            )),
-
-            _ => anyhow::bail!(
-                "Failed to connect to Vercel API: {} {}",
-                response.status(),
-                body,
-            ),
-        }
-    }
-}
-
-#[derive(Copy, Clone, Serialize, Deserialize)]
-pub enum VercelEmbeddingModel {
-    #[serde(rename = "text-embedding-3-small")]
-    TextEmbedding3Small,
-    #[serde(rename = "text-embedding-3-large")]
-    TextEmbedding3Large,
-}
-
-#[derive(Serialize)]
-struct VercelEmbeddingRequest<'a> {
-    model: VercelEmbeddingModel,
-    input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-pub struct VercelEmbeddingResponse {
-    pub data: Vec<VercelEmbedding>,
-}
-
-#[derive(Deserialize)]
-pub struct VercelEmbedding {
-    pub embedding: Vec<f32>,
-}
-
-pub fn embed<'a>(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    model: VercelEmbeddingModel,
-    texts: impl IntoIterator<Item = &'a str>,
-) -> impl 'static + Future<Output = Result<VercelEmbeddingResponse>> {
-    let uri = format!("{api_url}/embeddings");
-
-    let request = VercelEmbeddingRequest {
-        model,
-        input: texts.into_iter().collect(),
-    };
-    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
-    let request = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(uri)
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
-        .body(body)
-        .map(|request| client.send(request));
-
-    async move {
-        let mut response = request?.await?;
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        anyhow::ensure!(
-            response.status().is_success(),
-            "error during embedding, status: {:?}, body: {:?}",
-            response.status(),
-            body
-        );
-        let response: VercelEmbeddingResponse =
-            serde_json::from_str(&body).context("failed to parse Vercel embedding response")?;
-        Ok(response)
-    }
-}