WIP and merge

Anthony 2025-06-27 18:38:25 -04:00
parent 97f4406ef6
commit 1bdde8b2e4
584 changed files with 33536 additions and 17400 deletions

View file

@ -40,8 +40,8 @@ mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
open_router = { workspace = true, features = ["schemars"] }
vercel = { workspace = true, features = ["schemars"] }
partial-json-fixer.workspace = true
project.workspace = true
proto.workspace = true
release_channel.workspace = true
schemars.workspace = true
@ -55,9 +55,11 @@ thiserror.workspace = true
tiktoken-rs.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
ui.workspace = true
ui_input.workspace = true
util.workspace = true
workspace-hack.workspace = true
zed_llm_client.workspace = true
language.workspace = true
[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }

View file

@ -1,7 +1,6 @@
use std::sync::Arc;
use client::{Client, UserStore};
use fs::Fs;
use gpui::{App, Context, Entity};
use language_model::LanguageModelRegistry;
use provider::deepseek::DeepSeekLanguageModelProvider;
@ -20,10 +19,11 @@ use crate::provider::mistral::MistralLanguageModelProvider;
use crate::provider::ollama::OllamaLanguageModelProvider;
use crate::provider::open_ai::OpenAiLanguageModelProvider;
use crate::provider::open_router::OpenRouterLanguageModelProvider;
use crate::provider::vercel::VercelLanguageModelProvider;
pub use crate::settings::*;
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut App) {
crate::settings::init(fs, cx);
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
crate::settings::init(cx);
let registry = LanguageModelRegistry::global(cx);
registry.update(cx, |registry, cx| {
register_language_model_providers(registry, user_store, client, cx);
@ -77,5 +77,9 @@ fn register_language_model_providers(
OpenRouterLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(
VercelLanguageModelProvider::new(client.http_client(), cx),
cx,
);
registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx);
}

View file

@ -9,3 +9,4 @@ pub mod mistral;
pub mod ollama;
pub mod open_ai;
pub mod open_router;
pub mod vercel;

View file

@ -16,10 +16,10 @@ use gpui::{
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent,
RateLimiter, Role,
};
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema;
@ -41,7 +41,6 @@ pub struct AnthropicSettings {
pub api_url: String,
/// Extend Zed's list of Anthropic models.
pub available_models: Vec<AvailableModel>,
pub needs_setting_migration: bool,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
@ -51,12 +50,12 @@ pub struct AvailableModel {
/// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
pub display_name: Option<String>,
/// The model's context window size.
pub max_tokens: usize,
pub max_tokens: u64,
/// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling.
pub tool_override: Option<String>,
/// Configuration of Anthropic's caching API.
pub cache_configuration: Option<LanguageModelCacheConfiguration>,
pub max_output_tokens: Option<u32>,
pub max_output_tokens: Option<u64>,
pub default_temperature: Option<f32>,
#[serde(default)]
pub extra_beta_headers: Vec<String>,
@ -321,7 +320,7 @@ pub struct AnthropicModel {
pub fn count_anthropic_tokens(
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
cx.background_spawn(async move {
let messages = request.messages;
let mut tokens_from_images = 0;
@ -377,7 +376,7 @@ pub fn count_anthropic_tokens(
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
.map(|tokens| tokens + tokens_from_images)
.map(|tokens| (tokens + tokens_from_images) as u64)
})
.boxed()
}
@ -407,14 +406,7 @@ impl AnthropicModel {
let api_key = api_key.context("Missing Anthropic API Key")?;
let request =
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
request.await.map_err(|err| match err {
AnthropicError::RateLimit(duration) => {
LanguageModelCompletionError::RateLimit(duration)
}
err @ (AnthropicError::ApiError(..) | AnthropicError::Other(..)) => {
LanguageModelCompletionError::Other(anthropic_err_to_anyhow(err))
}
})
request.await.map_err(Into::into)
}
.boxed()
}
@ -461,11 +453,11 @@ impl LanguageModel for AnthropicModel {
self.state.read(cx).api_key.clone()
}
fn max_token_count(&self) -> usize {
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
fn max_output_tokens(&self) -> Option<u32> {
fn max_output_tokens(&self) -> Option<u64> {
Some(self.model.max_output_tokens())
}
@ -473,7 +465,7 @@ impl LanguageModel for AnthropicModel {
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
count_anthropic_tokens(request, cx)
}
@ -518,7 +510,7 @@ pub fn into_anthropic(
request: LanguageModelRequest,
model: String,
default_temperature: f32,
max_output_tokens: u32,
max_output_tokens: u64,
mode: AnthropicModelMode,
) -> anthropic::Request {
let mut new_messages: Vec<anthropic::Message> = Vec::new();
@ -561,9 +553,7 @@ pub fn into_anthropic(
}
MessageContent::RedactedThinking(data) => {
if !data.is_empty() {
Some(anthropic::RequestContent::RedactedThinking {
data: String::from_utf8(data).ok()?,
})
Some(anthropic::RequestContent::RedactedThinking { data })
} else {
None
}
@ -714,7 +704,7 @@ impl AnthropicEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
Err(error) => vec![Err(error.into())],
})
})
}
@ -737,10 +727,8 @@ impl AnthropicEventMapper {
signature: None,
})]
}
ResponseContent::RedactedThinking { .. } => {
// Redacted thinking is encrypted and not accessible to the user, see:
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production
Vec::new()
ResponseContent::RedactedThinking { data } => {
vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
}
ResponseContent::ToolUse { id, name, .. } => {
self.tool_uses_by_index.insert(
@ -859,9 +847,7 @@ impl AnthropicEventMapper {
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
}
Event::Error { error } => {
vec![Err(LanguageModelCompletionError::Other(anyhow!(
AnthropicError::ApiError(error)
)))]
vec![Err(error.into())]
}
_ => Vec::new(),
}
@ -874,16 +860,6 @@ struct RawToolUse {
input_json: String,
}
pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
if let AnthropicError::ApiError(api_err) = &err {
if let Some(tokens) = api_err.match_window_exceeded() {
return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens });
}
}
anyhow!(err)
}
/// Updates usage data by preferring counts from `new`.
fn update_usage(usage: &mut Usage, new: &Usage) {
if let Some(input_tokens) = new.input_tokens {

View file

@ -11,8 +11,8 @@ use aws_http_client::AwsHttpClient;
use bedrock::bedrock_client::Client as BedrockClient;
use bedrock::bedrock_client::config::timeout::TimeoutConfig;
use bedrock::bedrock_client::types::{
ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta,
StopReason,
CachePointBlock, CachePointType, ContentBlockDelta, ContentBlockStart, ConverseStreamOutput,
ReasoningContentBlockDelta, StopReason,
};
use bedrock::{
BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent,
@ -48,7 +48,7 @@ use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
use theme::ThemeSettings;
use tokio::runtime::Handle;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::{ResultExt, default};
use util::ResultExt;
use crate::AllLanguageModelSettings;
@ -88,9 +88,9 @@ pub enum BedrockAuthMethod {
pub struct AvailableModel {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: usize,
pub max_tokens: u64,
pub cache_configuration: Option<LanguageModelCacheConfiguration>,
pub max_output_tokens: Option<u32>,
pub max_output_tokens: Option<u64>,
pub default_temperature: Option<f32>,
pub mode: Option<ModelMode>,
}
@ -229,6 +229,17 @@ impl State {
Ok(())
})
}
fn get_region(&self) -> String {
// Get region - from credentials or directly from settings
let credentials_region = self.credentials.as_ref().map(|s| s.region.clone());
let settings_region = self.settings.as_ref().and_then(|s| s.region.clone());
// Use the credentials region if available, otherwise the settings region, and finally fall back to the default
credentials_region
.or(settings_region)
.unwrap_or(String::from("us-east-1"))
}
}
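A quick illustration of the precedence implemented by get_region above, as a standalone sketch (the resolve_region helper and the values are hypothetical, not part of this commit):

// Mirrors State::get_region: credentials region first, then settings region, then the default.
fn resolve_region(credentials_region: Option<String>, settings_region: Option<String>) -> String {
    credentials_region
        .or(settings_region)
        .unwrap_or(String::from("us-east-1"))
}

fn main() {
    assert_eq!(resolve_region(Some("eu-west-1".into()), None), "eu-west-1");
    assert_eq!(resolve_region(None, Some("us-west-2".into())), "us-west-2");
    assert_eq!(resolve_region(None, None), "us-east-1");
}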
pub struct BedrockLanguageModelProvider {
@ -289,8 +300,9 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
Some(self.create_language_model(bedrock::Model::default()))
}
fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
Some(self.create_language_model(bedrock::Model::default_fast()))
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
let region = self.state.read(cx).get_region();
Some(self.create_language_model(bedrock::Model::default_fast(region.as_str())))
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@ -317,6 +329,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
default_temperature: model.default_temperature,
cache_configuration: model.cache_configuration.as_ref().map(|config| {
bedrock::BedrockModelCacheConfiguration {
max_cache_anchors: config.max_cache_anchors,
min_total_token: config.min_total_token,
}
}),
},
);
}
@ -377,11 +395,7 @@ impl BedrockModel {
let endpoint = state.settings.as_ref().and_then(|s| s.endpoint.clone());
let region = state
.settings
.as_ref()
.and_then(|s| s.region.clone())
.unwrap_or(String::from("us-east-1"));
let region = state.get_region();
(
auth_method,
@ -495,7 +509,8 @@ impl LanguageModel for BedrockModel {
LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => {
self.model.supports_tool_use()
}
LanguageModelToolChoice::None => false,
// Add support for None - we'll filter tool calls out of the response
LanguageModelToolChoice::None => self.model.supports_tool_use(),
}
}
@ -503,11 +518,11 @@ impl LanguageModel for BedrockModel {
format!("bedrock/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
fn max_output_tokens(&self) -> Option<u32> {
fn max_output_tokens(&self) -> Option<u64> {
Some(self.model.max_output_tokens())
}
@ -515,7 +530,7 @@ impl LanguageModel for BedrockModel {
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
get_bedrock_tokens(request, cx)
}
@ -530,16 +545,7 @@ impl LanguageModel for BedrockModel {
LanguageModelCompletionError,
>,
> {
let Ok(region) = cx.read_entity(&self.state, |state, _cx| {
// Get region - from credentials or directly from settings
let credentials_region = state.credentials.as_ref().map(|s| s.region.clone());
let settings_region = state.settings.as_ref().and_then(|s| s.region.clone());
// Use credentials region if available, otherwise use settings region, finally fall back to default
credentials_region
.or(settings_region)
.unwrap_or(String::from("us-east-1"))
}) else {
let Ok(region) = cx.read_entity(&self.state, |state, _cx| state.get_region()) else {
return async move { Err(anyhow::anyhow!("App State Dropped").into()) }.boxed();
};
@ -550,12 +556,15 @@ impl LanguageModel for BedrockModel {
}
};
let deny_tool_calls = request.tool_choice == Some(LanguageModelToolChoice::None);
let request = match into_bedrock(
request,
model_id,
self.model.default_temperature(),
self.model.max_output_tokens(),
self.model.mode(),
self.model.supports_caching(),
) {
Ok(request) => request,
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
@ -566,25 +575,53 @@ impl LanguageModel for BedrockModel {
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request.map_err(|err| anyhow!(err))?.await;
Ok(map_to_language_model_completion_events(
response,
owned_handle,
))
let events = map_to_language_model_completion_events(response, owned_handle);
if deny_tool_calls {
Ok(deny_tool_use_events(events).boxed())
} else {
Ok(events.boxed())
}
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
None
self.model
.cache_configuration()
.map(|config| LanguageModelCacheConfiguration {
max_cache_anchors: config.max_cache_anchors,
should_speculate: false,
min_total_token: config.min_total_token,
})
}
}
fn deny_tool_use_events(
events: impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
events.map(|event| {
match event {
Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
// Convert tool use to an error message if the model decided to call it
Ok(LanguageModelCompletionEvent::Text(format!(
"\n\n[Error: Tool calls are disabled in this context. Attempted to call '{}']",
tool_use.name
)))
}
other => other,
}
})
}
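To make the new tool-choice handling easier to follow: when the request asks for LanguageModelToolChoice::None, the Bedrock request still goes out with Auto, and deny_tool_use_events rewrites any tool-use events into plain text afterwards. A minimal, self-contained sketch of that stream-adapter idea, using a simplified stand-in event type rather than the crate's real LanguageModelCompletionEvent (the Event enum and values below are illustrative only):

use futures::{StreamExt, executor::block_on, stream};

// Simplified stand-in for the completion event type (illustrative, not the real type).
#[derive(Debug)]
enum Event {
    Text(String),
    ToolUse { name: String },
}

// Same shape as deny_tool_use_events above: rewrite tool calls as an error message, pass everything else through.
fn deny_tool_use(events: impl stream::Stream<Item = Event>) -> impl stream::Stream<Item = Event> {
    events.map(|event| match event {
        Event::ToolUse { name } => Event::Text(format!(
            "\n\n[Error: Tool calls are disabled in this context. Attempted to call '{}']",
            name
        )),
        other => other,
    })
}

fn main() {
    let input = stream::iter(vec![
        Event::Text("hello".into()),
        Event::ToolUse { name: "read_file".into() },
    ]);
    let output: Vec<Event> = block_on(deny_tool_use(input).collect());
    println!("{output:?}");
}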
pub fn into_bedrock(
request: LanguageModelRequest,
model: String,
default_temperature: f32,
max_output_tokens: u32,
max_output_tokens: u64,
mode: BedrockModelMode,
supports_caching: bool,
) -> Result<bedrock::Request> {
let mut new_messages: Vec<BedrockMessage> = Vec::new();
let mut system_message = String::new();
@ -596,7 +633,7 @@ pub fn into_bedrock(
match message.role {
Role::User | Role::Assistant => {
let bedrock_message_content: Vec<BedrockInnerContent> = message
let mut bedrock_message_content: Vec<BedrockInnerContent> = message
.content
.into_iter()
.filter_map(|content| match content {
@ -608,6 +645,11 @@ pub fn into_bedrock(
}
}
MessageContent::Thinking { text, signature } => {
if model.contains(Model::DeepSeekR1.request_id()) {
// DeepSeek R1 doesn't support thinking blocks,
// and the AWS API requires that they be stripped.
return None;
}
let thinking = BedrockThinkingTextBlock::builder()
.text(text)
.set_signature(signature)
@ -620,19 +662,32 @@ pub fn into_bedrock(
))
}
MessageContent::RedactedThinking(blob) => {
if model.contains(Model::DeepSeekR1.request_id()) {
// DeepSeek R1 doesn't support thinking blocks,
// and the AWS API requires that they be stripped.
return None;
}
let redacted =
BedrockThinkingBlock::RedactedContent(BedrockBlob::new(blob));
Some(BedrockInnerContent::ReasoningContent(redacted))
}
MessageContent::ToolUse(tool_use) => BedrockToolUseBlock::builder()
.name(tool_use.name.to_string())
.tool_use_id(tool_use.id.to_string())
.input(value_to_aws_document(&tool_use.input))
.build()
.context("failed to build Bedrock tool use block")
.log_err()
.map(BedrockInnerContent::ToolUse),
MessageContent::ToolUse(tool_use) => {
let input = if tool_use.input.is_null() {
// Bedrock API requires valid JsonValue, not null, for tool use input
value_to_aws_document(&serde_json::json!({}))
} else {
value_to_aws_document(&tool_use.input)
};
BedrockToolUseBlock::builder()
.name(tool_use.name.to_string())
.tool_use_id(tool_use.id.to_string())
.input(input)
.build()
.context("failed to build Bedrock tool use block")
.log_err()
.map(BedrockInnerContent::ToolUse)
},
MessageContent::ToolResult(tool_result) => {
BedrockToolResultBlock::builder()
.tool_use_id(tool_result.tool_use_id.to_string())
@ -662,6 +717,14 @@ pub fn into_bedrock(
_ => None,
})
.collect();
if message.cache && supports_caching {
bedrock_message_content.push(BedrockInnerContent::CachePoint(
CachePointBlock::builder()
.r#type(CachePointType::Default)
.build()
.context("failed to build cache point block")?,
));
}
let bedrock_role = match message.role {
Role::User => bedrock::BedrockRole::User,
Role::Assistant => bedrock::BedrockRole::Assistant,
@ -690,7 +753,7 @@ pub fn into_bedrock(
}
}
let tool_spec: Vec<BedrockTool> = request
let mut tool_spec: Vec<BedrockTool> = request
.tools
.iter()
.filter_map(|tool| {
@ -707,6 +770,15 @@ pub fn into_bedrock(
})
.collect();
if !tool_spec.is_empty() && supports_caching {
tool_spec.push(BedrockTool::CachePoint(
CachePointBlock::builder()
.r#type(CachePointType::Default)
.build()
.context("failed to build cache point block")?,
));
}
let tool_choice = match request.tool_choice {
Some(LanguageModelToolChoice::Auto) | None => {
BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
@ -715,7 +787,8 @@ pub fn into_bedrock(
BedrockToolChoice::Any(BedrockAnyToolChoice::builder().build())
}
Some(LanguageModelToolChoice::None) => {
anyhow::bail!("LanguageModelToolChoice::None is not supported");
// For None, we still use Auto but will filter out tool calls in the response
BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
}
};
let tool_config: BedrockToolConfig = BedrockToolConfig::builder()
@ -747,7 +820,7 @@ pub fn into_bedrock(
pub fn get_bedrock_tokens(
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
cx.background_executor()
.spawn(async move {
let messages = request.messages;
@ -799,7 +872,7 @@ pub fn get_bedrock_tokens(
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
.map(|tokens| tokens + tokens_from_images)
.map(|tokens| (tokens + tokens_from_images) as u64)
})
.boxed()
}
@ -947,11 +1020,12 @@ pub fn map_to_language_model_completion_events(
let completion_event =
LanguageModelCompletionEvent::UsageUpdate(
TokenUsage {
input_tokens: metadata.input_tokens as u32,
output_tokens: metadata.output_tokens
as u32,
cache_creation_input_tokens: default(),
cache_read_input_tokens: default(),
input_tokens: metadata.input_tokens as u64,
output_tokens: metadata.output_tokens as u64,
cache_creation_input_tokens:
metadata.cache_write_input_tokens.unwrap_or_default() as u64,
cache_read_input_tokens:
metadata.cache_read_input_tokens.unwrap_or_default() as u64,
},
);
return Some((Some(Ok(completion_event)), state));

View file

@ -1,6 +1,6 @@
use anthropic::{AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Context as _, Result, anyhow};
use client::{Client, UserStore, zed_urls};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use futures::{
AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
@ -14,7 +14,7 @@ use language_model::{
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
@ -73,9 +73,9 @@ pub struct AvailableModel {
/// The size of the context window, indicating the maximum number of tokens the model can process.
pub max_tokens: usize,
/// The maximum number of output tokens allowed by the model.
pub max_output_tokens: Option<u32>,
pub max_output_tokens: Option<u64>,
/// The maximum number of completion tokens allowed by the model (o1-* only)
pub max_completion_tokens: Option<u32>,
pub max_completion_tokens: Option<u64>,
/// Override this model with a different Anthropic model for tool calls.
pub tool_override: Option<String>,
/// Indicates whether this custom model supports caching.
@ -530,7 +530,7 @@ pub struct CloudLanguageModel {
struct PerformLlmCompletionResponse {
response: Response<AsyncBody>,
usage: Option<RequestUsage>,
usage: Option<ModelRequestUsage>,
tool_use_limit_reached: bool,
includes_status_messages: bool,
}
@ -581,7 +581,7 @@ impl CloudLanguageModel {
let usage = if includes_status_messages {
None
} else {
RequestUsage::from_headers(response.headers()).ok()
ModelRequestUsage::from_headers(response.headers()).ok()
};
return Ok(PerformLlmCompletionResponse {
@ -715,8 +715,8 @@ impl LanguageModel for CloudLanguageModel {
}
}
fn max_token_count(&self) -> usize {
self.model.max_token_count
fn max_token_count(&self) -> u64 {
self.model.max_token_count as u64
}
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
@ -737,7 +737,7 @@ impl LanguageModel for CloudLanguageModel {
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
match self.model.provider {
zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
zed_llm_client::LanguageModelProvider::OpenAi => {
@ -786,7 +786,7 @@ impl LanguageModel for CloudLanguageModel {
let response_body: CountTokensResponse =
serde_json::from_str(&response_body)?;
Ok(response_body.tokens)
Ok(response_body.tokens as u64)
} else {
Err(anyhow!(ApiError {
status,
@ -821,7 +821,7 @@ impl LanguageModel for CloudLanguageModel {
request,
self.model.id.to_string(),
1.0,
self.model.max_output_tokens as u32,
self.model.max_output_tokens as u64,
if self.model.id.0.ends_with("-thinking") {
AnthropicModelMode::Thinking {
budget_tokens: Some(4_096),
@ -888,7 +888,12 @@ impl LanguageModel for CloudLanguageModel {
Ok(model) => model,
Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
};
let request = into_open_ai(request, &model, None);
let request = into_open_ai(
request,
model.id(),
model.supports_parallel_tool_calls(),
None,
);
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let PerformLlmCompletionResponse {
@ -1002,7 +1007,7 @@ where
}
fn usage_updated_event<T>(
usage: Option<RequestUsage>,
usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
futures::stream::iter(usage.map(|usage| {
Ok(CloudCompletionEvent::Status(

View file

@ -10,35 +10,30 @@ use copilot::copilot_chat::{
ToolCall,
};
use copilot::{Copilot, Status};
use editor::{Editor, EditorElement, EditorStyle};
use fs::Fs;
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use futures::{FutureExt, Stream, StreamExt};
use gpui::{
Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, FontStyle, Render,
Subscription, Task, TextStyle, Transformation, WhiteSpace, percentage, svg,
Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
Transformation, percentage, svg,
};
use language::language_settings::all_language_settings;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
StopReason,
StopReason, TokenUsage,
};
use settings::{Settings, SettingsStore, update_settings_file};
use settings::SettingsStore;
use std::time::Duration;
use theme::ThemeSettings;
use ui::prelude::*;
use util::debug_panic;
use crate::{AllLanguageModelSettings, CopilotChatSettingsContent};
use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens;
use super::open_ai::count_open_ai_tokens;
pub(crate) use copilot::copilot_chat::CopilotChatSettings;
const PROVIDER_ID: &str = "copilot_chat";
const PROVIDER_NAME: &str = "GitHub Copilot Chat";
@ -69,11 +64,16 @@ impl CopilotChatLanguageModelProvider {
_copilot_chat_subscription: copilot_chat_subscription,
_settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
if let Some(copilot_chat) = CopilotChat::global(cx) {
let settings = AllLanguageModelSettings::get_global(cx)
.copilot_chat
.clone();
let language_settings = all_language_settings(None, cx);
let configuration = copilot::copilot_chat::CopilotChatConfiguration {
enterprise_uri: language_settings
.edit_predictions
.copilot
.enterprise_uri
.clone(),
};
copilot_chat.update(cx, |chat, cx| {
chat.set_settings(settings, cx);
chat.set_configuration(configuration, cx);
});
}
cx.notify();
@ -174,10 +174,9 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
Task::ready(Err(err.into()))
}
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
let state = self.state.clone();
cx.new(|cx| ConfigurationView::new(state, window, cx))
.into()
cx.new(|cx| ConfigurationView::new(state, cx)).into()
}
fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
@ -238,7 +237,7 @@ impl LanguageModel for CopilotChatLanguageModel {
format!("copilot_chat/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
@ -246,7 +245,7 @@ impl LanguageModel for CopilotChatLanguageModel {
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
match self.model.vendor() {
ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
ModelVendor::Google => count_google_tokens(request, cx),
@ -268,23 +267,6 @@ impl LanguageModel for CopilotChatLanguageModel {
LanguageModelCompletionError,
>,
> {
if let Some(message) = request.messages.last() {
if message.contents_empty() {
const EMPTY_PROMPT_MSG: &str =
"Empty prompts aren't allowed. Please provide a non-empty prompt.";
return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG).into()))
.boxed();
}
// Copilot Chat has a restriction that the final message must be from the user.
// While their API does return an error message for this, we can catch it earlier
// and provide a more helpful error message.
if !matches!(message.role, Role::User) {
const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG).into())).boxed();
}
}
let copilot_request = match into_copilot_chat(&self.model, request) {
Ok(request) => request,
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
@ -379,6 +361,17 @@ pub fn map_to_language_model_completion_events(
}
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
)));
}
match choice.finish_reason.as_deref() {
Some("stop") => {
events.push(Ok(LanguageModelCompletionEvent::Stop(
@ -622,38 +615,15 @@ fn into_copilot_chat(
struct ConfigurationView {
copilot_status: Option<copilot::Status>,
api_url_editor: Entity<Editor>,
models_url_editor: Entity<Editor>,
auth_url_editor: Entity<Editor>,
state: Entity<State>,
_subscription: Option<Subscription>,
}
impl ConfigurationView {
pub fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
let copilot = Copilot::global(cx);
let settings = AllLanguageModelSettings::get_global(cx)
.copilot_chat
.clone();
let api_url_editor = cx.new(|cx| Editor::single_line(window, cx));
api_url_editor.update(cx, |this, cx| {
this.set_text(settings.api_url.clone(), window, cx);
this.set_placeholder_text("GitHub Copilot API URL", cx);
});
let models_url_editor = cx.new(|cx| Editor::single_line(window, cx));
models_url_editor.update(cx, |this, cx| {
this.set_text(settings.models_url.clone(), window, cx);
this.set_placeholder_text("GitHub Copilot Models URL", cx);
});
let auth_url_editor = cx.new(|cx| Editor::single_line(window, cx));
auth_url_editor.update(cx, |this, cx| {
this.set_text(settings.auth_url.clone(), window, cx);
this.set_placeholder_text("GitHub Copilot Auth URL", cx);
});
Self {
api_url_editor,
models_url_editor,
auth_url_editor,
copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
state,
_subscription: copilot.as_ref().map(|copilot| {
@ -664,104 +634,6 @@ impl ConfigurationView {
}),
}
}
fn make_input_styles(&self, cx: &App) -> Div {
let bg_color = cx.theme().colors().editor_background;
let border_color = cx.theme().colors().border;
h_flex()
.w_full()
.px_2()
.py_1()
.bg(bg_color)
.border_1()
.border_color(border_color)
.rounded_sm()
}
fn make_text_style(&self, cx: &Context<Self>) -> TextStyle {
let settings = ThemeSettings::get_global(cx);
TextStyle {
color: cx.theme().colors().text,
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features.clone(),
font_fallbacks: settings.ui_font.fallbacks.clone(),
font_size: rems(0.875).into(),
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(1.3),
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
text_overflow: None,
text_align: Default::default(),
line_clamp: None,
}
}
fn render_api_url_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
let text_style = self.make_text_style(cx);
EditorElement::new(
&self.api_url_editor,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
..Default::default()
},
)
}
fn render_auth_url_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
let text_style = self.make_text_style(cx);
EditorElement::new(
&self.auth_url_editor,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
..Default::default()
},
)
}
fn render_models_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
let text_style = self.make_text_style(cx);
EditorElement::new(
&self.models_url_editor,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
..Default::default()
},
)
}
fn update_copilot_settings(&self, cx: &mut Context<'_, Self>) {
let settings = CopilotChatSettings {
api_url: self.api_url_editor.read(cx).text(cx).into(),
models_url: self.models_url_editor.read(cx).text(cx).into(),
auth_url: self.auth_url_editor.read(cx).text(cx).into(),
};
update_settings_file::<AllLanguageModelSettings>(<dyn Fs>::global(cx), cx, {
let settings = settings.clone();
move |content, _| {
content.copilot_chat = Some(CopilotChatSettingsContent {
api_url: Some(settings.api_url.as_ref().into()),
models_url: Some(settings.models_url.as_ref().into()),
auth_url: Some(settings.auth_url.as_ref().into()),
});
}
});
if let Some(chat) = CopilotChat::global(cx) {
chat.update(cx, |this, cx| {
this.set_settings(settings, cx);
});
}
}
}
impl Render for ConfigurationView {
@ -819,59 +691,15 @@ impl Render for ConfigurationView {
}
_ => {
const LABEL: &str = "To use Zed's assistant with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
v_flex()
.gap_2()
.child(Label::new(LABEL))
.on_action(cx.listener(|this, _: &menu::Confirm, window, cx| {
this.update_copilot_settings(cx);
copilot::initiate_sign_in(window, cx);
}))
.child(
v_flex()
.gap_0p5()
.child(Label::new("API URL").size(LabelSize::Small))
.child(
self.make_input_styles(cx)
.child(self.render_api_url_editor(cx)),
),
)
.child(
v_flex()
.gap_0p5()
.child(Label::new("Auth URL").size(LabelSize::Small))
.child(
self.make_input_styles(cx)
.child(self.render_auth_url_editor(cx)),
),
)
.child(
v_flex()
.gap_0p5()
.child(Label::new("Models list URL").size(LabelSize::Small))
.child(
self.make_input_styles(cx)
.child(self.render_models_editor(cx)),
),
)
.child(
Button::new("sign_in", "Sign in to use GitHub Copilot")
.icon_color(Color::Muted)
.icon(IconName::Github)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Medium)
.full_width()
.on_click(cx.listener(|this, _, window, cx| {
this.update_copilot_settings(cx);
copilot::initiate_sign_in(window, cx)
})),
)
.child(
Label::new(
format!("You can also assign the {} environment variable and restart Zed.", copilot::copilot_chat::COPILOT_OAUTH_ENV_VAR),
)
.size(LabelSize::Small)
.color(Color::Muted),
)
v_flex().gap_2().child(Label::new(LABEL)).child(
Button::new("sign_in", "Sign in to use GitHub Copilot")
.icon_color(Color::Muted)
.icon(IconName::Github)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Medium)
.full_width()
.on_click(|_, window, cx| copilot::initiate_sign_in(window, cx)),
)
}
},
None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),

View file

@ -14,7 +14,7 @@ use language_model::{
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason,
RateLimiter, Role, StopReason, TokenUsage,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -49,8 +49,8 @@ pub struct DeepSeekSettings {
pub struct AvailableModel {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: usize,
pub max_output_tokens: Option<u32>,
pub max_tokens: u64,
pub max_output_tokens: Option<u64>,
}
pub struct DeepSeekLanguageModelProvider {
@ -306,11 +306,11 @@ impl LanguageModel for DeepSeekLanguageModel {
format!("deepseek/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
fn max_output_tokens(&self) -> Option<u32> {
fn max_output_tokens(&self) -> Option<u64> {
self.model.max_output_tokens()
}
@ -318,7 +318,7 @@ impl LanguageModel for DeepSeekLanguageModel {
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
cx.background_spawn(async move {
let messages = request
.messages
@ -335,7 +335,7 @@ impl LanguageModel for DeepSeekLanguageModel {
})
.collect::<Vec<_>>();
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
})
.boxed()
}
@ -365,7 +365,7 @@ impl LanguageModel for DeepSeekLanguageModel {
pub fn into_deepseek(
request: LanguageModelRequest,
model: &deepseek::Model,
max_output_tokens: Option<u32>,
max_output_tokens: Option<u64>,
) -> deepseek::Request {
let is_reasoner = *model == deepseek::Model::Reasoner;
@ -513,6 +513,15 @@ impl DeepSeekEventMapper {
}
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
})));
}
match choice.finish_reason.as_deref() {
Some("stop") => {
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));

View file

@ -79,7 +79,7 @@ impl From<GoogleModelMode> for ModelMode {
pub struct AvailableModel {
name: String,
display_name: Option<String>,
max_tokens: usize,
max_tokens: u64,
mode: Option<ModelMode>,
}
@ -342,11 +342,11 @@ impl LanguageModel for GoogleLanguageModel {
}
fn supports_tools(&self) -> bool {
true
self.model.supports_tools()
}
fn supports_images(&self) -> bool {
true
self.model.supports_images()
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
@ -365,15 +365,19 @@ impl LanguageModel for GoogleLanguageModel {
format!("google/{}", self.model.request_id())
}
fn max_token_count(&self) -> usize {
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
fn max_output_tokens(&self) -> Option<u64> {
self.model.max_output_tokens()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
let model_id = self.model.request_id().to_string();
let request = into_google(request, model_id.clone(), self.model.mode());
let http_client = self.http_client.clone();
@ -437,14 +441,29 @@ pub fn into_google(
content
.into_iter()
.flat_map(|content| match content {
language_model::MessageContent::Text(text)
| language_model::MessageContent::Thinking { text, .. } => {
language_model::MessageContent::Text(text) => {
if !text.is_empty() {
vec![Part::TextPart(google_ai::TextPart { text })]
} else {
vec![]
}
}
language_model::MessageContent::Thinking {
text: _,
signature: Some(signature),
} => {
if !signature.is_empty() {
vec![Part::ThoughtPart(google_ai::ThoughtPart {
thought: true,
thought_signature: signature,
})]
} else {
vec![]
}
}
language_model::MessageContent::Thinking { .. } => {
vec![]
}
language_model::MessageContent::RedactedThinking(_) => vec![],
language_model::MessageContent::Image(image) => {
vec![Part::InlineDataPart(google_ai::InlineDataPart {
@ -664,7 +683,12 @@ impl GoogleEventMapper {
)));
}
Part::FunctionResponsePart(_) => {}
Part::ThoughtPart(_) => {}
Part::ThoughtPart(part) => {
events.push(Ok(LanguageModelCompletionEvent::Thinking {
text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
signature: Some(part.thought_signature),
}));
}
});
}
}
@ -682,7 +706,7 @@ impl GoogleEventMapper {
pub fn count_google_tokens(
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
// We couldn't use the GoogleLanguageModelProvider to count tokens because GitHub Copilot doesn't have access to google_ai directly.
// So we have to use the tokenizer from tiktoken_rs to count tokens.
cx.background_spawn(async move {
@ -703,7 +727,7 @@ pub fn count_google_tokens(
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
})
.boxed()
}
@ -730,10 +754,10 @@ fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
}
fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
let prompt_tokens = usage.prompt_token_count.unwrap_or(0) as u32;
let cached_tokens = usage.cached_content_token_count.unwrap_or(0) as u32;
let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
let input_tokens = prompt_tokens - cached_tokens;
let output_tokens = usage.candidates_token_count.unwrap_or(0) as u32;
let output_tokens = usage.candidates_token_count.unwrap_or(0);
language_model::TokenUsage {
input_tokens,

View file

@ -7,17 +7,14 @@ use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
StopReason,
StopReason, TokenUsage,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models,
stream_chat_completion,
};
use lmstudio::{ModelType, get_models};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
@ -47,8 +44,9 @@ pub struct LmStudioSettings {
pub struct AvailableModel {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: usize,
pub max_tokens: u64,
pub supports_tool_calls: bool,
pub supports_images: bool,
}
pub struct LmStudioLanguageModelProvider {
@ -88,6 +86,7 @@ impl State {
.loaded_context_length
.or_else(|| model.max_context_length),
model.capabilities.supports_tool_calls(),
model.capabilities.supports_images() || model.r#type == ModelType::Vlm,
)
})
.collect();
@ -201,6 +200,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
supports_tool_calls: model.supports_tool_calls,
supports_images: model.supports_images,
},
);
}
@ -244,23 +244,34 @@ pub struct LmStudioLanguageModel {
}
impl LmStudioLanguageModel {
fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
fn to_lmstudio_request(
&self,
request: LanguageModelRequest,
) -> lmstudio::ChatCompletionRequest {
let mut messages = Vec::new();
for message in request.messages {
for content in message.content {
match content {
MessageContent::Text(text) => messages.push(match message.role {
Role::User => ChatMessage::User { content: text },
Role::Assistant => ChatMessage::Assistant {
content: Some(text),
tool_calls: Vec::new(),
},
Role::System => ChatMessage::System { content: text },
}),
MessageContent::Text(text) => add_message_content_part(
lmstudio::MessagePart::Text { text },
message.role,
&mut messages,
),
MessageContent::Thinking { .. } => {}
MessageContent::RedactedThinking(_) => {}
MessageContent::Image(_) => {}
MessageContent::Image(image) => {
add_message_content_part(
lmstudio::MessagePart::Image {
image_url: lmstudio::ImageUrl {
url: image.to_base64_url(),
detail: None,
},
},
message.role,
&mut messages,
);
}
MessageContent::ToolUse(tool_use) => {
let tool_call = lmstudio::ToolCall {
id: tool_use.id.to_string(),
@ -285,23 +296,32 @@ impl LmStudioLanguageModel {
}
}
MessageContent::ToolResult(tool_result) => {
match &tool_result.content {
let content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => {
messages.push(lmstudio::ChatMessage::Tool {
content: text.to_string(),
tool_call_id: tool_result.tool_use_id.to_string(),
});
vec![lmstudio::MessagePart::Text {
text: text.to_string(),
}]
}
LanguageModelToolResultContent::Image(_) => {
// no support for images for now
LanguageModelToolResultContent::Image(image) => {
vec![lmstudio::MessagePart::Image {
image_url: lmstudio::ImageUrl {
url: image.to_base64_url(),
detail: None,
},
}]
}
};
messages.push(lmstudio::ChatMessage::Tool {
content: content.into(),
tool_call_id: tool_result.tool_use_id.to_string(),
});
}
}
}
}
ChatCompletionRequest {
lmstudio::ChatCompletionRequest {
model: self.model.name.clone(),
messages,
stream: true,
@ -332,10 +352,12 @@ impl LmStudioLanguageModel {
fn stream_completion(
&self,
request: ChatCompletionRequest,
request: lmstudio::ChatCompletionRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<lmstudio::ResponseStreamEvent>>>,
> {
let http_client = self.http_client.clone();
let Ok(api_url) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
@ -345,7 +367,7 @@ impl LmStudioLanguageModel {
};
let future = self.request_limiter.stream(async move {
let request = stream_chat_completion(http_client.as_ref(), &api_url, request);
let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
let response = request.await?;
Ok(response)
});
@ -385,14 +407,14 @@ impl LanguageModel for LmStudioLanguageModel {
}
fn supports_images(&self) -> bool {
false
self.model.supports_images
}
fn telemetry_id(&self) -> String {
format!("lmstudio/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
@ -400,7 +422,7 @@ impl LanguageModel for LmStudioLanguageModel {
&self,
request: LanguageModelRequest,
_cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
// Endpoint for this is coming soon. In the meantime, hacky estimation
let token_count = request
.messages
@ -408,7 +430,7 @@ impl LanguageModel for LmStudioLanguageModel {
.map(|msg| msg.string_contents().split_whitespace().count())
.sum::<usize>();
let estimated_tokens = (token_count as f64 * 0.75) as usize;
let estimated_tokens = (token_count as f64 * 0.75) as u64;
async move { Ok(estimated_tokens) }.boxed()
}
@ -446,7 +468,7 @@ impl LmStudioEventMapper {
pub fn map_stream(
mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
events: Pin<Box<dyn Send + Stream<Item = Result<lmstudio::ResponseStreamEvent>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events.flat_map(move |event| {
@ -459,7 +481,7 @@ impl LmStudioEventMapper {
pub fn map_event(
&mut self,
event: ResponseStreamEvent,
event: lmstudio::ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.into_iter().next() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
@ -506,6 +528,15 @@ impl LmStudioEventMapper {
}
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
})));
}
match choice.finish_reason.as_deref() {
Some("stop") => {
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
@ -551,6 +582,40 @@ struct RawToolCall {
arguments: String,
}
fn add_message_content_part(
new_part: lmstudio::MessagePart,
role: Role,
messages: &mut Vec<lmstudio::ChatMessage>,
) {
match (role, messages.last_mut()) {
(Role::User, Some(lmstudio::ChatMessage::User { content }))
| (
Role::Assistant,
Some(lmstudio::ChatMessage::Assistant {
content: Some(content),
..
}),
)
| (Role::System, Some(lmstudio::ChatMessage::System { content })) => {
content.push_part(new_part);
}
_ => {
messages.push(match role {
Role::User => lmstudio::ChatMessage::User {
content: lmstudio::MessageContent::from(vec![new_part]),
},
Role::Assistant => lmstudio::ChatMessage::Assistant {
content: Some(lmstudio::MessageContent::from(vec![new_part])),
tool_calls: Vec::new(),
},
Role::System => lmstudio::ChatMessage::System {
content: lmstudio::MessageContent::from(vec![new_part]),
},
});
}
}
}
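A hedged usage sketch for the helper above (illustrative only, written as a hypothetical test): consecutive parts sent with the same role are appended to the previous message rather than opening a new one, so a text part followed by an image part for the user ends up as a single multi-part ChatMessage::User.

// Illustration only, not part of this commit; assumes the lmstudio types used above.
#[test]
fn merges_consecutive_parts_for_same_role() {
    let mut messages: Vec<lmstudio::ChatMessage> = Vec::new();
    add_message_content_part(
        lmstudio::MessagePart::Text { text: "Describe this screenshot:".into() },
        Role::User,
        &mut messages,
    );
    add_message_content_part(
        lmstudio::MessagePart::Image {
            image_url: lmstudio::ImageUrl {
                url: "data:image/png;base64,...".into(),
                detail: None,
            },
        },
        Role::User,
        &mut messages,
    );
    // Both parts land in the single existing ChatMessage::User.
    assert_eq!(messages.len(), 1);
}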
struct ConfigurationView {
state: gpui::Entity<State>,
loading_models_task: Option<Task<()>>,

View file

@ -13,7 +13,7 @@ use language_model::{
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason,
RateLimiter, Role, StopReason, TokenUsage,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@ -36,16 +36,15 @@ const PROVIDER_NAME: &str = "Mistral";
pub struct MistralSettings {
pub api_url: String,
pub available_models: Vec<AvailableModel>,
pub needs_setting_migration: bool,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: usize,
pub max_output_tokens: Option<u32>,
pub max_completion_tokens: Option<u32>,
pub max_tokens: u64,
pub max_output_tokens: Option<u64>,
pub max_completion_tokens: Option<u64>,
pub supports_tools: Option<bool>,
pub supports_images: Option<bool>,
}
@ -322,11 +321,11 @@ impl LanguageModel for MistralLanguageModel {
format!("mistral/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
fn max_output_tokens(&self) -> Option<u32> {
fn max_output_tokens(&self) -> Option<u64> {
self.model.max_output_tokens()
}
@ -334,7 +333,7 @@ impl LanguageModel for MistralLanguageModel {
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
cx.background_spawn(async move {
let messages = request
.messages
@ -351,7 +350,7 @@ impl LanguageModel for MistralLanguageModel {
})
.collect::<Vec<_>>();
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
})
.boxed()
}
@ -386,7 +385,7 @@ impl LanguageModel for MistralLanguageModel {
pub fn into_mistral(
request: LanguageModelRequest,
model: String,
max_output_tokens: Option<u32>,
max_output_tokens: Option<u64>,
) -> mistral::Request {
let stream = true;
@ -626,6 +625,15 @@ impl MistralEventMapper {
}
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
})));
}
if let Some(finish_reason) = choice.finish_reason.as_deref() {
match finish_reason {
"stop" => {

View file

@ -8,7 +8,7 @@ use language_model::{
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason,
LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
};
use ollama::{
ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
@ -46,7 +46,7 @@ pub struct AvailableModel {
/// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
pub display_name: Option<String>,
/// The Context Length parameter to the model (aka num_ctx or n_ctx)
pub max_tokens: usize,
pub max_tokens: u64,
/// The number of seconds to keep the connection open after the last request
pub keep_alive: Option<KeepAlive>,
/// Whether the model supports tools
@ -377,7 +377,7 @@ impl LanguageModel for OllamaLanguageModel {
format!("ollama/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
@ -385,7 +385,7 @@ impl LanguageModel for OllamaLanguageModel {
&self,
request: LanguageModelRequest,
_cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
// There is no endpoint for this _yet_ in Ollama
// see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
let token_count = request
@ -395,7 +395,7 @@ impl LanguageModel for OllamaLanguageModel {
.sum::<usize>()
/ 4;
async move { Ok(token_count) }.boxed()
async move { Ok(token_count as u64) }.boxed()
}
fn stream_completion(
@ -507,6 +507,12 @@ fn map_to_language_model_completion_events(
};
if delta.done {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
input_tokens: delta.prompt_eval_count.unwrap_or(0),
output_tokens: delta.eval_count.unwrap_or(0),
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
})));
if state.used_tools {
state.used_tools = false;
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));

View file

@ -1,32 +1,34 @@
use anyhow::{Context as _, Result, anyhow};
use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use fs::Fs;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason,
RateLimiter, Role, StopReason, TokenUsage,
};
use menu;
use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use settings::{Settings, SettingsStore, update_settings_file};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
use crate::OpenAiSettingsContent;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "openai";
@ -36,16 +38,15 @@ const PROVIDER_NAME: &str = "OpenAI";
pub struct OpenAiSettings {
pub api_url: String,
pub available_models: Vec<AvailableModel>,
pub needs_setting_migration: bool,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: usize,
pub max_output_tokens: Option<u32>,
pub max_completion_tokens: Option<u32>,
pub max_tokens: u64,
pub max_output_tokens: Option<u64>,
pub max_completion_tokens: Option<u64>,
}
pub struct OpenAiLanguageModelProvider {
@ -62,6 +63,7 @@ pub struct State {
const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";
impl State {
//
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
@ -312,11 +314,11 @@ impl LanguageModel for OpenAiLanguageModel {
format!("openai/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
fn max_output_tokens(&self) -> Option<u32> {
fn max_output_tokens(&self) -> Option<u64> {
self.model.max_output_tokens()
}
@ -324,7 +326,7 @@ impl LanguageModel for OpenAiLanguageModel {
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
count_open_ai_tokens(request, self.model.clone(), cx)
}
@ -342,7 +344,12 @@ impl LanguageModel for OpenAiLanguageModel {
LanguageModelCompletionError,
>,
> {
let request = into_open_ai(request, &self.model, self.max_output_tokens());
let request = into_open_ai(
request,
self.model.id(),
self.model.supports_parallel_tool_calls(),
self.max_output_tokens(),
);
let completions = self.stream_completion(request, cx);
async move {
let mapper = OpenAiEventMapper::new();
@ -354,10 +361,11 @@ impl LanguageModel for OpenAiLanguageModel {
pub fn into_open_ai(
request: LanguageModelRequest,
model: &Model,
max_output_tokens: Option<u32>,
model_id: &str,
supports_parallel_tool_calls: bool,
max_output_tokens: Option<u64>,
) -> open_ai::Request {
let stream = !model.id().starts_with("o1-");
let stream = !model_id.starts_with("o1-");
let mut messages = Vec::new();
for message in request.messages {
@ -433,13 +441,13 @@ pub fn into_open_ai(
}
open_ai::Request {
model: model.id().into(),
model: model_id.into(),
messages,
stream,
stop: request.stop,
temperature: request.temperature.unwrap_or(1.0),
max_tokens: max_output_tokens,
parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
max_completion_tokens: max_output_tokens,
parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
// Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
Some(false)
} else {
@ -526,13 +534,20 @@ impl OpenAiEventMapper {
&mut self,
event: ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let mut events = Vec::new();
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
})));
}
let Some(choice) = event.choices.first() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
"Response contained no choices"
)))];
return events;
};
let mut events = Vec::new();
if let Some(content) = choice.delta.content.clone() {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
@ -606,7 +621,7 @@ pub fn count_open_ai_tokens(
request: LanguageModelRequest,
model: Model,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
cx.background_spawn(async move {
let messages = request
.messages
@ -648,18 +663,18 @@ pub fn count_open_ai_tokens(
| Model::FourPointOneMini
| Model::FourPointOneNano
| Model::O1
| Model::O1Preview
| Model::O1Mini
| Model::O3
| Model::O3Mini
| Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
}
.map(|tokens| tokens as u64)
})
.boxed()
}
struct ConfigurationView {
api_key_editor: Entity<Editor>,
api_key_editor: Entity<SingleLineInput>,
api_url_editor: Entity<SingleLineInput>,
state: gpui::Entity<State>,
load_credentials_task: Option<Task<()>>,
}
@ -667,9 +682,28 @@ struct ConfigurationView {
impl ConfigurationView {
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
let api_key_editor = cx.new(|cx| {
let mut editor = Editor::single_line(window, cx);
editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
editor
SingleLineInput::new(
window,
cx,
"sk-000000000000000000000000000000000000000000000000",
)
.label("API key")
});
let api_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
let api_url_editor = cx.new(|cx| {
let input = SingleLineInput::new(window, cx, open_ai::OPEN_AI_API_URL).label("API URL");
if !api_url.is_empty() {
input.editor.update(cx, |editor, cx| {
editor.set_text(&*api_url, window, cx);
});
}
input
});
cx.observe(&state, |_, _, cx| {
@ -687,7 +721,6 @@ impl ConfigurationView {
// Don't log the result: "not signed in" also surfaces as an error here and isn't worth reporting.
let _ = task.await;
}
this.update(cx, |this, cx| {
this.load_credentials_task = None;
cx.notify();
@ -698,14 +731,24 @@ impl ConfigurationView {
Self {
api_key_editor,
api_url_editor,
state,
load_credentials_task,
}
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let api_key = self.api_key_editor.read(cx).text(cx);
if api_key.is_empty() {
let api_key = self
.api_key_editor
.read(cx)
.editor()
.read(cx)
.text(cx)
.trim()
.to_string();
// Don't proceed if no API key is provided and we're not authenticated
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
return;
}
@ -721,8 +764,11 @@ impl ConfigurationView {
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.api_key_editor
.update(cx, |editor, cx| editor.set_text("", window, cx));
self.api_key_editor.update(cx, |input, cx| {
input.editor.update(cx, |editor, cx| {
editor.set_text("", window, cx);
});
});
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
@ -733,29 +779,55 @@ impl ConfigurationView {
cx.notify();
}
fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features.clone(),
font_fallbacks: settings.ui_font.fallbacks.clone(),
font_size: rems(0.875).into(),
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(1.3),
white_space: WhiteSpace::Normal,
..Default::default()
fn save_api_url(&mut self, cx: &mut Context<Self>) {
let api_url = self
.api_url_editor
.read(cx)
.editor()
.read(cx)
.text(cx)
.trim()
.to_string();
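// Treat an empty setting as the default URL so the default never gets written
// back into the settings file.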
let current_url = AllLanguageModelSettings::get_global(cx)
.openai
.api_url
.clone();
let effective_current_url = if current_url.is_empty() {
open_ai::OPEN_AI_API_URL
} else {
&current_url
};
EditorElement::new(
&self.api_key_editor,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
..Default::default()
},
)
if !api_url.is_empty() && api_url != effective_current_url {
let fs = <dyn Fs>::global(cx);
update_settings_file::<AllLanguageModelSettings>(fs, cx, move |settings, _| {
if let Some(settings) = settings.openai.as_mut() {
settings.api_url = Some(api_url.clone());
} else {
settings.openai = Some(OpenAiSettingsContent {
api_url: Some(api_url.clone()),
available_models: None,
});
}
});
}
}
fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.api_url_editor.update(cx, |input, cx| {
input.editor.update(cx, |editor, cx| {
editor.set_text("", window, cx);
});
});
let fs = <dyn Fs>::global(cx);
update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
if let Some(settings) = settings.openai.as_mut() {
settings.api_url = None;
}
});
cx.notify();
}
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
@ -767,12 +839,10 @@ impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
} else if self.should_render_editor(cx) {
let api_key_section = if self.should_render_editor(cx) {
v_flex()
.size_full()
.on_action(cx.listener(Self::save_api_key))
.child(Label::new("To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:"))
.child(
List::new()
@ -788,18 +858,7 @@ impl Render for ConfigurationView {
"Paste your API key below and hit enter to start using the assistant",
)),
)
.child(
h_flex()
.w_full()
.my_2()
.px_2()
.py_1()
.bg(cx.theme().colors().editor_background)
.border_1()
.border_color(cx.theme().colors().border)
.rounded_sm()
.child(self.render_api_key_editor(cx)),
)
.child(self.api_key_editor.clone())
.child(
Label::new(
format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
@ -808,7 +867,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
"Note that having a subscription for another service like GitHub Copilot won't work.".to_string(),
"Note that having a subscription for another service like GitHub Copilot won't work.",
)
.size(LabelSize::Small).color(Color::Muted),
)
@ -833,18 +892,82 @@ impl Render for ConfigurationView {
})),
)
.child(
Button::new("reset-key", "Reset Key")
Button::new("reset-api-key", "Reset API Key")
.label_size(LabelSize::Small)
.icon(Some(IconName::Trash))
.icon(IconName::Undo)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.layer(ElevationIndex::ModalSurface)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
.into_any()
};
let custom_api_url_set =
AllLanguageModelSettings::get_global(cx).openai.api_url != open_ai::OPEN_AI_API_URL;
let api_url_section = if custom_api_url_set {
h_flex()
.mt_1()
.p_1()
.justify_between()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().background)
.child(
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new("Custom API URL configured.")),
)
.child(
Button::new("reset-api-url", "Reset API URL")
.label_size(LabelSize::Small)
.icon(IconName::Undo)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.on_click(
cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)),
),
)
.into_any()
} else {
v_flex()
.on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
this.save_api_url(cx);
cx.notify();
}))
.mt_2()
.pt_2()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
.gap_1()
.child(
List::new()
.child(InstructionListItem::text_only(
"Optionally, you can change the base URL for the OpenAI API request.",
))
.child(InstructionListItem::text_only(
"Paste the new API endpoint below and hit enter",
)),
)
.child(self.api_url_editor.clone())
.into_any()
};
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials…")).into_any()
} else {
v_flex()
.size_full()
.child(api_key_section)
.child(api_url_section)
.into_any()
}
}
}

View file

@ -12,9 +12,11 @@ use language_model::{
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason,
RateLimiter, Role, StopReason, TokenUsage,
};
use open_router::{
Model, ModelMode as OpenRouterModelMode, ResponseStreamEvent, list_models, stream_completion,
};
use open_router::{Model, ResponseStreamEvent, list_models, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
@ -40,9 +42,44 @@ pub struct OpenRouterSettings {
pub struct AvailableModel {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: usize,
pub max_output_tokens: Option<u32>,
pub max_completion_tokens: Option<u32>,
pub max_tokens: u64,
pub max_output_tokens: Option<u64>,
pub max_completion_tokens: Option<u64>,
pub supports_tools: Option<bool>,
pub supports_images: Option<bool>,
pub mode: Option<ModelMode>,
}
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
#[default]
Default,
Thinking {
budget_tokens: Option<u32>,
},
}
impl From<ModelMode> for OpenRouterModelMode {
fn from(value: ModelMode) -> Self {
match value {
ModelMode::Default => OpenRouterModelMode::Default,
ModelMode::Thinking { budget_tokens } => {
OpenRouterModelMode::Thinking { budget_tokens }
}
}
}
}
impl From<OpenRouterModelMode> for ModelMode {
fn from(value: OpenRouterModelMode) -> Self {
match value {
OpenRouterModelMode::Default => ModelMode::Default,
OpenRouterModelMode::Thinking { budget_tokens } => {
ModelMode::Thinking { budget_tokens }
}
}
}
}
pub struct OpenRouterLanguageModelProvider {
@ -56,6 +93,7 @@ pub struct State {
http_client: Arc<dyn HttpClient>,
available_models: Vec<open_router::Model>,
fetch_models_task: Option<Task<Result<()>>>,
settings: OpenRouterSettings,
_subscription: Subscription,
}
@ -98,6 +136,7 @@ impl State {
.log_err();
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.restart_fetch_models_task(cx);
cx.notify();
})
})
@ -130,6 +169,7 @@ impl State {
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
this.restart_fetch_models_task(cx);
cx.notify();
})?;
@ -153,8 +193,10 @@ impl State {
}
fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
let task = self.fetch_models(cx);
self.fetch_models_task.replace(task);
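// Skip fetching models until credentials are available.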
if self.is_authenticated() {
let task = self.fetch_models(cx);
self.fetch_models_task.replace(task);
}
}
}
@ -166,8 +208,14 @@ impl OpenRouterLanguageModelProvider {
http_client: http_client.clone(),
available_models: Vec::new(),
fetch_models_task: None,
settings: OpenRouterSettings::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.restart_fetch_models_task(cx);
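// Only restart the model fetch when the OpenRouter settings actually changed.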
let current_settings = &AllLanguageModelSettings::get_global(cx).open_router;
let settings_changed = current_settings != &this.settings;
if settings_changed {
this.settings = current_settings.clone();
this.restart_fetch_models_task(cx);
}
cx.notify();
}),
});
@ -227,7 +275,9 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
supports_tools: Some(false),
supports_tools: model.supports_tools,
supports_images: model.supports_images,
mode: model.mode.clone().unwrap_or_default().into(),
});
}
@ -328,11 +378,11 @@ impl LanguageModel for OpenRouterLanguageModel {
format!("openrouter/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
fn max_output_tokens(&self) -> Option<u32> {
fn max_output_tokens(&self) -> Option<u64> {
self.model.max_output_tokens()
}
@ -345,14 +395,14 @@ impl LanguageModel for OpenRouterLanguageModel {
}
fn supports_images(&self) -> bool {
false
self.model.supports_images.unwrap_or(false)
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
count_open_router_tokens(request, self.model.clone(), cx)
}
@ -383,23 +433,28 @@ impl LanguageModel for OpenRouterLanguageModel {
pub fn into_open_router(
request: LanguageModelRequest,
model: &Model,
max_output_tokens: Option<u32>,
max_output_tokens: Option<u64>,
) -> open_router::Request {
let mut messages = Vec::new();
for req_message in request.messages {
for content in req_message.content {
for message in request.messages {
for content in message.content {
match content {
MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
.push(match req_message.role {
Role::User => open_router::RequestMessage::User { content: text },
Role::Assistant => open_router::RequestMessage::Assistant {
content: Some(text),
tool_calls: Vec::new(),
},
Role::System => open_router::RequestMessage::System { content: text },
}),
MessageContent::Text(text) => add_message_content_part(
open_router::MessagePart::Text { text },
message.role,
&mut messages,
),
MessageContent::Thinking { .. } => {}
MessageContent::RedactedThinking(_) => {}
MessageContent::Image(_) => {}
MessageContent::Image(image) => {
add_message_content_part(
open_router::MessagePart::Image {
image_url: image.to_base64_url(),
},
message.role,
&mut messages,
);
}
MessageContent::ToolUse(tool_use) => {
let tool_call = open_router::ToolCall {
id: tool_use.id.to_string(),
@ -425,16 +480,20 @@ pub fn into_open_router(
}
MessageContent::ToolResult(tool_result) => {
let content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => {
text.to_string()
LanguageModelToolResultContent::Text(text) => {
vec![open_router::MessagePart::Text {
text: text.to_string(),
}]
}
LanguageModelToolResultContent::Image(_) => {
"[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string()
LanguageModelToolResultContent::Image(image) => {
vec![open_router::MessagePart::Image {
image_url: image.to_base64_url(),
}]
}
};
messages.push(open_router::RequestMessage::Tool {
content: content,
content: content.into(),
tool_call_id: tool_result.tool_use_id.to_string(),
});
}
@ -454,6 +513,17 @@ pub fn into_open_router(
} else {
None
},
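// Ask OpenRouter to include token usage accounting in the response.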
usage: open_router::RequestUsage { include: true },
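// Enable reasoning output when the model is configured in thinking mode.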
reasoning: if let OpenRouterModelMode::Thinking { budget_tokens } = model.mode {
Some(open_router::Reasoning {
effort: None,
max_tokens: budget_tokens,
exclude: Some(false),
enabled: Some(true),
})
} else {
None
},
tools: request
.tools
.into_iter()
@ -473,6 +543,42 @@ pub fn into_open_router(
}
}
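/// Appends `new_part` to the previous message when it shares the same role, so
/// consecutive text/image parts collapse into a single multi-part message;
/// otherwise a new message of that role is started.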
fn add_message_content_part(
new_part: open_router::MessagePart,
role: Role,
messages: &mut Vec<open_router::RequestMessage>,
) {
match (role, messages.last_mut()) {
(Role::User, Some(open_router::RequestMessage::User { content }))
| (Role::System, Some(open_router::RequestMessage::System { content })) => {
content.push_part(new_part);
}
(
Role::Assistant,
Some(open_router::RequestMessage::Assistant {
content: Some(content),
..
}),
) => {
content.push_part(new_part);
}
_ => {
messages.push(match role {
Role::User => open_router::RequestMessage::User {
content: open_router::MessageContent::from(vec![new_part]),
},
Role::Assistant => open_router::RequestMessage::Assistant {
content: Some(open_router::MessageContent::from(vec![new_part])),
tool_calls: Vec::new(),
},
Role::System => open_router::RequestMessage::System {
content: open_router::MessageContent::from(vec![new_part]),
},
});
}
}
}
pub struct OpenRouterEventMapper {
tool_calls_by_index: HashMap<usize, RawToolCall>,
}
@ -508,8 +614,19 @@ impl OpenRouterEventMapper {
};
let mut events = Vec::new();
if let Some(reasoning) = choice.delta.reasoning.clone() {
events.push(Ok(LanguageModelCompletionEvent::Thinking {
text: reasoning,
signature: None,
}));
}
if let Some(content) = choice.delta.content.clone() {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
// OpenRouter sends an empty content string alongside reasoning content;
// skip it as a workaround for that API quirk.
if !content.is_empty() {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
}
if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
@ -532,6 +649,15 @@ impl OpenRouterEventMapper {
}
}
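// OpenRouter includes token usage on the final chunk when usage accounting is enabled.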
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
})));
}
match choice.finish_reason.as_deref() {
Some("stop") => {
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
@ -560,7 +686,7 @@ impl OpenRouterEventMapper {
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
}
Some(stop_reason) => {
log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}",);
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
}
None => {}
@ -581,7 +707,7 @@ pub fn count_open_router_tokens(
request: LanguageModelRequest,
_model: open_router::Model,
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
) -> BoxFuture<'static, Result<u64>> {
cx.background_spawn(async move {
let messages = request
.messages
@ -598,7 +724,7 @@ pub fn count_open_router_tokens(
})
.collect::<Vec<_>>();
tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64)
})
.boxed()
}

View file

@ -0,0 +1,577 @@
use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelToolChoice, RateLimiter, Role,
};
use menu;
use open_ai::ResponseStreamEvent;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use strum::IntoEnumIterator;
use vercel::Model;
use ui::{ElevationIndex, List, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
const PROVIDER_ID: &str = "vercel";
const PROVIDER_NAME: &str = "Vercel";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelSettings {
pub api_url: String,
pub available_models: Vec<AvailableModel>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: u64,
pub max_output_tokens: Option<u64>,
pub max_completion_tokens: Option<u64>,
}
pub struct VercelLanguageModelProvider {
http_client: Arc<dyn HttpClient>,
state: gpui::Entity<State>,
}
pub struct State {
api_key: Option<String>,
api_key_from_env: bool,
_subscription: Subscription,
}
const VERCEL_API_KEY_VAR: &str = "VERCEL_API_KEY";
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
let api_url = if settings.api_url.is_empty() {
vercel::VERCEL_API_URL.to_string()
} else {
settings.api_url.clone()
};
cx.spawn(async move |this, cx| {
credentials_provider
.delete_credentials(&api_url, &cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
cx.notify();
})
})
}
fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
let api_url = if settings.api_url.is_empty() {
vercel::VERCEL_API_URL.to_string()
} else {
settings.api_url.clone()
};
cx.spawn(async move |this, cx| {
credentials_provider
.write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
.await
.log_err();
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}
fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
let api_url = if settings.api_url.is_empty() {
vercel::VERCEL_API_URL.to_string()
} else {
settings.api_url.clone()
};
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(VERCEL_API_KEY_VAR) {
(api_key, true)
} else {
let (_, api_key) = credentials_provider
.read_credentials(&api_url, &cx)
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
String::from_utf8(api_key).context(format!("invalid {PROVIDER_NAME} API key"))?,
false,
)
};
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
cx.notify();
})?;
Ok(())
})
}
}
impl VercelLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let state = cx.new(|cx| State {
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
cx.notify();
}),
});
Self { http_client, state }
}
fn create_language_model(&self, model: vercel::Model) -> Arc<dyn LanguageModel> {
Arc::new(VercelLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
})
}
}
impl LanguageModelProviderState for VercelLanguageModelProvider {
type ObservableEntity = State;
fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
Some(self.state.clone())
}
}
impl LanguageModelProvider for VercelLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn icon(&self) -> IconName {
IconName::AiVZero
}
fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
Some(self.create_language_model(vercel::Model::default()))
}
fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
Some(self.create_language_model(vercel::Model::default_fast()))
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();
for model in vercel::Model::iter() {
if !matches!(model, vercel::Model::Custom { .. }) {
models.insert(model.id().to_string(), model);
}
}
for model in &AllLanguageModelSettings::get_global(cx)
.vercel
.available_models
{
models.insert(
model.name.clone(),
vercel::Model::Custom {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
max_completion_tokens: model.max_completion_tokens,
},
);
}
models
.into_values()
.map(|model| self.create_language_model(model))
.collect()
}
fn is_authenticated(&self, cx: &App) -> bool {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
.into()
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.reset_api_key(cx))
}
}
pub struct VercelLanguageModel {
id: LanguageModelId,
model: vercel::Model,
state: gpui::Entity<State>,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl VercelLanguageModel {
fn stream_completion(
&self,
request: open_ai::Request,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
{
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).vercel;
let api_url = if settings.api_url.is_empty() {
vercel::VERCEL_API_URL.to_string()
} else {
settings.api_url.clone()
};
(state.api_key.clone(), api_url)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let api_key = api_key.context("Missing Vercel API Key")?;
let request =
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
Ok(response)
});
async move { Ok(future.await?.boxed()) }.boxed()
}
}
impl LanguageModel for VercelLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn supports_tools(&self) -> bool {
true
}
fn supports_images(&self) -> bool {
true
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
LanguageModelToolChoice::Auto
| LanguageModelToolChoice::Any
| LanguageModelToolChoice::None => true,
}
}
fn telemetry_id(&self) -> String {
format!("vercel/{}", self.model.id())
}
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
fn max_output_tokens(&self) -> Option<u64> {
self.model.max_output_tokens()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
count_vercel_tokens(request, self.model.clone(), cx)
}
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<
futures::stream::BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
> {
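// Vercel v0 exposes an OpenAI-compatible API, so reuse the OpenAI provider's
// request construction and event mapping.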
let request = crate::provider::open_ai::into_open_ai(
request,
self.model.id(),
self.model.supports_parallel_tool_calls(),
self.max_output_tokens(),
);
let completions = self.stream_completion(request, cx);
async move {
let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
Ok(mapper.map_stream(completions.await?).boxed())
}
.boxed()
}
}
pub fn count_vercel_tokens(
request: LanguageModelRequest,
model: Model,
cx: &App,
) -> BoxFuture<'static, Result<u64>> {
cx.background_spawn(async move {
let messages = request
.messages
.into_iter()
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
role: match message.role {
Role::User => "user".into(),
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
content: Some(message.string_contents()),
name: None,
function_call: None,
})
.collect::<Vec<_>>();
match model {
Model::Custom { max_tokens, .. } => {
let model = if max_tokens >= 100_000 {
// If max_tokens is 100k or more, the model likely uses the o200k_base tokenizer (as in gpt-4o)
"gpt-4o"
} else {
// Otherwise fall back to gpt-4, since this tiktoken method only supports the
// cl100k_base and o200k_base encodings
"gpt-4"
};
tiktoken_rs::num_tokens_from_messages(model, &messages)
}
// Map Vercel models to appropriate OpenAI models for token counting,
// since Vercel uses an OpenAI-compatible API
Model::VZeroOnePointFiveMedium => {
// Vercel v0 is similar to GPT-4o, so use gpt-4o for token counting
tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
}
}
.map(|tokens| tokens as u64)
})
.boxed()
}
struct ConfigurationView {
api_key_editor: Entity<SingleLineInput>,
state: gpui::Entity<State>,
load_credentials_task: Option<Task<()>>,
}
impl ConfigurationView {
fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
let api_key_editor = cx.new(|cx| {
SingleLineInput::new(
window,
cx,
"v1:0000000000000000000000000000000000000000000000000",
)
.label("API key")
});
cx.observe(&state, |_, _, cx| {
cx.notify();
})
.detach();
let load_credentials_task = Some(cx.spawn_in(window, {
let state = state.clone();
async move |this, cx| {
if let Some(task) = state
.update(cx, |state, cx| state.authenticate(cx))
.log_err()
{
// Don't log the result: "not signed in" also surfaces as an error here and isn't worth reporting.
let _ = task.await;
}
this.update(cx, |this, cx| {
this.load_credentials_task = None;
cx.notify();
})
.log_err();
}
}));
Self {
api_key_editor,
state,
load_credentials_task,
}
}
fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let api_key = self
.api_key_editor
.read(cx)
.editor()
.read(cx)
.text(cx)
.trim()
.to_string();
// Don't proceed if no API key is provided and we're not authenticated
if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
return;
}
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state
.update(cx, |state, cx| state.set_api_key(api_key, cx))?
.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
self.api_key_editor.update(cx, |input, cx| {
input.editor.update(cx, |editor, cx| {
editor.set_text("", window, cx);
});
});
let state = self.state.clone();
cx.spawn_in(window, async move |_, cx| {
state.update(cx, |state, cx| state.reset_api_key(cx))?.await
})
.detach_and_log_err(cx);
cx.notify();
}
fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
!self.state.read(cx).is_authenticated()
}
}
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let env_var_set = self.state.read(cx).api_key_from_env;
let api_key_section = if self.should_render_editor(cx) {
v_flex()
.on_action(cx.listener(Self::save_api_key))
.child(Label::new("To use Zed's agent with Vercel v0, you need to add an API key. Follow these steps:"))
.child(
List::new()
.child(InstructionListItem::new(
"Create one by visiting",
Some("Vercel v0's console"),
Some("https://v0.dev/chat/settings/keys"),
))
.child(InstructionListItem::text_only(
"Paste your API key below and hit enter to start using the agent",
)),
)
.child(self.api_key_editor.clone())
.child(
Label::new(format!(
"You can also assign the {VERCEL_API_KEY_VAR} environment variable and restart Zed."
))
.size(LabelSize::Small)
.color(Color::Muted),
)
.child(
Label::new("Note that Vercel v0 is a custom OpenAI-compatible provider.")
.size(LabelSize::Small)
.color(Color::Muted),
)
.into_any()
} else {
h_flex()
.mt_1()
.p_1()
.justify_between()
.rounded_md()
.border_1()
.border_color(cx.theme().colors().border)
.bg(cx.theme().colors().background)
.child(
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {VERCEL_API_KEY_VAR} environment variable.")
} else {
"API key configured.".to_string()
})),
)
.child(
Button::new("reset-api-key", "Reset API Key")
.label_size(LabelSize::Small)
.icon(IconName::Undo)
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.layer(ElevationIndex::ModalSurface)
.when(env_var_set, |this| {
this.tooltip(Tooltip::text(format!("To reset your API key, unset the {VERCEL_API_KEY_VAR} environment variable.")))
})
.on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
)
.into_any()
};
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials…")).into_any()
} else {
v_flex().size_full().child(api_key_section).into_any()
}
}
}

View file

@ -1,19 +1,14 @@
use std::sync::Arc;
use anyhow::Result;
use gpui::App;
use language_model::LanguageModelCacheConfiguration;
use project::Fs;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources, update_settings_file};
use settings::{Settings, SettingsSources};
use crate::provider::{
self,
anthropic::AnthropicSettings,
bedrock::AmazonBedrockSettings,
cloud::{self, ZedDotDevSettings},
copilot_chat::CopilotChatSettings,
deepseek::DeepSeekSettings,
google::GoogleSettings,
lmstudio::LmStudioSettings,
@ -21,39 +16,12 @@ use crate::provider::{
ollama::OllamaSettings,
open_ai::OpenAiSettings,
open_router::OpenRouterSettings,
vercel::VercelSettings,
};
/// Initializes the language model settings.
pub fn init(fs: Arc<dyn Fs>, cx: &mut App) {
pub fn init(cx: &mut App) {
AllLanguageModelSettings::register(cx);
if AllLanguageModelSettings::get_global(cx)
.openai
.needs_setting_migration
{
update_settings_file::<AllLanguageModelSettings>(fs.clone(), cx, move |setting, _| {
if let Some(settings) = setting.openai.clone() {
let (newest_version, _) = settings.upgrade();
setting.openai = Some(OpenAiSettingsContent::Versioned(
VersionedOpenAiSettingsContent::V1(newest_version),
));
}
});
}
if AllLanguageModelSettings::get_global(cx)
.anthropic
.needs_setting_migration
{
update_settings_file::<AllLanguageModelSettings>(fs, cx, move |setting, _| {
if let Some(settings) = setting.anthropic.clone() {
let (newest_version, _) = settings.upgrade();
setting.anthropic = Some(AnthropicSettingsContent::Versioned(
VersionedAnthropicSettingsContent::V1(newest_version),
));
}
});
}
}
#[derive(Default)]
@ -65,7 +33,8 @@ pub struct AllLanguageModelSettings {
pub open_router: OpenRouterSettings,
pub zed_dot_dev: ZedDotDevSettings,
pub google: GoogleSettings,
pub copilot_chat: CopilotChatSettings,
pub vercel: VercelSettings,
pub lmstudio: LmStudioSettings,
pub deepseek: DeepSeekSettings,
pub mistral: MistralSettings,
@ -83,83 +52,13 @@ pub struct AllLanguageModelSettingsContent {
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
pub google: Option<GoogleSettingsContent>,
pub deepseek: Option<DeepseekSettingsContent>,
pub copilot_chat: Option<CopilotChatSettingsContent>,
pub vercel: Option<VercelSettingsContent>,
pub mistral: Option<MistralSettingsContent>,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum AnthropicSettingsContent {
Versioned(VersionedAnthropicSettingsContent),
Legacy(LegacyAnthropicSettingsContent),
}
impl AnthropicSettingsContent {
pub fn upgrade(self) -> (AnthropicSettingsContentV1, bool) {
match self {
AnthropicSettingsContent::Legacy(content) => (
AnthropicSettingsContentV1 {
api_url: content.api_url,
available_models: content.available_models.map(|models| {
models
.into_iter()
.filter_map(|model| match model {
anthropic::Model::Custom {
name,
display_name,
max_tokens,
tool_override,
cache_configuration,
max_output_tokens,
default_temperature,
extra_beta_headers,
mode,
} => Some(provider::anthropic::AvailableModel {
name,
display_name,
max_tokens,
tool_override,
cache_configuration: cache_configuration.as_ref().map(
|config| LanguageModelCacheConfiguration {
max_cache_anchors: config.max_cache_anchors,
should_speculate: config.should_speculate,
min_total_token: config.min_total_token,
},
),
max_output_tokens,
default_temperature,
extra_beta_headers,
mode: Some(mode.into()),
}),
_ => None,
})
.collect()
}),
},
true,
),
AnthropicSettingsContent::Versioned(content) => match content {
VersionedAnthropicSettingsContent::V1(content) => (content, false),
},
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct LegacyAnthropicSettingsContent {
pub api_url: Option<String>,
pub available_models: Option<Vec<anthropic::Model>>,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(tag = "version")]
pub enum VersionedAnthropicSettingsContent {
#[serde(rename = "1")]
V1(AnthropicSettingsContentV1),
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct AnthropicSettingsContentV1 {
pub struct AnthropicSettingsContent {
pub api_url: Option<String>,
pub available_models: Option<Vec<provider::anthropic::AvailableModel>>,
}
@ -198,68 +97,17 @@ pub struct MistralSettingsContent {
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(untagged)]
pub enum OpenAiSettingsContent {
Versioned(VersionedOpenAiSettingsContent),
Legacy(LegacyOpenAiSettingsContent),
}
impl OpenAiSettingsContent {
pub fn upgrade(self) -> (OpenAiSettingsContentV1, bool) {
match self {
OpenAiSettingsContent::Legacy(content) => (
OpenAiSettingsContentV1 {
api_url: content.api_url,
available_models: content.available_models.map(|models| {
models
.into_iter()
.filter_map(|model| match model {
open_ai::Model::Custom {
name,
display_name,
max_tokens,
max_output_tokens,
max_completion_tokens,
} => Some(provider::open_ai::AvailableModel {
name,
max_tokens,
max_output_tokens,
display_name,
max_completion_tokens,
}),
_ => None,
})
.collect()
}),
},
true,
),
OpenAiSettingsContent::Versioned(content) => match content {
VersionedOpenAiSettingsContent::V1(content) => (content, false),
},
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct LegacyOpenAiSettingsContent {
pub api_url: Option<String>,
pub available_models: Option<Vec<open_ai::Model>>,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
#[serde(tag = "version")]
pub enum VersionedOpenAiSettingsContent {
#[serde(rename = "1")]
V1(OpenAiSettingsContentV1),
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OpenAiSettingsContentV1 {
pub struct OpenAiSettingsContent {
pub api_url: Option<String>,
pub available_models: Option<Vec<provider::open_ai::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct VercelSettingsContent {
pub api_url: Option<String>,
pub available_models: Option<Vec<provider::vercel::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct GoogleSettingsContent {
pub api_url: Option<String>,
@ -271,13 +119,6 @@ pub struct ZedDotDevSettingsContent {
available_models: Option<Vec<cloud::AvailableModel>>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct CopilotChatSettingsContent {
pub api_url: Option<String>,
pub auth_url: Option<String>,
pub models_url: Option<String>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OpenRouterSettingsContent {
pub api_url: Option<String>,
@ -302,15 +143,7 @@ impl settings::Settings for AllLanguageModelSettings {
for value in sources.defaults_and_customizations() {
// Anthropic
let (anthropic, upgraded) = match value.anthropic.clone().map(|s| s.upgrade()) {
Some((content, upgraded)) => (Some(content), upgraded),
None => (None, false),
};
if upgraded {
settings.anthropic.needs_setting_migration = true;
}
let anthropic = value.anthropic.clone();
merge(
&mut settings.anthropic.api_url,
anthropic.as_ref().and_then(|s| s.api_url.clone()),
@ -376,15 +209,7 @@ impl settings::Settings for AllLanguageModelSettings {
);
// OpenAI
let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) {
Some((content, upgraded)) => (Some(content), upgraded),
None => (None, false),
};
if upgraded {
settings.openai.needs_setting_migration = true;
}
let openai = value.openai.clone();
merge(
&mut settings.openai.api_url,
openai.as_ref().and_then(|s| s.api_url.clone()),
@ -393,6 +218,18 @@ impl settings::Settings for AllLanguageModelSettings {
&mut settings.openai.available_models,
openai.as_ref().and_then(|s| s.available_models.clone()),
);
// Vercel
let vercel = value.vercel.clone();
merge(
&mut settings.vercel.api_url,
vercel.as_ref().and_then(|s| s.api_url.clone()),
);
merge(
&mut settings.vercel.available_models,
vercel.as_ref().and_then(|s| s.available_models.clone()),
);
merge(
&mut settings.zed_dot_dev.available_models,
value
@ -435,24 +272,6 @@ impl settings::Settings for AllLanguageModelSettings {
.as_ref()
.and_then(|s| s.available_models.clone()),
);
// Copilot Chat
let copilot_chat = value.copilot_chat.clone().unwrap_or_default();
settings.copilot_chat.api_url = copilot_chat.api_url.map_or_else(
|| Arc::from("https://api.githubcopilot.com/chat/completions"),
Arc::from,
);
settings.copilot_chat.auth_url = copilot_chat.auth_url.map_or_else(
|| Arc::from("https://api.github.com/copilot_internal/v2/token"),
Arc::from,
);
settings.copilot_chat.models_url = copilot_chat.models_url.map_or_else(
|| Arc::from("https://api.githubcopilot.com/models"),
Arc::from,
);
}
Ok(settings)

View file

@ -40,29 +40,30 @@ impl IntoElement for InstructionListItem {
let link = button_link.clone();
let unique_id = SharedString::from(format!("{}-button", self.label));
h_flex().flex_wrap().child(Label::new(self.label)).child(
Button::new(unique_id, button_label)
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, _window, cx| cx.open_url(&link)),
)
h_flex()
.flex_wrap()
.child(Label::new(self.label))
.child(
Button::new(unique_id, button_label)
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, _window, cx| cx.open_url(&link)),
)
.into_any_element()
} else {
div().child(Label::new(self.label))
Label::new(self.label).into_any_element()
};
div()
.child(
ListItem::new("list-item")
.selectable(false)
.start_slot(
Icon::new(IconName::Dash)
.size(IconSize::XSmall)
.color(Color::Hidden),
)
.child(item_content),
ListItem::new("list-item")
.selectable(false)
.start_slot(
Icon::new(IconName::Dash)
.size(IconSize::XSmall)
.color(Color::Hidden),
)
.into_any()
.child(div().w_full().child(item_content))
.into_any_element()
}
}