vercel: Reuse existing OpenAI code (#33362)
Follow up to #33292 Since Vercel's API is OpenAI compatible, we can reuse a bunch of code. Release Notes: - N/A
This commit is contained in:
parent
c979452c2d
commit
18f1221a44
6 changed files with 30 additions and 674 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -17431,11 +17431,8 @@ name = "vercel"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"futures 0.3.31",
|
|
||||||
"http_client",
|
|
||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
|
||||||
"strum 0.27.1",
|
"strum 0.27.1",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
]
|
]
|
||||||
|
|
|
@ -888,7 +888,12 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
Ok(model) => model,
|
Ok(model) => model,
|
||||||
Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
|
Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
|
||||||
};
|
};
|
||||||
let request = into_open_ai(request, &model, None);
|
let request = into_open_ai(
|
||||||
|
request,
|
||||||
|
model.id(),
|
||||||
|
model.supports_parallel_tool_calls(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let PerformLlmCompletionResponse {
|
let PerformLlmCompletionResponse {
|
||||||
|
|
|
@ -344,7 +344,12 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||||
LanguageModelCompletionError,
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let request = into_open_ai(request, &self.model, self.max_output_tokens());
|
let request = into_open_ai(
|
||||||
|
request,
|
||||||
|
self.model.id(),
|
||||||
|
self.model.supports_parallel_tool_calls(),
|
||||||
|
self.max_output_tokens(),
|
||||||
|
);
|
||||||
let completions = self.stream_completion(request, cx);
|
let completions = self.stream_completion(request, cx);
|
||||||
async move {
|
async move {
|
||||||
let mapper = OpenAiEventMapper::new();
|
let mapper = OpenAiEventMapper::new();
|
||||||
|
@ -356,10 +361,11 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||||
|
|
||||||
pub fn into_open_ai(
|
pub fn into_open_ai(
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
model: &Model,
|
model_id: &str,
|
||||||
|
supports_parallel_tool_calls: bool,
|
||||||
max_output_tokens: Option<u64>,
|
max_output_tokens: Option<u64>,
|
||||||
) -> open_ai::Request {
|
) -> open_ai::Request {
|
||||||
let stream = !model.id().starts_with("o1-");
|
let stream = !model_id.starts_with("o1-");
|
||||||
|
|
||||||
let mut messages = Vec::new();
|
let mut messages = Vec::new();
|
||||||
for message in request.messages {
|
for message in request.messages {
|
||||||
|
@ -435,13 +441,13 @@ pub fn into_open_ai(
|
||||||
}
|
}
|
||||||
|
|
||||||
open_ai::Request {
|
open_ai::Request {
|
||||||
model: model.id().into(),
|
model: model_id.into(),
|
||||||
messages,
|
messages,
|
||||||
stream,
|
stream,
|
||||||
stop: request.stop,
|
stop: request.stop,
|
||||||
temperature: request.temperature.unwrap_or(1.0),
|
temperature: request.temperature.unwrap_or(1.0),
|
||||||
max_completion_tokens: max_output_tokens,
|
max_completion_tokens: max_output_tokens,
|
||||||
parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
|
parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
|
||||||
// Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
|
// Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
|
||||||
Some(false)
|
Some(false)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use collections::{BTreeMap, HashMap};
|
use collections::BTreeMap;
|
||||||
use credentials_provider::CredentialsProvider;
|
use credentials_provider::CredentialsProvider;
|
||||||
|
|
||||||
use futures::Stream;
|
|
||||||
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
use futures::{FutureExt, StreamExt, future::BoxFuture};
|
||||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
|
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
|
||||||
use http_client::HttpClient;
|
use http_client::HttpClient;
|
||||||
|
@ -10,16 +8,13 @@ use language_model::{
|
||||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||||
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
|
||||||
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
|
||||||
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
|
LanguageModelToolChoice, RateLimiter, Role,
|
||||||
RateLimiter, Role, StopReason,
|
|
||||||
};
|
};
|
||||||
use menu;
|
use menu;
|
||||||
use open_ai::{ImageUrl, ResponseStreamEvent, stream_completion};
|
use open_ai::ResponseStreamEvent;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
use std::pin::Pin;
|
|
||||||
use std::str::FromStr as _;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use strum::IntoEnumIterator;
|
use strum::IntoEnumIterator;
|
||||||
use vercel::Model;
|
use vercel::Model;
|
||||||
|
@ -200,14 +195,12 @@ impl LanguageModelProvider for VercelLanguageModelProvider {
|
||||||
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
||||||
let mut models = BTreeMap::default();
|
let mut models = BTreeMap::default();
|
||||||
|
|
||||||
// Add base models from vercel::Model::iter()
|
|
||||||
for model in vercel::Model::iter() {
|
for model in vercel::Model::iter() {
|
||||||
if !matches!(model, vercel::Model::Custom { .. }) {
|
if !matches!(model, vercel::Model::Custom { .. }) {
|
||||||
models.insert(model.id().to_string(), model);
|
models.insert(model.id().to_string(), model);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Override with available models from settings
|
|
||||||
for model in &AllLanguageModelSettings::get_global(cx)
|
for model in &AllLanguageModelSettings::get_global(cx)
|
||||||
.vercel
|
.vercel
|
||||||
.available_models
|
.available_models
|
||||||
|
@ -278,7 +271,8 @@ impl VercelLanguageModel {
|
||||||
|
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let api_key = api_key.context("Missing Vercel API Key")?;
|
let api_key = api_key.context("Missing Vercel API Key")?;
|
||||||
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
let request =
|
||||||
|
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||||
let response = request.await?;
|
let response = request.await?;
|
||||||
Ok(response)
|
Ok(response)
|
||||||
});
|
});
|
||||||
|
@ -354,264 +348,21 @@ impl LanguageModel for VercelLanguageModel {
|
||||||
LanguageModelCompletionError,
|
LanguageModelCompletionError,
|
||||||
>,
|
>,
|
||||||
> {
|
> {
|
||||||
let request = into_vercel(request, &self.model, self.max_output_tokens());
|
let request = crate::provider::open_ai::into_open_ai(
|
||||||
|
request,
|
||||||
|
self.model.id(),
|
||||||
|
self.model.supports_parallel_tool_calls(),
|
||||||
|
self.max_output_tokens(),
|
||||||
|
);
|
||||||
let completions = self.stream_completion(request, cx);
|
let completions = self.stream_completion(request, cx);
|
||||||
async move {
|
async move {
|
||||||
let mapper = VercelEventMapper::new();
|
let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
|
||||||
Ok(mapper.map_stream(completions.await?).boxed())
|
Ok(mapper.map_stream(completions.await?).boxed())
|
||||||
}
|
}
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn into_vercel(
|
|
||||||
request: LanguageModelRequest,
|
|
||||||
model: &vercel::Model,
|
|
||||||
max_output_tokens: Option<u64>,
|
|
||||||
) -> open_ai::Request {
|
|
||||||
let stream = !model.id().starts_with("o1-");
|
|
||||||
|
|
||||||
let mut messages = Vec::new();
|
|
||||||
for message in request.messages {
|
|
||||||
for content in message.content {
|
|
||||||
match content {
|
|
||||||
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
|
|
||||||
add_message_content_part(
|
|
||||||
open_ai::MessagePart::Text { text: text },
|
|
||||||
message.role,
|
|
||||||
&mut messages,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
MessageContent::RedactedThinking(_) => {}
|
|
||||||
MessageContent::Image(image) => {
|
|
||||||
add_message_content_part(
|
|
||||||
open_ai::MessagePart::Image {
|
|
||||||
image_url: ImageUrl {
|
|
||||||
url: image.to_base64_url(),
|
|
||||||
detail: None,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
message.role,
|
|
||||||
&mut messages,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
MessageContent::ToolUse(tool_use) => {
|
|
||||||
let tool_call = open_ai::ToolCall {
|
|
||||||
id: tool_use.id.to_string(),
|
|
||||||
content: open_ai::ToolCallContent::Function {
|
|
||||||
function: open_ai::FunctionContent {
|
|
||||||
name: tool_use.name.to_string(),
|
|
||||||
arguments: serde_json::to_string(&tool_use.input)
|
|
||||||
.unwrap_or_default(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
|
|
||||||
messages.last_mut()
|
|
||||||
{
|
|
||||||
tool_calls.push(tool_call);
|
|
||||||
} else {
|
|
||||||
messages.push(open_ai::RequestMessage::Assistant {
|
|
||||||
content: None,
|
|
||||||
tool_calls: vec![tool_call],
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MessageContent::ToolResult(tool_result) => {
|
|
||||||
let content = match &tool_result.content {
|
|
||||||
LanguageModelToolResultContent::Text(text) => {
|
|
||||||
vec![open_ai::MessagePart::Text {
|
|
||||||
text: text.to_string(),
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
LanguageModelToolResultContent::Image(image) => {
|
|
||||||
vec![open_ai::MessagePart::Image {
|
|
||||||
image_url: ImageUrl {
|
|
||||||
url: image.to_base64_url(),
|
|
||||||
detail: None,
|
|
||||||
},
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
messages.push(open_ai::RequestMessage::Tool {
|
|
||||||
content: content.into(),
|
|
||||||
tool_call_id: tool_result.tool_use_id.to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
open_ai::Request {
|
|
||||||
model: model.id().into(),
|
|
||||||
messages,
|
|
||||||
stream,
|
|
||||||
stop: request.stop,
|
|
||||||
temperature: request.temperature.unwrap_or(1.0),
|
|
||||||
max_completion_tokens: max_output_tokens,
|
|
||||||
parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
|
|
||||||
// Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
|
|
||||||
Some(false)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
},
|
|
||||||
tools: request
|
|
||||||
.tools
|
|
||||||
.into_iter()
|
|
||||||
.map(|tool| open_ai::ToolDefinition::Function {
|
|
||||||
function: open_ai::FunctionDefinition {
|
|
||||||
name: tool.name,
|
|
||||||
description: Some(tool.description),
|
|
||||||
parameters: Some(tool.input_schema),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
tool_choice: request.tool_choice.map(|choice| match choice {
|
|
||||||
LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
|
|
||||||
LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
|
|
||||||
LanguageModelToolChoice::None => open_ai::ToolChoice::None,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn add_message_content_part(
|
|
||||||
new_part: open_ai::MessagePart,
|
|
||||||
role: Role,
|
|
||||||
messages: &mut Vec<open_ai::RequestMessage>,
|
|
||||||
) {
|
|
||||||
match (role, messages.last_mut()) {
|
|
||||||
(Role::User, Some(open_ai::RequestMessage::User { content }))
|
|
||||||
| (
|
|
||||||
Role::Assistant,
|
|
||||||
Some(open_ai::RequestMessage::Assistant {
|
|
||||||
content: Some(content),
|
|
||||||
..
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
| (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
|
|
||||||
content.push_part(new_part);
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
messages.push(match role {
|
|
||||||
Role::User => open_ai::RequestMessage::User {
|
|
||||||
content: open_ai::MessageContent::from(vec![new_part]),
|
|
||||||
},
|
|
||||||
Role::Assistant => open_ai::RequestMessage::Assistant {
|
|
||||||
content: Some(open_ai::MessageContent::from(vec![new_part])),
|
|
||||||
tool_calls: Vec::new(),
|
|
||||||
},
|
|
||||||
Role::System => open_ai::RequestMessage::System {
|
|
||||||
content: open_ai::MessageContent::from(vec![new_part]),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct VercelEventMapper {
|
|
||||||
tool_calls_by_index: HashMap<usize, RawToolCall>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl VercelEventMapper {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
tool_calls_by_index: HashMap::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn map_stream(
|
|
||||||
mut self,
|
|
||||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
|
||||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
|
||||||
{
|
|
||||||
events.flat_map(move |event| {
|
|
||||||
futures::stream::iter(match event {
|
|
||||||
Ok(event) => self.map_event(event),
|
|
||||||
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn map_event(
|
|
||||||
&mut self,
|
|
||||||
event: ResponseStreamEvent,
|
|
||||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
|
||||||
let Some(choice) = event.choices.first() else {
|
|
||||||
return Vec::new();
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut events = Vec::new();
|
|
||||||
if let Some(content) = choice.delta.content.clone() {
|
|
||||||
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
|
|
||||||
for tool_call in tool_calls {
|
|
||||||
let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
|
|
||||||
|
|
||||||
if let Some(tool_id) = tool_call.id.clone() {
|
|
||||||
entry.id = tool_id;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(function) = tool_call.function.as_ref() {
|
|
||||||
if let Some(name) = function.name.clone() {
|
|
||||||
entry.name = name;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(arguments) = function.arguments.clone() {
|
|
||||||
entry.arguments.push_str(&arguments);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
match choice.finish_reason.as_deref() {
|
|
||||||
Some("stop") => {
|
|
||||||
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
|
|
||||||
}
|
|
||||||
Some("tool_calls") => {
|
|
||||||
events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
|
|
||||||
match serde_json::Value::from_str(&tool_call.arguments) {
|
|
||||||
Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
|
|
||||||
LanguageModelToolUse {
|
|
||||||
id: tool_call.id.clone().into(),
|
|
||||||
name: tool_call.name.as_str().into(),
|
|
||||||
is_input_complete: true,
|
|
||||||
input,
|
|
||||||
raw_input: tool_call.arguments.clone(),
|
|
||||||
},
|
|
||||||
)),
|
|
||||||
Err(error) => Err(LanguageModelCompletionError::BadInputJson {
|
|
||||||
id: tool_call.id.into(),
|
|
||||||
tool_name: tool_call.name.as_str().into(),
|
|
||||||
raw_input: tool_call.arguments.into(),
|
|
||||||
json_parse_error: error.to_string(),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
|
|
||||||
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
|
|
||||||
}
|
|
||||||
Some(stop_reason) => {
|
|
||||||
log::error!("Unexpected Vercel stop_reason: {stop_reason:?}",);
|
|
||||||
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
|
|
||||||
}
|
|
||||||
None => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
events
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
struct RawToolCall {
|
|
||||||
id: String,
|
|
||||||
name: String,
|
|
||||||
arguments: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn count_vercel_tokens(
|
pub fn count_vercel_tokens(
|
||||||
request: LanguageModelRequest,
|
request: LanguageModelRequest,
|
||||||
model: Model,
|
model: Model,
|
||||||
|
@ -825,43 +576,3 @@ impl Render for ConfigurationView {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use gpui::TestAppContext;
|
|
||||||
use language_model::LanguageModelRequestMessage;
|
|
||||||
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
fn tiktoken_rs_support(cx: &TestAppContext) {
|
|
||||||
let request = LanguageModelRequest {
|
|
||||||
thread_id: None,
|
|
||||||
prompt_id: None,
|
|
||||||
intent: None,
|
|
||||||
mode: None,
|
|
||||||
messages: vec![LanguageModelRequestMessage {
|
|
||||||
role: Role::User,
|
|
||||||
content: vec![MessageContent::Text("message".into())],
|
|
||||||
cache: false,
|
|
||||||
}],
|
|
||||||
tools: vec![],
|
|
||||||
tool_choice: None,
|
|
||||||
stop: vec![],
|
|
||||||
temperature: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Validate that all models are supported by tiktoken-rs
|
|
||||||
for model in Model::iter() {
|
|
||||||
let count = cx
|
|
||||||
.executor()
|
|
||||||
.block(count_vercel_tokens(
|
|
||||||
request.clone(),
|
|
||||||
model,
|
|
||||||
&cx.app.borrow(),
|
|
||||||
))
|
|
||||||
.unwrap();
|
|
||||||
assert!(count > 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -17,10 +17,7 @@ schemars = ["dep:schemars"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
futures.workspace = true
|
|
||||||
http_client.workspace = true
|
|
||||||
schemars = { workspace = true, optional = true }
|
schemars = { workspace = true, optional = true }
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
|
||||||
strum.workspace = true
|
strum.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
|
|
|
@ -1,51 +1,9 @@
|
||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::Result;
|
||||||
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
|
|
||||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
|
||||||
use std::{convert::TryFrom, future::Future};
|
|
||||||
use strum::EnumIter;
|
use strum::EnumIter;
|
||||||
|
|
||||||
pub const VERCEL_API_URL: &str = "https://api.v0.dev/v1";
|
pub const VERCEL_API_URL: &str = "https://api.v0.dev/v1";
|
||||||
|
|
||||||
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
|
|
||||||
opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
pub enum Role {
|
|
||||||
User,
|
|
||||||
Assistant,
|
|
||||||
System,
|
|
||||||
Tool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryFrom<String> for Role {
|
|
||||||
type Error = anyhow::Error;
|
|
||||||
|
|
||||||
fn try_from(value: String) -> Result<Self> {
|
|
||||||
match value.as_str() {
|
|
||||||
"user" => Ok(Self::User),
|
|
||||||
"assistant" => Ok(Self::Assistant),
|
|
||||||
"system" => Ok(Self::System),
|
|
||||||
"tool" => Ok(Self::Tool),
|
|
||||||
_ => anyhow::bail!("invalid role '{value}'"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<Role> for String {
|
|
||||||
fn from(val: Role) -> Self {
|
|
||||||
match val {
|
|
||||||
Role::User => "user".to_owned(),
|
|
||||||
Role::Assistant => "assistant".to_owned(),
|
|
||||||
Role::System => "system".to_owned(),
|
|
||||||
Role::Tool => "tool".to_owned(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
|
||||||
pub enum Model {
|
pub enum Model {
|
||||||
|
@ -118,321 +76,3 @@ impl Model {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct Request {
|
|
||||||
pub model: String,
|
|
||||||
pub messages: Vec<RequestMessage>,
|
|
||||||
pub stream: bool,
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_completion_tokens: Option<u64>,
|
|
||||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
|
||||||
pub stop: Vec<String>,
|
|
||||||
pub temperature: f32,
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tool_choice: Option<ToolChoice>,
|
|
||||||
/// Whether to enable parallel function calling during tool use.
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
pub parallel_tool_calls: Option<bool>,
|
|
||||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
|
||||||
pub tools: Vec<ToolDefinition>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum ToolChoice {
|
|
||||||
Auto,
|
|
||||||
Required,
|
|
||||||
None,
|
|
||||||
Other(ToolDefinition),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, Serialize, Debug)]
|
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
|
||||||
pub enum ToolDefinition {
|
|
||||||
#[allow(dead_code)]
|
|
||||||
Function { function: FunctionDefinition },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
||||||
pub struct FunctionDefinition {
|
|
||||||
pub name: String,
|
|
||||||
pub description: Option<String>,
|
|
||||||
pub parameters: Option<Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
#[serde(tag = "role", rename_all = "lowercase")]
|
|
||||||
pub enum RequestMessage {
|
|
||||||
Assistant {
|
|
||||||
content: Option<MessageContent>,
|
|
||||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
|
||||||
tool_calls: Vec<ToolCall>,
|
|
||||||
},
|
|
||||||
User {
|
|
||||||
content: MessageContent,
|
|
||||||
},
|
|
||||||
System {
|
|
||||||
content: MessageContent,
|
|
||||||
},
|
|
||||||
Tool {
|
|
||||||
content: MessageContent,
|
|
||||||
tool_call_id: String,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum MessageContent {
|
|
||||||
Plain(String),
|
|
||||||
Multipart(Vec<MessagePart>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MessageContent {
|
|
||||||
pub fn empty() -> Self {
|
|
||||||
MessageContent::Multipart(vec![])
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn push_part(&mut self, part: MessagePart) {
|
|
||||||
match self {
|
|
||||||
MessageContent::Plain(text) => {
|
|
||||||
*self =
|
|
||||||
MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
|
|
||||||
}
|
|
||||||
MessageContent::Multipart(parts) if parts.is_empty() => match part {
|
|
||||||
MessagePart::Text { text } => *self = MessageContent::Plain(text),
|
|
||||||
MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
|
|
||||||
},
|
|
||||||
MessageContent::Multipart(parts) => parts.push(part),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<Vec<MessagePart>> for MessageContent {
|
|
||||||
fn from(mut parts: Vec<MessagePart>) -> Self {
|
|
||||||
if let [MessagePart::Text { text }] = parts.as_mut_slice() {
|
|
||||||
MessageContent::Plain(std::mem::take(text))
|
|
||||||
} else {
|
|
||||||
MessageContent::Multipart(parts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
pub enum MessagePart {
|
|
||||||
#[serde(rename = "text")]
|
|
||||||
Text { text: String },
|
|
||||||
#[serde(rename = "image_url")]
|
|
||||||
Image { image_url: ImageUrl },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
|
|
||||||
pub struct ImageUrl {
|
|
||||||
pub url: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub detail: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
pub struct ToolCall {
|
|
||||||
pub id: String,
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub content: ToolCallContent,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
#[serde(tag = "type", rename_all = "lowercase")]
|
|
||||||
pub enum ToolCallContent {
|
|
||||||
Function { function: FunctionContent },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
pub struct FunctionContent {
|
|
||||||
pub name: String,
|
|
||||||
pub arguments: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
pub struct ResponseMessageDelta {
|
|
||||||
pub role: Option<Role>,
|
|
||||||
pub content: Option<String>,
|
|
||||||
#[serde(default, skip_serializing_if = "is_none_or_empty")]
|
|
||||||
pub tool_calls: Option<Vec<ToolCallChunk>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
pub struct ToolCallChunk {
|
|
||||||
pub index: usize,
|
|
||||||
pub id: Option<String>,
|
|
||||||
|
|
||||||
// There is also an optional `type` field that would determine if a
|
|
||||||
// function is there. Sometimes this streams in with the `function` before
|
|
||||||
// it streams in the `type`
|
|
||||||
pub function: Option<FunctionChunk>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
|
||||||
pub struct FunctionChunk {
|
|
||||||
pub name: Option<String>,
|
|
||||||
pub arguments: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
|
||||||
pub struct Usage {
|
|
||||||
pub prompt_tokens: u32,
|
|
||||||
pub completion_tokens: u32,
|
|
||||||
pub total_tokens: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
|
||||||
pub struct ChoiceDelta {
|
|
||||||
pub index: u32,
|
|
||||||
pub delta: ResponseMessageDelta,
|
|
||||||
pub finish_reason: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum ResponseStreamResult {
|
|
||||||
Ok(ResponseStreamEvent),
|
|
||||||
Err { error: String },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
|
||||||
pub struct ResponseStreamEvent {
|
|
||||||
pub model: String,
|
|
||||||
pub choices: Vec<ChoiceDelta>,
|
|
||||||
pub usage: Option<Usage>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn stream_completion(
|
|
||||||
client: &dyn HttpClient,
|
|
||||||
api_url: &str,
|
|
||||||
api_key: &str,
|
|
||||||
request: Request,
|
|
||||||
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
|
|
||||||
let uri = format!("{api_url}/chat/completions");
|
|
||||||
let request_builder = HttpRequest::builder()
|
|
||||||
.method(Method::POST)
|
|
||||||
.uri(uri)
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.header("Authorization", format!("Bearer {}", api_key));
|
|
||||||
|
|
||||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
|
||||||
let mut response = client.send(request).await?;
|
|
||||||
if response.status().is_success() {
|
|
||||||
let reader = BufReader::new(response.into_body());
|
|
||||||
Ok(reader
|
|
||||||
.lines()
|
|
||||||
.filter_map(|line| async move {
|
|
||||||
match line {
|
|
||||||
Ok(line) => {
|
|
||||||
let line = line.strip_prefix("data: ")?;
|
|
||||||
if line == "[DONE]" {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
match serde_json::from_str(line) {
|
|
||||||
Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
|
|
||||||
Ok(ResponseStreamResult::Err { error }) => {
|
|
||||||
Some(Err(anyhow!(error)))
|
|
||||||
}
|
|
||||||
Err(error) => Some(Err(anyhow!(error))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(error) => Some(Err(anyhow!(error))),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.boxed())
|
|
||||||
} else {
|
|
||||||
let mut body = String::new();
|
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct VercelResponse {
|
|
||||||
error: VercelError,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct VercelError {
|
|
||||||
message: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
match serde_json::from_str::<VercelResponse>(&body) {
|
|
||||||
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
|
|
||||||
"Failed to connect to Vercel API: {}",
|
|
||||||
response.error.message,
|
|
||||||
)),
|
|
||||||
|
|
||||||
_ => anyhow::bail!(
|
|
||||||
"Failed to connect to Vercel API: {} {}",
|
|
||||||
response.status(),
|
|
||||||
body,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Copy, Clone, Serialize, Deserialize)]
|
|
||||||
pub enum VercelEmbeddingModel {
|
|
||||||
#[serde(rename = "text-embedding-3-small")]
|
|
||||||
TextEmbedding3Small,
|
|
||||||
#[serde(rename = "text-embedding-3-large")]
|
|
||||||
TextEmbedding3Large,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize)]
|
|
||||||
struct VercelEmbeddingRequest<'a> {
|
|
||||||
model: VercelEmbeddingModel,
|
|
||||||
input: Vec<&'a str>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct VercelEmbeddingResponse {
|
|
||||||
pub data: Vec<VercelEmbedding>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub struct VercelEmbedding {
|
|
||||||
pub embedding: Vec<f32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn embed<'a>(
|
|
||||||
client: &dyn HttpClient,
|
|
||||||
api_url: &str,
|
|
||||||
api_key: &str,
|
|
||||||
model: VercelEmbeddingModel,
|
|
||||||
texts: impl IntoIterator<Item = &'a str>,
|
|
||||||
) -> impl 'static + Future<Output = Result<VercelEmbeddingResponse>> {
|
|
||||||
let uri = format!("{api_url}/embeddings");
|
|
||||||
|
|
||||||
let request = VercelEmbeddingRequest {
|
|
||||||
model,
|
|
||||||
input: texts.into_iter().collect(),
|
|
||||||
};
|
|
||||||
let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
|
|
||||||
let request = HttpRequest::builder()
|
|
||||||
.method(Method::POST)
|
|
||||||
.uri(uri)
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
|
||||||
.body(body)
|
|
||||||
.map(|request| client.send(request));
|
|
||||||
|
|
||||||
async move {
|
|
||||||
let mut response = request?.await?;
|
|
||||||
let mut body = String::new();
|
|
||||||
response.body_mut().read_to_string(&mut body).await?;
|
|
||||||
|
|
||||||
anyhow::ensure!(
|
|
||||||
response.status().is_success(),
|
|
||||||
"error during embedding, status: {:?}, body: {:?}",
|
|
||||||
response.status(),
|
|
||||||
body
|
|
||||||
);
|
|
||||||
let response: VercelEmbeddingResponse =
|
|
||||||
serde_json::from_str(&body).context("failed to parse Vercel embedding response")?;
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue