mod authorization;
pub mod db;
mod token;

use crate::api::CloudflareIpCountryHeader;
use crate::llm::authorization::authorize_access_to_language_model;
use crate::llm::db::LlmDatabase;
use crate::{executor::Executor, Config, Error, Result};
use anyhow::{anyhow, Context as _};
use axum::TypedHeader;
use axum::{
    body::Body,
    http::{self, HeaderName, HeaderValue, Request, StatusCode},
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Extension, Json, Router,
};
use futures::StreamExt as _;
use http_client::IsahcHttpClient;
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use std::sync::Arc;

pub use token::*;

pub struct LlmState {
    pub config: Config,
    pub executor: Executor,
    pub db: Option<Arc<LlmDatabase>>,
    pub http_client: IsahcHttpClient,
}

impl LlmState {
    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
        // TODO: This is temporary until we have the LLM database stood up.
        let db = if config.is_development() {
            let database_url = config
                .llm_database_url
                .as_ref()
                .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
            let max_connections = config
                .llm_database_max_connections
                .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

            let mut db_options = db::ConnectOptions::new(database_url);
            db_options.max_connections(max_connections);
            let db = LlmDatabase::new(db_options, executor.clone()).await?;
            Some(Arc::new(db))
        } else {
            None
        };

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = IsahcHttpClient::builder()
            .default_header("User-Agent", user_agent)
            .build()
            .context("failed to construct http client")?;

        let this = Self {
            config,
            executor,
            db,
            http_client,
        };

        Ok(Arc::new(this))
    }
}

pub fn routes() -> Router<(), Body> {
    Router::new()
        .route("/completion", post(perform_completion))
        .layer(middleware::from_fn(validate_api_token))
}

async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
    let token = req
        .headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok())
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "missing authorization header".to_string(),
            )
        })?
        .strip_prefix("Bearer ")
        .ok_or_else(|| {
            Error::http(
                StatusCode::BAD_REQUEST,
                "invalid authorization header".to_string(),
            )
        })?;

    let state = req.extensions().get::<Arc<LlmState>>().unwrap();
    match LlmTokenClaims::validate(&token, &state.config) {
        Ok(claims) => {
            req.extensions_mut().insert(claims);
            Ok::<_, Error>(next.run(req).await.into_response())
        }
        Err(ValidateLlmTokenError::Expired) => Err(Error::Http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
            [(
                HeaderName::from_static(EXPIRED_LLM_TOKEN_HEADER_NAME),
                HeaderValue::from_static("true"),
            )]
            .into_iter()
            .collect(),
        )),
        Err(_err) => Err(Error::http(
            StatusCode::UNAUTHORIZED,
            "unauthorized".to_string(),
        )),
    }
}

async fn perform_completion(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
    authorize_access_to_language_model(
        &state.config,
        &claims,
        country_code_header.map(|header| header.to_string()),
        params.provider,
        &params.model,
    )?;

    match params.provider {
        LanguageModelProvider::Anthropic => {
            let api_key = state
                .config
                .anthropic_api_key
                .as_ref()
                .context("no Anthropic AI API key configured on the server")?;
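            // For illustration only (this payload is made up, not taken from
            // this module): `params.provider_request` arrives as an opaque
            // JSON blob that the client assembled for the target provider, so
            // a hypothetical Anthropic-style body might look like
            //
            //     {
            //         "model": "claude-3-5-sonnet",
            //         "max_tokens": 1024,
            //         "messages": [{ "role": "user", "content": "Hello" }]
            //     }
            //
            // and is deserialized into an `anthropic::Request` below.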
            let mut request: anthropic::Request =
                serde_json::from_str(&params.provider_request.get())?;

            // Parse the model, throw away the version that was included, and then set a specific
            // version that we control on the server.
            // Right now, we use the version that's defined in `model.id()`, but we will likely
            // want to change this code once a new version of an Anthropic model is released,
            // so that users can use the new version, without having to update Zed.
            request.model = match anthropic::Model::from_id(&request.model) {
                Ok(model) => model.id().to_string(),
                Err(_) => request.model,
            };

            let chunks = anthropic::stream_completion(
                &state.http_client,
                anthropic::ANTHROPIC_API_URL,
                api_key,
                request,
                None,
            )
            .await?;

            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    buffer.clear();
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
        LanguageModelProvider::OpenAi => {
            let api_key = state
                .config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;

            let chunks = open_ai::stream_completion(
                &state.http_client,
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    buffer.clear();
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
        LanguageModelProvider::Google => {
            let api_key = state
                .config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;

            let chunks = google_ai::stream_generate_content(
                &state.http_client,
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
            )
            .await?;

            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    buffer.clear();
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
        LanguageModelProvider::Zed => {
            let api_key = state
                .config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = state
                .config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;

            let chunks = open_ai::stream_completion(
                &state.http_client,
                api_url,
                api_key,
                serde_json::from_str(&params.provider_request.get())?,
                None,
            )
            .await?;

            let stream = chunks.map(|event| {
                let mut buffer = Vec::new();
                event.map(|chunk| {
                    buffer.clear();
                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
                    buffer.push(b'\n');
                    buffer
                })
            });

            Ok(Response::new(Body::wrap_stream(stream)))
        }
    }
}
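
#[cfg(test)]
mod tests {
    // Not part of the original module: a minimal sketch pinning down the
    // newline-delimited JSON framing that `perform_completion` applies to
    // every provider's stream (one serialized chunk per line). The payload
    // here is a made-up stand-in for a provider event, not a real chunk type.
    #[test]
    fn completion_chunks_are_newline_delimited_json() {
        let chunk = serde_json::json!({ "text": "hello" });

        // Mirror the framing done in the stream adapters above.
        let mut buffer = Vec::new();
        serde_json::to_writer(&mut buffer, &chunk).unwrap();
        buffer.push(b'\n');

        assert_eq!(buffer, b"{\"text\":\"hello\"}\n");
    }
}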