Simplify LLM protocol (#15366)
In this pull request, we change the zed.dev protocol so that we pass the raw JSON for the specified provider directly to our server. This avoids the need to define a protobuf message that's a superset of all these formats. @bennetbo: We also changed the settings for available_models under zed.dev to be a flat format, because the nesting seemed too confusing. Can you help us upgrade the local provider configuration to be consistent with this? We do whatever we need to do when parsing the settings to make this simple for users, even if it's a bit more complex on our end. We want to use versioning to avoid breaking existing users, but need to keep making progress. ```json "zed.dev": { "available_models": [ { "provider": "anthropic", "name": "some-newly-released-model-we-havent-added", "max_tokens": 200000 } ] } ``` Release Notes: - N/A --------- Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
parent
e0fe7f632c
commit
d6bdaa8a91
31 changed files with 896 additions and 2154 deletions
|
@ -11,9 +11,14 @@ workspace = true
|
|||
[lib]
|
||||
path = "src/google_ai.rs"
|
||||
|
||||
[features]
|
||||
schemars = ["dep:schemars"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
http_client.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
strum.workspace = true
|
||||
|
|
|
@ -1,23 +1,21 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
|
||||
use http_client::HttpClient;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
|
||||
|
||||
pub async fn stream_generate_content(
|
||||
client: Arc<dyn HttpClient>,
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
model: &str,
|
||||
request: GenerateContentRequest,
|
||||
mut request: GenerateContentRequest,
|
||||
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
|
||||
let uri = format!(
|
||||
"{}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={}",
|
||||
api_url, api_key
|
||||
"{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
|
||||
model = request.model
|
||||
);
|
||||
request.model.clear();
|
||||
|
||||
let request = serde_json::to_string(&request)?;
|
||||
let mut response = client.post_json(&uri, request.into()).await?;
|
||||
|
@ -52,8 +50,8 @@ pub async fn stream_generate_content(
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn count_tokens<T: HttpClient>(
|
||||
client: &T,
|
||||
pub async fn count_tokens(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
request: CountTokensRequest,
|
||||
|
@ -91,22 +89,24 @@ pub enum Task {
|
|||
BatchEmbedContents,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentRequest {
|
||||
#[serde(default, skip_serializing_if = "String::is_empty")]
|
||||
pub model: String,
|
||||
pub contents: Vec<Content>,
|
||||
pub generation_config: Option<GenerationConfig>,
|
||||
pub safety_settings: Option<Vec<SafetySetting>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentResponse {
|
||||
pub candidates: Option<Vec<GenerateContentCandidate>>,
|
||||
pub prompt_feedback: Option<PromptFeedback>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentCandidate {
|
||||
pub index: usize,
|
||||
|
@ -157,7 +157,7 @@ pub struct GenerativeContentBlob {
|
|||
pub data: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CitationSource {
|
||||
pub start_index: Option<usize>,
|
||||
|
@ -166,13 +166,13 @@ pub struct CitationSource {
|
|||
pub license: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CitationMetadata {
|
||||
pub citation_sources: Vec<CitationSource>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptFeedback {
|
||||
pub block_reason: Option<String>,
|
||||
|
@ -180,7 +180,7 @@ pub struct PromptFeedback {
|
|||
pub block_reason_message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerationConfig {
|
||||
pub candidate_count: Option<usize>,
|
||||
|
@ -191,7 +191,7 @@ pub struct GenerationConfig {
|
|||
pub top_k: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SafetySetting {
|
||||
pub category: HarmCategory,
|
||||
|
@ -224,7 +224,7 @@ pub enum HarmCategory {
|
|||
DangerousContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub enum HarmBlockThreshold {
|
||||
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
|
||||
Unspecified,
|
||||
|
@ -238,7 +238,7 @@ pub enum HarmBlockThreshold {
|
|||
BlockNone,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum HarmProbability {
|
||||
#[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
|
||||
|
@ -249,21 +249,85 @@ pub enum HarmProbability {
|
|||
High,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SafetyRating {
|
||||
pub category: HarmCategory,
|
||||
pub probability: HarmProbability,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CountTokensRequest {
|
||||
pub contents: Vec<Content>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CountTokensResponse {
|
||||
pub total_tokens: usize,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
|
||||
pub enum Model {
|
||||
#[serde(rename = "gemini-1.5-pro")]
|
||||
Gemini15Pro,
|
||||
#[serde(rename = "gemini-1.5-flash")]
|
||||
Gemini15Flash,
|
||||
#[serde(rename = "custom")]
|
||||
Custom { name: String, max_tokens: usize },
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn id(&self) -> &str {
|
||||
match self {
|
||||
Model::Gemini15Pro => "gemini-1.5-pro",
|
||||
Model::Gemini15Flash => "gemini-1.5-flash",
|
||||
Model::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &str {
|
||||
match self {
|
||||
Model::Gemini15Pro => "Gemini 1.5 Pro",
|
||||
Model::Gemini15Flash => "Gemini 1.5 Flash",
|
||||
Model::Custom { name, .. } => name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
match self {
|
||||
Model::Gemini15Pro => 2_000_000,
|
||||
Model::Gemini15Flash => 1_000_000,
|
||||
Model::Custom { max_tokens, .. } => *max_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Model {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.id())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_text_from_events(
|
||||
events: impl Stream<Item = Result<GenerateContentResponse>>,
|
||||
) -> impl Stream<Item = Result<String>> {
|
||||
events.filter_map(|event| async move {
|
||||
match event {
|
||||
Ok(event) => event.candidates.and_then(|candidates| {
|
||||
candidates.into_iter().next().and_then(|candidate| {
|
||||
candidate.content.parts.into_iter().next().and_then(|part| {
|
||||
if let Part::TextPart(TextPart { text }) = part {
|
||||
Some(Ok(text))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
}),
|
||||
Err(error) => Some(Err(error)),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue