Simplify LLM protocol (#15366)

In this pull request, we change the zed.dev protocol so that we pass the
raw JSON for the specified provider directly to our server. This avoids
the need to define a protobuf message that's a superset of all these
formats.

@bennetbo: We also changed the settings for available_models under
zed.dev to be a flat format, because the nesting seemed too confusing.
Can you help us upgrade the local provider configuration to be
consistent with this? When parsing the settings, we do whatever is needed
to keep this simple for users, even if it's a bit more complex on our end.
We want to use versioning to avoid breaking existing users, but we still
need to keep making progress.

```json
"zed.dev": {
  "available_models": [
    {
      "provider": "anthropic",
      "name": "some-newly-released-model-we-havent-added",
      "max_tokens": 200000
    }
  ]
}
```

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2024-07-28 11:07:10 +02:00 committed by GitHub
parent e0fe7f632c
commit d6bdaa8a91
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 896 additions and 2154 deletions

View file

@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
@ -98,7 +98,7 @@ impl From<Role> for String {
}
}
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
pub model: String,
pub messages: Vec<RequestMessage>,
@ -113,7 +113,7 @@ pub struct RequestMessage {
pub content: String,
}
#[derive(Deserialize, Debug)]
#[derive(Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseEvent {
MessageStart {
@ -138,7 +138,7 @@ pub enum ResponseEvent {
MessageStop {},
}
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseMessage {
#[serde(rename = "type")]
pub message_type: Option<String>,
@ -151,19 +151,19 @@ pub struct ResponseMessage {
pub usage: Option<Usage>,
}
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
/// Token counts that may accompany a response message; each count is optional.
pub struct Usage {
/// Number of input tokens, when present.
pub input_tokens: Option<u32>,
/// Number of output tokens, when present.
pub output_tokens: Option<u32>,
}
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
/// A block of message content.
///
/// In serialized form the variant is selected by a `type` field whose value is
/// the snake_case variant name (e.g. `"text"`).
pub enum ContentBlock {
/// Text content carried in the `text` field.
Text { text: String },
}
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TextDelta {
TextDelta { text: String },
@ -226,6 +226,25 @@ pub async fn stream_completion(
}
}
/// Reduces a stream of response events to the text fragments they carry.
///
/// The text of every `ContentBlockStart` and `ContentBlockDelta` event is
/// yielded as `Ok(String)`. All other events are dropped, and errors from the
/// input stream are forwarded unchanged.
pub fn extract_text_from_events(
    response: impl Stream<Item = Result<ResponseEvent>>,
) -> impl Stream<Item = Result<String>> {
    response.filter_map(|event| async move {
        match event {
            // Errors are never filtered out; the consumer decides how to react.
            Err(error) => Some(Err(error)),
            Ok(ResponseEvent::ContentBlockStart { content_block, .. }) => {
                // Irrefutable destructuring: stops compiling if a non-text
                // variant is ever added, forcing this code to be revisited.
                let ContentBlock::Text { text } = content_block;
                Some(Ok(text))
            }
            Ok(ResponseEvent::ContentBlockDelta { delta, .. }) => {
                let TextDelta::TextDelta { text } = delta;
                Some(Ok(text))
            }
            // Every remaining event kind carries no text of interest here.
            Ok(_) => None,
        }
    })
}
// #[cfg(test)]
// mod tests {
// use super::*;