diff --git a/azure/azure.go b/azure/azure.go index 354081c..f82e451 100644 --- a/azure/azure.go +++ b/azure/azure.go @@ -91,30 +91,60 @@ func WithEndpoint(endpoint string, apiVersion string) option.RequestOption { }) } +type tokenCredentialConfig struct { + Scopes []string +} + +// TokenCredentialOption is the type for any options that can be used to customize +// [WithTokenCredential], including things like using custom scopes. +type TokenCredentialOption func(*tokenCredentialConfig) error + +// WithTokenCredentialScopes overrides the default scope used when requesting access tokens. +func WithTokenCredentialScopes(scopes []string) func(*tokenCredentialConfig) error { + return func(tc *tokenCredentialConfig) error { + tc.Scopes = scopes + return nil + } +} + // WithTokenCredential configures this client to authenticate using an [Azure Identity] TokenCredential. // This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance. // // [Azure Identity]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity -func WithTokenCredential(tokenCredential azcore.TokenCredential) option.RequestOption { - bearerTokenPolicy := runtime.NewBearerTokenPolicy(tokenCredential, []string{"https://cognitiveservices.azure.com/.default"}, nil) - - // add in a middleware that uses the bearer token generated from the token credential - return option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { - pipeline := runtime.NewPipeline("azopenai-extensions", version, runtime.PipelineOptions{}, &policy.ClientOptions{ - InsecureAllowCredentialWithHTTP: true, // allow for plain HTTP proxies, etc.. - PerRetryPolicies: []policy.Policy{ - bearerTokenPolicy, - policyAdapter(next), - }, - }) - - req2, err := runtime.NewRequestFromRequest(req) - - if err != nil { - return nil, err +func WithTokenCredential(tokenCredential azcore.TokenCredential, options ...TokenCredentialOption) option.RequestOption { + return requestconfig.RequestOptionFunc(func(rc *requestconfig.RequestConfig) error { + tc := &tokenCredentialConfig{ + Scopes: []string{"https://cognitiveservices.azure.com/.default"}, } - return pipeline.Do(req2) + for _, option := range options { + if err := option(tc); err != nil { + return err + } + } + + bearerTokenPolicy := runtime.NewBearerTokenPolicy(tokenCredential, tc.Scopes, nil) + + // add in a middleware that uses the bearer token generated from the token credential + middlewareOption := option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + pipeline := runtime.NewPipeline("azopenai-extensions", version, runtime.PipelineOptions{}, &policy.ClientOptions{ + InsecureAllowCredentialWithHTTP: true, // allow for plain HTTP proxies, etc.. + PerRetryPolicies: []policy.Policy{ + bearerTokenPolicy, + policyAdapter(next), + }, + }) + + req2, err := runtime.NewRequestFromRequest(req) + + if err != nil { + return nil, err + } + + return pipeline.Do(req2) + }) + + return middlewareOption.Apply(rc) }) } diff --git a/azure/example_test.go b/azure/example_test.go index 211d6de..057f94e 100644 --- a/azure/example_test.go +++ b/azure/example_test.go @@ -45,3 +45,29 @@ func Example_authentication() { _ = client } } + +func Example_authentication_custom_scopes() { + // Custom scopes can also be passed, if needed, when using Azure OpenAI endpoints. + const azureOpenAIEndpoint = "https://.openai.azure.com" + const azureOpenAIAPIVersion = "" + + // For a full list of credential types look at the documentation for the Azure Identity + // package: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity + tokenCredential, err := azidentity.NewDefaultAzureCredential(nil) + + if err != nil { + fmt.Printf("Failed to create TokenCredential: %s\n", err) + return + } + + client := openai.NewClient( + azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion), + azure.WithTokenCredential(tokenCredential, + // This is an example of a custom scope. See documentation for your service + // endpoint for the proper value to pass. + azure.WithTokenCredentialScopes([]string{"your-custom-scope"}), + ), + ) + + _ = client +}