use serde::{Deserialize, Serialize};
use std::fmt::{self, Display};

/// A role that can be converted to and from `proto::LanguageModelRole`.
///
/// With serde, each variant is (de)serialized under its lowercase name
/// (`"user"`, `"assistant"`, `"system"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

impl Role {
    /// Decodes a raw protobuf enum value into a [`Role`].
    ///
    /// Any integer that `proto::LanguageModelRole::from_i32` does not
    /// recognize falls back to [`Role::User`].
    pub fn from_proto(role: i32) -> Role {
        proto::LanguageModelRole::from_i32(role)
            .map(|decoded| match decoded {
                proto::LanguageModelRole::LanguageModelUser => Self::User,
                proto::LanguageModelRole::LanguageModelAssistant => Self::Assistant,
                proto::LanguageModelRole::LanguageModelSystem => Self::System,
            })
            // Unknown wire values are treated as the user role.
            .unwrap_or(Self::User)
    }

    /// Returns the `proto::LanguageModelRole` variant matching this role.
    pub fn to_proto(&self) -> proto::LanguageModelRole {
        match *self {
            Self::User => proto::LanguageModelRole::LanguageModelUser,
            Self::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
            Self::System => proto::LanguageModelRole::LanguageModelSystem,
        }
    }

    /// Returns the next role in the rotation
    /// `User -> Assistant -> System -> User`.
    pub fn cycle(self) -> Role {
        match self {
            Self::User => Self::Assistant,
            Self::Assistant => Self::System,
            Self::System => Self::User,
        }
    }
}

impl Display for Role {
    /// Writes the lowercase name of the role: `user`, `assistant` or `system`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::User => "user",
            Self::Assistant => "assistant",
            Self::System => "system",
        };
        f.write_str(label)
    }
}