
This PR adds more data to the `GetAuthenticatedUserResponse`: it now returns more information about the authenticated user, along with their plan information.

Release Notes:

- N/A
370 lines · 11 KiB · Rust
use std::str::FromStr;
|
|
use std::sync::Arc;
|
|
|
|
use anyhow::Context as _;
|
|
use serde::{Deserialize, Serialize};
|
|
use strum::{Display, EnumIter, EnumString};
|
|
use uuid::Uuid;
|
|
|
|
/// Header through which the client reports the version of Zed it is running.
pub const ZED_VERSION_HEADER_NAME: &str = "x-zed-version";
|
|
|
|
/// Header marking a request that failed because its LLM token had expired.
///
/// Clients may treat the presence of this header as a cue to refresh their token.
pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
|
|
|
|
/// Header reporting which plan the user is currently subscribed to.
pub const CURRENT_PLAN_HEADER_NAME: &str = "x-zed-plan";
|
|
|
|
/// Header carrying the usage limit that applies to model requests.
pub const MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-model-requests-usage-limit";

/// Header carrying how many model requests have been used so far.
pub const MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-model-requests-usage-amount";

/// Header carrying the usage limit that applies to edit predictions.
pub const EDIT_PREDICTIONS_USAGE_LIMIT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-limit";

/// Header carrying how many edit predictions have been used so far.
pub const EDIT_PREDICTIONS_USAGE_AMOUNT_HEADER_NAME: &str = "x-zed-edit-predictions-usage-amount";
|
|
|
|
/// The name of the header used to indicate the resource for which the subscription limit has been reached.
pub const SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME: &str = "x-zed-subscription-limit-resource";

/// Value of the [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`] header that names model
/// requests as the resource whose subscription limit has been reached.
pub const MODEL_REQUESTS_RESOURCE_HEADER_VALUE: &str = "model_requests";

/// Value of the [`SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME`] header that names edit
/// predictions as the resource whose subscription limit has been reached.
pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions";
|
|
|
|
/// Header signalling that the cap on consecutive tool uses has been hit.
pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached";
|
|
|
|
/// The name of the header used to indicate the minimum required Zed version.
///
/// This can be used to force a Zed upgrade in order to continue communicating
/// with the LLM service.
pub const MINIMUM_REQUIRED_VERSION_HEADER_NAME: &str = "x-zed-minimum-required-version";
|
|
|
|
/// Header the client sends to tell the server it can receive status messages.
pub const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-client-supports-status-messages";

/// Header the server sends to tell the client it can send status messages.
pub const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME: &str =
    "x-zed-server-supports-status-messages";
|
|
|
|
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum UsageLimit {
|
|
Limited(i32),
|
|
Unlimited,
|
|
}
|
|
|
|
impl FromStr for UsageLimit {
|
|
type Err = anyhow::Error;
|
|
|
|
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
|
match value {
|
|
"unlimited" => Ok(Self::Unlimited),
|
|
limit => limit
|
|
.parse::<i32>()
|
|
.map(Self::Limited)
|
|
.context("failed to parse limit"),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum Plan {
|
|
#[default]
|
|
#[serde(alias = "Free")]
|
|
ZedFree,
|
|
#[serde(alias = "ZedPro")]
|
|
ZedPro,
|
|
#[serde(alias = "ZedProTrial")]
|
|
ZedProTrial,
|
|
}
|
|
|
|
impl Plan {
|
|
pub fn as_str(&self) -> &'static str {
|
|
match self {
|
|
Plan::ZedFree => "zed_free",
|
|
Plan::ZedPro => "zed_pro",
|
|
Plan::ZedProTrial => "zed_pro_trial",
|
|
}
|
|
}
|
|
|
|
pub fn model_requests_limit(&self) -> UsageLimit {
|
|
match self {
|
|
Plan::ZedPro => UsageLimit::Limited(500),
|
|
Plan::ZedProTrial => UsageLimit::Limited(150),
|
|
Plan::ZedFree => UsageLimit::Limited(50),
|
|
}
|
|
}
|
|
|
|
pub fn edit_predictions_limit(&self) -> UsageLimit {
|
|
match self {
|
|
Plan::ZedPro => UsageLimit::Unlimited,
|
|
Plan::ZedProTrial => UsageLimit::Unlimited,
|
|
Plan::ZedFree => UsageLimit::Limited(2_000),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl FromStr for Plan {
|
|
type Err = anyhow::Error;
|
|
|
|
fn from_str(value: &str) -> Result<Self, Self::Err> {
|
|
match value {
|
|
"zed_free" => Ok(Plan::ZedFree),
|
|
"zed_pro" => Ok(Plan::ZedPro),
|
|
"zed_pro_trial" => Ok(Plan::ZedProTrial),
|
|
plan => Err(anyhow::anyhow!("invalid plan: {plan:?}")),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(
|
|
Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize, EnumString, EnumIter, Display,
|
|
)]
|
|
#[serde(rename_all = "snake_case")]
|
|
#[strum(serialize_all = "snake_case")]
|
|
pub enum LanguageModelProvider {
|
|
Anthropic,
|
|
OpenAi,
|
|
Google,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct PredictEditsBody {
|
|
#[serde(skip_serializing_if = "Option::is_none", default)]
|
|
pub outline: Option<String>,
|
|
pub input_events: String,
|
|
pub input_excerpt: String,
|
|
#[serde(skip_serializing_if = "Option::is_none", default)]
|
|
pub speculated_output: Option<String>,
|
|
/// Whether the user provided consent for sampling this interaction.
|
|
#[serde(default, alias = "data_collection_permission")]
|
|
pub can_collect_data: bool,
|
|
#[serde(skip_serializing_if = "Option::is_none", default)]
|
|
pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct PredictEditsResponse {
|
|
pub request_id: Uuid,
|
|
pub output_excerpt: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct AcceptEditPredictionBody {
|
|
pub request_id: Uuid,
|
|
}
|
|
|
|
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum CompletionMode {
|
|
Normal,
|
|
Max,
|
|
}
|
|
|
|
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum CompletionIntent {
|
|
UserPrompt,
|
|
ToolResults,
|
|
ThreadSummarization,
|
|
ThreadContextSummarization,
|
|
CreateFile,
|
|
EditFile,
|
|
InlineAssist,
|
|
TerminalInlineAssist,
|
|
GenerateGitCommitMessage,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct CompletionBody {
|
|
#[serde(skip_serializing_if = "Option::is_none", default)]
|
|
pub thread_id: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none", default)]
|
|
pub prompt_id: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none", default)]
|
|
pub intent: Option<CompletionIntent>,
|
|
#[serde(skip_serializing_if = "Option::is_none", default)]
|
|
pub mode: Option<CompletionMode>,
|
|
pub provider: LanguageModelProvider,
|
|
pub model: String,
|
|
pub provider_request: serde_json::Value,
|
|
}
|
|
|
|
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum CompletionRequestStatus {
|
|
Queued {
|
|
position: usize,
|
|
},
|
|
Started,
|
|
Failed {
|
|
code: String,
|
|
message: String,
|
|
request_id: Uuid,
|
|
/// Retry duration in seconds.
|
|
retry_after: Option<f64>,
|
|
},
|
|
UsageUpdated {
|
|
amount: usize,
|
|
limit: UsageLimit,
|
|
},
|
|
ToolUseLimitReached,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum CompletionEvent<T> {
|
|
Status(CompletionRequestStatus),
|
|
Event(T),
|
|
}
|
|
|
|
impl<T> CompletionEvent<T> {
|
|
pub fn into_status(self) -> Option<CompletionRequestStatus> {
|
|
match self {
|
|
Self::Status(status) => Some(status),
|
|
Self::Event(_) => None,
|
|
}
|
|
}
|
|
|
|
pub fn into_event(self) -> Option<T> {
|
|
match self {
|
|
Self::Event(event) => Some(event),
|
|
Self::Status(_) => None,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
pub struct WebSearchBody {
|
|
pub query: String,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone)]
|
|
pub struct WebSearchResponse {
|
|
pub results: Vec<WebSearchResult>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone)]
|
|
pub struct WebSearchResult {
|
|
pub title: String,
|
|
pub url: String,
|
|
pub text: String,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
pub struct CountTokensBody {
|
|
pub provider: LanguageModelProvider,
|
|
pub model: String,
|
|
pub provider_request: serde_json::Value,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
pub struct CountTokensResponse {
|
|
pub tokens: usize,
|
|
}
|
|
|
|
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
|
|
pub struct LanguageModelId(pub Arc<str>);
|
|
|
|
impl std::fmt::Display for LanguageModelId {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "{}", self.0)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
|
pub struct LanguageModel {
|
|
pub provider: LanguageModelProvider,
|
|
pub id: LanguageModelId,
|
|
pub display_name: String,
|
|
pub max_token_count: usize,
|
|
pub max_token_count_in_max_mode: Option<usize>,
|
|
pub max_output_tokens: usize,
|
|
pub supports_tools: bool,
|
|
pub supports_images: bool,
|
|
pub supports_thinking: bool,
|
|
pub supports_max_mode: bool,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct ListModelsResponse {
|
|
pub models: Vec<LanguageModel>,
|
|
pub default_model: LanguageModelId,
|
|
pub default_fast_model: LanguageModelId,
|
|
pub recommended_models: Vec<LanguageModelId>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct GetSubscriptionResponse {
|
|
pub plan: Plan,
|
|
pub usage: Option<CurrentUsage>,
|
|
}
|
|
|
|
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
|
pub struct CurrentUsage {
|
|
pub model_requests: UsageData,
|
|
pub edit_predictions: UsageData,
|
|
}
|
|
|
|
#[derive(Debug, PartialEq, Serialize, Deserialize)]
|
|
pub struct UsageData {
|
|
pub used: u32,
|
|
pub limit: UsageLimit,
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use super::*;

    /// Every `Plan` variant, used to check that the string forms agree.
    const ALL_PLANS: [Plan; 3] = [Plan::ZedFree, Plan::ZedPro, Plan::ZedProTrial];

    #[test]
    fn test_plan_deserialize_snake_case() {
        let plan = serde_json::from_value::<Plan>(json!("zed_free")).unwrap();
        assert_eq!(plan, Plan::ZedFree);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro")).unwrap();
        assert_eq!(plan, Plan::ZedPro);

        let plan = serde_json::from_value::<Plan>(json!("zed_pro_trial")).unwrap();
        assert_eq!(plan, Plan::ZedProTrial);
    }

    #[test]
    fn test_plan_deserialize_aliases() {
        let plan = serde_json::from_value::<Plan>(json!("Free")).unwrap();
        assert_eq!(plan, Plan::ZedFree);

        let plan = serde_json::from_value::<Plan>(json!("ZedPro")).unwrap();
        assert_eq!(plan, Plan::ZedPro);

        let plan = serde_json::from_value::<Plan>(json!("ZedProTrial")).unwrap();
        assert_eq!(plan, Plan::ZedProTrial);
    }

    /// `Plan::as_str`, `Plan::from_str` and serde serialization are three separate
    /// hand-written/derived mappings; this guards against them drifting apart.
    #[test]
    fn test_plan_string_forms_agree() {
        for plan in ALL_PLANS {
            assert_eq!(Plan::from_str(plan.as_str()).unwrap(), plan);
            assert_eq!(serde_json::to_value(plan).unwrap(), json!(plan.as_str()));
        }

        assert!(Plan::from_str("not_a_plan").is_err());
    }

    #[test]
    fn test_usage_limit_from_str() {
        let limit = UsageLimit::from_str("unlimited").unwrap();
        assert_eq!(limit, UsageLimit::Unlimited);

        let limit = UsageLimit::from_str(&0.to_string()).unwrap();
        assert_eq!(limit, UsageLimit::Limited(0));

        let limit = UsageLimit::from_str(&50.to_string()).unwrap();
        assert_eq!(limit, UsageLimit::Limited(50));

        for value in ["not_a_number", "50xyz", ""] {
            let limit = UsageLimit::from_str(value);
            assert!(limit.is_err());
        }
    }
}
|