mirror of
https://github.com/openai/openai-go.git
synced 2026-04-01 17:17:14 +09:00
fixes to azure URL construction
This commit is contained in:
committed by
stainless-app[bot]
parent
11b7e1aee6
commit
b0a7c5397b
@@ -26,6 +26,7 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
@@ -38,7 +39,7 @@ import (
|
||||
// WithEndpoint configures this client to connect to an Azure OpenAI endpoint.
|
||||
//
|
||||
// - endpoint - the Azure OpenAI endpoint to connect to. Ex: https://<azure-openai-resource>.openai.azure.com
|
||||
// - apiVersion - the Azure OpenAI API version to target (ex: 2024-10-21). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty.
|
||||
// - apiVersion - the Azure OpenAI API version to target (ex: 2024-06-01). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty.
|
||||
//
|
||||
// This function should be paired with a call to authenticate, like [azure.WithAPIKey] or [azure.WithTokenCredential], similar to this:
|
||||
//
|
||||
@@ -70,7 +71,7 @@ func WithEndpoint(endpoint string, apiVersion string) option.RequestOption {
|
||||
|
||||
return requestconfig.RequestOptionFunc(func(rc *requestconfig.RequestConfig) error {
|
||||
if apiVersion == "" {
|
||||
return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details")
|
||||
return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details.")
|
||||
}
|
||||
|
||||
if err := withQueryAdd.Apply(rc); err != nil {
|
||||
@@ -183,8 +184,8 @@ func getReplacementPathWithDeployment(req *http.Request) (string, error) {
|
||||
return getMultipartRoute(req)
|
||||
}
|
||||
|
||||
// No need to relocate the path. We've already tacked on /openai when we setup the endpoint.
|
||||
return req.URL.Path, nil
|
||||
// If route doesn't require deployment ID substitution, just return path with prefix.
|
||||
return path.Join("/openai/", req.URL.Path), nil
|
||||
}
|
||||
|
||||
func getJSONRoute(req *http.Request) (string, error) {
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"bytes"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/internal/apijson"
|
||||
"github.com/openai/openai-go/v3/internal/requestconfig"
|
||||
)
|
||||
|
||||
func TestJSONRoute(t *testing.T) {
|
||||
@@ -84,48 +86,121 @@ func TestGetAudioMultipartRoute(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoRouteChangeNeeded(t *testing.T) {
|
||||
chatCompletionParams := openai.ChatCompletionNewParams{
|
||||
Model: openai.ChatModel("arbitraryDeployment"),
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||
openai.AssistantMessage("You are a helpful assistant"),
|
||||
openai.UserMessage("Can you tell me another word for the universe?"),
|
||||
func TestAPIKeyAuthentication(t *testing.T) {
|
||||
rc := &requestconfig.RequestConfig{
|
||||
Request: &http.Request{
|
||||
Header: make(http.Header),
|
||||
URL: &url.URL{},
|
||||
},
|
||||
}
|
||||
|
||||
serializedBytes, err := apijson.MarshalRoot(chatCompletionParams)
|
||||
WithAPIKey("my-api-key").Apply(rc)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", "/openai/does/not/need/a/deployment", bytes.NewReader(serializedBytes))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
replacementPath, err := getReplacementPathWithDeployment(req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if replacementPath != "/openai/does/not/need/a/deployment" {
|
||||
t.Fatalf("replacementpath didn't match: %s", replacementPath)
|
||||
if got := rc.Request.Header.Get("Api-Key"); got != "my-api-key" {
|
||||
t.Errorf("Api-Key header: got %q, expected %q", got, "my-api-key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthentication(t *testing.T) {
|
||||
// Test that the API key option is created successfully
|
||||
apiKeyOption := WithAPIKey("test-api-key")
|
||||
func TestJSONRoutePathConstruction(t *testing.T) {
|
||||
cases := []struct {
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{"/chat/completions", "/openai/deployments/gpt-4/chat/completions"},
|
||||
{"/completions", "/openai/deployments/gpt-4/completions"},
|
||||
{"/embeddings", "/openai/deployments/gpt-4/embeddings"},
|
||||
{"/audio/speech", "/openai/deployments/gpt-4/audio/speech"},
|
||||
{"/images/generations", "/openai/deployments/gpt-4/images/generations"},
|
||||
{"/models", "/openai/models"}, // endpoint without a deployment
|
||||
{"/files", "/openai/files"}, // endpoint without a deployment
|
||||
}
|
||||
for _, tc := range cases {
|
||||
req, _ := http.NewRequest("POST", tc.path, bytes.NewReader([]byte(`{"model":"gpt-4"}`)))
|
||||
got, _ := getReplacementPathWithDeployment(req)
|
||||
if got != tc.expected {
|
||||
t.Errorf("%s: got %q, expected %q", tc.path, got, tc.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the option is not nil
|
||||
if apiKeyOption == nil {
|
||||
t.Fatal("Expected API key option to be created")
|
||||
func TestModelWithSpecialCharsIsEscaped(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/chat/completions", bytes.NewReader([]byte(`{"model":"my-model/v1"}`)))
|
||||
got, _ := getReplacementPathWithDeployment(req)
|
||||
|
||||
expected := "/openai/deployments/my-model%2Fv1/chat/completions"
|
||||
if got != expected {
|
||||
t.Errorf("got %q, expected %q", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithEndpointBaseURL(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
endpoint string
|
||||
apiVersion string
|
||||
expectedBaseURL string
|
||||
expectedQuery string
|
||||
shouldFail bool
|
||||
}{
|
||||
"Azure endpoint": {
|
||||
endpoint: "https://my-resource.openai.azure.com",
|
||||
apiVersion: "2024-10-21",
|
||||
expectedBaseURL: "https://my-resource.openai.azure.com/",
|
||||
expectedQuery: "api-version=2024-10-21",
|
||||
},
|
||||
"Azure endpoint with trailing slash": {
|
||||
endpoint: "https://my-resource.openai.azure.com/",
|
||||
apiVersion: "2024-10-21",
|
||||
expectedBaseURL: "https://my-resource.openai.azure.com/",
|
||||
expectedQuery: "api-version=2024-10-21",
|
||||
},
|
||||
"Azure endpoint with path": {
|
||||
endpoint: "https://my-resource.openai.azure.com/custom/path",
|
||||
apiVersion: "2023-05-15",
|
||||
expectedBaseURL: "https://my-resource.openai.azure.com/custom/path/",
|
||||
expectedQuery: "api-version=2023-05-15",
|
||||
},
|
||||
"empty apiVersion": {
|
||||
endpoint: "https://my-resource.openai.azure.com",
|
||||
apiVersion: "",
|
||||
shouldFail: true,
|
||||
},
|
||||
}
|
||||
|
||||
// This test verifies the option is created correctly.
|
||||
// The actual header setting happens in the middleware chain.
|
||||
t.Log("API key option created successfully")
|
||||
for name, tc := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
opt := WithEndpoint(tc.endpoint, tc.apiVersion)
|
||||
|
||||
rc := &requestconfig.RequestConfig{
|
||||
Request: &http.Request{
|
||||
Header: make(http.Header),
|
||||
URL: &url.URL{},
|
||||
},
|
||||
}
|
||||
|
||||
err := opt.Apply(rc)
|
||||
|
||||
if tc.shouldFail {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("WithEndpoint returned error: %v", err)
|
||||
}
|
||||
|
||||
if rc.BaseURL == nil {
|
||||
t.Fatal("BaseURL was not set")
|
||||
}
|
||||
if rc.BaseURL.String() != tc.expectedBaseURL {
|
||||
t.Errorf("BaseURL: got %q, expected %q", rc.BaseURL.String(), tc.expectedBaseURL)
|
||||
}
|
||||
|
||||
query := rc.Request.URL.RawQuery
|
||||
if query != tc.expectedQuery {
|
||||
t.Errorf("Query: got %q, expected %q", query, tc.expectedQuery)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
37
examples/azure/main.go
Normal file
37
examples/azure/main.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/azure"
|
||||
"github.com/openai/openai-go/responses"
|
||||
)
|
||||
|
||||
func main() {
|
||||
apiKey := os.Getenv("AZURE_OPENAI_API_KEY")
|
||||
apiVersion := "2025-03-01-preview"
|
||||
endpoint := "https://example-endpoint.openai.azure.com"
|
||||
deploymentName := "model-name" // e.g. "gpt-4o"
|
||||
|
||||
client := openai.NewClient(
|
||||
azure.WithEndpoint(endpoint, apiVersion),
|
||||
azure.WithAPIKey(apiKey),
|
||||
)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
question := "Write me a haiku about computers"
|
||||
|
||||
resp, err := client.Responses.New(ctx, responses.ResponseNewParams{
|
||||
Input: responses.ResponseNewParamsInputUnion{OfString: openai.String(question)},
|
||||
Model: openai.ChatModel(deploymentName),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
println(resp.OutputText())
|
||||
}
|
||||
Reference in New Issue
Block a user