use anyhow::{Result, anyhow}; use futures::{ AsyncBufReadExt, AsyncReadExt, io::BufReader, stream::{BoxStream, StreamExt}, }; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::convert::TryFrom; pub const DEEPSEEK_API_URL: &str = "https://api.deepseek.com"; #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] pub enum Role { User, Assistant, System, Tool, } impl TryFrom for Role { type Error = anyhow::Error; fn try_from(value: String) -> Result { match value.as_str() { "user" => Ok(Self::User), "assistant" => Ok(Self::Assistant), "system" => Ok(Self::System), "tool" => Ok(Self::Tool), _ => anyhow::bail!("invalid role '{value}'"), } } } impl From for String { fn from(val: Role) -> Self { match val { Role::User => "user".to_owned(), Role::Assistant => "assistant".to_owned(), Role::System => "system".to_owned(), Role::Tool => "tool".to_owned(), } } } #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] pub enum Model { #[serde(rename = "deepseek-chat")] #[default] Chat, #[serde(rename = "deepseek-reasoner")] Reasoner, #[serde(rename = "custom")] Custom { name: String, /// The name displayed in the UI, such as in the assistant panel model dropdown menu. display_name: Option, max_tokens: u64, max_output_tokens: Option, }, } impl Model { pub fn default_fast() -> Self { Model::Chat } pub fn from_id(id: &str) -> Result { match id { "deepseek-chat" => Ok(Self::Chat), "deepseek-reasoner" => Ok(Self::Reasoner), _ => anyhow::bail!("invalid model id {id}"), } } pub fn id(&self) -> &str { match self { Self::Chat => "deepseek-chat", Self::Reasoner => "deepseek-reasoner", Self::Custom { name, .. } => name, } } pub fn display_name(&self) -> &str { match self { Self::Chat => "DeepSeek Chat", Self::Reasoner => "DeepSeek Reasoner", Self::Custom { name, display_name, .. } => display_name.as_ref().unwrap_or(name).as_str(), } } pub fn max_token_count(&self) -> u64 { match self { Self::Chat | Self::Reasoner => 64_000, Self::Custom { max_tokens, .. } => *max_tokens, } } pub fn max_output_tokens(&self) -> Option { match self { Self::Chat => Some(8_192), Self::Reasoner => Some(8_192), Self::Custom { max_output_tokens, .. } => *max_output_tokens, } } } #[derive(Debug, Serialize, Deserialize)] pub struct Request { pub model: String, pub messages: Vec, pub stream: bool, #[serde(default, skip_serializing_if = "Option::is_none")] pub max_tokens: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub temperature: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub response_format: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub tools: Vec, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ResponseFormat { Text, #[serde(rename = "json_object")] JsonObject, } #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ToolDefinition { Function { function: FunctionDefinition }, } #[derive(Debug, Serialize, Deserialize)] pub struct FunctionDefinition { pub name: String, pub description: Option, pub parameters: Option, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(tag = "role", rename_all = "lowercase")] pub enum RequestMessage { Assistant { content: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] tool_calls: Vec, }, User { content: String, }, System { content: String, }, Tool { content: String, tool_call_id: String, }, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ToolCall { pub id: String, #[serde(flatten)] pub content: ToolCallContent, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(tag = "type", rename_all = "lowercase")] pub enum ToolCallContent { Function { function: FunctionContent }, } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct FunctionContent { pub name: String, pub arguments: String, } #[derive(Serialize, Deserialize, Debug)] pub struct Response { pub id: String, pub object: String, pub created: u64, pub model: String, pub choices: Vec, pub usage: Usage, #[serde(default, skip_serializing_if = "Option::is_none")] pub reasoning_content: Option, } #[derive(Serialize, Deserialize, Debug)] pub struct Usage { pub prompt_tokens: u64, pub completion_tokens: u64, pub total_tokens: u64, #[serde(default)] pub prompt_cache_hit_tokens: u64, #[serde(default)] pub prompt_cache_miss_tokens: u64, } #[derive(Serialize, Deserialize, Debug)] pub struct Choice { pub index: u32, pub message: RequestMessage, pub finish_reason: Option, } #[derive(Serialize, Deserialize, Debug)] pub struct StreamResponse { pub id: String, pub object: String, pub created: u64, pub model: String, pub choices: Vec, pub usage: Option, } #[derive(Serialize, Deserialize, Debug)] pub struct StreamChoice { pub index: u32, pub delta: StreamDelta, pub finish_reason: Option, } #[derive(Serialize, Deserialize, Debug)] pub struct StreamDelta { pub role: Option, pub content: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, #[serde(default, skip_serializing_if = "Option::is_none")] pub reasoning_content: Option, } #[derive(Serialize, Deserialize, Debug)] pub struct ToolCallChunk { pub index: usize, pub id: Option, pub function: Option, } #[derive(Serialize, Deserialize, Debug)] pub struct FunctionChunk { pub name: Option, pub arguments: Option, } pub async fn stream_completion( client: &dyn HttpClient, api_url: &str, api_key: &str, request: Request, ) -> Result>> { let uri = format!("{api_url}/v1/chat/completions"); let request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key)); let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); Ok(reader .lines() .filter_map(|line| async move { match line { Ok(line) => { let line = line.strip_prefix("data: ")?; if line == "[DONE]" { None } else { match serde_json::from_str(line) { Ok(response) => Some(Ok(response)), Err(error) => Some(Err(anyhow!(error))), } } } Err(error) => Some(Err(anyhow!(error))), } }) .boxed()) } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; anyhow::bail!( "Failed to connect to DeepSeek API: {} {}", response.status(), body, ); } }