Remove dependents of language_models (#25511)

This PR removes the dependents of the `language_models` crate.

The following types have been moved from `language_models` to
`language_model` to facilitate this:

- `LlmApiToken`
- `RefreshLlmTokenListener`
- `MaxMonthlySpendReachedError`
- `PaymentRequiredError`

With this change only `zed` now depends on `language_models`.

Release Notes:

- N/A
This commit is contained in:
Marshall Bowers 2025-02-24 17:46:45 -05:00 committed by GitHub
parent bbb8d63de0
commit def342e35c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 117 additions and 124 deletions

5
Cargo.lock generated
View file

@ -462,7 +462,6 @@ dependencies = [
"language", "language",
"language_model", "language_model",
"language_model_selector", "language_model_selector",
"language_models",
"log", "log",
"lsp", "lsp",
"markdown", "markdown",
@ -517,7 +516,6 @@ dependencies = [
"language", "language",
"language_model", "language_model",
"language_model_selector", "language_model_selector",
"language_models",
"languages", "languages",
"log", "log",
"multi_buffer", "multi_buffer",
@ -7083,7 +7081,6 @@ dependencies = [
"smol", "smol",
"strum", "strum",
"theme", "theme",
"thiserror 1.0.69",
"tiktoken-rs", "tiktoken-rs",
"ui", "ui",
"util", "util",
@ -17190,7 +17187,7 @@ dependencies = [
"indoc", "indoc",
"inline_completion", "inline_completion",
"language", "language",
"language_models", "language_model",
"log", "log",
"menu", "menu",
"migrator", "migrator",

View file

@ -46,7 +46,6 @@ itertools.workspace = true
language.workspace = true language.workspace = true
language_model.workspace = true language_model.workspace = true
language_model_selector.workspace = true language_model_selector.workspace = true
language_models.workspace = true
log.workspace = true log.workspace = true
lsp.workspace = true lsp.workspace = true
markdown.workspace = true markdown.workspace = true

View file

@ -10,9 +10,9 @@ use gpui::{App, Context, EventEmitter, SharedString, Task};
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent, Role, StopReason, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
Role, StopReason,
}; };
use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use util::{post_inc, TryFutureExt as _}; use util::{post_inc, TryFutureExt as _};
use uuid::Uuid; use uuid::Uuid;

View file

@ -30,7 +30,6 @@ indexed_docs.workspace = true
language.workspace = true language.workspace = true
language_model.workspace = true language_model.workspace = true
language_model_selector.workspace = true language_model_selector.workspace = true
language_models.workspace = true
log.workspace = true log.workspace = true
multi_buffer.workspace = true multi_buffer.workspace = true
open_ai.workspace = true open_ai.workspace = true

View file

@ -21,9 +21,9 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
use language_model::{ use language_model::{
report_assistant_event, LanguageModel, LanguageModelCacheConfiguration, report_assistant_event, LanguageModel, LanguageModelCacheConfiguration,
LanguageModelCompletionEvent, LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, Role, StopReason, LanguageModelRequestMessage, LanguageModelToolUseId, MaxMonthlySpendReachedError,
MessageContent, PaymentRequiredError, Role, StopReason,
}; };
use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
use open_ai::Model as OpenAiModel; use open_ai::Model as OpenAiModel;
use paths::contexts_dir; use paths::contexts_dir;
use project::Project; use project::Project;

View file

@ -9,6 +9,7 @@ mod telemetry;
pub mod fake_provider; pub mod fake_provider;
use anyhow::Result; use anyhow::Result;
use client::Client;
use futures::FutureExt; use futures::FutureExt;
use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _}; use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window}; use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
@ -29,8 +30,9 @@ pub use crate::telemetry::*;
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev"; pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
pub fn init(cx: &mut App) { pub fn init(client: Arc<Client>, cx: &mut App) {
registry::init(cx); registry::init(cx);
RefreshLlmTokenListener::register(client.clone(), cx);
} }
/// The availability of a [`LanguageModel`]. /// The availability of a [`LanguageModel`].

View file

@ -1,7 +1,17 @@
use proto::Plan; use std::fmt;
use std::sync::Arc;
use anyhow::Result;
use client::Client;
use gpui::{
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
};
use proto::{Plan, TypedEnvelope};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use strum::EnumIter; use strum::EnumIter;
use thiserror::Error;
use ui::IconName; use ui::IconName;
use crate::LanguageModelAvailability; use crate::LanguageModelAvailability;
@ -102,3 +112,92 @@ impl CloudModel {
} }
} }
} }
#[derive(Error, Debug)]
pub struct PaymentRequiredError;
impl fmt::Display for PaymentRequiredError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Payment required to use this language model. Please upgrade your account."
)
}
}
#[derive(Error, Debug)]
pub struct MaxMonthlySpendReachedError;
impl fmt::Display for MaxMonthlySpendReachedError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Maximum spending limit reached for this month. For more usage, increase your spending limit."
)
}
}
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
impl LlmApiToken {
pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
let lock = self.0.upgradable_read().await;
if let Some(token) = lock.as_ref() {
Ok(token.to_string())
} else {
Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
}
}
pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
Self::fetch(self.0.write().await, client).await
}
async fn fetch<'a>(
mut lock: RwLockWriteGuard<'a, Option<String>>,
client: &Arc<Client>,
) -> Result<String> {
let response = client.request(proto::GetLlmToken {}).await?;
*lock = Some(response.token.clone());
Ok(response.token.clone())
}
}
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
impl Global for GlobalRefreshLlmTokenListener {}
pub struct RefreshLlmTokenEvent;
pub struct RefreshLlmTokenListener {
_llm_token_subscription: client::Subscription,
}
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
impl RefreshLlmTokenListener {
pub fn register(client: Arc<Client>, cx: &mut App) {
let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
cx.set_global(GlobalRefreshLlmTokenListener(listener));
}
pub fn global(cx: &App) -> Entity<Self> {
GlobalRefreshLlmTokenListener::global(cx).0.clone()
}
fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
Self {
_llm_token_subscription: client
.add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
}
}
async fn handle_refresh_llm_token(
this: Entity<Self>,
_: TypedEnvelope<proto::RefreshLlmToken>,
mut cx: AsyncApp,
) -> Result<()> {
this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
}
}

View file

@ -41,7 +41,6 @@ settings.workspace = true
smol.workspace = true smol.workspace = true
strum.workspace = true strum.workspace = true
theme.workspace = true theme.workspace = true
thiserror.workspace = true
tiktoken-rs.workspace = true tiktoken-rs.workspace = true
ui.workspace = true ui.workspace = true
util.workspace = true util.workspace = true

View file

@ -11,8 +11,6 @@ mod settings;
use crate::provider::anthropic::AnthropicLanguageModelProvider; use crate::provider::anthropic::AnthropicLanguageModelProvider;
use crate::provider::cloud::CloudLanguageModelProvider; use crate::provider::cloud::CloudLanguageModelProvider;
pub use crate::provider::cloud::LlmApiToken;
pub use crate::provider::cloud::RefreshLlmTokenListener;
use crate::provider::copilot_chat::CopilotChatLanguageModelProvider; use crate::provider::copilot_chat::CopilotChatLanguageModelProvider;
use crate::provider::google::GoogleLanguageModelProvider; use crate::provider::google::GoogleLanguageModelProvider;
use crate::provider::lmstudio::LmStudioLanguageModelProvider; use crate::provider::lmstudio::LmStudioLanguageModelProvider;
@ -37,8 +35,6 @@ fn register_language_model_providers(
) { ) {
use feature_flags::FeatureFlagAppExt; use feature_flags::FeatureFlagAppExt;
RefreshLlmTokenListener::register(client.clone(), cx);
registry.register_provider( registry.register_provider(
AnthropicLanguageModelProvider::new(client.http_client(), cx), AnthropicLanguageModelProvider::new(client.http_client(), cx),
cx, cx,

View file

@ -10,10 +10,7 @@ use futures::{
future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
TryStreamExt as _, TryStreamExt as _,
}; };
use gpui::{ use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
AnyElement, AnyView, App, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal,
Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{ use language_model::{
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
@ -22,24 +19,19 @@ use language_model::{
ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_ID,
}; };
use language_model::{ use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener,
}; };
use proto::TypedEnvelope;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue; use serde_json::value::RawValue;
use settings::{Settings, SettingsStore}; use settings::{Settings, SettingsStore};
use smol::{ use smol::io::{AsyncReadExt, BufReader};
io::{AsyncReadExt, BufReader},
lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::fmt;
use std::{ use std::{
future, future,
sync::{Arc, LazyLock}, sync::{Arc, LazyLock},
}; };
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{prelude::*, TintColor}; use ui::{prelude::*, TintColor};
use crate::provider::anthropic::{ use crate::provider::anthropic::{
@ -101,44 +93,6 @@ pub struct AvailableModel {
pub extra_beta_headers: Vec<String>, pub extra_beta_headers: Vec<String>,
} }
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
impl Global for GlobalRefreshLlmTokenListener {}
pub struct RefreshLlmTokenEvent;
pub struct RefreshLlmTokenListener {
_llm_token_subscription: client::Subscription,
}
impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
impl RefreshLlmTokenListener {
pub fn register(client: Arc<Client>, cx: &mut App) {
let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
cx.set_global(GlobalRefreshLlmTokenListener(listener));
}
pub fn global(cx: &App) -> Entity<Self> {
GlobalRefreshLlmTokenListener::global(cx).0.clone()
}
fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
Self {
_llm_token_subscription: client
.add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
}
}
async fn handle_refresh_llm_token(
this: Entity<Self>,
_: TypedEnvelope<proto::RefreshLlmToken>,
mut cx: AsyncApp,
) -> Result<()> {
this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
}
}
pub struct CloudLanguageModelProvider { pub struct CloudLanguageModelProvider {
client: Arc<Client>, client: Arc<Client>,
state: gpui::Entity<State>, state: gpui::Entity<State>,
@ -475,33 +429,6 @@ pub struct CloudLanguageModel {
request_limiter: RateLimiter, request_limiter: RateLimiter,
} }
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
#[derive(Error, Debug)]
pub struct PaymentRequiredError;
impl fmt::Display for PaymentRequiredError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Payment required to use this language model. Please upgrade your account."
)
}
}
#[derive(Error, Debug)]
pub struct MaxMonthlySpendReachedError;
impl fmt::Display for MaxMonthlySpendReachedError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Maximum spending limit reached for this month. For more usage, increase your spending limit."
)
}
}
impl CloudLanguageModel { impl CloudLanguageModel {
async fn perform_llm_completion( async fn perform_llm_completion(
client: Arc<Client>, client: Arc<Client>,
@ -847,30 +774,6 @@ fn response_lines<T: DeserializeOwned>(
) )
} }
impl LlmApiToken {
pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
let lock = self.0.upgradable_read().await;
if let Some(token) = lock.as_ref() {
Ok(token.to_string())
} else {
Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
}
}
pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
Self::fetch(self.0.write().await, client).await
}
async fn fetch<'a>(
mut lock: RwLockWriteGuard<'a, Option<String>>,
client: &Arc<Client>,
) -> Result<String> {
let response = client.request(proto::GetLlmToken {}).await?;
*lock = Some(response.token.clone());
Ok(response.token.clone())
}
}
struct ConfigurationView { struct ConfigurationView {
state: gpui::Entity<State>, state: gpui::Entity<State>,
} }

View file

@ -436,7 +436,7 @@ fn main() {
cx, cx,
); );
supermaven::init(app_state.client.clone(), cx); supermaven::init(app_state.client.clone(), cx);
language_model::init(cx); language_model::init(app_state.client.clone(), cx);
language_models::init( language_models::init(
app_state.user_store.clone(), app_state.user_store.clone(),
app_state.client.clone(), app_state.client.clone(),

View file

@ -4237,7 +4237,7 @@ mod tests {
cx, cx,
); );
image_viewer::init(cx); image_viewer::init(cx);
language_model::init(cx); language_model::init(app_state.client.clone(), cx);
language_models::init( language_models::init(
app_state.user_store.clone(), app_state.user_store.clone(),
app_state.client.clone(), app_state.client.clone(),

View file

@ -33,7 +33,7 @@ http_client.workspace = true
indoc.workspace = true indoc.workspace = true
inline_completion.workspace = true inline_completion.workspace = true
language.workspace = true language.workspace = true
language_models.workspace = true language_model.workspace = true
log.workspace = true log.workspace = true
menu.workspace = true menu.workspace = true
migrator.workspace = true migrator.workspace = true

View file

@ -31,7 +31,7 @@ use input_excerpt::excerpt_for_cursor_position;
use language::{ use language::{
text_diff, Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff, Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint,
}; };
use language_models::LlmApiToken; use language_model::{LlmApiToken, RefreshLlmTokenListener};
use postage::watch; use postage::watch;
use project::Project; use project::Project;
use release_channel::AppVersion; use release_channel::AppVersion;
@ -244,7 +244,7 @@ impl Zeta {
user_store: Entity<UserStore>, user_store: Entity<UserStore>,
cx: &mut Context<Self>, cx: &mut Context<Self>,
) -> Self { ) -> Self {
let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx); let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
let data_collection_choice = Self::load_data_collection_choices(); let data_collection_choice = Self::load_data_collection_choices();
let data_collection_choice = cx.new(|_| data_collection_choice); let data_collection_choice = cx.new(|_| data_collection_choice);
@ -1649,7 +1649,6 @@ mod tests {
use http_client::FakeHttpClient; use http_client::FakeHttpClient;
use indoc::indoc; use indoc::indoc;
use language::Point; use language::Point;
use language_models::RefreshLlmTokenListener;
use rpc::proto; use rpc::proto;
use settings::SettingsStore; use settings::SettingsStore;