
Release Notes: - Add support for OpenAI GPT-4.1 via Copilot Chat and OpenAI API --------- Co-authored-by: Danilo Leal <daniloleal09@gmail.com> Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
208 lines
6.7 KiB
Rust
208 lines
6.7 KiB
Rust
use std::fmt;
|
|
use std::sync::Arc;
|
|
|
|
use anyhow::Result;
|
|
use client::Client;
|
|
use gpui::{
|
|
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
|
|
};
|
|
use proto::{Plan, TypedEnvelope};
|
|
use schemars::JsonSchema;
|
|
use serde::{Deserialize, Serialize};
|
|
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
|
|
use strum::EnumIter;
|
|
use thiserror::Error;
|
|
|
|
use crate::{LanguageModelAvailability, LanguageModelToolSchemaFormat};
|
|
|
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
|
|
#[serde(tag = "provider", rename_all = "lowercase")]
|
|
pub enum CloudModel {
|
|
Anthropic(anthropic::Model),
|
|
OpenAi(open_ai::Model),
|
|
Google(google_ai::Model),
|
|
}
|
|
|
|
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
|
|
pub enum ZedModel {
|
|
#[serde(rename = "Qwen/Qwen2-7B-Instruct")]
|
|
Qwen2_7bInstruct,
|
|
}
|
|
|
|
impl Default for CloudModel {
|
|
fn default() -> Self {
|
|
Self::Anthropic(anthropic::Model::default())
|
|
}
|
|
}
|
|
|
|
impl CloudModel {
|
|
pub fn id(&self) -> &str {
|
|
match self {
|
|
Self::Anthropic(model) => model.id(),
|
|
Self::OpenAi(model) => model.id(),
|
|
Self::Google(model) => model.id(),
|
|
}
|
|
}
|
|
|
|
pub fn display_name(&self) -> &str {
|
|
match self {
|
|
Self::Anthropic(model) => model.display_name(),
|
|
Self::OpenAi(model) => model.display_name(),
|
|
Self::Google(model) => model.display_name(),
|
|
}
|
|
}
|
|
|
|
pub fn max_token_count(&self) -> usize {
|
|
match self {
|
|
Self::Anthropic(model) => model.max_token_count(),
|
|
Self::OpenAi(model) => model.max_token_count(),
|
|
Self::Google(model) => model.max_token_count(),
|
|
}
|
|
}
|
|
|
|
/// Returns the availability of this model.
|
|
pub fn availability(&self) -> LanguageModelAvailability {
|
|
match self {
|
|
Self::Anthropic(model) => match model {
|
|
anthropic::Model::Claude3_5Sonnet
|
|
| anthropic::Model::Claude3_7Sonnet
|
|
| anthropic::Model::Claude3_7SonnetThinking => {
|
|
LanguageModelAvailability::RequiresPlan(Plan::Free)
|
|
}
|
|
anthropic::Model::Claude3Opus
|
|
| anthropic::Model::Claude3Sonnet
|
|
| anthropic::Model::Claude3Haiku
|
|
| anthropic::Model::Claude3_5Haiku
|
|
| anthropic::Model::Custom { .. } => {
|
|
LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
|
|
}
|
|
},
|
|
Self::OpenAi(model) => match model {
|
|
open_ai::Model::ThreePointFiveTurbo
|
|
| open_ai::Model::Four
|
|
| open_ai::Model::FourTurbo
|
|
| open_ai::Model::FourOmni
|
|
| open_ai::Model::FourOmniMini
|
|
| open_ai::Model::FourPointOne
|
|
| open_ai::Model::FourPointOneMini
|
|
| open_ai::Model::FourPointOneNano
|
|
| open_ai::Model::O1Mini
|
|
| open_ai::Model::O1Preview
|
|
| open_ai::Model::O1
|
|
| open_ai::Model::O3Mini
|
|
| open_ai::Model::Custom { .. } => {
|
|
LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
|
|
}
|
|
},
|
|
Self::Google(model) => match model {
|
|
google_ai::Model::Gemini15Pro
|
|
| google_ai::Model::Gemini15Flash
|
|
| google_ai::Model::Gemini20Pro
|
|
| google_ai::Model::Gemini20Flash
|
|
| google_ai::Model::Gemini20FlashThinking
|
|
| google_ai::Model::Gemini20FlashLite
|
|
| google_ai::Model::Gemini25ProExp0325
|
|
| google_ai::Model::Gemini25ProPreview0325
|
|
| google_ai::Model::Custom { .. } => {
|
|
LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
|
|
}
|
|
},
|
|
}
|
|
}
|
|
|
|
pub fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
|
|
match self {
|
|
Self::Anthropic(_) | Self::OpenAi(_) => LanguageModelToolSchemaFormat::JsonSchema,
|
|
Self::Google(_) => LanguageModelToolSchemaFormat::JsonSchemaSubset,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Error, Debug)]
|
|
pub struct PaymentRequiredError;
|
|
|
|
impl fmt::Display for PaymentRequiredError {
|
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
write!(
|
|
f,
|
|
"Payment required to use this language model. Please upgrade your account."
|
|
)
|
|
}
|
|
}
|
|
|
|
#[derive(Error, Debug)]
|
|
pub struct MaxMonthlySpendReachedError;
|
|
|
|
impl fmt::Display for MaxMonthlySpendReachedError {
|
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
write!(
|
|
f,
|
|
"Maximum spending limit reached for this month. For more usage, increase your spending limit."
|
|
)
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Default)]
|
|
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
|
|
|
|
impl LlmApiToken {
|
|
pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
|
|
let lock = self.0.upgradable_read().await;
|
|
if let Some(token) = lock.as_ref() {
|
|
Ok(token.to_string())
|
|
} else {
|
|
Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
|
|
}
|
|
}
|
|
|
|
pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
|
|
Self::fetch(self.0.write().await, client).await
|
|
}
|
|
|
|
async fn fetch(
|
|
mut lock: RwLockWriteGuard<'_, Option<String>>,
|
|
client: &Arc<Client>,
|
|
) -> Result<String> {
|
|
let response = client.request(proto::GetLlmToken {}).await?;
|
|
*lock = Some(response.token.clone());
|
|
Ok(response.token.clone())
|
|
}
|
|
}
|
|
|
|
/// Newtype wrapper used to store the [`RefreshLlmTokenListener`] entity as an
/// app-wide global.
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
|
|
|
|
// Marker impl: allows the wrapper to be installed with `set_global` and read
// back with `GlobalRefreshLlmTokenListener::global`.
impl Global for GlobalRefreshLlmTokenListener {}
|
|
|
|
/// Event emitted by [`RefreshLlmTokenListener`] each time a
/// `proto::RefreshLlmToken` message is handled.
pub struct RefreshLlmTokenEvent;
|
|
|
|
pub struct RefreshLlmTokenListener {
|
|
_llm_token_subscription: client::Subscription,
|
|
}
|
|
|
|
// Declares `RefreshLlmTokenEvent` as an event this listener emits (see the
// `cx.emit` call in `handle_refresh_llm_token`).
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
|
|
|
|
impl RefreshLlmTokenListener {
|
|
pub fn register(client: Arc<Client>, cx: &mut App) {
|
|
let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
|
|
cx.set_global(GlobalRefreshLlmTokenListener(listener));
|
|
}
|
|
|
|
pub fn global(cx: &App) -> Entity<Self> {
|
|
GlobalRefreshLlmTokenListener::global(cx).0.clone()
|
|
}
|
|
|
|
fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
|
|
Self {
|
|
_llm_token_subscription: client
|
|
.add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
|
|
}
|
|
}
|
|
|
|
async fn handle_refresh_llm_token(
|
|
this: Entity<Self>,
|
|
_: TypedEnvelope<proto::RefreshLlmToken>,
|
|
mut cx: AsyncApp,
|
|
) -> Result<()> {
|
|
this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
|
|
}
|
|
}
|