Add support for Vercel as a language model provider (#33292)
Vercel v0 is an OpenAI-compatible model, so this is mostly a dupe of the OpenAI provider files with some adaptations for v0, including going ahead and using the custom endpoint for the API URL field. Release Notes: - Added support for Vercel as a language model provider.
This commit is contained in:
parent
0d70bcb88c
commit
94735aef69
12 changed files with 1394 additions and 0 deletions
15
Cargo.lock
generated
15
Cargo.lock
generated
|
@ -8985,6 +8985,7 @@ dependencies = [
|
|||
"ui",
|
||||
"ui_input",
|
||||
"util",
|
||||
"vercel",
|
||||
"workspace-hack",
|
||||
"zed_llm_client",
|
||||
]
|
||||
|
@ -17424,6 +17425,20 @@ version = "0.2.15"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||
|
||||
[[package]]
|
||||
name = "vercel"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"futures 0.3.31",
|
||||
"http_client",
|
||||
"schemars",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"strum 0.27.1",
|
||||
"workspace-hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "version-compare"
|
||||
version = "0.2.0"
|
||||
|
|
|
@ -165,6 +165,7 @@ members = [
|
|||
"crates/ui_prompt",
|
||||
"crates/util",
|
||||
"crates/util_macros",
|
||||
"crates/vercel",
|
||||
"crates/vim",
|
||||
"crates/vim_mode_setting",
|
||||
"crates/watch",
|
||||
|
@ -375,6 +376,7 @@ ui_macros = { path = "crates/ui_macros" }
|
|||
ui_prompt = { path = "crates/ui_prompt" }
|
||||
util = { path = "crates/util" }
|
||||
util_macros = { path = "crates/util_macros" }
|
||||
vercel = { path = "crates/vercel" }
|
||||
vim = { path = "crates/vim" }
|
||||
vim_mode_setting = { path = "crates/vim_mode_setting" }
|
||||
|
||||
|
|
16
assets/icons/ai_v_zero.svg
Normal file
16
assets/icons/ai_v_zero.svg
Normal file
|
@ -0,0 +1,16 @@
|
|||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g clip-path="url(#clip0_2639_570)">
|
||||
<g clip-path="url(#clip1_2639_570)">
|
||||
<path d="M9.85676 4H13.6675C15.2128 4 16.4654 5.25266 16.4654 6.7979V10.4322H14.9002V6.7979C14.9002 6.76067 14.8988 6.7237 14.8959 6.68706L11.0851 10.4316C11.098 10.432 11.1109 10.4322 11.1238 10.4322H14.9002V11.9105H11.1238C9.57856 11.9105 8.29152 10.6456 8.29152 9.10032V5.47569H9.85676V9.10032C9.85676 9.17012 9.86216 9.23908 9.87264 9.30672L13.7673 5.4798C13.7344 5.47708 13.7012 5.47569 13.6675 5.47569H9.85676V4Z" fill="black"/>
|
||||
<path d="M6.00752 11.6382L0.5 5.47504H2.71573L5.94924 9.09348V5.47504H7.6014V11.0298C7.6014 11.8682 6.56616 12.2634 6.00752 11.6382Z" fill="black"/>
|
||||
</g>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_2639_570">
|
||||
<rect width="16" height="16" fill="white" transform="translate(0.5)"/>
|
||||
</clipPath>
|
||||
<clipPath id="clip1_2639_570">
|
||||
<rect width="16" height="8" fill="white" transform="translate(0.5 4)"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
After Width: | Height: | Size: 1,015 B |
|
@ -19,6 +19,7 @@ pub enum IconName {
|
|||
AiOllama,
|
||||
AiOpenAi,
|
||||
AiOpenRouter,
|
||||
AiVZero,
|
||||
AiZed,
|
||||
ArrowCircle,
|
||||
ArrowDown,
|
||||
|
|
|
@ -40,6 +40,7 @@ mistral = { workspace = true, features = ["schemars"] }
|
|||
ollama = { workspace = true, features = ["schemars"] }
|
||||
open_ai = { workspace = true, features = ["schemars"] }
|
||||
open_router = { workspace = true, features = ["schemars"] }
|
||||
vercel = { workspace = true, features = ["schemars"] }
|
||||
partial-json-fixer.workspace = true
|
||||
project.workspace = true
|
||||
proto.workspace = true
|
||||
|
|
|
@ -20,6 +20,7 @@ use crate::provider::mistral::MistralLanguageModelProvider;
|
|||
use crate::provider::ollama::OllamaLanguageModelProvider;
|
||||
use crate::provider::open_ai::OpenAiLanguageModelProvider;
|
||||
use crate::provider::open_router::OpenRouterLanguageModelProvider;
|
||||
use crate::provider::vercel::VercelLanguageModelProvider;
|
||||
pub use crate::settings::*;
|
||||
|
||||
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut App) {
|
||||
|
@ -77,5 +78,9 @@ fn register_language_model_providers(
|
|||
OpenRouterLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(
|
||||
VercelLanguageModelProvider::new(client.http_client(), cx),
|
||||
cx,
|
||||
);
|
||||
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
|
||||
}
|
||||
|
|
|
@ -9,3 +9,4 @@ pub mod mistral;
|
|||
pub mod ollama;
|
||||
pub mod open_ai;
|
||||
pub mod open_router;
|
||||
pub mod vercel;
|
||||
|
|
867
crates/language_models/src/provider/vercel.rs
Normal file
867
crates/language_models/src/provider/vercel.rs
Normal file
|
@ -0,0 +1,867 @@
|
|||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{BTreeMap, HashMap};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
|
||||
use futures::Stream;
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
|
||||
RateLimiter, Role, StopReason,
|
||||
};
|
||||
use menu;
|
||||
use open_ai::{ImageUrl, ResponseStreamEvent, stream_completion};
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr as _;
|
||||
use std::sync::Arc;
|
||||
use strum::IntoEnumIterator;
|
||||
use vercel::Model;
|
||||
|
||||
use ui::{ElevationIndex, List, Tooltip, prelude::*};
|
||||
use ui_input::SingleLineInput;
|
||||
use util::ResultExt;
|
||||
|
||||
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
|
||||
|
||||
const PROVIDER_ID: &str = "vercel";
|
||||
const PROVIDER_NAME: &str = "Vercel";
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq)]
|
||||
pub struct VercelSettings {
|
||||
pub api_url: String,
|
||||
pub available_models: Vec<AvailableModel>,
|
||||
pub needs_setting_migration: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct AvailableModel {
|
||||
pub name: String,
|
||||
pub display_name: Option<String>,
|
||||
pub max_tokens: u64,
|
||||
pub max_output_tokens: Option<u64>,
|
||||
pub max_completion_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
pub struct VercelLanguageModelProvider {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
state: gpui::Entity<State>,
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
api_key: Option<String>,
|
||||
api_key_from_env: bool,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
const VERCEL_API_KEY_VAR: &str = "VERCEL_API_KEY";
|
||||
|
||||
impl State {
|
||||
fn is_authenticated(&self) -> bool {
|
||||
self.api_key.is_some()
|
||||
}
|
||||
|
||||
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
vercel::VERCEL_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.delete_credentials(&api_url, &cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = None;
|
||||
this.api_key_from_env = false;
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
vercel::VERCEL_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
credentials_provider
|
||||
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
.log_err();
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
cx.notify();
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
|
||||
if self.is_authenticated() {
|
||||
return Task::ready(Ok(()));
|
||||
}
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
vercel::VERCEL_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
cx.spawn(async move |this, cx| {
|
||||
let (api_key, from_env) = if let Ok(api_key) = std::env::var(VERCEL_API_KEY_VAR) {
|
||||
(api_key, true)
|
||||
} else {
|
||||
let (_, api_key) = credentials_provider
|
||||
.read_credentials(&api_url, &cx)
|
||||
.await?
|
||||
.ok_or(AuthenticateError::CredentialsNotFound)?;
|
||||
(
|
||||
String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
|
||||
false,
|
||||
)
|
||||
};
|
||||
this.update(cx, |this, cx| {
|
||||
this.api_key = Some(api_key);
|
||||
this.api_key_from_env = from_env;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl VercelLanguageModelProvider {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let state = cx.new(|cx| State {
|
||||
api_key: None,
|
||||
api_key_from_env: false,
|
||||
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
|
||||
cx.notify();
|
||||
}),
|
||||
});
|
||||
|
||||
Self { http_client, state }
|
||||
}
|
||||
|
||||
fn create_language_model(&self, model: vercel::Model) -> Arc<dyn LanguageModel> {
|
||||
Arc::new(VercelLanguageModel {
|
||||
id: LanguageModelId::from(model.id().to_string()),
|
||||
model,
|
||||
state: self.state.clone(),
|
||||
http_client: self.http_client.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProviderState for VercelLanguageModelProvider {
|
||||
type ObservableEntity = State;
|
||||
|
||||
fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
|
||||
Some(self.state.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProvider for VercelLanguageModelProvider {
|
||||
fn id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn icon(&self) -> IconName {
|
||||
IconName::AiVZero
|
||||
}
|
||||
|
||||
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
Some(self.create_language_model(vercel::Model::default()))
|
||||
}
|
||||
|
||||
fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
Some(self.create_language_model(vercel::Model::default_fast()))
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
||||
let mut models = BTreeMap::default();
|
||||
|
||||
// Add base models from vercel::Model::iter()
|
||||
for model in vercel::Model::iter() {
|
||||
if !matches!(model, vercel::Model::Custom { .. }) {
|
||||
models.insert(model.id().to_string(), model);
|
||||
}
|
||||
}
|
||||
|
||||
// Override with available models from settings
|
||||
for model in &AllLanguageModelSettings::get_global(cx)
|
||||
.vercel
|
||||
.available_models
|
||||
{
|
||||
models.insert(
|
||||
model.name.clone(),
|
||||
vercel::Model::Custom {
|
||||
name: model.name.clone(),
|
||||
display_name: model.display_name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
max_output_tokens: model.max_output_tokens,
|
||||
max_completion_tokens: model.max_completion_tokens,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
models
|
||||
.into_values()
|
||||
.map(|model| self.create_language_model(model))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn is_authenticated(&self, cx: &App) -> bool {
|
||||
self.state.read(cx).is_authenticated()
|
||||
}
|
||||
|
||||
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
|
||||
self.state.update(cx, |state, cx| state.authenticate(cx))
|
||||
}
|
||||
|
||||
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
|
||||
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
|
||||
.into()
|
||||
}
|
||||
|
||||
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
|
||||
self.state.update(cx, |state, cx| state.reset_api_key(cx))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VercelLanguageModel {
|
||||
id: LanguageModelId,
|
||||
model: vercel::Model,
|
||||
state: gpui::Entity<State>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
request_limiter: RateLimiter,
|
||||
}
|
||||
|
||||
impl VercelLanguageModel {
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: open_ai::Request,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
|
||||
{
|
||||
let http_client = self.http_client.clone();
|
||||
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
|
||||
let api_url = if settings.api_url.is_empty() {
|
||||
vercel::VERCEL_API_URL.to_string()
|
||||
} else {
|
||||
settings.api_url.clone()
|
||||
};
|
||||
(state.api_key.clone(), api_url)
|
||||
}) else {
|
||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
||||
};
|
||||
|
||||
let future = self.request_limiter.stream(async move {
|
||||
let api_key = api_key.context("Missing Vercel API Key")?;
|
||||
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||
let response = request.await?;
|
||||
Ok(response)
|
||||
});
|
||||
|
||||
async move { Ok(future.await?.boxed()) }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModel for VercelLanguageModel {
|
||||
fn id(&self) -> LanguageModelId {
|
||||
self.id.clone()
|
||||
}
|
||||
|
||||
fn name(&self) -> LanguageModelName {
|
||||
LanguageModelName::from(self.model.display_name().to_string())
|
||||
}
|
||||
|
||||
fn provider_id(&self) -> LanguageModelProviderId {
|
||||
LanguageModelProviderId(PROVIDER_ID.into())
|
||||
}
|
||||
|
||||
fn provider_name(&self) -> LanguageModelProviderName {
|
||||
LanguageModelProviderName(PROVIDER_NAME.into())
|
||||
}
|
||||
|
||||
fn supports_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
|
||||
match choice {
|
||||
LanguageModelToolChoice::Auto => true,
|
||||
LanguageModelToolChoice::Any => true,
|
||||
LanguageModelToolChoice::None => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("vercel/{}", self.model.id())
|
||||
}
|
||||
|
||||
fn max_token_count(&self) -> u64 {
|
||||
self.model.max_token_count()
|
||||
}
|
||||
|
||||
fn max_output_tokens(&self) -> Option<u64> {
|
||||
self.model.max_output_tokens()
|
||||
}
|
||||
|
||||
fn count_tokens(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<u64>> {
|
||||
count_vercel_tokens(request, self.model.clone(), cx)
|
||||
}
|
||||
|
||||
fn stream_completion(
|
||||
&self,
|
||||
request: LanguageModelRequest,
|
||||
cx: &AsyncApp,
|
||||
) -> BoxFuture<
|
||||
'static,
|
||||
Result<
|
||||
futures::stream::BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
> {
|
||||
let request = into_vercel(request, &self.model, self.max_output_tokens());
|
||||
let completions = self.stream_completion(request, cx);
|
||||
async move {
|
||||
let mapper = VercelEventMapper::new();
|
||||
Ok(mapper.map_stream(completions.await?).boxed())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_vercel(
|
||||
request: LanguageModelRequest,
|
||||
model: &vercel::Model,
|
||||
max_output_tokens: Option<u64>,
|
||||
) -> open_ai::Request {
|
||||
let stream = !model.id().starts_with("o1-");
|
||||
|
||||
let mut messages = Vec::new();
|
||||
for message in request.messages {
|
||||
for content in message.content {
|
||||
match content {
|
||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
|
||||
add_message_content_part(
|
||||
open_ai::MessagePart::Text { text: text },
|
||||
message.role,
|
||||
&mut messages,
|
||||
)
|
||||
}
|
||||
MessageContent::RedactedThinking(_) => {}
|
||||
MessageContent::Image(image) => {
|
||||
add_message_content_part(
|
||||
open_ai::MessagePart::Image {
|
||||
image_url: ImageUrl {
|
||||
url: image.to_base64_url(),
|
||||
detail: None,
|
||||
},
|
||||
},
|
||||
message.role,
|
||||
&mut messages,
|
||||
);
|
||||
}
|
||||
MessageContent::ToolUse(tool_use) => {
|
||||
let tool_call = open_ai::ToolCall {
|
||||
id: tool_use.id.to_string(),
|
||||
content: open_ai::ToolCallContent::Function {
|
||||
function: open_ai::FunctionContent {
|
||||
name: tool_use.name.to_string(),
|
||||
arguments: serde_json::to_string(&tool_use.input)
|
||||
.unwrap_or_default(),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
|
||||
messages.last_mut()
|
||||
{
|
||||
tool_calls.push(tool_call);
|
||||
} else {
|
||||
messages.push(open_ai::RequestMessage::Assistant {
|
||||
content: None,
|
||||
tool_calls: vec![tool_call],
|
||||
});
|
||||
}
|
||||
}
|
||||
MessageContent::ToolResult(tool_result) => {
|
||||
let content = match &tool_result.content {
|
||||
LanguageModelToolResultContent::Text(text) => {
|
||||
vec![open_ai::MessagePart::Text {
|
||||
text: text.to_string(),
|
||||
}]
|
||||
}
|
||||
LanguageModelToolResultContent::Image(image) => {
|
||||
vec![open_ai::MessagePart::Image {
|
||||
image_url: ImageUrl {
|
||||
url: image.to_base64_url(),
|
||||
detail: None,
|
||||
},
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
messages.push(open_ai::RequestMessage::Tool {
|
||||
content: content.into(),
|
||||
tool_call_id: tool_result.tool_use_id.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
open_ai::Request {
|
||||
model: model.id().into(),
|
||||
messages,
|
||||
stream,
|
||||
stop: request.stop,
|
||||
temperature: request.temperature.unwrap_or(1.0),
|
||||
max_completion_tokens: max_output_tokens,
|
||||
parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
|
||||
// Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
|
||||
Some(false)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
tools: request
|
||||
.tools
|
||||
.into_iter()
|
||||
.map(|tool| open_ai::ToolDefinition::Function {
|
||||
function: open_ai::FunctionDefinition {
|
||||
name: tool.name,
|
||||
description: Some(tool.description),
|
||||
parameters: Some(tool.input_schema),
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
tool_choice: request.tool_choice.map(|choice| match choice {
|
||||
LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
|
||||
LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
|
||||
LanguageModelToolChoice::None => open_ai::ToolChoice::None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_message_content_part(
|
||||
new_part: open_ai::MessagePart,
|
||||
role: Role,
|
||||
messages: &mut Vec<open_ai::RequestMessage>,
|
||||
) {
|
||||
match (role, messages.last_mut()) {
|
||||
(Role::User, Some(open_ai::RequestMessage::User { content }))
|
||||
| (
|
||||
Role::Assistant,
|
||||
Some(open_ai::RequestMessage::Assistant {
|
||||
content: Some(content),
|
||||
..
|
||||
}),
|
||||
)
|
||||
| (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
|
||||
content.push_part(new_part);
|
||||
}
|
||||
_ => {
|
||||
messages.push(match role {
|
||||
Role::User => open_ai::RequestMessage::User {
|
||||
content: open_ai::MessageContent::from(vec![new_part]),
|
||||
},
|
||||
Role::Assistant => open_ai::RequestMessage::Assistant {
|
||||
content: Some(open_ai::MessageContent::from(vec![new_part])),
|
||||
tool_calls: Vec::new(),
|
||||
},
|
||||
Role::System => open_ai::RequestMessage::System {
|
||||
content: open_ai::MessageContent::from(vec![new_part]),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VercelEventMapper {
|
||||
tool_calls_by_index: HashMap<usize, RawToolCall>,
|
||||
}
|
||||
|
||||
impl VercelEventMapper {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tool_calls_by_index: HashMap::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map_stream(
|
||||
mut self,
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||
{
|
||||
events.flat_map(move |event| {
|
||||
futures::stream::iter(match event {
|
||||
Ok(event) => self.map_event(event),
|
||||
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn map_event(
|
||||
&mut self,
|
||||
event: ResponseStreamEvent,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
let Some(choice) = event.choices.first() else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
let mut events = Vec::new();
|
||||
if let Some(content) = choice.delta.content.clone() {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
|
||||
for tool_call in tool_calls {
|
||||
let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
|
||||
|
||||
if let Some(tool_id) = tool_call.id.clone() {
|
||||
entry.id = tool_id;
|
||||
}
|
||||
|
||||
if let Some(function) = tool_call.function.as_ref() {
|
||||
if let Some(name) = function.name.clone() {
|
||||
entry.name = name;
|
||||
}
|
||||
|
||||
if let Some(arguments) = function.arguments.clone() {
|
||||
entry.arguments.push_str(&arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match choice.finish_reason.as_deref() {
|
||||
Some("stop") => {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
|
||||
}
|
||||
Some("tool_calls") => {
|
||||
events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
|
||||
match serde_json::Value::from_str(&tool_call.arguments) {
|
||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
||||
LanguageModelToolUse {
|
||||
id: tool_call.id.clone().into(),
|
||||
name: tool_call.name.as_str().into(),
|
||||
is_input_complete: true,
|
||||
input,
|
||||
raw_input: tool_call.arguments.clone(),
|
||||
},
|
||||
)),
|
||||
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
|
||||
id: tool_call.id.into(),
|
||||
tool_name: tool_call.name.as_str().into(),
|
||||
raw_input: tool_call.arguments.into(),
|
||||
json_parse_error: error.to_string(),
|
||||
}),
|
||||
}
|
||||
}));
|
||||
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
|
||||
}
|
||||
Some(stop_reason) => {
|
||||
log::error!("Unexpected Vercel stop_reason: {stop_reason:?}",);
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
events
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct RawToolCall {
|
||||
id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
pub fn count_vercel_tokens(
|
||||
request: LanguageModelRequest,
|
||||
model: Model,
|
||||
cx: &App,
|
||||
) -> BoxFuture<'static, Result<u64>> {
|
||||
cx.background_spawn(async move {
|
||||
let messages = request
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: Some(message.string_contents()),
|
||||
name: None,
|
||||
function_call: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
match model {
|
||||
Model::Custom { max_tokens, .. } => {
|
||||
let model = if max_tokens >= 100_000 {
|
||||
// If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
|
||||
"gpt-4o"
|
||||
} else {
|
||||
// Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
|
||||
// supported with this tiktoken method
|
||||
"gpt-4"
|
||||
};
|
||||
tiktoken_rs::num_tokens_from_messages(model, &messages)
|
||||
}
|
||||
// Map Vercel models to appropriate OpenAI models for token counting
|
||||
// since Vercel uses OpenAI-compatible API
|
||||
Model::VZero => {
|
||||
// Vercel v0 is similar to GPT-4o, so use gpt-4o for token counting
|
||||
tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
|
||||
}
|
||||
}
|
||||
.map(|tokens| tokens as u64)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
struct ConfigurationView {
|
||||
api_key_editor: Entity<SingleLineInput>,
|
||||
state: gpui::Entity<State>,
|
||||
load_credentials_task: Option<Task<()>>,
|
||||
}
|
||||
|
||||
impl ConfigurationView {
|
||||
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
|
||||
let api_key_editor = cx.new(|cx| {
|
||||
SingleLineInput::new(
|
||||
window,
|
||||
cx,
|
||||
"v1:0000000000000000000000000000000000000000000000000",
|
||||
)
|
||||
.label("API key")
|
||||
});
|
||||
|
||||
cx.observe(&state, |_, _, cx| {
|
||||
cx.notify();
|
||||
})
|
||||
.detach();
|
||||
|
||||
let load_credentials_task = Some(cx.spawn_in(window, {
|
||||
let state = state.clone();
|
||||
async move |this, cx| {
|
||||
if let Some(task) = state
|
||||
.update(cx, |state, cx| state.authenticate(cx))
|
||||
.log_err()
|
||||
{
|
||||
// We don't log an error, because "not signed in" is also an error.
|
||||
let _ = task.await;
|
||||
}
|
||||
this.update(cx, |this, cx| {
|
||||
this.load_credentials_task = None;
|
||||
cx.notify();
|
||||
})
|
||||
.log_err();
|
||||
}
|
||||
}));
|
||||
|
||||
Self {
|
||||
api_key_editor,
|
||||
state,
|
||||
load_credentials_task,
|
||||
}
|
||||
}
|
||||
|
||||
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let api_key = self
|
||||
.api_key_editor
|
||||
.read(cx)
|
||||
.editor()
|
||||
.read(cx)
|
||||
.text(cx)
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
// Don't proceed if no API key is provided and we're not authenticated
|
||||
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
|
||||
return;
|
||||
}
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state
|
||||
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
|
||||
.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
self.api_key_editor.update(cx, |input, cx| {
|
||||
input.editor.update(cx, |editor, cx| {
|
||||
editor.set_text("", window, cx);
|
||||
});
|
||||
});
|
||||
|
||||
let state = self.state.clone();
|
||||
cx.spawn_in(window, async move |_, cx| {
|
||||
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
|
||||
})
|
||||
.detach_and_log_err(cx);
|
||||
|
||||
cx.notify();
|
||||
}
|
||||
|
||||
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
|
||||
!self.state.read(cx).is_authenticated()
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for ConfigurationView {
|
||||
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let env_var_set = self.state.read(cx).api_key_from_env;
|
||||
|
||||
let api_key_section = if self.should_render_editor(cx) {
|
||||
v_flex()
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.child(Label::new("To use Zed's agent with Vercel v0, you need to add an API key. Follow these steps:"))
|
||||
.child(
|
||||
List::new()
|
||||
.child(InstructionListItem::new(
|
||||
"Create one by visiting",
|
||||
Some("Vercel v0's console"),
|
||||
Some("https://v0.dev/chat/settings/keys"),
|
||||
))
|
||||
.child(InstructionListItem::text_only(
|
||||
"Paste your API key below and hit enter to start using the agent",
|
||||
)),
|
||||
)
|
||||
.child(self.api_key_editor.clone())
|
||||
.child(
|
||||
Label::new(format!(
|
||||
"You can also assign the {VERCEL_API_KEY_VAR} environment variable and restart Zed."
|
||||
))
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.child(
|
||||
Label::new("Note that Vercel v0 is a custom OpenAI-compatible provider.")
|
||||
.size(LabelSize::Small)
|
||||
.color(Color::Muted),
|
||||
)
|
||||
.into_any()
|
||||
} else {
|
||||
h_flex()
|
||||
.mt_1()
|
||||
.p_1()
|
||||
.justify_between()
|
||||
.rounded_md()
|
||||
.border_1()
|
||||
.border_color(cx.theme().colors().border)
|
||||
.bg(cx.theme().colors().background)
|
||||
.child(
|
||||
h_flex()
|
||||
.gap_1()
|
||||
.child(Icon::new(IconName::Check).color(Color::Success))
|
||||
.child(Label::new(if env_var_set {
|
||||
format!("API key set in {VERCEL_API_KEY_VAR} environment variable.")
|
||||
} else {
|
||||
"API key configured.".to_string()
|
||||
})),
|
||||
)
|
||||
.child(
|
||||
Button::new("reset-api-key", "Reset API Key")
|
||||
.label_size(LabelSize::Small)
|
||||
.icon(IconName::Undo)
|
||||
.icon_size(IconSize::Small)
|
||||
.icon_position(IconPosition::Start)
|
||||
.layer(ElevationIndex::ModalSurface)
|
||||
.when(env_var_set, |this| {
|
||||
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {VERCEL_API_KEY_VAR} environment variable.")))
|
||||
})
|
||||
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
|
||||
)
|
||||
.into_any()
|
||||
};
|
||||
|
||||
if self.load_credentials_task.is_some() {
|
||||
div().child(Label::new("Loading credentials…")).into_any()
|
||||
} else {
|
||||
v_flex().size_full().child(api_key_section).into_any()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use gpui::TestAppContext;
|
||||
use language_model::LanguageModelRequestMessage;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[gpui::test]
|
||||
fn tiktoken_rs_support(cx: &TestAppContext) {
|
||||
let request = LanguageModelRequest {
|
||||
thread_id: None,
|
||||
prompt_id: None,
|
||||
intent: None,
|
||||
mode: None,
|
||||
messages: vec![LanguageModelRequestMessage {
|
||||
role: Role::User,
|
||||
content: vec![MessageContent::Text("message".into())],
|
||||
cache: false,
|
||||
}],
|
||||
tools: vec![],
|
||||
tool_choice: None,
|
||||
stop: vec![],
|
||||
temperature: None,
|
||||
};
|
||||
|
||||
// Validate that all models are supported by tiktoken-rs
|
||||
for model in Model::iter() {
|
||||
let count = cx
|
||||
.executor()
|
||||
.block(count_vercel_tokens(
|
||||
request.clone(),
|
||||
model,
|
||||
&cx.app.borrow(),
|
||||
))
|
||||
.unwrap();
|
||||
assert!(count > 0);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ use crate::provider::{
|
|||
ollama::OllamaSettings,
|
||||
open_ai::OpenAiSettings,
|
||||
open_router::OpenRouterSettings,
|
||||
vercel::VercelSettings,
|
||||
};
|
||||
|
||||
/// Initializes the language model settings.
|
||||
|
@ -64,6 +65,7 @@ pub struct AllLanguageModelSettings {
|
|||
pub open_router: OpenRouterSettings,
|
||||
pub zed_dot_dev: ZedDotDevSettings,
|
||||
pub google: GoogleSettings,
|
||||
pub vercel: VercelSettings,
|
||||
|
||||
pub lmstudio: LmStudioSettings,
|
||||
pub deepseek: DeepSeekSettings,
|
||||
|
@ -82,6 +84,7 @@ pub struct AllLanguageModelSettingsContent {
|
|||
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
|
||||
pub google: Option<GoogleSettingsContent>,
|
||||
pub deepseek: Option<DeepseekSettingsContent>,
|
||||
pub vercel: Option<VercelSettingsContent>,
|
||||
|
||||
pub mistral: Option<MistralSettingsContent>,
|
||||
}
|
||||
|
@ -259,6 +262,12 @@ pub struct OpenAiSettingsContentV1 {
|
|||
pub available_models: Option<Vec<provider::open_ai::AvailableModel>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct VercelSettingsContent {
|
||||
pub api_url: Option<String>,
|
||||
pub available_models: Option<Vec<provider::vercel::AvailableModel>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
|
||||
pub struct GoogleSettingsContent {
|
||||
pub api_url: Option<String>,
|
||||
|
@ -385,6 +394,18 @@ impl settings::Settings for AllLanguageModelSettings {
|
|||
&mut settings.openai.available_models,
|
||||
openai.as_ref().and_then(|s| s.available_models.clone()),
|
||||
);
|
||||
|
||||
// Vercel
|
||||
let vercel = value.vercel.clone();
|
||||
merge(
|
||||
&mut settings.vercel.api_url,
|
||||
vercel.as_ref().and_then(|s| s.api_url.clone()),
|
||||
);
|
||||
merge(
|
||||
&mut settings.vercel.available_models,
|
||||
vercel.as_ref().and_then(|s| s.available_models.clone()),
|
||||
);
|
||||
|
||||
merge(
|
||||
&mut settings.zed_dot_dev.available_models,
|
||||
value
|
||||
|
|
26
crates/vercel/Cargo.toml
Normal file
26
crates/vercel/Cargo.toml
Normal file
|
@ -0,0 +1,26 @@
|
|||
[package]
|
||||
name = "vercel"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/vercel.rs"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
schemars = ["dep:schemars"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
workspace-hack.workspace = true
|
1
crates/vercel/LICENSE-GPL
Symbolic link
1
crates/vercel/LICENSE-GPL
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE-GPL
|
438
crates/vercel/src/vercel.rs
Normal file
438
crates/vercel/src/vercel.rs
Normal file
|
@ -0,0 +1,438 @@
|
|||
use anyhow::{Context as _, Result, anyhow};
|
||||
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::{convert::TryFrom, future::Future};
|
||||
use strum::EnumIter;
|
||||
|
||||
pub const VERCEL_API_URL: &str = "https://api.v0.dev/v1";
|
||||
|
||||
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
|
||||
opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
Tool,
|
||||
}
|
||||
|
||||
impl TryFrom<String> for Role {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: String) -> Result<Self> {
|
||||
match value.as_str() {
|
||||
"user" => Ok(Self::User),
|
||||
"assistant" => Ok(Self::Assistant),
|
||||
"system" => Ok(Self::System),
|
||||
"tool" => Ok(Self::Tool),
|
||||
_ => anyhow::bail!("invalid role '{value}'"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for String {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => "user".to_owned(),
|
||||
Role::Assistant => "assistant".to_owned(),
|
||||
Role::System => "system".to_owned(),
|
||||
Role::Tool => "tool".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||
pub enum Model {
|
||||
#[serde(rename = "v-0")]
|
||||
#[default]
|
||||
VZero,
|
||||
|
||||
#[serde(rename = "custom")]
|
||||
Custom {
|
||||
name: String,
|
||||
/// The name displayed in the UI, such as in the assistant panel model dropdown menu.
|
||||
display_name: Option<String>,
|
||||
max_tokens: u64,
|
||||
max_output_tokens: Option<u64>,
|
||||
max_completion_tokens: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn default_fast() -> Self {
|
||||
Self::VZero
|
||||
}
|
||||
|
||||
pub fn from_id(id: &str) -> Result<Self> {
|
||||
match id {
|
||||
"v-0" => Ok(Self::VZero),
|
||||
invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Self::VZero => "v-0",
|
||||
Self::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &str {
|
||||
match self {
|
||||
Self::VZero => "Vercel v0",
|
||||
Self::Custom {
|
||||
name, display_name, ..
|
||||
} => display_name.as_ref().unwrap_or(name),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> u64 {
|
||||
match self {
|
||||
Self::VZero => 128_000,
|
||||
Self::Custom { max_tokens, .. } => *max_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_output_tokens(&self) -> Option<u64> {
|
||||
match self {
|
||||
Self::Custom {
|
||||
max_output_tokens, ..
|
||||
} => *max_output_tokens,
|
||||
Self::VZero => Some(32_768),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns whether the given model supports the `parallel_tool_calls` parameter.
|
||||
///
|
||||
/// If the model does not support the parameter, do not pass it up, or the API will return an error.
|
||||
pub fn supports_parallel_tool_calls(&self) -> bool {
|
||||
match self {
|
||||
Self::VZero => true,
|
||||
Model::Custom { .. } => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Request {
|
||||
pub model: String,
|
||||
pub messages: Vec<RequestMessage>,
|
||||
pub stream: bool,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub max_completion_tokens: Option<u64>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub stop: Vec<String>,
|
||||
pub temperature: f32,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
/// Whether to enable parallel function calling during tool use.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub tools: Vec<ToolDefinition>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ToolChoice {
|
||||
Auto,
|
||||
Required,
|
||||
None,
|
||||
Other(ToolDefinition),
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize, Debug)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ToolDefinition {
|
||||
#[allow(dead_code)]
|
||||
Function { function: FunctionDefinition },
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct FunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub parameters: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "role", rename_all = "lowercase")]
|
||||
pub enum RequestMessage {
|
||||
Assistant {
|
||||
content: Option<MessageContent>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
User {
|
||||
content: MessageContent,
|
||||
},
|
||||
System {
|
||||
content: MessageContent,
|
||||
},
|
||||
Tool {
|
||||
content: MessageContent,
|
||||
tool_call_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
||||
#[serde(untagged)]
|
||||
pub enum MessageContent {
|
||||
Plain(String),
|
||||
Multipart(Vec<MessagePart>),
|
||||
}
|
||||
|
||||
impl MessageContent {
|
||||
pub fn empty() -> Self {
|
||||
MessageContent::Multipart(vec![])
|
||||
}
|
||||
|
||||
pub fn push_part(&mut self, part: MessagePart) {
|
||||
match self {
|
||||
MessageContent::Plain(text) => {
|
||||
*self =
|
||||
MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
|
||||
}
|
||||
MessageContent::Multipart(parts) if parts.is_empty() => match part {
|
||||
MessagePart::Text { text } => *self = MessageContent::Plain(text),
|
||||
MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
|
||||
},
|
||||
MessageContent::Multipart(parts) => parts.push(part),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<MessagePart>> for MessageContent {
|
||||
fn from(mut parts: Vec<MessagePart>) -> Self {
|
||||
if let [MessagePart::Text { text }] = parts.as_mut_slice() {
|
||||
MessageContent::Plain(std::mem::take(text))
|
||||
} else {
|
||||
MessageContent::Multipart(parts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagePart {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image_url")]
|
||||
Image { image_url: ImageUrl },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
||||
pub struct ImageUrl {
|
||||
pub url: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub detail: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(flatten)]
|
||||
pub content: ToolCallContent,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum ToolCallContent {
|
||||
Function { function: FunctionContent },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct FunctionContent {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ResponseMessageDelta {
|
||||
pub role: Option<Role>,
|
||||
pub content: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "is_none_or_empty")]
|
||||
pub tool_calls: Option<Vec<ToolCallChunk>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct ToolCallChunk {
|
||||
pub index: usize,
|
||||
pub id: Option<String>,
|
||||
|
||||
// There is also an optional `type` field that would determine if a
|
||||
// function is there. Sometimes this streams in with the `function` before
|
||||
// it streams in the `type`
|
||||
pub function: Option<FunctionChunk>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct FunctionChunk {
|
||||
pub name: Option<String>,
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ChoiceDelta {
|
||||
pub index: u32,
|
||||
pub delta: ResponseMessageDelta,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
pub enum ResponseStreamResult {
|
||||
Ok(ResponseStreamEvent),
|
||||
Err { error: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ResponseStreamEvent {
|
||||
pub model: String,
|
||||
pub choices: Vec<ChoiceDelta>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: Request,
|
||||
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
|
||||
let uri = format!("{api_url}/chat/completions");
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key));
|
||||
|
||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
Ok(reader
|
||||
.lines()
|
||||
.filter_map(|line| async move {
|
||||
match line {
|
||||
Ok(line) => {
|
||||
let line = line.strip_prefix("data: ")?;
|
||||
if line == "[DONE]" {
|
||||
None
|
||||
} else {
|
||||
match serde_json::from_str(line) {
|
||||
Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
|
||||
Ok(ResponseStreamResult::Err { error }) => {
|
||||
Some(Err(anyhow!(error)))
|
||||
}
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(error) => Some(Err(anyhow!(error))),
|
||||
}
|
||||
})
|
||||
.boxed())
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct VercelResponse {
|
||||
error: VercelError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct VercelError {
|
||||
message: String,
|
||||
}
|
||||
|
||||
match serde_json::from_str::<VercelResponse>(&body) {
|
||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
||||
"Failed to connect to Vercel API: {}",
|
||||
response.error.message,
|
||||
)),
|
||||
|
||||
_ => anyhow::bail!(
|
||||
"Failed to connect to Vercel API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Serialize, Deserialize)]
|
||||
pub enum VercelEmbeddingModel {
|
||||
#[serde(rename = "text-embedding-3-small")]
|
||||
TextEmbedding3Small,
|
||||
#[serde(rename = "text-embedding-3-large")]
|
||||
TextEmbedding3Large,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct VercelEmbeddingRequest<'a> {
|
||||
model: VercelEmbeddingModel,
|
||||
input: Vec<&'a str>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct VercelEmbeddingResponse {
|
||||
pub data: Vec<VercelEmbedding>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct VercelEmbedding {
|
||||
pub embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
pub fn embed<'a>(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
model: VercelEmbeddingModel,
|
||||
texts: impl IntoIterator<Item = &'a str>,
|
||||
) -> impl 'static + Future<Output = Result<VercelEmbeddingResponse>> {
|
||||
let uri = format!("{api_url}/embeddings");
|
||||
|
||||
let request = VercelEmbeddingRequest {
|
||||
model,
|
||||
input: texts.into_iter().collect(),
|
||||
};
|
||||
let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
|
||||
let request = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(body)
|
||||
.map(|request| client.send(request));
|
||||
|
||||
async move {
|
||||
let mut response = request?.await?;
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
response.status().is_success(),
|
||||
"error during embedding, status: {:?}, body: {:?}",
|
||||
response.status(),
|
||||
body
|
||||
);
|
||||
let response: VercelEmbeddingResponse =
|
||||
serde_json::from_str(&body).context("failed to parse Vercel embedding response")?;
|
||||
Ok(response)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue