diff --git a/azure/azure.go b/azure/azure.go index 62d0f4d..533fa19 100644 --- a/azure/azure.go +++ b/azure/azure.go @@ -26,6 +26,7 @@ import ( "mime/multipart" "net/http" "net/url" + "path" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -38,7 +39,7 @@ import ( // WithEndpoint configures this client to connect to an Azure OpenAI endpoint. // // - endpoint - the Azure OpenAI endpoint to connect to. Ex: https://.openai.azure.com -// - apiVersion - the Azure OpenAI API version to target (ex: 2024-10-21). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty. +// - apiVersion - the Azure OpenAI API version to target (ex: 2024-06-01). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty. // // This function should be paired with a call to authenticate, like [azure.WithAPIKey] or [azure.WithTokenCredential], similar to this: // @@ -70,7 +71,7 @@ func WithEndpoint(endpoint string, apiVersion string) option.RequestOption { return requestconfig.RequestOptionFunc(func(rc *requestconfig.RequestConfig) error { if apiVersion == "" { - return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details") + return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details.") } if err := withQueryAdd.Apply(rc); err != nil { @@ -183,8 +184,8 @@ func getReplacementPathWithDeployment(req *http.Request) (string, error) { return getMultipartRoute(req) } - // No need to relocate the path. We've already tacked on /openai when we setup the endpoint. - return req.URL.Path, nil + // If route doesn't require deployment ID substitution, just return path with prefix. + return path.Join("/openai/", req.URL.Path), nil } func getJSONRoute(req *http.Request) (string, error) { diff --git a/azure/azure_test.go b/azure/azure_test.go index aa163f5..c11028c 100644 --- a/azure/azure_test.go +++ b/azure/azure_test.go @@ -4,10 +4,12 @@ import ( "bytes" "mime/multipart" "net/http" + "net/url" "testing" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/internal/apijson" + "github.com/openai/openai-go/v3/internal/requestconfig" ) func TestJSONRoute(t *testing.T) { @@ -84,48 +86,121 @@ func TestGetAudioMultipartRoute(t *testing.T) { } } -func TestNoRouteChangeNeeded(t *testing.T) { - chatCompletionParams := openai.ChatCompletionNewParams{ - Model: openai.ChatModel("arbitraryDeployment"), - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.AssistantMessage("You are a helpful assistant"), - openai.UserMessage("Can you tell me another word for the universe?"), +func TestAPIKeyAuthentication(t *testing.T) { + rc := &requestconfig.RequestConfig{ + Request: &http.Request{ + Header: make(http.Header), + URL: &url.URL{}, }, } - serializedBytes, err := apijson.MarshalRoot(chatCompletionParams) + WithAPIKey("my-api-key").Apply(rc) - if err != nil { - t.Fatal(err) - } - - req, err := http.NewRequest("POST", "/openai/does/not/need/a/deployment", bytes.NewReader(serializedBytes)) - - if err != nil { - t.Fatal(err) - } - - replacementPath, err := getReplacementPathWithDeployment(req) - - if err != nil { - t.Fatal(err) - } - - if replacementPath != "/openai/does/not/need/a/deployment" { - t.Fatalf("replacementpath didn't match: %s", replacementPath) + if got := rc.Request.Header.Get("Api-Key"); got != "my-api-key" { + t.Errorf("Api-Key header: got %q, expected %q", got, "my-api-key") } } -func TestAPIKeyAuthentication(t *testing.T) { - // Test that the API key option is created successfully - apiKeyOption := WithAPIKey("test-api-key") +func TestJSONRoutePathConstruction(t *testing.T) { + cases := []struct { + path string + expected string + }{ + {"/chat/completions", "/openai/deployments/gpt-4/chat/completions"}, + {"/completions", "/openai/deployments/gpt-4/completions"}, + {"/embeddings", "/openai/deployments/gpt-4/embeddings"}, + {"/audio/speech", "/openai/deployments/gpt-4/audio/speech"}, + {"/images/generations", "/openai/deployments/gpt-4/images/generations"}, + {"/models", "/openai/models"}, // endpoint without a deployment + {"/files", "/openai/files"}, // endpoint without a deployment + } + for _, tc := range cases { + req, _ := http.NewRequest("POST", tc.path, bytes.NewReader([]byte(`{"model":"gpt-4"}`))) + got, _ := getReplacementPathWithDeployment(req) + if got != tc.expected { + t.Errorf("%s: got %q, expected %q", tc.path, got, tc.expected) + } + } +} - // Verify the option is not nil - if apiKeyOption == nil { - t.Fatal("Expected API key option to be created") +func TestModelWithSpecialCharsIsEscaped(t *testing.T) { + req, _ := http.NewRequest("POST", "/chat/completions", bytes.NewReader([]byte(`{"model":"my-model/v1"}`))) + got, _ := getReplacementPathWithDeployment(req) + + expected := "/openai/deployments/my-model%2Fv1/chat/completions" + if got != expected { + t.Errorf("got %q, expected %q", got, expected) + } +} + +func TestWithEndpointBaseURL(t *testing.T) { + tests := map[string]struct { + endpoint string + apiVersion string + expectedBaseURL string + expectedQuery string + shouldFail bool + }{ + "Azure endpoint": { + endpoint: "https://my-resource.openai.azure.com", + apiVersion: "2024-10-21", + expectedBaseURL: "https://my-resource.openai.azure.com/", + expectedQuery: "api-version=2024-10-21", + }, + "Azure endpoint with trailing slash": { + endpoint: "https://my-resource.openai.azure.com/", + apiVersion: "2024-10-21", + expectedBaseURL: "https://my-resource.openai.azure.com/", + expectedQuery: "api-version=2024-10-21", + }, + "Azure endpoint with path": { + endpoint: "https://my-resource.openai.azure.com/custom/path", + apiVersion: "2023-05-15", + expectedBaseURL: "https://my-resource.openai.azure.com/custom/path/", + expectedQuery: "api-version=2023-05-15", + }, + "empty apiVersion": { + endpoint: "https://my-resource.openai.azure.com", + apiVersion: "", + shouldFail: true, + }, } - // This test verifies the option is created correctly. - // The actual header setting happens in the middleware chain. - t.Log("API key option created successfully") + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + opt := WithEndpoint(tc.endpoint, tc.apiVersion) + + rc := &requestconfig.RequestConfig{ + Request: &http.Request{ + Header: make(http.Header), + URL: &url.URL{}, + }, + } + + err := opt.Apply(rc) + + if tc.shouldFail { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("WithEndpoint returned error: %v", err) + } + + if rc.BaseURL == nil { + t.Fatal("BaseURL was not set") + } + if rc.BaseURL.String() != tc.expectedBaseURL { + t.Errorf("BaseURL: got %q, expected %q", rc.BaseURL.String(), tc.expectedBaseURL) + } + + query := rc.Request.URL.RawQuery + if query != tc.expectedQuery { + t.Errorf("Query: got %q, expected %q", query, tc.expectedQuery) + } + }) + } } diff --git a/examples/azure/main.go b/examples/azure/main.go new file mode 100644 index 0000000..c747694 --- /dev/null +++ b/examples/azure/main.go @@ -0,0 +1,37 @@ +package main + +import ( + "context" + "os" + + "github.com/openai/openai-go" + "github.com/openai/openai-go/azure" + "github.com/openai/openai-go/responses" +) + +func main() { + apiKey := os.Getenv("AZURE_OPENAI_API_KEY") + apiVersion := "2025-03-01-preview" + endpoint := "https://example-endpoint.openai.azure.com" + deploymentName := "model-name" // e.g. "gpt-4o" + + client := openai.NewClient( + azure.WithEndpoint(endpoint, apiVersion), + azure.WithAPIKey(apiKey), + ) + + ctx := context.Background() + + question := "Write me a haiku about computers" + + resp, err := client.Responses.New(ctx, responses.ResponseNewParams{ + Input: responses.ResponseNewParamsInputUnion{OfString: openai.String(question)}, + Model: openai.ChatModel(deploymentName), + }) + + if err != nil { + panic(err) + } + + println(resp.OutputText()) +}