language_models: Count Google AI tokens through LLM service (#29319)
This PR wires the counting of Google AI tokens back up. Token counting now goes through the LLM service instead of collab's RPC. It is still only available for Zed staff. Release Notes: - N/A
This commit is contained in:
parent
8b5835de17
commit
fef2681cfa
3 changed files with 58 additions and 7 deletions
|
@ -35,9 +35,9 @@ use strum::IntoEnumIterator;
|
|||
use thiserror::Error;
|
||||
use ui::{TintColor, prelude::*};
|
||||
use zed_llm_client::{
|
||||
CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, EXPIRED_LLM_TOKEN_HEADER_NAME,
|
||||
MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
|
||||
SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
|
||||
CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, CountTokensBody, CountTokensResponse,
|
||||
EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
|
||||
MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
|
||||
};
|
||||
|
||||
use crate::AllLanguageModelSettings;
|
||||
|
@ -686,7 +686,58 @@ impl LanguageModel for CloudLanguageModel {
|
|||
match self.model.clone() {
|
||||
CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
|
||||
CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
|
||||
CloudModel::Google(_model) => async move { Ok(0) }.boxed(),
|
||||
CloudModel::Google(model) => {
|
||||
let client = self.client.clone();
|
||||
let llm_api_token = self.llm_api_token.clone();
|
||||
let request = into_google(request, model.id().into());
|
||||
async move {
|
||||
let http_client = &client.http_client();
|
||||
let token = llm_api_token.acquire(&client).await?;
|
||||
|
||||
let request_builder = http_client::Request::builder().method(Method::POST);
|
||||
let request_builder =
|
||||
if let Ok(completions_url) = std::env::var("ZED_COUNT_TOKENS_URL") {
|
||||
request_builder.uri(completions_url)
|
||||
} else {
|
||||
request_builder.uri(
|
||||
http_client
|
||||
.build_zed_llm_url("/count_tokens", &[])?
|
||||
.as_ref(),
|
||||
)
|
||||
};
|
||||
let request_body = CountTokensBody {
|
||||
provider: zed_llm_client::LanguageModelProvider::Google,
|
||||
model: model.id().into(),
|
||||
provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
|
||||
contents: request.contents,
|
||||
})?,
|
||||
};
|
||||
let request = request_builder
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.body(serde_json::to_string(&request_body)?.into())?;
|
||||
let mut response = http_client.send(request).await?;
|
||||
let status = response.status();
|
||||
let mut response_body = String::new();
|
||||
response
|
||||
.body_mut()
|
||||
.read_to_string(&mut response_body)
|
||||
.await?;
|
||||
|
||||
if status.is_success() {
|
||||
let response_body: CountTokensResponse =
|
||||
serde_json::from_str(&response_body)?;
|
||||
|
||||
Ok(response_body.tokens)
|
||||
} else {
|
||||
Err(anyhow!(ApiError {
|
||||
status,
|
||||
body: response_body
|
||||
}))
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue