WIP and merge

This commit is contained in:
parent 97f4406ef6
commit 1bdde8b2e4

584 changed files with 33536 additions and 17400 deletions
@@ -40,8 +40,8 @@ mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
vercel = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
project.workspace = true
proto.workspace = true
release_channel.workspace = true
schemars.workspace = true
@@ -55,9 +55,11 @@ thiserror.workspace = true
tiktoken-rs.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
ui.workspace = true
ui_input.workspace = true
util.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
language.workspace = true

[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }

@@ -1,7 +1,6 @@
 use std::sync::Arc;

 use client::{Client, UserStore};
-use fs::Fs;
 use gpui::{App, Context, Entity};
 use language_model::LanguageModelRegistry;
 use provider::deepseek::DeepSeekLanguageModelProvider;
@@ -20,10 +19,11 @@ use crate::provider::mistral::MistralLanguageModelProvider;
 use crate::provider::ollama::OllamaLanguageModelProvider;
 use crate::provider::open_ai::OpenAiLanguageModelProvider;
 use crate::provider::open_router::OpenRouterLanguageModelProvider;
+use crate::provider::vercel::VercelLanguageModelProvider;
 pub use crate::settings::*;

-pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut App) {
-    crate::settings::init(fs, cx);
+pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
+    crate::settings::init(cx);
     let registry = LanguageModelRegistry::global(cx);
     registry.update(cx, |registry, cx| {
         register_language_model_providers(registry, user_store, client, cx);
@@ -77,5 +77,9 @@ fn register_language_model_providers(
         OpenRouterLanguageModelProvider::new(client.http_client(), cx),
         cx,
     );
+    registry.register_provider(
+        VercelLanguageModelProvider::new(client.http_client(), cx),
+        cx,
+    );
     registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
 }

@@ -9,3 +9,4 @@ pub mod mistral;
 pub mod ollama;
 pub mod open_ai;
 pub mod open_router;
+pub mod vercel;

@@ -16,10 +16,10 @@ use gpui::{
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
-    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
-    LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
+    LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+    LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent,
+    RateLimiter, Role,
 };
 use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
 use schemars::JsonSchema;
@@ -41,7 +41,6 @@ pub struct AnthropicSettings {
     pub api_url: String,
     /// Extend Zed's list of Anthropic models.
     pub available_models: Vec<AvailableModel>,
-    pub needs_setting_migration: bool,
 }

 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@@ -51,12 +50,12 @@ pub struct AvailableModel {
     /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
     pub display_name: Option<String>,
     /// The model's context window size.
-    pub max_tokens: usize,
+    pub max_tokens: u64,
     /// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling.
     pub tool_override: Option<String>,
     /// Configuration of Anthropic's caching API.
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
-    pub max_output_tokens: Option<u32>,
+    pub max_output_tokens: Option<u64>,
     pub default_temperature: Option<f32>,
     #[serde(default)]
     pub extra_beta_headers: Vec<String>,
@@ -321,7 +320,7 @@ pub struct AnthropicModel {
 pub fn count_anthropic_tokens(
     request: LanguageModelRequest,
     cx: &App,
-) -> BoxFuture<'static, Result<usize>> {
+) -> BoxFuture<'static, Result<u64>> {
     cx.background_spawn(async move {
         let messages = request.messages;
         let mut tokens_from_images = 0;
@@ -377,7 +376,7 @@ pub fn count_anthropic_tokens(
         // Tiktoken doesn't yet support these models, so we manually use the
         // same tokenizer as GPT-4.
         tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
-            .map(|tokens| tokens + tokens_from_images)
+            .map(|tokens| (tokens + tokens_from_images) as u64)
     })
     .boxed()
 }

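The recurring change in this file, and in every provider below, is widening token counts from `usize` to `u64`. A minimal, self-contained sketch of the new counting shape, with a whitespace count standing in for the tiktoken call:

```rust
use anyhow::Result;
use futures::{FutureExt, future::BoxFuture};

// Sketch only: a word count stands in for tiktoken_rs, which returns usize;
// the result is widened to the u64 the LanguageModel trait now expects.
fn count_tokens_sketch(
    messages: Vec<String>,
    tokens_from_images: usize,
) -> BoxFuture<'static, Result<u64>> {
    async move {
        let tokens: usize = messages.iter().map(|m| m.split_whitespace().count()).sum();
        Ok((tokens + tokens_from_images) as u64)
    }
    .boxed()
}
```
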
@@ -407,14 +406,7 @@ impl AnthropicModel {
             let api_key = api_key.context("Missing Anthropic API Key")?;
             let request =
                 anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
-            request.await.map_err(|err| match err {
-                AnthropicError::RateLimit(duration) => {
-                    LanguageModelCompletionError::RateLimit(duration)
-                }
-                err @ (AnthropicError::ApiError(..) | AnthropicError::Other(..)) => {
-                    LanguageModelCompletionError::Other(anthropic_err_to_anyhow(err))
-                }
-            })
+            request.await.map_err(Into::into)
         }
         .boxed()
     }

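The simplification above relies on a `From<AnthropicError> for LanguageModelCompletionError` conversion existing elsewhere in the commit. A self-contained sketch of the pattern, with stand-in error types:

```rust
use std::time::Duration;

#[derive(Debug)]
enum ProviderError {
    RateLimit(Duration),
    Api(String),
}

#[derive(Debug)]
enum CompletionError {
    RateLimit(Duration),
    Other(String),
}

// Once this impl exists, every call site can collapse its per-variant match
// into `map_err(Into::into)`, as the hunk above does.
impl From<ProviderError> for CompletionError {
    fn from(err: ProviderError) -> Self {
        match err {
            ProviderError::RateLimit(duration) => CompletionError::RateLimit(duration),
            ProviderError::Api(message) => CompletionError::Other(message),
        }
    }
}

fn send(result: Result<(), ProviderError>) -> Result<(), CompletionError> {
    result.map_err(Into::into)
}
```
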
@@ -461,11 +453,11 @@ impl LanguageModel for AnthropicModel {
         self.state.read(cx).api_key.clone()
     }

-    fn max_token_count(&self) -> usize {
+    fn max_token_count(&self) -> u64 {
         self.model.max_token_count()
     }

-    fn max_output_tokens(&self) -> Option<u32> {
+    fn max_output_tokens(&self) -> Option<u64> {
         Some(self.model.max_output_tokens())
     }

@@ -473,7 +465,7 @@ impl LanguageModel for AnthropicModel {
         &self,
         request: LanguageModelRequest,
         cx: &App,
-    ) -> BoxFuture<'static, Result<usize>> {
+    ) -> BoxFuture<'static, Result<u64>> {
         count_anthropic_tokens(request, cx)
     }

@@ -518,7 +510,7 @@ pub fn into_anthropic(
     request: LanguageModelRequest,
     model: String,
     default_temperature: f32,
-    max_output_tokens: u32,
+    max_output_tokens: u64,
     mode: AnthropicModelMode,
 ) -> anthropic::Request {
     let mut new_messages: Vec<anthropic::Message> = Vec::new();
@@ -561,9 +553,7 @@ pub fn into_anthropic(
                         }
                         MessageContent::RedactedThinking(data) => {
                             if !data.is_empty() {
-                                Some(anthropic::RequestContent::RedactedThinking {
-                                    data: String::from_utf8(data).ok()?,
-                                })
+                                Some(anthropic::RequestContent::RedactedThinking { data })
                             } else {
                                 None
                             }

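Redacted thinking blocks now flow through as the opaque `String` Anthropic returns, instead of being stored as bytes and re-decoded with `String::from_utf8`. A sketch with a stand-in type:

```rust
// Stand-in for anthropic::RequestContent.
enum RequestContent {
    RedactedThinking { data: String },
}

// Empty payloads are dropped; anything else is passed back verbatim so the
// API can decrypt its own reasoning on a later turn.
fn redacted_thinking_block(data: String) -> Option<RequestContent> {
    if data.is_empty() {
        None
    } else {
        Some(RequestContent::RedactedThinking { data })
    }
}
```
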
@@ -714,7 +704,7 @@ impl AnthropicEventMapper {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+                Err(error) => vec![Err(error.into())],
             })
         })
     }
@@ -737,10 +727,8 @@ impl AnthropicEventMapper {
                         signature: None,
                     })]
                 }
-                ResponseContent::RedactedThinking { .. } => {
-                    // Redacted thinking is encrypted and not accessible to the user, see:
-                    // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production
-                    Vec::new()
+                ResponseContent::RedactedThinking { data } => {
+                    vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
                 }
                 ResponseContent::ToolUse { id, name, .. } => {
                     self.tool_uses_by_index.insert(
@@ -859,9 +847,7 @@ impl AnthropicEventMapper {
                 vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
             }
             Event::Error { error } => {
-                vec![Err(LanguageModelCompletionError::Other(anyhow!(
-                    AnthropicError::ApiError(error)
-                )))]
+                vec![Err(error.into())]
             }
             _ => Vec::new(),
         }
@@ -874,16 +860,6 @@ struct RawToolUse {
     input_json: String,
 }

-pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
-    if let AnthropicError::ApiError(api_err) = &err {
-        if let Some(tokens) = api_err.match_window_exceeded() {
-            return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens });
-        }
-    }
-
-    anyhow!(err)
-}
-
 /// Updates usage data by preferring counts from `new`.
 fn update_usage(usage: &mut Usage, new: &Usage) {
     if let Some(input_tokens) = new.input_tokens {

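All of the event mappers touched by this commit share the same outer shape: each provider event expands into zero or more completion events, and stream errors convert via `Into`. A generic, runnable sketch of that shape:

```rust
use futures::{Stream, StreamExt, stream};

// One incoming event can fan out into several outgoing events (text deltas,
// tool-use updates, usage updates), so flat_map over a per-event Vec.
fn map_stream<S, T, E, F>(events: S, mut map_event: F) -> impl Stream<Item = Result<T, E>>
where
    S: Stream<Item = Result<T, E>>,
    F: FnMut(T) -> Vec<Result<T, E>>,
{
    events.flat_map(move |event| {
        stream::iter(match event {
            Ok(event) => map_event(event),
            // With the From impls added in this commit, provider errors
            // convert directly instead of being wrapped in anyhow.
            Err(error) => vec![Err(error)],
        })
    })
}
```
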
@@ -11,8 +11,8 @@ use aws_http_client::AwsHttpClient;
 use bedrock::bedrock_client::Client as BedrockClient;
 use bedrock::bedrock_client::config::timeout::TimeoutConfig;
 use bedrock::bedrock_client::types::{
-    ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta,
-    StopReason,
+    CachePointBlock, CachePointType, ContentBlockDelta, ContentBlockStart, ConverseStreamOutput,
+    ReasoningContentBlockDelta, StopReason,
 };
 use bedrock::{
     BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent,
@@ -48,7 +48,7 @@ use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
 use theme::ThemeSettings;
 use tokio::runtime::Handle;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::{ResultExt, default};
+use util::ResultExt;

 use crate::AllLanguageModelSettings;

@@ -88,9 +88,9 @@ pub enum BedrockAuthMethod {
 pub struct AvailableModel {
     pub name: String,
     pub display_name: Option<String>,
-    pub max_tokens: usize,
+    pub max_tokens: u64,
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
-    pub max_output_tokens: Option<u32>,
+    pub max_output_tokens: Option<u64>,
     pub default_temperature: Option<f32>,
     pub mode: Option<ModelMode>,
 }
@@ -229,6 +229,17 @@ impl State {
             Ok(())
         })
     }
+
+    fn get_region(&self) -> String {
+        // Get region - from credentials or directly from settings
+        let credentials_region = self.credentials.as_ref().map(|s| s.region.clone());
+        let settings_region = self.settings.as_ref().and_then(|s| s.region.clone());
+
+        // Use credentials region if available, otherwise use settings region, finally fall back to default
+        credentials_region
+            .or(settings_region)
+            .unwrap_or(String::from("us-east-1"))
+    }
 }

 pub struct BedrockLanguageModelProvider {

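`get_region` centralizes a fallback chain that previously appeared inline at two call sites. The logic in isolation, runnable as-is:

```rust
// Prefer the region attached to stored credentials, then the configured
// settings region, then AWS's default region.
fn resolve_region(credentials_region: Option<String>, settings_region: Option<String>) -> String {
    credentials_region
        .or(settings_region)
        .unwrap_or_else(|| String::from("us-east-1"))
}

fn main() {
    assert_eq!(resolve_region(None, Some("eu-west-1".into())), "eu-west-1");
    assert_eq!(resolve_region(Some("us-west-2".into()), None), "us-west-2");
    assert_eq!(resolve_region(None, None), "us-east-1");
}
```
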
@@ -289,8 +300,9 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
         Some(self.create_language_model(bedrock::Model::default()))
     }

-    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        Some(self.create_language_model(bedrock::Model::default_fast()))
+    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
+        let region = self.state.read(cx).get_region();
+        Some(self.create_language_model(bedrock::Model::default_fast(region.as_str())))
     }

     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -317,6 +329,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
                     max_tokens: model.max_tokens,
                     max_output_tokens: model.max_output_tokens,
                     default_temperature: model.default_temperature,
+                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
+                        bedrock::BedrockModelCacheConfiguration {
+                            max_cache_anchors: config.max_cache_anchors,
+                            min_total_token: config.min_total_token,
+                        }
+                    }),
                 },
             );
         }
@@ -377,11 +395,7 @@ impl BedrockModel {

             let endpoint = state.settings.as_ref().and_then(|s| s.endpoint.clone());

-            let region = state
-                .settings
-                .as_ref()
-                .and_then(|s| s.region.clone())
-                .unwrap_or(String::from("us-east-1"));
+            let region = state.get_region();

             (
                 auth_method,
@@ -495,7 +509,8 @@ impl LanguageModel for BedrockModel {
             LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => {
                 self.model.supports_tool_use()
             }
-            LanguageModelToolChoice::None => false,
+            // Add support for None - we'll filter tool calls at response
+            LanguageModelToolChoice::None => self.model.supports_tool_use(),
         }
     }

@@ -503,11 +518,11 @@ impl LanguageModel for BedrockModel {
         format!("bedrock/{}", self.model.id())
     }

-    fn max_token_count(&self) -> usize {
+    fn max_token_count(&self) -> u64 {
         self.model.max_token_count()
     }

-    fn max_output_tokens(&self) -> Option<u32> {
+    fn max_output_tokens(&self) -> Option<u64> {
         Some(self.model.max_output_tokens())
     }

@@ -515,7 +530,7 @@ impl LanguageModel for BedrockModel {
         &self,
         request: LanguageModelRequest,
         cx: &App,
-    ) -> BoxFuture<'static, Result<usize>> {
+    ) -> BoxFuture<'static, Result<u64>> {
         get_bedrock_tokens(request, cx)
     }

@@ -530,16 +545,7 @@ impl LanguageModel for BedrockModel {
             LanguageModelCompletionError,
         >,
     > {
-        let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
-            // Get region - from credentials or directly from settings
-            let credentials_region = state.credentials.as_ref().map(|s| s.region.clone());
-            let settings_region = state.settings.as_ref().and_then(|s| s.region.clone());
-
-            // Use credentials region if available, otherwise use settings region, finally fall back to default
-            credentials_region
-                .or(settings_region)
-                .unwrap_or(String::from("us-east-1"))
-        }) else {
+        let Ok(region) = cx.read_entity(&self.state, |state, _cx| state.get_region()) else {
             return async move { Err(anyhow::anyhow!("App State Dropped").into()) }.boxed();
         };

@@ -550,12 +556,15 @@ impl LanguageModel for BedrockModel {
             }
         };

+        let deny_tool_calls = request.tool_choice == Some(LanguageModelToolChoice::None);
+
         let request = match into_bedrock(
             request,
             model_id,
             self.model.default_temperature(),
             self.model.max_output_tokens(),
             self.model.mode(),
+            self.model.supports_caching(),
         ) {
             Ok(request) => request,
             Err(err) => return futures::future::ready(Err(err.into())).boxed(),

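`supports_tool_choice` now reports `None` as supported whenever the model supports tools at all; the actual enforcement happens later, on the response stream. The rule in isolation:

```rust
#[derive(Clone, Copy)]
enum ToolChoice {
    Auto,
    Any,
    None,
}

// Bedrock's Converse API has no "never call tools" option, so None is
// accepted here and tool calls are filtered out of the response instead
// (see deny_tool_use_events below).
fn supports_tool_choice(supports_tool_use: bool, choice: ToolChoice) -> bool {
    match choice {
        ToolChoice::Auto | ToolChoice::Any => supports_tool_use,
        ToolChoice::None => supports_tool_use,
    }
}
```
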
@@ -566,25 +575,53 @@ impl LanguageModel for BedrockModel {
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
             let response = request.map_err(|err| anyhow!(err))?.await;
-            Ok(map_to_language_model_completion_events(
-                response,
-                owned_handle,
-            ))
+            let events = map_to_language_model_completion_events(response, owned_handle);
+
+            if deny_tool_calls {
+                Ok(deny_tool_use_events(events).boxed())
+            } else {
+                Ok(events.boxed())
+            }
         });

         async move { Ok(future.await?.boxed()) }.boxed()
     }

     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
-        None
+        self.model
+            .cache_configuration()
+            .map(|config| LanguageModelCacheConfiguration {
+                max_cache_anchors: config.max_cache_anchors,
+                should_speculate: false,
+                min_total_token: config.min_total_token,
+            })
     }
 }

+fn deny_tool_use_events(
+    events: impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+    events.map(|event| {
+        match event {
+            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
+                // Convert tool use to an error message if model decided to call it
+                Ok(LanguageModelCompletionEvent::Text(format!(
+                    "\n\n[Error: Tool calls are disabled in this context. Attempted to call '{}']",
+                    tool_use.name
+                )))
+            }
+            other => other,
+        }
+    })
+}
+
 pub fn into_bedrock(
     request: LanguageModelRequest,
     model: String,
     default_temperature: f32,
-    max_output_tokens: u32,
+    max_output_tokens: u64,
     mode: BedrockModelMode,
+    supports_caching: bool,
 ) -> Result<bedrock::Request> {
     let mut new_messages: Vec<BedrockMessage> = Vec::new();
     let mut system_message = String::new();
@@ -596,7 +633,7 @@ pub fn into_bedrock(

         match message.role {
             Role::User | Role::Assistant => {
-                let bedrock_message_content: Vec<BedrockInnerContent> = message
+                let mut bedrock_message_content: Vec<BedrockInnerContent> = message
                     .content
                     .into_iter()
                     .filter_map(|content| match content {
@@ -608,6 +645,11 @@ pub fn into_bedrock(
                             }
                         }
                         MessageContent::Thinking { text, signature } => {
+                            if model.contains(Model::DeepSeekR1.request_id()) {
+                                // DeepSeekR1 doesn't support thinking blocks
+                                // And the AWS API demands that you strip them
+                                return None;
+                            }
                            let thinking = BedrockThinkingTextBlock::builder()
                                 .text(text)
                                 .set_signature(signature)
@@ -620,19 +662,32 @@ pub fn into_bedrock(
                             ))
                         }
                         MessageContent::RedactedThinking(blob) => {
+                            if model.contains(Model::DeepSeekR1.request_id()) {
+                                // DeepSeekR1 doesn't support thinking blocks
+                                // And the AWS API demands that you strip them
+                                return None;
+                            }
                             let redacted =
                                 BedrockThinkingBlock::RedactedContent(BedrockBlob::new(blob));

                             Some(BedrockInnerContent::ReasoningContent(redacted))
                         }
-                        MessageContent::ToolUse(tool_use) => BedrockToolUseBlock::builder()
-                            .name(tool_use.name.to_string())
-                            .tool_use_id(tool_use.id.to_string())
-                            .input(value_to_aws_document(&tool_use.input))
-                            .build()
-                            .context("failed to build Bedrock tool use block")
-                            .log_err()
-                            .map(BedrockInnerContent::ToolUse),
+                        MessageContent::ToolUse(tool_use) => {
+                            let input = if tool_use.input.is_null() {
+                                // Bedrock API requires valid JsonValue, not null, for tool use input
+                                value_to_aws_document(&serde_json::json!({}))
+                            } else {
+                                value_to_aws_document(&tool_use.input)
+                            };
+                            BedrockToolUseBlock::builder()
+                                .name(tool_use.name.to_string())
+                                .tool_use_id(tool_use.id.to_string())
+                                .input(input)
+                                .build()
+                                .context("failed to build Bedrock tool use block")
+                                .log_err()
+                                .map(BedrockInnerContent::ToolUse)
+                        },
                         MessageContent::ToolResult(tool_result) => {
                             BedrockToolResultBlock::builder()
                                 .tool_use_id(tool_result.tool_use_id.to_string())

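A runnable reduction of `deny_tool_use_events`: tool-use events are rewritten into explanatory text while everything else passes through untouched.

```rust
use futures::{StreamExt, executor::block_on, stream};

#[derive(Debug, PartialEq)]
enum Event {
    Text(String),
    ToolUse { name: String },
}

fn deny_tool_use(events: impl futures::Stream<Item = Event>) -> impl futures::Stream<Item = Event> {
    events.map(|event| match event {
        // Rewrite attempted tool calls into a visible error message.
        Event::ToolUse { name } => Event::Text(format!(
            "\n\n[Error: Tool calls are disabled in this context. Attempted to call '{name}']"
        )),
        other => other,
    })
}

fn main() {
    let events = stream::iter(vec![
        Event::Text("hello".into()),
        Event::ToolUse { name: "search".into() },
    ]);
    let out: Vec<Event> = block_on(deny_tool_use(events).collect());
    assert!(matches!(out[1], Event::Text(_)));
}
```
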
@@ -662,6 +717,14 @@ pub fn into_bedrock(
                         _ => None,
                     })
                     .collect();
+                if message.cache && supports_caching {
+                    bedrock_message_content.push(BedrockInnerContent::CachePoint(
+                        CachePointBlock::builder()
+                            .r#type(CachePointType::Default)
+                            .build()
+                            .context("failed to build cache point block")?,
+                    ));
+                }
                 let bedrock_role = match message.role {
                     Role::User => bedrock::BedrockRole::User,
                     Role::Assistant => bedrock::BedrockRole::Assistant,
@@ -690,7 +753,7 @@ pub fn into_bedrock(
         }
     }

-    let tool_spec: Vec<BedrockTool> = request
+    let mut tool_spec: Vec<BedrockTool> = request
         .tools
         .iter()
         .filter_map(|tool| {
@@ -707,6 +770,15 @@ pub fn into_bedrock(
         })
         .collect();

+    if !tool_spec.is_empty() && supports_caching {
+        tool_spec.push(BedrockTool::CachePoint(
+            CachePointBlock::builder()
+                .r#type(CachePointType::Default)
+                .build()
+                .context("failed to build cache point block")?,
+        ));
+    }
+
     let tool_choice = match request.tool_choice {
         Some(LanguageModelToolChoice::Auto) | None => {
             BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
@@ -715,7 +787,8 @@ pub fn into_bedrock(
             BedrockToolChoice::Any(BedrockAnyToolChoice::builder().build())
         }
         Some(LanguageModelToolChoice::None) => {
-            anyhow::bail!("LanguageModelToolChoice::None is not supported");
+            // For None, we still use Auto but will filter out tool calls in the response
+            BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
         }
     };
     let tool_config: BedrockToolConfig = BedrockToolConfig::builder()
@@ -747,7 +820,7 @@ pub fn into_bedrock(
 pub fn get_bedrock_tokens(
     request: LanguageModelRequest,
     cx: &App,
-) -> BoxFuture<'static, Result<usize>> {
+) -> BoxFuture<'static, Result<u64>> {
     cx.background_executor()
         .spawn(async move {
             let messages = request.messages;

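Prompt caching on Bedrock works by appending checkpoint markers: one after a cached message's content blocks and one after the tool definitions, so everything before the marker becomes cacheable. A reduced sketch of the rule, with a stand-in block type:

```rust
// Stand-in for Bedrock's content blocks; CachePoint marks everything before
// it as cacheable.
#[derive(Debug, PartialEq)]
enum Block {
    Text(String),
    CachePoint,
}

fn append_cache_point(blocks: &mut Vec<Block>, cache_requested: bool, supports_caching: bool) {
    // Only emit a checkpoint when the caller asked for caching and the model
    // actually supports it; otherwise the request is left untouched.
    if cache_requested && supports_caching {
        blocks.push(Block::CachePoint);
    }
}
```
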
@@ -799,7 +872,7 @@ pub fn get_bedrock_tokens(
             // Tiktoken doesn't yet support these models, so we manually use the
             // same tokenizer as GPT-4.
             tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
-                .map(|tokens| tokens + tokens_from_images)
+                .map(|tokens| (tokens + tokens_from_images) as u64)
         })
         .boxed()
 }
@@ -947,11 +1020,12 @@ pub fn map_to_language_model_completion_events(
                             let completion_event =
                                 LanguageModelCompletionEvent::UsageUpdate(
                                     TokenUsage {
-                                        input_tokens: metadata.input_tokens as u32,
-                                        output_tokens: metadata.output_tokens
-                                            as u32,
-                                        cache_creation_input_tokens: default(),
-                                        cache_read_input_tokens: default(),
+                                        input_tokens: metadata.input_tokens as u64,
+                                        output_tokens: metadata.output_tokens as u64,
+                                        cache_creation_input_tokens:
+                                            metadata.cache_write_input_tokens.unwrap_or_default() as u64,
+                                        cache_read_input_tokens:
+                                            metadata.cache_read_input_tokens.unwrap_or_default() as u64,
                                     },
                                 );
                             return Some((Some(Ok(completion_event)), state));

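The usage event now reports Bedrock's cache statistics instead of zeroing them. A sketch of the conversion, assuming the i32 counts the AWS SDK metadata exposes:

```rust
#[derive(Debug, Default)]
struct TokenUsage {
    input_tokens: u64,
    output_tokens: u64,
    cache_creation_input_tokens: u64,
    cache_read_input_tokens: u64,
}

fn usage_from_metadata(
    input_tokens: i32,
    output_tokens: i32,
    cache_write_input_tokens: Option<i32>,
    cache_read_input_tokens: Option<i32>,
) -> TokenUsage {
    TokenUsage {
        input_tokens: input_tokens as u64,
        output_tokens: output_tokens as u64,
        // Cache counters are optional in the metadata; absent means zero.
        cache_creation_input_tokens: cache_write_input_tokens.unwrap_or_default() as u64,
        cache_read_input_tokens: cache_read_input_tokens.unwrap_or_default() as u64,
    }
}
```
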
@@ -1,6 +1,6 @@
 use anthropic::{AnthropicModelMode, parse_prompt_too_long};
 use anyhow::{Context as _, Result, anyhow};
-use client::{Client, UserStore, zed_urls};
+use client::{Client, ModelRequestUsage, UserStore, zed_urls};
 use futures::{
     AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
 };
@@ -14,7 +14,7 @@ use language_model::{
     LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
-    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
+    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
     ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
@@ -73,9 +73,9 @@ pub struct AvailableModel {
     /// The size of the context window, indicating the maximum number of tokens the model can process.
     pub max_tokens: usize,
     /// The maximum number of output tokens allowed by the model.
-    pub max_output_tokens: Option<u32>,
+    pub max_output_tokens: Option<u64>,
     /// The maximum number of completion tokens allowed by the model (o1-* only)
-    pub max_completion_tokens: Option<u32>,
+    pub max_completion_tokens: Option<u64>,
     /// Override this model with a different Anthropic model for tool calls.
     pub tool_override: Option<String>,
     /// Indicates whether this custom model supports caching.
@@ -530,7 +530,7 @@ pub struct CloudLanguageModel {

 struct PerformLlmCompletionResponse {
     response: Response<AsyncBody>,
-    usage: Option<RequestUsage>,
+    usage: Option<ModelRequestUsage>,
     tool_use_limit_reached: bool,
     includes_status_messages: bool,
 }
@@ -581,7 +581,7 @@ impl CloudLanguageModel {
                 let usage = if includes_status_messages {
                     None
                 } else {
-                    RequestUsage::from_headers(response.headers()).ok()
+                    ModelRequestUsage::from_headers(response.headers()).ok()
                 };

                 return Ok(PerformLlmCompletionResponse {
@@ -715,8 +715,8 @@ impl LanguageModel for CloudLanguageModel {
         }
     }

-    fn max_token_count(&self) -> usize {
-        self.model.max_token_count
+    fn max_token_count(&self) -> u64 {
+        self.model.max_token_count as u64
     }

     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
@@ -737,7 +737,7 @@ impl LanguageModel for CloudLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &App,
-    ) -> BoxFuture<'static, Result<usize>> {
+    ) -> BoxFuture<'static, Result<u64>> {
         match self.model.provider {
             zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
             zed_llm_client::LanguageModelProvider::OpenAi => {
@@ -786,7 +786,7 @@ impl LanguageModel for CloudLanguageModel {
                     let response_body: CountTokensResponse =
                         serde_json::from_str(&response_body)?;

-                    Ok(response_body.tokens)
+                    Ok(response_body.tokens as u64)
                 } else {
                     Err(anyhow!(ApiError {
                         status,
@@ -821,7 +821,7 @@ impl LanguageModel for CloudLanguageModel {
                     request,
                     self.model.id.to_string(),
                     1.0,
-                    self.model.max_output_tokens as u32,
+                    self.model.max_output_tokens as u64,
                     if self.model.id.0.ends_with("-thinking") {
                         AnthropicModelMode::Thinking {
                             budget_tokens: Some(4_096),
@@ -888,7 +888,12 @@ impl LanguageModel for CloudLanguageModel {
                     Ok(model) => model,
                     Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                 };
-                let request = into_open_ai(request, &model, None);
+                let request = into_open_ai(
+                    request,
+                    model.id(),
+                    model.supports_parallel_tool_calls(),
+                    None,
+                );
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {
@@ -1002,7 +1007,7 @@ where
 }

 fn usage_updated_event<T>(
-    usage: Option<RequestUsage>,
+    usage: Option<ModelRequestUsage>,
 ) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
     futures::stream::iter(usage.map(|usage| {
         Ok(CloudCompletionEvent::Status(

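`ModelRequestUsage` replaces the old `RequestUsage` and is still read from response headers, but only when the stream does not already carry status messages, so the two sources never double-report. A hedged sketch of that selection rule; the header name below is invented for illustration (the real constants live in zed_llm_client):

```rust
use std::collections::HashMap;

#[derive(Debug, PartialEq)]
struct ModelRequestUsage(u32);

fn usage_for_response(
    includes_status_messages: bool,
    headers: &HashMap<String, String>,
) -> Option<ModelRequestUsage> {
    if includes_status_messages {
        // Status messages are the authoritative source; skip the headers.
        return None;
    }
    headers
        .get("x-model-requests-usage") // assumed name, for illustration only
        .and_then(|value| value.parse().ok())
        .map(ModelRequestUsage)
}
```
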
@@ -10,35 +10,30 @@ use copilot::copilot_chat::{
     ToolCall,
 };
 use copilot::{Copilot, Status};
-use editor::{Editor, EditorElement, EditorStyle};
-use fs::Fs;
 use futures::future::BoxFuture;
 use futures::stream::BoxStream;
 use futures::{FutureExt, Stream, StreamExt};
 use gpui::{
-    Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, FontStyle, Render,
-    Subscription, Task, TextStyle, Transformation, WhiteSpace, percentage, svg,
+    Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
+    Transformation, percentage, svg,
 };
+use language::language_settings::all_language_settings;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
     LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
-    StopReason,
+    StopReason, TokenUsage,
 };
-use settings::{Settings, SettingsStore, update_settings_file};
+use settings::SettingsStore;
 use std::time::Duration;
-use theme::ThemeSettings;
 use ui::prelude::*;
 use util::debug_panic;

-use crate::{AllLanguageModelSettings, CopilotChatSettingsContent};
-
 use super::anthropic::count_anthropic_tokens;
 use super::google::count_google_tokens;
 use super::open_ai::count_open_ai_tokens;
 pub(crate) use copilot::copilot_chat::CopilotChatSettings;

 const PROVIDER_ID: &str = "copilot_chat";
 const PROVIDER_NAME: &str = "GitHub Copilot Chat";
@@ -69,11 +64,16 @@ impl CopilotChatLanguageModelProvider {
             _copilot_chat_subscription: copilot_chat_subscription,
             _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                 if let Some(copilot_chat) = CopilotChat::global(cx) {
-                    let settings = AllLanguageModelSettings::get_global(cx)
-                        .copilot_chat
-                        .clone();
+                    let language_settings = all_language_settings(None, cx);
+                    let configuration = copilot::copilot_chat::CopilotChatConfiguration {
+                        enterprise_uri: language_settings
+                            .edit_predictions
+                            .copilot
+                            .enterprise_uri
+                            .clone(),
+                    };
                     copilot_chat.update(cx, |chat, cx| {
-                        chat.set_settings(settings, cx);
+                        chat.set_configuration(configuration, cx);
                     });
                 }
                 cx.notify();

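Copilot Chat's configuration is now derived from the language settings (the Copilot enterprise URI) on every settings change, replacing the per-provider api_url/models_url/auth_url entries. A minimal sketch of the derivation, with stand-in types:

```rust
#[derive(Clone, Debug, Default, PartialEq)]
struct CopilotChatConfiguration {
    enterprise_uri: Option<String>,
}

// Called from the SettingsStore observer: rebuild the configuration from the
// language settings and hand it to the chat entity (set_configuration in the
// real code).
fn on_settings_changed(
    enterprise_uri: Option<String>,
    apply: impl FnOnce(CopilotChatConfiguration),
) {
    apply(CopilotChatConfiguration { enterprise_uri });
}
```
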
@@ -174,10 +174,9 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
         Task::ready(Err(err.into()))
     }

-    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
+    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
         let state = self.state.clone();
-        cx.new(|cx| ConfigurationView::new(state, window, cx))
-            .into()
+        cx.new(|cx| ConfigurationView::new(state, cx)).into()
     }

     fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
@@ -238,7 +237,7 @@ impl LanguageModel for CopilotChatLanguageModel {
         format!("copilot_chat/{}", self.model.id())
     }

-    fn max_token_count(&self) -> usize {
+    fn max_token_count(&self) -> u64 {
         self.model.max_token_count()
     }

@@ -246,7 +245,7 @@ impl LanguageModel for CopilotChatLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &App,
-    ) -> BoxFuture<'static, Result<usize>> {
+    ) -> BoxFuture<'static, Result<u64>> {
         match self.model.vendor() {
             ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
             ModelVendor::Google => count_google_tokens(request, cx),
@@ -268,23 +267,6 @@ impl LanguageModel for CopilotChatLanguageModel {
             LanguageModelCompletionError,
         >,
     > {
-        if let Some(message) = request.messages.last() {
-            if message.contents_empty() {
-                const EMPTY_PROMPT_MSG: &str =
-                    "Empty prompts aren't allowed. Please provide a non-empty prompt.";
-                return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG).into()))
-                    .boxed();
-            }
-
-            // Copilot Chat has a restriction that the final message must be from the user.
-            // While their API does return an error message for this, we can catch it earlier
-            // and provide a more helpful error message.
-            if !matches!(message.role, Role::User) {
-                const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
-                return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG).into())).boxed();
-            }
-        }
-
         let copilot_request = match into_copilot_chat(&self.model, request) {
             Ok(request) => request,
             Err(err) => return futures::future::ready(Err(err.into())).boxed(),
@@ -379,6 +361,17 @@ pub fn map_to_language_model_completion_events(
                         }
                     }

+                    if let Some(usage) = event.usage {
+                        events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+                            TokenUsage {
+                                input_tokens: usage.prompt_tokens,
+                                output_tokens: usage.completion_tokens,
+                                cache_creation_input_tokens: 0,
+                                cache_read_input_tokens: 0,
+                            },
+                        )));
+                    }
+
                     match choice.finish_reason.as_deref() {
                         Some("stop") => {
                             events.push(Ok(LanguageModelCompletionEvent::Stop(

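The same `UsageUpdate` plumbing is added to several mappers in this commit (Copilot Chat here, DeepSeek and LM Studio below): when a stream chunk carries a usage payload, it is forwarded as an event. Reduced to its data conversion:

```rust
struct Usage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

struct TokenUsage {
    input_tokens: u64,
    output_tokens: u64,
    cache_creation_input_tokens: u64,
    cache_read_input_tokens: u64,
}

// These providers report prompt/completion counts only, so the cache
// counters are zeroed.
fn usage_update(usage: Usage) -> TokenUsage {
    TokenUsage {
        input_tokens: usage.prompt_tokens,
        output_tokens: usage.completion_tokens,
        cache_creation_input_tokens: 0,
        cache_read_input_tokens: 0,
    }
}
```
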
@@ -622,38 +615,15 @@ fn into_copilot_chat(

 struct ConfigurationView {
     copilot_status: Option<copilot::Status>,
-    api_url_editor: Entity<Editor>,
-    models_url_editor: Entity<Editor>,
-    auth_url_editor: Entity<Editor>,
     state: Entity<State>,
     _subscription: Option<Subscription>,
 }

 impl ConfigurationView {
-    pub fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
+    pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
         let copilot = Copilot::global(cx);
-        let settings = AllLanguageModelSettings::get_global(cx)
-            .copilot_chat
-            .clone();
-        let api_url_editor = cx.new(|cx| Editor::single_line(window, cx));
-        api_url_editor.update(cx, |this, cx| {
-            this.set_text(settings.api_url.clone(), window, cx);
-            this.set_placeholder_text("GitHub Copilot API URL", cx);
-        });
-        let models_url_editor = cx.new(|cx| Editor::single_line(window, cx));
-        models_url_editor.update(cx, |this, cx| {
-            this.set_text(settings.models_url.clone(), window, cx);
-            this.set_placeholder_text("GitHub Copilot Models URL", cx);
-        });
-        let auth_url_editor = cx.new(|cx| Editor::single_line(window, cx));
-        auth_url_editor.update(cx, |this, cx| {
-            this.set_text(settings.auth_url.clone(), window, cx);
-            this.set_placeholder_text("GitHub Copilot Auth URL", cx);
-        });

         Self {
-            api_url_editor,
-            models_url_editor,
-            auth_url_editor,
             copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
             state,
             _subscription: copilot.as_ref().map(|copilot| {
@@ -664,104 +634,6 @@ impl ConfigurationView {
             }),
         }
     }
-
-    fn make_input_styles(&self, cx: &App) -> Div {
-        let bg_color = cx.theme().colors().editor_background;
-        let border_color = cx.theme().colors().border;
-
-        h_flex()
-            .w_full()
-            .px_2()
-            .py_1()
-            .bg(bg_color)
-            .border_1()
-            .border_color(border_color)
-            .rounded_sm()
-    }
-
-    fn make_text_style(&self, cx: &Context<Self>) -> TextStyle {
-        let settings = ThemeSettings::get_global(cx);
-        TextStyle {
-            color: cx.theme().colors().text,
-            font_family: settings.ui_font.family.clone(),
-            font_features: settings.ui_font.features.clone(),
-            font_fallbacks: settings.ui_font.fallbacks.clone(),
-            font_size: rems(0.875).into(),
-            font_weight: settings.ui_font.weight,
-            font_style: FontStyle::Normal,
-            line_height: relative(1.3),
-            background_color: None,
-            underline: None,
-            strikethrough: None,
-            white_space: WhiteSpace::Normal,
-            text_overflow: None,
-            text_align: Default::default(),
-            line_clamp: None,
-        }
-    }
-
-    fn render_api_url_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
-        let text_style = self.make_text_style(cx);
-
-        EditorElement::new(
-            &self.api_url_editor,
-            EditorStyle {
-                background: cx.theme().colors().editor_background,
-                local_player: cx.theme().players().local(),
-                text: text_style,
-                ..Default::default()
-            },
-        )
-    }
-
-    fn render_auth_url_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
-        let text_style = self.make_text_style(cx);
-
-        EditorElement::new(
-            &self.auth_url_editor,
-            EditorStyle {
-                background: cx.theme().colors().editor_background,
-                local_player: cx.theme().players().local(),
-                text: text_style,
-                ..Default::default()
-            },
-        )
-    }
-
-    fn render_models_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
-        let text_style = self.make_text_style(cx);
-
-        EditorElement::new(
-            &self.models_url_editor,
-            EditorStyle {
-                background: cx.theme().colors().editor_background,
-                local_player: cx.theme().players().local(),
-                text: text_style,
-                ..Default::default()
-            },
-        )
-    }
-
-    fn update_copilot_settings(&self, cx: &mut Context<'_, Self>) {
-        let settings = CopilotChatSettings {
-            api_url: self.api_url_editor.read(cx).text(cx).into(),
-            models_url: self.models_url_editor.read(cx).text(cx).into(),
-            auth_url: self.auth_url_editor.read(cx).text(cx).into(),
-        };
-        update_settings_file::<AllLanguageModelSettings>(<dyn Fs>::global(cx), cx, {
-            let settings = settings.clone();
-            move |content, _| {
-                content.copilot_chat = Some(CopilotChatSettingsContent {
-                    api_url: Some(settings.api_url.as_ref().into()),
-                    models_url: Some(settings.models_url.as_ref().into()),
-                    auth_url: Some(settings.auth_url.as_ref().into()),
-                });
-            }
-        });
-        if let Some(chat) = CopilotChat::global(cx) {
-            chat.update(cx, |this, cx| {
-                this.set_settings(settings, cx);
-            });
-        }
-    }
 }

 impl Render for ConfigurationView {
@@ -819,59 +691,15 @@ impl Render for ConfigurationView {
                 }
                 _ => {
                     const LABEL: &str = "To use Zed's assistant with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
-                    v_flex()
-                        .gap_2()
-                        .child(Label::new(LABEL))
-                        .on_action(cx.listener(|this, _: &menu::Confirm, window, cx| {
-                            this.update_copilot_settings(cx);
-                            copilot::initiate_sign_in(window, cx);
-                        }))
-                        .child(
-                            v_flex()
-                                .gap_0p5()
-                                .child(Label::new("API URL").size(LabelSize::Small))
-                                .child(
-                                    self.make_input_styles(cx)
-                                        .child(self.render_api_url_editor(cx)),
-                                ),
-                        )
-                        .child(
-                            v_flex()
-                                .gap_0p5()
-                                .child(Label::new("Auth URL").size(LabelSize::Small))
-                                .child(
-                                    self.make_input_styles(cx)
-                                        .child(self.render_auth_url_editor(cx)),
-                                ),
-                        )
-                        .child(
-                            v_flex()
-                                .gap_0p5()
-                                .child(Label::new("Models list URL").size(LabelSize::Small))
-                                .child(
-                                    self.make_input_styles(cx)
-                                        .child(self.render_models_editor(cx)),
-                                ),
-                        )
-                        .child(
-                            Button::new("sign_in", "Sign in to use GitHub Copilot")
-                                .icon_color(Color::Muted)
-                                .icon(IconName::Github)
-                                .icon_position(IconPosition::Start)
-                                .icon_size(IconSize::Medium)
-                                .full_width()
-                                .on_click(cx.listener(|this, _, window, cx| {
-                                    this.update_copilot_settings(cx);
-                                    copilot::initiate_sign_in(window, cx)
-                                })),
-                        )
-                        .child(
-                            Label::new(
-                                format!("You can also assign the {} environment variable and restart Zed.", copilot::copilot_chat::COPILOT_OAUTH_ENV_VAR),
-                            )
-                            .size(LabelSize::Small)
-                            .color(Color::Muted),
-                        )
+                    v_flex().gap_2().child(Label::new(LABEL)).child(
+                        Button::new("sign_in", "Sign in to use GitHub Copilot")
+                            .icon_color(Color::Muted)
+                            .icon(IconName::Github)
+                            .icon_position(IconPosition::Start)
+                            .icon_size(IconSize::Medium)
+                            .full_width()
+                            .on_click(|_, window, cx| copilot::initiate_sign_in(window, cx)),
+                    )
                 }
             },
             None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),

@@ -14,7 +14,7 @@ use language_model::{
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    RateLimiter, Role, StopReason,
+    RateLimiter, Role, StopReason, TokenUsage,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -49,8 +49,8 @@ pub struct DeepSeekSettings {
 pub struct AvailableModel {
     pub name: String,
     pub display_name: Option<String>,
-    pub max_tokens: usize,
-    pub max_output_tokens: Option<u32>,
+    pub max_tokens: u64,
+    pub max_output_tokens: Option<u64>,
 }

 pub struct DeepSeekLanguageModelProvider {
@@ -306,11 +306,11 @@ impl LanguageModel for DeepSeekLanguageModel {
         format!("deepseek/{}", self.model.id())
     }

-    fn max_token_count(&self) -> usize {
+    fn max_token_count(&self) -> u64 {
         self.model.max_token_count()
     }

-    fn max_output_tokens(&self) -> Option<u32> {
+    fn max_output_tokens(&self) -> Option<u64> {
         self.model.max_output_tokens()
     }

@@ -318,7 +318,7 @@ impl LanguageModel for DeepSeekLanguageModel {
         &self,
         request: LanguageModelRequest,
         cx: &App,
-    ) -> BoxFuture<'static, Result<usize>> {
+    ) -> BoxFuture<'static, Result<u64>> {
         cx.background_spawn(async move {
             let messages = request
                 .messages
@@ -335,7 +335,7 @@ impl LanguageModel for DeepSeekLanguageModel {
                 })
                 .collect::<Vec<_>>();

-            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
+            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
         })
         .boxed()
     }
@@ -365,7 +365,7 @@ impl LanguageModel for DeepSeekLanguageModel {
 pub fn into_deepseek(
     request: LanguageModelRequest,
     model: &deepseek::Model,
-    max_output_tokens: Option<u32>,
+    max_output_tokens: Option<u64>,
 ) -> deepseek::Request {
     let is_reasoner = *model == deepseek::Model::Reasoner;

@@ -513,6 +513,15 @@ impl DeepSeekEventMapper {
             }
         }

+        if let Some(usage) = event.usage {
+            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: usage.prompt_tokens,
+                output_tokens: usage.completion_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+            })));
+        }
+
        match choice.finish_reason.as_deref() {
             Some("stop") => {
                 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));

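DeepSeek, Mistral, and Google all count through tiktoken's GPT-4 encoding as an approximation, since tiktoken-rs has no tokenizer for those models. A sketch of that fallback, assuming tiktoken-rs's `ChatCompletionRequestMessage` shape:

```rust
use tiktoken_rs::{ChatCompletionRequestMessage, num_tokens_from_messages};

fn approximate_token_count(turns: &[(String, String)]) -> anyhow::Result<u64> {
    let messages: Vec<ChatCompletionRequestMessage> = turns
        .iter()
        .map(|(role, content)| ChatCompletionRequestMessage {
            role: role.clone(),
            content: Some(content.clone()),
            name: None,
            function_call: None,
        })
        .collect();
    // GPT-4 is not these models' tokenizer; this is an estimate, widened to u64.
    num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
}
```
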
@@ -79,7 +79,7 @@ impl From<GoogleModelMode> for ModelMode {
 pub struct AvailableModel {
     name: String,
     display_name: Option<String>,
-    max_tokens: usize,
+    max_tokens: u64,
     mode: Option<ModelMode>,
 }

@@ -342,11 +342,11 @@ impl LanguageModel for GoogleLanguageModel {
     }

     fn supports_tools(&self) -> bool {
-        true
+        self.model.supports_tools()
     }

     fn supports_images(&self) -> bool {
-        true
+        self.model.supports_images()
     }

     fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
@@ -365,15 +365,19 @@ impl LanguageModel for GoogleLanguageModel {
         format!("google/{}", self.model.request_id())
     }

-    fn max_token_count(&self) -> usize {
+    fn max_token_count(&self) -> u64 {
         self.model.max_token_count()
     }

+    fn max_output_tokens(&self) -> Option<u64> {
+        self.model.max_output_tokens()
+    }
+
     fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &App,
-    ) -> BoxFuture<'static, Result<usize>> {
+    ) -> BoxFuture<'static, Result<u64>> {
         let model_id = self.model.request_id().to_string();
         let request = into_google(request, model_id.clone(), self.model.mode());
         let http_client = self.http_client.clone();
@@ -437,14 +441,29 @@ pub fn into_google(
             content
                 .into_iter()
                 .flat_map(|content| match content {
-                    language_model::MessageContent::Text(text)
-                    | language_model::MessageContent::Thinking { text, .. } => {
+                    language_model::MessageContent::Text(text) => {
                         if !text.is_empty() {
                             vec![Part::TextPart(google_ai::TextPart { text })]
                         } else {
                             vec![]
                         }
                     }
+                    language_model::MessageContent::Thinking {
+                        text: _,
+                        signature: Some(signature),
+                    } => {
+                        if !signature.is_empty() {
+                            vec![Part::ThoughtPart(google_ai::ThoughtPart {
+                                thought: true,
+                                thought_signature: signature,
+                            })]
+                        } else {
+                            vec![]
+                        }
+                    }
+                    language_model::MessageContent::Thinking { .. } => {
+                        vec![]
+                    }
                     language_model::MessageContent::RedactedThinking(_) => vec![],
                     language_model::MessageContent::Image(image) => {
                         vec![Part::InlineDataPart(google_ai::InlineDataPart {
@@ -664,7 +683,12 @@ impl GoogleEventMapper {
                         )));
                     }
                     Part::FunctionResponsePart(_) => {}
-                    Part::ThoughtPart(_) => {}
+                    Part::ThoughtPart(part) => {
+                        events.push(Ok(LanguageModelCompletionEvent::Thinking {
+                            text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
+                            signature: Some(part.thought_signature),
+                        }));
+                    }
                 });
             }
         }

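Gemini's encrypted thought parts must be echoed back with their signature on later turns, so the mapper now surfaces them as Thinking events with placeholder text instead of dropping them. The conversion in isolation, with stand-in types:

```rust
// Stand-ins for google_ai::ThoughtPart and the Thinking completion event.
struct ThoughtPart {
    thought: bool,
    thought_signature: String,
}

struct ThinkingEvent {
    text: String,
    signature: Option<String>,
}

fn thinking_event(part: ThoughtPart) -> ThinkingEvent {
    ThinkingEvent {
        // The thought text itself is encrypted; only the signature is usable,
        // and it must round-trip back into the next request.
        text: "(Encrypted thought)".to_string(),
        signature: Some(part.thought_signature),
    }
}
```
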
@@ -682,7 +706,7 @@ impl GoogleEventMapper {
 pub fn count_google_tokens(
     request: LanguageModelRequest,
     cx: &App,
-) -> BoxFuture<'static, Result<usize>> {
+) -> BoxFuture<'static, Result<u64>> {
     // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly.
     // So we have to use tokenizer from tiktoken_rs to count tokens.
     cx.background_spawn(async move {
@@ -703,7 +727,7 @@ pub fn count_google_tokens(

         // Tiktoken doesn't yet support these models, so we manually use the
         // same tokenizer as GPT-4.
-        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
+        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
     })
     .boxed()
 }
@@ -730,10 +754,10 @@ fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
 }

 fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
-    let prompt_tokens = usage.prompt_token_count.unwrap_or(0) as u32;
-    let cached_tokens = usage.cached_content_token_count.unwrap_or(0) as u32;
+    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
+    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
     let input_tokens = prompt_tokens - cached_tokens;
-    let output_tokens = usage.candidates_token_count.unwrap_or(0) as u32;
+    let output_tokens = usage.candidates_token_count.unwrap_or(0);

     language_model::TokenUsage {
         input_tokens,

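Google reports a prompt count that already includes cached reads, so cached tokens are subtracted back out before the usage event is built. A sketch of that split (the exact field placement in `TokenUsage` is an assumption; the diff only shows the subtraction):

```rust
fn split_prompt_tokens(
    prompt_token_count: Option<u64>,
    cached_content_token_count: Option<u64>,
) -> (u64, u64) {
    let prompt_tokens = prompt_token_count.unwrap_or(0);
    let cached_tokens = cached_content_token_count.unwrap_or(0);
    // Uncached input on the left, cached reads on the right. Like the diff's
    // subtraction, this underflows (panics in debug builds) if the API ever
    // reported more cached than prompt tokens.
    (prompt_tokens - cached_tokens, cached_tokens)
}
```
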
@@ -7,17 +7,14 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
-    StopReason,
+    StopReason, TokenUsage,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelRequest, RateLimiter, Role,
 };
-use lmstudio::{
-    ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models,
-    stream_chat_completion,
-};
+use lmstudio::{ModelType, get_models};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
@@ -47,8 +44,9 @@ pub struct LmStudioSettings {
 pub struct AvailableModel {
     pub name: String,
     pub display_name: Option<String>,
-    pub max_tokens: usize,
+    pub max_tokens: u64,
     pub supports_tool_calls: bool,
+    pub supports_images: bool,
 }

 pub struct LmStudioLanguageModelProvider {
@@ -88,6 +86,7 @@ impl State {
                                 .loaded_context_length
                                 .or_else(|| model.max_context_length),
                             model.capabilities.supports_tool_calls(),
+                            model.capabilities.supports_images() || model.r#type == ModelType::Vlm,
                         )
                     })
                     .collect();
@@ -201,6 +200,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
                     display_name: model.display_name.clone(),
                     max_tokens: model.max_tokens,
                     supports_tool_calls: model.supports_tool_calls,
+                    supports_images: model.supports_images,
                 },
             );
         }
@@ -244,23 +244,34 @@ pub struct LmStudioLanguageModel {
 }

 impl LmStudioLanguageModel {
-    fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
+    fn to_lmstudio_request(
+        &self,
+        request: LanguageModelRequest,
+    ) -> lmstudio::ChatCompletionRequest {
         let mut messages = Vec::new();

         for message in request.messages {
             for content in message.content {
                 match content {
-                    MessageContent::Text(text) => messages.push(match message.role {
-                        Role::User => ChatMessage::User { content: text },
-                        Role::Assistant => ChatMessage::Assistant {
-                            content: Some(text),
-                            tool_calls: Vec::new(),
-                        },
-                        Role::System => ChatMessage::System { content: text },
-                    }),
+                    MessageContent::Text(text) => add_message_content_part(
+                        lmstudio::MessagePart::Text { text },
+                        message.role,
+                        &mut messages,
+                    ),
                     MessageContent::Thinking { .. } => {}
                     MessageContent::RedactedThinking(_) => {}
-                    MessageContent::Image(_) => {}
+                    MessageContent::Image(image) => {
+                        add_message_content_part(
+                            lmstudio::MessagePart::Image {
+                                image_url: lmstudio::ImageUrl {
+                                    url: image.to_base64_url(),
+                                    detail: None,
+                                },
+                            },
+                            message.role,
+                            &mut messages,
+                        );
+                    }
                     MessageContent::ToolUse(tool_use) => {
                         let tool_call = lmstudio::ToolCall {
                             id: tool_use.id.to_string(),
@@ -285,23 +296,32 @@ impl LmStudioLanguageModel {
                         }
                     }
                     MessageContent::ToolResult(tool_result) => {
-                        match &tool_result.content {
+                        let content = match &tool_result.content {
                             LanguageModelToolResultContent::Text(text) => {
-                                messages.push(lmstudio::ChatMessage::Tool {
-                                    content: text.to_string(),
-                                    tool_call_id: tool_result.tool_use_id.to_string(),
-                                });
+                                vec![lmstudio::MessagePart::Text {
+                                    text: text.to_string(),
+                                }]
                             }
-                            LanguageModelToolResultContent::Image(_) => {
-                                // no support for images for now
+                            LanguageModelToolResultContent::Image(image) => {
+                                vec![lmstudio::MessagePart::Image {
+                                    image_url: lmstudio::ImageUrl {
+                                        url: image.to_base64_url(),
+                                        detail: None,
+                                    },
+                                }]
                             }
-                        }
+                        };
+
+                        messages.push(lmstudio::ChatMessage::Tool {
+                            content: content.into(),
+                            tool_call_id: tool_result.tool_use_id.to_string(),
+                        });
                     }
                 }
             }
         }

-        ChatCompletionRequest {
+        lmstudio::ChatCompletionRequest {
             model: self.model.name.clone(),
             messages,
             stream: true,
@@ -332,10 +352,12 @@ impl LmStudioLanguageModel {

     fn stream_completion(
         &self,
-        request: ChatCompletionRequest,
+        request: lmstudio::ChatCompletionRequest,
         cx: &AsyncApp,
-    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
-    {
+    ) -> BoxFuture<
+        'static,
+        Result<futures::stream::BoxStream<'static, Result<lmstudio::ResponseStreamEvent>>>,
+    > {
         let http_client = self.http_client.clone();
         let Ok(api_url) = cx.update(|cx| {
             let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
@@ -345,7 +367,7 @@ impl LmStudioLanguageModel {
         };

         let future = self.request_limiter.stream(async move {
-            let request = stream_chat_completion(http_client.as_ref(), &api_url, request);
+            let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
             let response = request.await?;
             Ok(response)
         });
@@ -385,14 +407,14 @@ impl LanguageModel for LmStudioLanguageModel {
     }

     fn supports_images(&self) -> bool {
-        false
+        self.model.supports_images
     }

     fn telemetry_id(&self) -> String {
         format!("lmstudio/{}", self.model.id())
     }

-    fn max_token_count(&self) -> usize {
+    fn max_token_count(&self) -> u64 {
         self.model.max_token_count()
     }

@@ -400,7 +422,7 @@ impl LanguageModel for LmStudioLanguageModel {
         &self,
         request: LanguageModelRequest,
         _cx: &App,
-    ) -> BoxFuture<'static, Result<usize>> {
+    ) -> BoxFuture<'static, Result<u64>> {
         // Endpoint for this is coming soon. In the meantime, hacky estimation
         let token_count = request
             .messages
@@ -408,7 +430,7 @@ impl LanguageModel for LmStudioLanguageModel {
             .map(|msg| msg.string_contents().split_whitespace().count())
             .sum::<usize>();

-        let estimated_tokens = (token_count as f64 * 0.75) as usize;
+        let estimated_tokens = (token_count as f64 * 0.75) as u64;
         async move { Ok(estimated_tokens) }.boxed()
     }

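LM Studio has no token-counting endpoint yet, so the count stays a word-count heuristic scaled by 0.75, now returned as `u64`. Runnable in isolation:

```rust
fn estimate_tokens(messages: &[String]) -> u64 {
    let word_count: usize = messages
        .iter()
        .map(|m| m.split_whitespace().count())
        .sum();
    // Roughly 0.75 tokens per whitespace-delimited word, as in the diff.
    (word_count as f64 * 0.75) as u64
}

fn main() {
    assert_eq!(estimate_tokens(&["four words right here".to_string()]), 3);
}
```
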
@ -446,7 +468,7 @@ impl LmStudioEventMapper {
|
|||
|
||||
pub fn map_stream(
|
||||
mut self,
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
||||
events: Pin<Box<dyn Send + Stream<Item = Result<lmstudio::ResponseStreamEvent>>>>,
|
||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
|
||||
{
|
||||
events.flat_map(move |event| {
|
||||
|
@ -459,7 +481,7 @@ impl LmStudioEventMapper {
|
|||
|
||||
pub fn map_event(
|
||||
&mut self,
|
||||
event: ResponseStreamEvent,
|
||||
event: lmstudio::ResponseStreamEvent,
|
||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||
let Some(choice) = event.choices.into_iter().next() else {
|
||||
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
|
||||
|
@ -506,6 +528,15 @@ impl LmStudioEventMapper {
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(usage) = event.usage {
|
||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
||||
input_tokens: usage.prompt_tokens,
|
||||
output_tokens: usage.completion_tokens,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 0,
|
||||
})));
|
||||
}
|
||||
|
||||
match choice.finish_reason.as_deref() {
|
||||
Some("stop") => {
|
||||
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
|
||||
|
@@ -551,6 +582,40 @@ struct RawToolCall {
    arguments: String,
}

fn add_message_content_part(
    new_part: lmstudio::MessagePart,
    role: Role,
    messages: &mut Vec<lmstudio::ChatMessage>,
) {
    match (role, messages.last_mut()) {
        (Role::User, Some(lmstudio::ChatMessage::User { content }))
        | (
            Role::Assistant,
            Some(lmstudio::ChatMessage::Assistant {
                content: Some(content),
                ..
            }),
        )
        | (Role::System, Some(lmstudio::ChatMessage::System { content })) => {
            content.push_part(new_part);
        }
        _ => {
            messages.push(match role {
                Role::User => lmstudio::ChatMessage::User {
                    content: lmstudio::MessageContent::from(vec![new_part]),
                },
                Role::Assistant => lmstudio::ChatMessage::Assistant {
                    content: Some(lmstudio::MessageContent::from(vec![new_part])),
                    tool_calls: Vec::new(),
                },
                Role::System => lmstudio::ChatMessage::System {
                    content: lmstudio::MessageContent::from(vec![new_part]),
                },
            });
        }
    }
}
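add_message_content_part coalesces a new part into the previous message when the roles line up, so consecutive parts for one role become a single multi-part message rather than several messages. A usage sketch, assuming lmstudio::MessagePart has a Text variant shaped like the open_router one later in this diff:

```rust
let mut messages: Vec<lmstudio::ChatMessage> = Vec::new();
add_message_content_part(
    lmstudio::MessagePart::Text { text: "Hello,".into() }, // assumed variant shape
    Role::User,
    &mut messages,
);
add_message_content_part(
    lmstudio::MessagePart::Text { text: "world!".into() },
    Role::User,
    &mut messages,
);
// Both parts land in one user message instead of two.
assert_eq!(messages.len(), 1);
```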
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,

@@ -13,7 +13,7 @@ use language_model::{
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
    RateLimiter, Role, StopReason,
    RateLimiter, Role, StopReason, TokenUsage,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

@@ -36,16 +36,15 @@ const PROVIDER_NAME: &str = "Mistral";
pub struct MistralSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
    pub needs_setting_migration: bool,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub max_output_tokens: Option<u32>,
    pub max_completion_tokens: Option<u32>,
    pub max_tokens: u64,
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
    pub supports_tools: Option<bool>,
    pub supports_images: Option<bool>,
}

@@ -322,11 +321,11 @@ impl LanguageModel for MistralLanguageModel {
        format!("mistral/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u32> {
    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

@@ -334,7 +333,7 @@ impl LanguageModel for MistralLanguageModel {
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
    ) -> BoxFuture<'static, Result<u64>> {
        cx.background_spawn(async move {
            let messages = request
                .messages

@@ -351,7 +350,7 @@ impl LanguageModel for MistralLanguageModel {
                })
                .collect::<Vec<_>>();

            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
        })
        .boxed()
    }
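tiktoken-rs reports counts as usize while the provider API now traffics in u64, hence the `.map(|tokens| tokens as u64)` adapters that recur throughout this commit. A minimal standalone sketch of the same adaptation:

```rust
use tiktoken_rs::{ChatCompletionRequestMessage, num_tokens_from_messages};

// Widen tiktoken-rs's usize count to the u64 used across providers.
fn count_as_u64(messages: &[ChatCompletionRequestMessage]) -> anyhow::Result<u64> {
    num_tokens_from_messages("gpt-4", messages).map(|tokens| tokens as u64)
}
```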
@@ -386,7 +385,7 @@ impl LanguageModel for MistralLanguageModel {
pub fn into_mistral(
    request: LanguageModelRequest,
    model: String,
    max_output_tokens: Option<u32>,
    max_output_tokens: Option<u64>,
) -> mistral::Request {
    let stream = true;

@@ -626,6 +625,15 @@ impl MistralEventMapper {
            }
        }

        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        if let Some(finish_reason) = choice.finish_reason.as_deref() {
            match finish_reason {
                "stop" => {

@@ -8,7 +8,7 @@ use language_model::{
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason,
    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
};
use ollama::{
    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,

@@ -46,7 +46,7 @@ pub struct AvailableModel {
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The Context Length parameter to the model (aka num_ctx or n_ctx)
    pub max_tokens: usize,
    pub max_tokens: u64,
    /// The number of seconds to keep the connection open after the last request
    pub keep_alive: Option<KeepAlive>,
    /// Whether the model supports tools

@@ -377,7 +377,7 @@ impl LanguageModel for OllamaLanguageModel {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

@@ -385,7 +385,7 @@ impl LanguageModel for OllamaLanguageModel {
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
    ) -> BoxFuture<'static, Result<u64>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        let token_count = request

@@ -395,7 +395,7 @@ impl LanguageModel for OllamaLanguageModel {
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
        async move { Ok(token_count as u64) }.boxed()
    }

    fn stream_completion(
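Until Ollama grows a tokenize endpoint (see the linked issues above), the count is approximated as total characters divided by four, a common rule of thumb for English text. A standalone sketch of that heuristic, helper name illustrative:

```rust
/// Approximate token count: ~4 characters per token.
fn estimate_ollama_tokens<'a>(contents: impl Iterator<Item = &'a str>) -> u64 {
    let chars: usize = contents.map(|text| text.chars().count()).sum();
    (chars / 4) as u64
}
```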
@@ -507,6 +507,12 @@ fn map_to_language_model_completion_events(
        };

        if delta.done {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: delta.prompt_eval_count.unwrap_or(0),
                output_tokens: delta.eval_count.unwrap_or(0),
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
            if state.used_tools {
                state.used_tools = false;
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));

@@ -1,32 +1,34 @@
use anyhow::{Context as _, Result, anyhow};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use fs::Fs;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{
    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
    RateLimiter, Role, StopReason,
    RateLimiter, Role, StopReason, TokenUsage,
};
use menu;
use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use settings::{Settings, SettingsStore, update_settings_file};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;

use crate::OpenAiSettingsContent;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};

const PROVIDER_ID: &str = "openai";

@@ -36,16 +38,15 @@ const PROVIDER_NAME: &str = "OpenAI";
pub struct OpenAiSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
    pub needs_setting_migration: bool,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub max_output_tokens: Option<u32>,
    pub max_completion_tokens: Option<u32>,
    pub max_tokens: u64,
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
}

pub struct OpenAiLanguageModelProvider {

@@ -62,6 +63,7 @@ pub struct State {
const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

@@ -312,11 +314,11 @@ impl LanguageModel for OpenAiLanguageModel {
        format!("openai/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u32> {
    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

@@ -324,7 +326,7 @@ impl LanguageModel for OpenAiLanguageModel {
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
    ) -> BoxFuture<'static, Result<u64>> {
        count_open_ai_tokens(request, self.model.clone(), cx)
    }

@@ -342,7 +344,12 @@ impl LanguageModel for OpenAiLanguageModel {
            LanguageModelCompletionError,
        >,
    > {
        let request = into_open_ai(request, &self.model, self.max_output_tokens());
        let request = into_open_ai(
            request,
            self.model.id(),
            self.model.supports_parallel_tool_calls(),
            self.max_output_tokens(),
        );
        let completions = self.stream_completion(request, cx);
        async move {
            let mapper = OpenAiEventMapper::new();

@@ -354,10 +361,11 @@ impl LanguageModel for OpenAiLanguageModel {

pub fn into_open_ai(
    request: LanguageModelRequest,
    model: &Model,
    max_output_tokens: Option<u32>,
    model_id: &str,
    supports_parallel_tool_calls: bool,
    max_output_tokens: Option<u64>,
) -> open_ai::Request {
    let stream = !model.id().starts_with("o1-");
    let stream = !model_id.starts_with("o1-");

    let mut messages = Vec::new();
    for message in request.messages {

@@ -433,13 +441,13 @@ pub fn into_open_ai(
    }

    open_ai::Request {
        model: model.id().into(),
        model: model_id.into(),
        messages,
        stream,
        stop: request.stop,
        temperature: request.temperature.unwrap_or(1.0),
        max_tokens: max_output_tokens,
        parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
        max_completion_tokens: max_output_tokens,
        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
            Some(false)
        } else {
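into_open_ai now takes the model id and a capability flag rather than the whole open_ai::Model, which is what lets OpenAI-compatible providers (Vercel, below) reuse the conversion. A call site then looks like this (a sketch; `model` stands for whatever model type exposes these accessors):

```rust
let request = into_open_ai(
    request,
    model.id(),
    model.supports_parallel_tool_calls(),
    model.max_output_tokens(),
);
```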
@@ -526,13 +534,20 @@ impl OpenAiEventMapper {
        &mut self,
        event: ResponseStreamEvent,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut events = Vec::new();
        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        let Some(choice) = event.choices.first() else {
            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
                "Response contained no choices"
            )))];
            return events;
        };

        let mut events = Vec::new();
        if let Some(content) = choice.delta.content.clone() {
            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
        }

@@ -606,7 +621,7 @@ pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    model: Model,
    cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
    cx.background_spawn(async move {
        let messages = request
            .messages

@@ -648,18 +663,18 @@ pub fn count_open_ai_tokens(
            | Model::FourPointOneMini
            | Model::FourPointOneNano
            | Model::O1
            | Model::O1Preview
            | Model::O1Mini
            | Model::O3
            | Model::O3Mini
            | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
        }
        .map(|tokens| tokens as u64)
    })
    .boxed()
}

struct ConfigurationView {
    api_key_editor: Entity<Editor>,
    api_key_editor: Entity<SingleLineInput>,
    api_url_editor: Entity<SingleLineInput>,
    state: gpui::Entity<State>,
    load_credentials_task: Option<Task<()>>,
}

@@ -667,9 +682,28 @@ struct ConfigurationView {
impl ConfigurationView {
    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor = cx.new(|cx| {
            let mut editor = Editor::single_line(window, cx);
            editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
            editor
            SingleLineInput::new(
                window,
                cx,
                "sk-000000000000000000000000000000000000000000000000",
            )
            .label("API key")
        });

        let api_url = AllLanguageModelSettings::get_global(cx)
            .openai
            .api_url
            .clone();

        let api_url_editor = cx.new(|cx| {
            let input = SingleLineInput::new(window, cx, open_ai::OPEN_AI_API_URL).label("API URL");

            if !api_url.is_empty() {
                input.editor.update(cx, |editor, cx| {
                    editor.set_text(&*api_url, window, cx);
                });
            }
            input
        });

        cx.observe(&state, |_, _, cx| {

@@ -687,7 +721,6 @@ impl ConfigurationView {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }

                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();

@@ -698,14 +731,24 @@ impl ConfigurationView {

        Self {
            api_key_editor,
            api_url_editor,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx);
        if api_key.is_empty() {
        let api_key = self
            .api_key_editor
            .read(cx)
            .editor()
            .read(cx)
            .text(cx)
            .trim()
            .to_string();

        // Don't proceed if no API key is provided and we're not authenticated
        if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
            return;
        }

@@ -721,8 +764,11 @@ impl ConfigurationView {
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));
        self.api_key_editor.update(cx, |input, cx| {
            input.editor.update(cx, |editor, cx| {
                editor.set_text("", window, cx);
            });
        });

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {

@@ -733,29 +779,55 @@ impl ConfigurationView {
        cx.notify();
    }

    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_fallbacks: settings.ui_font.fallbacks.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            white_space: WhiteSpace::Normal,
            ..Default::default()
    fn save_api_url(&mut self, cx: &mut Context<Self>) {
        let api_url = self
            .api_url_editor
            .read(cx)
            .editor()
            .read(cx)
            .text(cx)
            .trim()
            .to_string();

        let current_url = AllLanguageModelSettings::get_global(cx)
            .openai
            .api_url
            .clone();

        let effective_current_url = if current_url.is_empty() {
            open_ai::OPEN_AI_API_URL
        } else {
            &current_url
        };
        EditorElement::new(
            &self.api_key_editor,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )

        if !api_url.is_empty() && api_url != effective_current_url {
            let fs = <dyn Fs>::global(cx);
            update_settings_file::<AllLanguageModelSettings>(fs, cx, move |settings, _| {
                if let Some(settings) = settings.openai.as_mut() {
                    settings.api_url = Some(api_url.clone());
                } else {
                    settings.openai = Some(OpenAiSettingsContent {
                        api_url: Some(api_url.clone()),
                        available_models: None,
                    });
                }
            });
        }
    }

    fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_url_editor.update(cx, |input, cx| {
            input.editor.update(cx, |editor, cx| {
                editor.set_text("", window, cx);
            });
        });
        let fs = <dyn Fs>::global(cx);
        update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
            if let Some(settings) = settings.openai.as_mut() {
                settings.api_url = None;
            }
        });
        cx.notify();
    }

    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
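Both save_api_url and reset_api_url persist through update_settings_file: writing Some(url) installs an override and writing None clears it back to the default. A condensed sketch of that round trip (the URL is a placeholder):

```rust
let fs = <dyn Fs>::global(cx);
// Install a custom endpoint…
update_settings_file::<AllLanguageModelSettings>(fs.clone(), cx, |settings, _| {
    settings.openai = Some(OpenAiSettingsContent {
        api_url: Some("https://my-proxy.example/v1".to_string()), // placeholder
        available_models: None,
    });
});
// …and clear it again.
update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _| {
    if let Some(openai) = settings.openai.as_mut() {
        openai.api_url = None;
    }
});
```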
@@ -767,12 +839,10 @@ impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_from_env;

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
        let api_key_section = if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()

@@ -788,18 +858,7 @@ impl Render for ConfigurationView {
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(
                    h_flex()
                        .w_full()
                        .my_2()
                        .px_2()
                        .py_1()
                        .bg(cx.theme().colors().editor_background)
                        .border_1()
                        .border_color(cx.theme().colors().border)
                        .rounded_sm()
                        .child(self.render_api_key_editor(cx)),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(
                        format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),

@@ -808,7 +867,7 @@ impl Render for ConfigurationView {
                )
                .child(
                    Label::new(
                        "Note that having a subscription for another service like GitHub Copilot won't work.".to_string(),
                        "Note that having a subscription for another service like GitHub Copilot won't work.",
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )

@@ -833,18 +892,82 @@ impl Render for ConfigurationView {
                    })),
                )
                .child(
                    Button::new("reset-key", "Reset Key")
                    Button::new("reset-api-key", "Reset API Key")
                        .label_size(LabelSize::Small)
                        .icon(Some(IconName::Trash))
                        .icon(IconName::Undo)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .disabled(env_var_set)
                        .layer(ElevationIndex::ModalSurface)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        };

        let custom_api_url_set =
            AllLanguageModelSettings::get_global(cx).openai.api_url != open_ai::OPEN_AI_API_URL;

        let api_url_section = if custom_api_url_set {
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new("Custom API URL configured.")),
                )
                .child(
                    Button::new("reset-api-url", "Reset API URL")
                        .label_size(LabelSize::Small)
                        .icon(IconName::Undo)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .layer(ElevationIndex::ModalSurface)
                        .on_click(
                            cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)),
                        ),
                )
                .into_any()
        } else {
            v_flex()
                .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
                    this.save_api_url(cx);
                    cx.notify();
                }))
                .mt_2()
                .pt_2()
                .border_t_1()
                .border_color(cx.theme().colors().border_variant)
                .gap_1()
                .child(
                    List::new()
                        .child(InstructionListItem::text_only(
                            "Optionally, you can change the base URL for the OpenAI API request.",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste the new API endpoint below and hit enter",
                        )),
                )
                .child(self.api_url_editor.clone())
                .into_any()
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex()
                .size_full()
                .child(api_key_section)
                .child(api_url_section)
                .into_any()
        }
    }
}

@@ -12,9 +12,11 @@ use language_model::{
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
    RateLimiter, Role, StopReason,
    RateLimiter, Role, StopReason, TokenUsage,
};
use open_router::{
    Model, ModelMode as OpenRouterModelMode, ResponseStreamEvent, list_models, stream_completion,
};
use open_router::{Model, ResponseStreamEvent, list_models, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};

@@ -40,9 +42,44 @@ pub struct OpenRouterSettings {
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub max_output_tokens: Option<u32>,
    pub max_completion_tokens: Option<u32>,
    pub max_tokens: u64,
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
    pub supports_tools: Option<bool>,
    pub supports_images: Option<bool>,
    pub mode: Option<ModelMode>,
}

#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for OpenRouterModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => OpenRouterModelMode::Default,
            ModelMode::Thinking { budget_tokens } => {
                OpenRouterModelMode::Thinking { budget_tokens }
            }
        }
    }
}

impl From<OpenRouterModelMode> for ModelMode {
    fn from(value: OpenRouterModelMode) -> Self {
        match value {
            OpenRouterModelMode::Default => ModelMode::Default,
            OpenRouterModelMode::Thinking { budget_tokens } => {
                ModelMode::Thinking { budget_tokens }
            }
        }
    }
}

pub struct OpenRouterLanguageModelProvider {
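The settings-level ModelMode mirrors open_router::ModelMode so user configuration can round-trip losslessly through the two From impls, with budget_tokens riding along for thinking models. A short sketch:

```rust
let mode = ModelMode::Thinking { budget_tokens: Some(4096) };
let client_mode: OpenRouterModelMode = mode.clone().into();
assert_eq!(ModelMode::from(client_mode), mode);
```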
@@ -56,6 +93,7 @@ pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<open_router::Model>,
    fetch_models_task: Option<Task<Result<()>>>,
    settings: OpenRouterSettings,
    _subscription: Subscription,
}

@@ -98,6 +136,7 @@ impl State {
                .log_err();
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                this.restart_fetch_models_task(cx);
                cx.notify();
            })
        })

@@ -130,6 +169,7 @@ impl State {
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                this.api_key_from_env = from_env;
                this.restart_fetch_models_task(cx);
                cx.notify();
            })?;

@@ -153,8 +193,10 @@ impl State {
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_models_task.replace(task);
        if self.is_authenticated() {
            let task = self.fetch_models(cx);
            self.fetch_models_task.replace(task);
        }
    }
}

@@ -166,8 +208,14 @@ impl OpenRouterLanguageModelProvider {
            http_client: http_client.clone(),
            available_models: Vec::new(),
            fetch_models_task: None,
            settings: OpenRouterSettings::default(),
            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                this.restart_fetch_models_task(cx);
                let current_settings = &AllLanguageModelSettings::get_global(cx).open_router;
                let settings_changed = current_settings != &this.settings;
                if settings_changed {
                    this.settings = current_settings.clone();
                    this.restart_fetch_models_task(cx);
                }
                cx.notify();
            }),
        });

@@ -227,7 +275,9 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
                name: model.name.clone(),
                display_name: model.display_name.clone(),
                max_tokens: model.max_tokens,
                supports_tools: Some(false),
                supports_tools: model.supports_tools,
                supports_images: model.supports_images,
                mode: model.mode.clone().unwrap_or_default().into(),
            });
        }

@@ -328,11 +378,11 @@ impl LanguageModel for OpenRouterLanguageModel {
        format!("openrouter/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u32> {
    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

@@ -345,14 +395,14 @@ impl LanguageModel for OpenRouterLanguageModel {
    }

    fn supports_images(&self) -> bool {
        false
        self.model.supports_images.unwrap_or(false)
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
    ) -> BoxFuture<'static, Result<u64>> {
        count_open_router_tokens(request, self.model.clone(), cx)
    }

@@ -383,23 +433,28 @@ impl LanguageModel for OpenRouterLanguageModel {
pub fn into_open_router(
    request: LanguageModelRequest,
    model: &Model,
    max_output_tokens: Option<u32>,
    max_output_tokens: Option<u64>,
) -> open_router::Request {
    let mut messages = Vec::new();
    for req_message in request.messages {
        for content in req_message.content {
    for message in request.messages {
        for content in message.content {
            match content {
                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
                    .push(match req_message.role {
                        Role::User => open_router::RequestMessage::User { content: text },
                        Role::Assistant => open_router::RequestMessage::Assistant {
                            content: Some(text),
                            tool_calls: Vec::new(),
                        },
                        Role::System => open_router::RequestMessage::System { content: text },
                    }),
                MessageContent::Text(text) => add_message_content_part(
                    open_router::MessagePart::Text { text },
                    message.role,
                    &mut messages,
                ),
                MessageContent::Thinking { .. } => {}
                MessageContent::RedactedThinking(_) => {}
                MessageContent::Image(_) => {}
                MessageContent::Image(image) => {
                    add_message_content_part(
                        open_router::MessagePart::Image {
                            image_url: image.to_base64_url(),
                        },
                        message.role,
                        &mut messages,
                    );
                }
                MessageContent::ToolUse(tool_use) => {
                    let tool_call = open_router::ToolCall {
                        id: tool_use.id.to_string(),

@@ -425,16 +480,20 @@ pub fn into_open_router(
                }
                MessageContent::ToolResult(tool_result) => {
                    let content = match &tool_result.content {
                        LanguageModelToolResultContent::Text(text) => {
                            text.to_string()
                        LanguageModelToolResultContent::Text(text) => {
                            vec![open_router::MessagePart::Text {
                                text: text.to_string(),
                            }]
                        }
                        LanguageModelToolResultContent::Image(_) => {
                            "[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string()
                        LanguageModelToolResultContent::Image(image) => {
                            vec![open_router::MessagePart::Image {
                                image_url: image.to_base64_url(),
                            }]
                        }
                    };

                    messages.push(open_router::RequestMessage::Tool {
                        content: content,
                        content: content.into(),
                        tool_call_id: tool_result.tool_use_id.to_string(),
                    });
                }

@@ -454,6 +513,17 @@ pub fn into_open_router(
        } else {
            None
        },
        usage: open_router::RequestUsage { include: true },
        reasoning: if let OpenRouterModelMode::Thinking { budget_tokens } = model.mode {
            Some(open_router::Reasoning {
                effort: None,
                max_tokens: budget_tokens,
                exclude: Some(false),
                enabled: Some(true),
            })
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()

@@ -473,6 +543,42 @@ pub fn into_open_router(
    }
}

fn add_message_content_part(
    new_part: open_router::MessagePart,
    role: Role,
    messages: &mut Vec<open_router::RequestMessage>,
) {
    match (role, messages.last_mut()) {
        (Role::User, Some(open_router::RequestMessage::User { content }))
        | (Role::System, Some(open_router::RequestMessage::System { content })) => {
            content.push_part(new_part);
        }
        (
            Role::Assistant,
            Some(open_router::RequestMessage::Assistant {
                content: Some(content),
                ..
            }),
        ) => {
            content.push_part(new_part);
        }
        _ => {
            messages.push(match role {
                Role::User => open_router::RequestMessage::User {
                    content: open_router::MessageContent::from(vec![new_part]),
                },
                Role::Assistant => open_router::RequestMessage::Assistant {
                    content: Some(open_router::MessageContent::from(vec![new_part])),
                    tool_calls: Vec::new(),
                },
                Role::System => open_router::RequestMessage::System {
                    content: open_router::MessageContent::from(vec![new_part]),
                },
            });
        }
    }
}

pub struct OpenRouterEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}

@@ -508,8 +614,19 @@ impl OpenRouterEventMapper {
        };

        let mut events = Vec::new();
        if let Some(reasoning) = choice.delta.reasoning.clone() {
            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                text: reasoning,
                signature: None,
            }));
        }

        if let Some(content) = choice.delta.content.clone() {
            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
            // OpenRouter sends an empty content string alongside reasoning
            // content; skip it so no empty Text events are emitted. This works
            // around that OpenRouter API bug.
            if !content.is_empty() {
                events.push(Ok(LanguageModelCompletionEvent::Text(content)));
            }
        }

        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
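The guard above exists because OpenRouter pairs reasoning deltas with an empty content string; forwarding those verbatim would interleave empty Text events between Thinking events. The filtering rule in isolation, as a standalone sketch:

```rust
// Forward content only when it is non-empty.
fn text_event(content: Option<String>) -> Option<String> {
    content.filter(|text| !text.is_empty())
}

fn main() {
    assert_eq!(text_event(Some(String::new())), None);
    assert_eq!(text_event(Some("hi".into())), Some("hi".into()));
    assert_eq!(text_event(None), None);
}
```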
@@ -532,6 +649,15 @@ impl OpenRouterEventMapper {
            }
        }

        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));

@@ -560,7 +686,7 @@ impl OpenRouterEventMapper {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            Some(stop_reason) => {
                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
                log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}",);
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}

@@ -581,7 +707,7 @@ pub fn count_open_router_tokens(
    request: LanguageModelRequest,
    _model: open_router::Model,
    cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
    cx.background_spawn(async move {
        let messages = request
            .messages

@@ -598,7 +724,7 @@ pub fn count_open_router_tokens(
            })
            .collect::<Vec<_>>();

        tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
        tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64)
    })
    .boxed()
}

crates/language_models/src/provider/vercel.rs (new file, 577 lines)
@@ -0,0 +1,577 @@
use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, RateLimiter, Role,
};
use menu;
use open_ai::ResponseStreamEvent;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use strum::IntoEnumIterator;
use vercel::Model;

use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;

use crate::{AllLanguageModelSettings, ui::InstructionListItem};

const PROVIDER_ID: &str = "vercel";
const PROVIDER_NAME: &str = "Vercel";

#[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
}

pub struct VercelLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

pub struct State {
    api_key: Option<String>,
    api_key_from_env: bool,
    _subscription: Subscription,
}

const VERCEL_API_KEY_VAR: &str = "VERCEL_API_KEY";

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let settings = &AllLanguageModelSettings::get_global(cx).vercel;
        let api_url = if settings.api_url.is_empty() {
            vercel::VERCEL_API_URL.to_string()
        } else {
            settings.api_url.clone()
        };
        cx.spawn(async move |this, cx| {
            credentials_provider
                .delete_credentials(&api_url, &cx)
                .await
                .log_err();
            this.update(cx, |this, cx| {
                this.api_key = None;
                this.api_key_from_env = false;
                cx.notify();
            })
        })
    }

    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let settings = &AllLanguageModelSettings::get_global(cx).vercel;
        let api_url = if settings.api_url.is_empty() {
            vercel::VERCEL_API_URL.to_string()
        } else {
            settings.api_url.clone()
        };
        cx.spawn(async move |this, cx| {
            credentials_provider
                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
                .await
                .log_err();
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                cx.notify();
            })
        })
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let settings = &AllLanguageModelSettings::get_global(cx).vercel;
        let api_url = if settings.api_url.is_empty() {
            vercel::VERCEL_API_URL.to_string()
        } else {
            settings.api_url.clone()
        };
        cx.spawn(async move |this, cx| {
            let (api_key, from_env) = if let Ok(api_key) = std::env::var(VERCEL_API_KEY_VAR) {
                (api_key, true)
            } else {
                let (_, api_key) = credentials_provider
                    .read_credentials(&api_url, &cx)
                    .await?
                    .ok_or(AuthenticateError::CredentialsNotFound)?;
                (
                    String::from_utf8(api_key)
                        .context(format!("invalid {PROVIDER_NAME} API key"))?,
                    false,
                )
            };
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                this.api_key_from_env = from_env;
                cx.notify();
            })?;

            Ok(())
        })
    }
}

impl VercelLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| State {
            api_key: None,
            api_key_from_env: false,
            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
                cx.notify();
            }),
        });

        Self { http_client, state }
    }

    fn create_language_model(&self, model: vercel::Model) -> Arc<dyn LanguageModel> {
        Arc::new(VercelLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for VercelLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for VercelLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiVZero
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(vercel::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(vercel::Model::default_fast()))
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        for model in vercel::Model::iter() {
            if !matches!(model, vercel::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        for model in &AllLanguageModelSettings::get_global(cx)
            .vercel
            .available_models
        {
            models.insert(
                model.name.clone(),
                vercel::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| self.create_language_model(model))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.reset_api_key(cx))
    }
}

pub struct VercelLanguageModel {
    id: LanguageModelId,
    model: vercel::Model,
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl VercelLanguageModel {
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();
        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).vercel;
            let api_url = if settings.api_url.is_empty() {
                vercel::VERCEL_API_URL.to_string()
            } else {
                settings.api_url.clone()
            };
            (state.api_key.clone(), api_url)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let api_key = api_key.context("Missing Vercel API Key")?;
            let request =
                open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for VercelLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_images(&self) -> bool {
        true
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("vercel/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        count_vercel_tokens(request, self.model.clone(), cx)
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = crate::provider::open_ai::into_open_ai(
            request,
            self.model.id(),
            self.model.supports_parallel_tool_calls(),
            self.max_output_tokens(),
        );
        let completions = self.stream_completion(request, cx);
        async move {
            let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
            Ok(mapper.map_stream(completions.await?).boxed())
        }
        .boxed()
    }
}

pub fn count_vercel_tokens(
    request: LanguageModelRequest,
    model: Model,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    cx.background_spawn(async move {
        let messages = request
            .messages
            .into_iter()
            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                role: match message.role {
                    Role::User => "user".into(),
                    Role::Assistant => "assistant".into(),
                    Role::System => "system".into(),
                },
                content: Some(message.string_contents()),
                name: None,
                function_call: None,
            })
            .collect::<Vec<_>>();

        match model {
            Model::Custom { max_tokens, .. } => {
                let model = if max_tokens >= 100_000 {
                    // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
                    "gpt-4o"
                } else {
                    // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
                    // supported with this tiktoken method
                    "gpt-4"
                };
                tiktoken_rs::num_tokens_from_messages(model, &messages)
            }
            // Map Vercel models to appropriate OpenAI models for token counting
            // since Vercel uses OpenAI-compatible API
            Model::VZeroOnePointFiveMedium => {
                // Vercel v0 is similar to GPT-4o, so use gpt-4o for token counting
                tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
            }
        }
        .map(|tokens| tokens as u64)
    })
    .boxed()
}
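count_vercel_tokens picks a tiktoken vocabulary by context size: num_tokens_from_messages in tiktoken-rs only handles the cl100k_base and o200k_base encodings here, so custom models with contexts of 100k tokens or more are assumed to use the newer gpt-4o (o200k_base) encoding. The heuristic in isolation, as a standalone sketch:

```rust
/// Choose a tiktoken model name for counting; sketch of the heuristic above.
fn tokenizer_for_context(max_tokens: u64) -> &'static str {
    if max_tokens >= 100_000 { "gpt-4o" } else { "gpt-4" }
}

fn main() {
    assert_eq!(tokenizer_for_context(200_000), "gpt-4o");
    assert_eq!(tokenizer_for_context(32_768), "gpt-4");
}
```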
struct ConfigurationView {
    api_key_editor: Entity<SingleLineInput>,
    state: gpui::Entity<State>,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor = cx.new(|cx| {
            SingleLineInput::new(
                window,
                cx,
                "v1:0000000000000000000000000000000000000000000000000",
            )
            .label("API key")
        });

        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self
            .api_key_editor
            .read(cx)
            .editor()
            .read(cx)
            .text(cx)
            .trim()
            .to_string();

        // Don't proceed if no API key is provided and we're not authenticated
        if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
            return;
        }

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor.update(cx, |input, cx| {
            input.editor.update(cx, |editor, cx| {
                editor.set_text("", window, cx);
            });
        });

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_from_env;

        let api_key_section = if self.should_render_editor(cx) {
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with Vercel v0, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("Vercel v0's console"),
                            Some("https://v0.dev/chat/settings/keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the agent",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also assign the {VERCEL_API_KEY_VAR} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .child(
                    Label::new("Note that Vercel v0 is a custom OpenAI-compatible provider.")
                        .size(LabelSize::Small)
                        .color(Color::Muted),
                )
                .into_any()
        } else {
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {VERCEL_API_KEY_VAR} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-api-key", "Reset API Key")
                        .label_size(LabelSize::Small)
                        .icon(IconName::Undo)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .layer(ElevationIndex::ModalSurface)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {VERCEL_API_KEY_VAR} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex().size_full().child(api_key_section).into_any()
        }
    }
}

@ -1,19 +1,14 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use gpui::App;
|
||||
use language_model::LanguageModelCacheConfiguration;
|
||||
use project::Fs;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use settings::{Settings, SettingsSources, update_settings_file};
|
||||
use settings::{Settings, SettingsSources};
|
||||
|
||||
use crate::provider::{
|
||||
self,
|
||||
anthropic::AnthropicSettings,
|
||||
bedrock::AmazonBedrockSettings,
|
||||
cloud::{self, ZedDotDevSettings},
|
||||
copilot_chat::CopilotChatSettings,
|
||||
deepseek::DeepSeekSettings,
|
||||
google::GoogleSettings,
|
||||
lmstudio::LmStudioSettings,
|
||||
|
@ -21,39 +16,12 @@ use crate::provider::{
|
|||
ollama::OllamaSettings,
|
||||
open_ai::OpenAiSettings,
|
||||
open_router::OpenRouterSettings,
|
||||
vercel::VercelSettings,
|
||||
};
|
||||
|
||||
/// Initializes the language model settings.
|
||||
pub fn init(fs: Arc<dyn Fs>, cx: &mut App) {
|
||||
pub fn init(cx: &mut App) {
|
||||
AllLanguageModelSettings::register(cx);
|
||||
|
||||
if AllLanguageModelSettings::get_global(cx)
|
||||
.openai
|
||||
.needs_setting_migration
|
||||
{
|
||||
update_settings_file::<AllLanguageModelSettings>(fs.clone(), cx, move |setting, _| {
|
||||
if let Some(settings) = setting.openai.clone() {
|
||||
let (newest_version, _) = settings.upgrade();
|
||||
setting.openai = Some(OpenAiSettingsContent::Versioned(
|
||||
VersionedOpenAiSettingsContent::V1(newest_version),
|
||||
));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if AllLanguageModelSettings::get_global(cx)
|
||||
.anthropic
|
||||
.needs_setting_migration
|
||||
{
|
||||
update_settings_file::<AllLanguageModelSettings>(fs, cx, move |setting, _| {
|
||||
if let Some(settings) = setting.anthropic.clone() {
|
||||
let (newest_version, _) = settings.upgrade();
|
||||
setting.anthropic = Some(AnthropicSettingsContent::Versioned(
|
||||
VersionedAnthropicSettingsContent::V1(newest_version),
|
||||
));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
|
@ -65,7 +33,8 @@ pub struct AllLanguageModelSettings {
|
|||
pub open_router: OpenRouterSettings,
|
||||
pub zed_dot_dev: ZedDotDevSettings,
|
||||
pub google: GoogleSettings,
|
||||
pub copilot_chat: CopilotChatSettings,
|
||||
pub vercel: VercelSettings,
|
||||
|
||||
pub lmstudio: LmStudioSettings,
|
||||
pub deepseek: DeepSeekSettings,
|
||||
pub mistral: MistralSettings,
|
||||
|
@@ -83,83 +52,13 @@ pub struct AllLanguageModelSettingsContent {
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
pub google: Option<GoogleSettingsContent>,
pub deepseek: Option<DeepseekSettingsContent>,
pub copilot_chat: Option<CopilotChatSettingsContent>,
pub vercel: Option<VercelSettingsContent>,
pub mistral: Option<MistralSettingsContent>,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum AnthropicSettingsContent {
Versioned(VersionedAnthropicSettingsContent),
Legacy(LegacyAnthropicSettingsContent),
}

impl AnthropicSettingsContent {
pub fn upgrade(self) -> (AnthropicSettingsContentV1, bool) {
match self {
AnthropicSettingsContent::Legacy(content) => (
AnthropicSettingsContentV1 {
api_url: content.api_url,
available_models: content.available_models.map(|models| {
models
.into_iter()
.filter_map(|model| match model {
anthropic::Model::Custom {
name,
display_name,
max_tokens,
tool_override,
cache_configuration,
max_output_tokens,
default_temperature,
extra_beta_headers,
mode,
} => Some(provider::anthropic::AvailableModel {
name,
display_name,
max_tokens,
tool_override,
cache_configuration: cache_configuration.as_ref().map(
|config| LanguageModelCacheConfiguration {
max_cache_anchors: config.max_cache_anchors,
should_speculate: config.should_speculate,
min_total_token: config.min_total_token,
},
),
max_output_tokens,
default_temperature,
extra_beta_headers,
mode: Some(mode.into()),
}),
_ => None,
})
.collect()
}),
},
true,
),
AnthropicSettingsContent::Versioned(content) => match content {
VersionedAnthropicSettingsContent::V1(content) => (content, false),
},
}
}
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct LegacyAnthropicSettingsContent {
pub api_url: Option<String>,
pub available_models: Option<Vec<anthropic::Model>>,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(tag = "version")]
pub enum VersionedAnthropicSettingsContent {
#[serde(rename = "1")]
V1(AnthropicSettingsContentV1),
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct AnthropicSettingsContentV1 {
pub struct AnthropicSettingsContent {
pub api_url: Option<String>,
pub available_models: Option<Vec<provider::anthropic::AvailableModel>>,
}
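For context on the machinery deleted above: #[serde(untagged)] tells serde to try each variant in order, so a payload carrying a "version" tag deserializes as Versioned, while a bare legacy payload falls through to Legacy. A minimal self-contained sketch of that trick, with illustrative names and assuming serde plus serde_json are available:

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    #[serde(untagged)]
    enum SettingsContent {
        // Tried first; only matches when a "version" tag is present.
        Versioned(VersionedContent),
        // Fallback for the old, tag-less shape.
        Legacy(LegacyContent),
    }

    #[derive(Debug, Deserialize)]
    #[serde(tag = "version")]
    enum VersionedContent {
        #[serde(rename = "1")]
        V1(ContentV1),
    }

    #[derive(Debug, Deserialize)]
    struct ContentV1 {
        api_url: Option<String>,
    }

    #[derive(Debug, Deserialize)]
    struct LegacyContent {
        api_url: Option<String>,
    }

    fn demo() {
        // The tagged payload picks the Versioned arm...
        let v: SettingsContent =
            serde_json::from_str(r#"{"version": "1", "api_url": "https://example.com"}"#).unwrap();
        // ...and the untagged one falls back to Legacy.
        let l: SettingsContent =
            serde_json::from_str(r#"{"api_url": "https://example.com"}"#).unwrap();
        println!("{v:?}\n{l:?}");
    }

Replacing this with a single struct of Options removes the upgrade() bookkeeping, at the cost of no longer migrating old settings files on disk.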
@@ -198,68 +97,17 @@ pub struct MistralSettingsContent {
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum OpenAiSettingsContent {
Versioned(VersionedOpenAiSettingsContent),
Legacy(LegacyOpenAiSettingsContent),
}

impl OpenAiSettingsContent {
pub fn upgrade(self) -> (OpenAiSettingsContentV1, bool) {
match self {
OpenAiSettingsContent::Legacy(content) => (
OpenAiSettingsContentV1 {
api_url: content.api_url,
available_models: content.available_models.map(|models| {
models
.into_iter()
.filter_map(|model| match model {
open_ai::Model::Custom {
name,
display_name,
max_tokens,
max_output_tokens,
max_completion_tokens,
} => Some(provider::open_ai::AvailableModel {
name,
max_tokens,
max_output_tokens,
display_name,
max_completion_tokens,
}),
_ => None,
})
.collect()
}),
},
true,
),
OpenAiSettingsContent::Versioned(content) => match content {
VersionedOpenAiSettingsContent::V1(content) => (content, false),
},
}
}
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct LegacyOpenAiSettingsContent {
pub api_url: Option<String>,
pub available_models: Option<Vec<open_ai::Model>>,
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(tag = "version")]
pub enum VersionedOpenAiSettingsContent {
#[serde(rename = "1")]
V1(OpenAiSettingsContentV1),
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OpenAiSettingsContentV1 {
pub struct OpenAiSettingsContent {
pub api_url: Option<String>,
pub available_models: Option<Vec<provider::open_ai::AvailableModel>>,
}

#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct VercelSettingsContent {
pub api_url: Option<String>,
pub available_models: Option<Vec<provider::vercel::AvailableModel>>,
}
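Since every field of the new VercelSettingsContent is an Option, users may specify any subset and the merge pass in the load implementation further down fills in the rest. A quick illustrative round-trip; it assumes serde_json is available as a dev-dependency, and the URL is a placeholder rather than Vercel's real endpoint:

    #[test]
    fn vercel_settings_accepts_partial_input() {
        let content: VercelSettingsContent =
            serde_json::from_str(r#"{"api_url": "https://example.com/v1"}"#).unwrap();
        assert_eq!(content.api_url.as_deref(), Some("https://example.com/v1"));
        // Unspecified fields simply stay None.
        assert!(content.available_models.is_none());
    }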
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct GoogleSettingsContent {
pub api_url: Option<String>,
@@ -271,13 +119,6 @@ pub struct ZedDotDevSettingsContent {
available_models: Option<Vec<cloud::AvailableModel>>,
}

#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct CopilotChatSettingsContent {
pub api_url: Option<String>,
pub auth_url: Option<String>,
pub models_url: Option<String>,
}

#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OpenRouterSettingsContent {
pub api_url: Option<String>,
@@ -302,15 +143,7 @@ impl settings::Settings for AllLanguageModelSettings {

for value in sources.defaults_and_customizations() {
// Anthropic
let (anthropic, upgraded) = match value.anthropic.clone().map(|s| s.upgrade()) {
Some((content, upgraded)) => (Some(content), upgraded),
None => (None, false),
};

if upgraded {
settings.anthropic.needs_setting_migration = true;
}

let anthropic = value.anthropic.clone();
merge(
&mut settings.anthropic.api_url,
anthropic.as_ref().and_then(|s| s.api_url.clone()),
@@ -376,15 +209,7 @@ impl settings::Settings for AllLanguageModelSettings {
);

// OpenAI
let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) {
Some((content, upgraded)) => (Some(content), upgraded),
None => (None, false),
};

if upgraded {
settings.openai.needs_setting_migration = true;
}

let openai = value.openai.clone();
merge(
&mut settings.openai.api_url,
openai.as_ref().and_then(|s| s.api_url.clone()),
@@ -393,6 +218,18 @@ impl settings::Settings for AllLanguageModelSettings {
&mut settings.openai.available_models,
openai.as_ref().and_then(|s| s.available_models.clone()),
);

// Vercel
let vercel = value.vercel.clone();
merge(
&mut settings.vercel.api_url,
vercel.as_ref().and_then(|s| s.api_url.clone()),
);
merge(
&mut settings.vercel.available_models,
vercel.as_ref().and_then(|s| s.available_models.clone()),
);

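The merge helper used throughout this function is defined outside the visible hunks; judging from its call sites, it presumably looks like the sketch below (a reconstruction, not the committed code):

    // Overwrite the accumulated value only when this settings layer
    // actually provides one; None leaves the earlier layer's value alone.
    fn merge<T>(target: &mut T, value: Option<T>) {
        if let Some(value) = value {
            *target = value;
        }
    }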
merge(
&mut settings.zed_dot_dev.available_models,
value
@@ -435,24 +272,6 @@ impl settings::Settings for AllLanguageModelSettings {
.as_ref()
.and_then(|s| s.available_models.clone()),
);

// Copilot Chat
let copilot_chat = value.copilot_chat.clone().unwrap_or_default();

settings.copilot_chat.api_url = copilot_chat.api_url.map_or_else(
|| Arc::from("https://api.githubcopilot.com/chat/completions"),
Arc::from,
);

settings.copilot_chat.auth_url = copilot_chat.auth_url.map_or_else(
|| Arc::from("https://api.github.com/copilot_internal/v2/token"),
Arc::from,
);

settings.copilot_chat.models_url = copilot_chat.models_url.map_or_else(
|| Arc::from("https://api.githubcopilot.com/models"),
Arc::from,
);
}

Ok(settings)
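The deleted Copilot Chat block above relied on a compact Option-to-Arc<str> defaulting idiom worth spelling out. A standalone equivalent, with the URL taken from the deleted lines:

    use std::sync::Arc;

    // map_or_else supplies the hardcoded default when the user set
    // nothing, and converts either branch into an Arc<str>.
    fn models_url(user_value: Option<String>) -> Arc<str> {
        user_value.map_or_else(
            || Arc::from("https://api.githubcopilot.com/models"),
            Arc::from,
        )
    }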
@@ -40,29 +40,30 @@ impl IntoElement for InstructionListItem {
let link = button_link.clone();
let unique_id = SharedString::from(format!("{}-button", self.label));

h_flex().flex_wrap().child(Label::new(self.label)).child(
Button::new(unique_id, button_label)
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, _window, cx| cx.open_url(&link)),
)
h_flex()
.flex_wrap()
.child(Label::new(self.label))
.child(
Button::new(unique_id, button_label)
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, _window, cx| cx.open_url(&link)),
)
.into_any_element()
} else {
div().child(Label::new(self.label))
Label::new(self.label).into_any_element()
};

div()
.child(
ListItem::new("list-item")
.selectable(false)
.start_slot(
Icon::new(IconName::Dash)
.size(IconSize::XSmall)
.color(Color::Hidden),
)
.child(item_content),
ListItem::new("list-item")
.selectable(false)
.start_slot(
Icon::new(IconName::Dash)
.size(IconSize::XSmall)
.color(Color::Hidden),
)
.into_any()
.child(div().w_full().child(item_content))
.into_any_element()
}
}
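A note on the repeated into_any_element() calls in this hunk: both arms of an if/else must evaluate to a single type, and erasing each arm's concrete element type is how gpui code satisfies that. The same constraint in plain Rust, as an analogy with no gpui dependency:

    use std::fmt::Display;

    // The arms have different concrete types (String vs &str), so each
    // is boxed into a common trait object, just as into_any_element()
    // erases differing gpui element types to AnyElement.
    fn item_content(has_link: bool) -> Box<dyn Display> {
        if has_link {
            Box::new(format!("{} (https://example.com)", "label"))
        } else {
            Box::new("label")
        }
    }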