Allow OpenAI API URL to be configured via `assistant.openai_api_url` (#7552)
Partially fixes #4321, since Azure OpenAI API can be converted to OpenAI API.

Release Notes:

- Added `assistant.openai_api_url` setting to allow OpenAI API URL to be configured.

---------

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
This commit is contained in:
parent
d959719f3e
commit
9e17018416
7 changed files with 60 additions and 12 deletions
|
@ -103,6 +103,7 @@ pub struct OpenAiResponseStreamEvent {
|
|||
}
|
||||
|
||||
pub async fn stream_completion(
|
||||
api_url: String,
|
||||
credential: ProviderCredential,
|
||||
executor: BackgroundExecutor,
|
||||
request: Box<dyn CompletionRequest>,
|
||||
|
@ -117,7 +118,7 @@ pub async fn stream_completion(
|
|||
let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
|
||||
|
||||
let json_data = request.data()?;
|
||||
let mut response = Request::post(format!("{OPEN_AI_API_URL}/chat/completions"))
|
||||
let mut response = Request::post(format!("{api_url}/chat/completions"))
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.body(json_data)?
|
||||
|
@ -195,18 +196,20 @@ pub async fn stream_completion(
|
|||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAiCompletionProvider {
|
||||
api_url: String,
|
||||
model: OpenAiLanguageModel,
|
||||
credential: Arc<RwLock<ProviderCredential>>,
|
||||
executor: BackgroundExecutor,
|
||||
}
|
||||
|
||||
impl OpenAiCompletionProvider {
|
||||
pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
|
||||
pub async fn new(api_url: String, model_name: String, executor: BackgroundExecutor) -> Self {
|
||||
let model = executor
|
||||
.spawn(async move { OpenAiLanguageModel::load(&model_name) })
|
||||
.await;
|
||||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||
Self {
|
||||
api_url,
|
||||
model,
|
||||
credential,
|
||||
executor,
|
||||
|
@ -303,7 +306,8 @@ impl CompletionProvider for OpenAiCompletionProvider {
|
|||
// which is currently model based, due to the language model.
|
||||
// At some point in the future we should rectify this.
|
||||
let credential = self.credential.read().clone();
|
||||
let request = stream_completion(credential, self.executor.clone(), prompt);
|
||||
let api_url = self.api_url.clone();
|
||||
let request = stream_completion(api_url, credential, self.executor.clone(), prompt);
|
||||
async move {
|
||||
let response = request.await?;
|
||||
let stream = response
|
||||
|
|
|
@ -35,6 +35,7 @@ lazy_static! {
|
|||
|
||||
#[derive(Clone)]
|
||||
pub struct OpenAiEmbeddingProvider {
|
||||
api_url: String,
|
||||
model: OpenAiLanguageModel,
|
||||
credential: Arc<RwLock<ProviderCredential>>,
|
||||
pub client: Arc<dyn HttpClient>,
|
||||
|
@ -69,7 +70,11 @@ struct OpenAiEmbeddingUsage {
|
|||
}
|
||||
|
||||
impl OpenAiEmbeddingProvider {
|
||||
pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
|
||||
pub async fn new(
|
||||
api_url: String,
|
||||
client: Arc<dyn HttpClient>,
|
||||
executor: BackgroundExecutor,
|
||||
) -> Self {
|
||||
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
|
||||
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
|
||||
|
||||
|
@ -80,6 +85,7 @@ impl OpenAiEmbeddingProvider {
|
|||
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
|
||||
|
||||
OpenAiEmbeddingProvider {
|
||||
api_url,
|
||||
model,
|
||||
credential,
|
||||
client,
|
||||
|
@ -130,11 +136,12 @@ impl OpenAiEmbeddingProvider {
|
|||
}
|
||||
async fn send_request(
|
||||
&self,
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
spans: Vec<&str>,
|
||||
request_timeout: u64,
|
||||
) -> Result<Response<AsyncBody>> {
|
||||
let request = Request::post(format!("{OPEN_AI_API_URL}/embeddings"))
|
||||
let request = Request::post(format!("{api_url}/embeddings"))
|
||||
.redirect_policy(isahc::config::RedirectPolicy::Follow)
|
||||
.timeout(Duration::from_secs(request_timeout))
|
||||
.header("Content-Type", "application/json")
|
||||
|
@ -246,6 +253,7 @@ impl EmbeddingProvider for OpenAiEmbeddingProvider {
|
|||
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
|
||||
const MAX_RETRIES: usize = 4;
|
||||
|
||||
let api_url = self.api_url.as_str();
|
||||
let api_key = self.get_api_key()?;
|
||||
|
||||
let mut request_number = 0;
|
||||
|
@ -255,6 +263,7 @@ impl EmbeddingProvider for OpenAiEmbeddingProvider {
|
|||
while request_number < MAX_RETRIES {
|
||||
response = self
|
||||
.send_request(
|
||||
&api_url,
|
||||
&api_key,
|
||||
spans.iter().map(|x| &**x).collect(),
|
||||
request_timeout,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue