fixes to azure URL construction

This commit is contained in:
renfred
2025-12-11 00:06:46 -08:00
committed by stainless-app[bot]
parent 11b7e1aee6
commit b0a7c5397b
3 changed files with 151 additions and 38 deletions

View File

@@ -26,6 +26,7 @@ import (
"mime/multipart"
"net/http"
"net/url"
"path"
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
@@ -38,7 +39,7 @@ import (
// WithEndpoint configures this client to connect to an Azure OpenAI endpoint.
//
// - endpoint - the Azure OpenAI endpoint to connect to. Ex: https://<azure-openai-resource>.openai.azure.com
// - apiVersion - the Azure OpenAI API version to target (ex: 2024-10-21). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty.
// - apiVersion - the Azure OpenAI API version to target (ex: 2024-06-01). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty.
//
// This function should be paired with a call to authenticate, like [azure.WithAPIKey] or [azure.WithTokenCredential], similar to this:
//
@@ -70,7 +71,7 @@ func WithEndpoint(endpoint string, apiVersion string) option.RequestOption {
return requestconfig.RequestOptionFunc(func(rc *requestconfig.RequestConfig) error {
if apiVersion == "" {
return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details")
return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details.")
}
if err := withQueryAdd.Apply(rc); err != nil {
@@ -183,8 +184,8 @@ func getReplacementPathWithDeployment(req *http.Request) (string, error) {
return getMultipartRoute(req)
}
// No need to relocate the path. We've already tacked on /openai when we setup the endpoint.
return req.URL.Path, nil
// If route doesn't require deployment ID substitution, just return path with prefix.
return path.Join("/openai/", req.URL.Path), nil
}
func getJSONRoute(req *http.Request) (string, error) {

View File

@@ -4,10 +4,12 @@ import (
"bytes"
"mime/multipart"
"net/http"
"net/url"
"testing"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/internal/apijson"
"github.com/openai/openai-go/v3/internal/requestconfig"
)
func TestJSONRoute(t *testing.T) {
@@ -84,48 +86,121 @@ func TestGetAudioMultipartRoute(t *testing.T) {
}
}
func TestNoRouteChangeNeeded(t *testing.T) {
chatCompletionParams := openai.ChatCompletionNewParams{
Model: openai.ChatModel("arbitraryDeployment"),
Messages: []openai.ChatCompletionMessageParamUnion{
openai.AssistantMessage("You are a helpful assistant"),
openai.UserMessage("Can you tell me another word for the universe?"),
func TestAPIKeyAuthentication(t *testing.T) {
rc := &requestconfig.RequestConfig{
Request: &http.Request{
Header: make(http.Header),
URL: &url.URL{},
},
}
serializedBytes, err := apijson.MarshalRoot(chatCompletionParams)
WithAPIKey("my-api-key").Apply(rc)
if err != nil {
t.Fatal(err)
}
req, err := http.NewRequest("POST", "/openai/does/not/need/a/deployment", bytes.NewReader(serializedBytes))
if err != nil {
t.Fatal(err)
}
replacementPath, err := getReplacementPathWithDeployment(req)
if err != nil {
t.Fatal(err)
}
if replacementPath != "/openai/does/not/need/a/deployment" {
t.Fatalf("replacementpath didn't match: %s", replacementPath)
if got := rc.Request.Header.Get("Api-Key"); got != "my-api-key" {
t.Errorf("Api-Key header: got %q, expected %q", got, "my-api-key")
}
}
func TestAPIKeyAuthentication(t *testing.T) {
// Test that the API key option is created successfully
apiKeyOption := WithAPIKey("test-api-key")
func TestJSONRoutePathConstruction(t *testing.T) {
cases := []struct {
path string
expected string
}{
{"/chat/completions", "/openai/deployments/gpt-4/chat/completions"},
{"/completions", "/openai/deployments/gpt-4/completions"},
{"/embeddings", "/openai/deployments/gpt-4/embeddings"},
{"/audio/speech", "/openai/deployments/gpt-4/audio/speech"},
{"/images/generations", "/openai/deployments/gpt-4/images/generations"},
{"/models", "/openai/models"}, // endpoint without a deployment
{"/files", "/openai/files"}, // endpoint without a deployment
}
for _, tc := range cases {
req, _ := http.NewRequest("POST", tc.path, bytes.NewReader([]byte(`{"model":"gpt-4"}`)))
got, _ := getReplacementPathWithDeployment(req)
if got != tc.expected {
t.Errorf("%s: got %q, expected %q", tc.path, got, tc.expected)
}
}
}
// Verify the option is not nil
if apiKeyOption == nil {
t.Fatal("Expected API key option to be created")
func TestModelWithSpecialCharsIsEscaped(t *testing.T) {
req, _ := http.NewRequest("POST", "/chat/completions", bytes.NewReader([]byte(`{"model":"my-model/v1"}`)))
got, _ := getReplacementPathWithDeployment(req)
expected := "/openai/deployments/my-model%2Fv1/chat/completions"
if got != expected {
t.Errorf("got %q, expected %q", got, expected)
}
}
func TestWithEndpointBaseURL(t *testing.T) {
tests := map[string]struct {
endpoint string
apiVersion string
expectedBaseURL string
expectedQuery string
shouldFail bool
}{
"Azure endpoint": {
endpoint: "https://my-resource.openai.azure.com",
apiVersion: "2024-10-21",
expectedBaseURL: "https://my-resource.openai.azure.com/",
expectedQuery: "api-version=2024-10-21",
},
"Azure endpoint with trailing slash": {
endpoint: "https://my-resource.openai.azure.com/",
apiVersion: "2024-10-21",
expectedBaseURL: "https://my-resource.openai.azure.com/",
expectedQuery: "api-version=2024-10-21",
},
"Azure endpoint with path": {
endpoint: "https://my-resource.openai.azure.com/custom/path",
apiVersion: "2023-05-15",
expectedBaseURL: "https://my-resource.openai.azure.com/custom/path/",
expectedQuery: "api-version=2023-05-15",
},
"empty apiVersion": {
endpoint: "https://my-resource.openai.azure.com",
apiVersion: "",
shouldFail: true,
},
}
// This test verifies the option is created correctly.
// The actual header setting happens in the middleware chain.
t.Log("API key option created successfully")
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
opt := WithEndpoint(tc.endpoint, tc.apiVersion)
rc := &requestconfig.RequestConfig{
Request: &http.Request{
Header: make(http.Header),
URL: &url.URL{},
},
}
err := opt.Apply(rc)
if tc.shouldFail {
if err == nil {
t.Fatal("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("WithEndpoint returned error: %v", err)
}
if rc.BaseURL == nil {
t.Fatal("BaseURL was not set")
}
if rc.BaseURL.String() != tc.expectedBaseURL {
t.Errorf("BaseURL: got %q, expected %q", rc.BaseURL.String(), tc.expectedBaseURL)
}
query := rc.Request.URL.RawQuery
if query != tc.expectedQuery {
t.Errorf("Query: got %q, expected %q", query, tc.expectedQuery)
}
})
}
}

37
examples/azure/main.go Normal file
View File

@@ -0,0 +1,37 @@
package main
import (
"context"
"os"
"github.com/openai/openai-go"
"github.com/openai/openai-go/azure"
"github.com/openai/openai-go/responses"
)
func main() {
apiKey := os.Getenv("AZURE_OPENAI_API_KEY")
apiVersion := "2025-03-01-preview"
endpoint := "https://example-endpoint.openai.azure.com"
deploymentName := "model-name" // e.g. "gpt-4o"
client := openai.NewClient(
azure.WithEndpoint(endpoint, apiVersion),
azure.WithAPIKey(apiKey),
)
ctx := context.Background()
question := "Write me a haiku about computers"
resp, err := client.Responses.New(ctx, responses.ResponseNewParams{
Input: responses.ResponseNewParamsInputUnion{OfString: openai.String(question)},
Model: openai.ChatModel(deploymentName),
})
if err != nil {
panic(err)
}
println(resp.OutputText())
}