language_model: Denote the availability of language models (#15660)

This PR updates the `LanguageModel` trait with a new method for denoting
the availability of a model.

Right now we have two variants:

- `Public` for models that have no additional restrictions (other than
their respective setup/authentication requirements)
- `RequiresPlan` for models that require a specific Zed plan

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2024-08-01 18:26:27 -04:00 committed by GitHub
parent 906d9736d5
commit 5e011ab029
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 73 additions and 13 deletions

View file

@ -14,6 +14,7 @@ use gpui::{
};
pub use model::*;
use project::Fs;
use proto::Plan;
pub(crate) use rate_limiter::*;
pub use registry::*;
pub use request::*;
@ -32,6 +33,15 @@ pub fn init(
registry::init(user_store, client, cx);
}
/// The availability of a [`LanguageModel`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
/// The language model is available to the general public.
Public,
/// The language model is available to users on the indicated plan.
RequiresPlan(Plan),
}
pub trait LanguageModel: Send + Sync {
fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName;
@ -39,6 +49,11 @@ pub trait LanguageModel: Send + Sync {
fn provider_name(&self) -> LanguageModelProviderName;
fn telemetry_id(&self) -> String;
/// Returns the availability of this language model.
fn availability(&self) -> LanguageModelAvailability {
LanguageModelAvailability::Public
}
fn max_token_count(&self) -> usize;
fn count_tokens(

View file

@ -1,7 +1,10 @@
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use strum::EnumIter;
use crate::LanguageModelAvailability;
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum CloudModel {
@ -46,28 +49,65 @@ impl Default for CloudModel {
impl CloudModel {
pub fn id(&self) -> &str {
match self {
CloudModel::Anthropic(model) => model.id(),
CloudModel::OpenAi(model) => model.id(),
CloudModel::Google(model) => model.id(),
CloudModel::Zed(model) => model.id(),
Self::Anthropic(model) => model.id(),
Self::OpenAi(model) => model.id(),
Self::Google(model) => model.id(),
Self::Zed(model) => model.id(),
}
}
pub fn display_name(&self) -> &str {
match self {
CloudModel::Anthropic(model) => model.display_name(),
CloudModel::OpenAi(model) => model.display_name(),
CloudModel::Google(model) => model.display_name(),
CloudModel::Zed(model) => model.display_name(),
Self::Anthropic(model) => model.display_name(),
Self::OpenAi(model) => model.display_name(),
Self::Google(model) => model.display_name(),
Self::Zed(model) => model.display_name(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
CloudModel::Anthropic(model) => model.max_token_count(),
CloudModel::OpenAi(model) => model.max_token_count(),
CloudModel::Google(model) => model.max_token_count(),
CloudModel::Zed(model) => model.max_token_count(),
Self::Anthropic(model) => model.max_token_count(),
Self::OpenAi(model) => model.max_token_count(),
Self::Google(model) => model.max_token_count(),
Self::Zed(model) => model.max_token_count(),
}
}
/// Returns the availability of this model.
pub fn availability(&self) -> LanguageModelAvailability {
match self {
Self::Anthropic(model) => match model {
anthropic::Model::Claude3_5Sonnet => {
LanguageModelAvailability::RequiresPlan(Plan::Free)
}
anthropic::Model::Claude3Opus
| anthropic::Model::Claude3Sonnet
| anthropic::Model::Claude3Haiku
| anthropic::Model::Custom { .. } => {
LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
}
},
Self::OpenAi(model) => match model {
open_ai::Model::ThreePointFiveTurbo
| open_ai::Model::Four
| open_ai::Model::FourTurbo
| open_ai::Model::FourOmni
| open_ai::Model::FourOmniMini
| open_ai::Model::Custom { .. } => {
LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
}
},
Self::Google(model) => match model {
google_ai::Model::Gemini15Pro
| google_ai::Model::Gemini15Flash
| google_ai::Model::Custom { .. } => {
LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
}
},
Self::Zed(model) => match model {
ZedModel::Qwen2_7bInstruct => LanguageModelAvailability::RequiresPlan(Plan::ZedPro),
},
}
}
}

View file

@ -18,7 +18,7 @@ use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
use crate::LanguageModelProvider;
use crate::{LanguageModelAvailability, LanguageModelProvider};
use super::anthropic::count_anthropic_tokens;
@ -236,6 +236,10 @@ impl LanguageModel for CloudLanguageModel {
format!("zed.dev/{}", self.model.id())
}
fn availability(&self) -> LanguageModelAvailability {
self.model.availability()
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}