Remove code paths that skip LLM db in prod (#16008)
Release Notes: - N/A
This commit is contained in:
parent
c1872e9cb0
commit
225726ba4a
2 changed files with 33 additions and 52 deletions
|
@ -27,7 +27,7 @@ pub use token::*;
|
||||||
pub struct LlmState {
|
pub struct LlmState {
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
pub executor: Executor,
|
pub executor: Executor,
|
||||||
pub db: Option<Arc<LlmDatabase>>,
|
pub db: Arc<LlmDatabase>,
|
||||||
pub http_client: IsahcHttpClient,
|
pub http_client: IsahcHttpClient,
|
||||||
active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
|
active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
|
||||||
}
|
}
|
||||||
|
@ -36,25 +36,20 @@ const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
|
||||||
|
|
||||||
impl LlmState {
|
impl LlmState {
|
||||||
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
|
||||||
// TODO: This is temporary until we have the LLM database stood up.
|
let database_url = config
|
||||||
let db = if config.is_development() {
|
.llm_database_url
|
||||||
let database_url = config
|
.as_ref()
|
||||||
.llm_database_url
|
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
|
||||||
.as_ref()
|
let max_connections = config
|
||||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
|
.llm_database_max_connections
|
||||||
let max_connections = config
|
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
|
||||||
.llm_database_max_connections
|
|
||||||
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
|
|
||||||
|
|
||||||
let mut db_options = db::ConnectOptions::new(database_url);
|
let mut db_options = db::ConnectOptions::new(database_url);
|
||||||
db_options.max_connections(max_connections);
|
db_options.max_connections(max_connections);
|
||||||
let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
|
let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
|
||||||
db.initialize().await?;
|
db.initialize().await?;
|
||||||
|
|
||||||
Some(Arc::new(db))
|
let db = Arc::new(db);
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
|
||||||
let http_client = IsahcHttpClient::builder()
|
let http_client = IsahcHttpClient::builder()
|
||||||
|
@ -62,11 +57,8 @@ impl LlmState {
|
||||||
.build()
|
.build()
|
||||||
.context("failed to construct http client")?;
|
.context("failed to construct http client")?;
|
||||||
|
|
||||||
let initial_active_user_count = if let Some(db) = &db {
|
let initial_active_user_count =
|
||||||
Some((Utc::now(), db.get_active_user_count(Utc::now()).await?))
|
Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let this = Self {
|
let this = Self {
|
||||||
config,
|
config,
|
||||||
|
@ -88,14 +80,10 @@ impl LlmState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(db) = &self.db {
|
let mut cache = self.active_user_count.write().await;
|
||||||
let mut cache = self.active_user_count.write().await;
|
let new_count = self.db.get_active_user_count(now).await?;
|
||||||
let new_count = db.get_active_user_count(now).await?;
|
*cache = Some((now, new_count));
|
||||||
*cache = Some((now, new_count));
|
Ok(new_count)
|
||||||
Ok(new_count)
|
|
||||||
} else {
|
|
||||||
Ok(ActiveUserCount::default())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,9 +153,7 @@ async fn perform_completion(
|
||||||
|
|
||||||
let user_id = claims.user_id as i32;
|
let user_id = claims.user_id as i32;
|
||||||
|
|
||||||
if state.db.is_some() {
|
check_usage_limit(&state, params.provider, &model, &claims).await?;
|
||||||
check_usage_limit(&state, params.provider, &model, &claims).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
match params.provider {
|
match params.provider {
|
||||||
LanguageModelProvider::Anthropic => {
|
LanguageModelProvider::Anthropic => {
|
||||||
|
@ -199,14 +185,14 @@ async fn perform_completion(
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut recorder = state.db.clone().map(|db| UsageRecorder {
|
let mut recorder = UsageRecorder {
|
||||||
db,
|
db: state.db.clone(),
|
||||||
executor: state.executor.clone(),
|
executor: state.executor.clone(),
|
||||||
user_id,
|
user_id,
|
||||||
provider: params.provider,
|
provider: params.provider,
|
||||||
model,
|
model,
|
||||||
token_count: 0,
|
token_count: 0,
|
||||||
});
|
};
|
||||||
|
|
||||||
let stream = chunks.map(move |event| {
|
let stream = chunks.map(move |event| {
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Vec::new();
|
||||||
|
@ -216,10 +202,8 @@ async fn perform_completion(
|
||||||
message: anthropic::Response { usage, .. },
|
message: anthropic::Response { usage, .. },
|
||||||
}
|
}
|
||||||
| anthropic::Event::MessageDelta { usage, .. } => {
|
| anthropic::Event::MessageDelta { usage, .. } => {
|
||||||
if let Some(recorder) = &mut recorder {
|
recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
|
||||||
recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
|
recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
|
||||||
recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
@ -349,12 +333,9 @@ async fn check_usage_limit(
|
||||||
model_name: &str,
|
model_name: &str,
|
||||||
claims: &LlmTokenClaims,
|
claims: &LlmTokenClaims,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let db = state
|
let model = state.db.model(provider, model_name)?;
|
||||||
|
let usage = state
|
||||||
.db
|
.db
|
||||||
.as_ref()
|
|
||||||
.ok_or_else(|| anyhow!("LLM database not configured"))?;
|
|
||||||
let model = db.model(provider, model_name)?;
|
|
||||||
let usage = db
|
|
||||||
.get_usage(claims.user_id as i32, provider, model_name, Utc::now())
|
.get_usage(claims.user_id as i32, provider, model_name, Utc::now())
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|
|
@ -248,11 +248,6 @@ async fn setup_app_database(config: &Config) -> Result<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn setup_llm_database(config: &Config) -> Result<()> {
|
async fn setup_llm_database(config: &Config) -> Result<()> {
|
||||||
// TODO: This is temporary until we have the LLM database stood up.
|
|
||||||
if !config.is_development() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let database_url = config
|
let database_url = config
|
||||||
.llm_database_url
|
.llm_database_url
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
@ -298,7 +293,12 @@ async fn handle_liveness_probe(
|
||||||
state.db.get_all_users(0, 1).await?;
|
state.db.get_all_users(0, 1).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(_llm_state) = llm_state {}
|
if let Some(llm_state) = llm_state {
|
||||||
|
llm_state
|
||||||
|
.db
|
||||||
|
.get_active_user_count(chrono::Utc::now())
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
Ok("ok".to_string())
|
Ok("ok".to_string())
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue