diff --git a/client_test.go b/client_test.go index 27d2262..eb8adac 100644 --- a/client_test.go +++ b/client_test.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "net/http" + "reflect" "testing" "time" @@ -49,12 +50,12 @@ func TestUserAgentHeader(t *testing.T) { } func TestRetryAfter(t *testing.T) { - attempts := 0 + retryCountHeaders := make([]string, 0) client := openai.NewClient( option.WithHTTPClient(&http.Client{ Transport: &closureTransport{ fn: func(req *http.Request) (*http.Response, error) { - attempts++ + retryCountHeaders = append(retryCountHeaders, req.Header.Get("X-Stainless-Retry-Count")) return &http.Response{ StatusCode: http.StatusTooManyRequests, Header: http.Header{ @@ -75,8 +76,85 @@ func TestRetryAfter(t *testing.T) { if err == nil || res != nil { t.Error("Expected there to be a cancel error and for the response to be nil") } - if want := 3; attempts != want { - t.Errorf("Expected %d attempts, got %d", want, attempts) + + attempts := len(retryCountHeaders) + if attempts != 3 { + t.Errorf("Expected %d attempts, got %d", 3, attempts) + } + + expectedRetryCountHeaders := []string{"0", "1", "2"} + if !reflect.DeepEqual(retryCountHeaders, expectedRetryCountHeaders) { + t.Errorf("Expected %v retry count headers, got %v", expectedRetryCountHeaders, retryCountHeaders) + } +} + +func TestDeleteRetryCountHeader(t *testing.T) { + retryCountHeaders := make([]string, 0) + client := openai.NewClient( + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + retryCountHeaders = append(retryCountHeaders, req.Header.Get("X-Stainless-Retry-Count")) + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{ + http.CanonicalHeaderKey("Retry-After"): []string{"0.1"}, + }, + }, nil + }, + }, + }), + option.WithHeaderDel("X-Stainless-Retry-Count"), + ) + res, err := client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{ + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{ + Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), + Content: openai.F([]openai.ChatCompletionContentPartUnionParam{openai.ChatCompletionContentPartTextParam{Text: openai.F("text"), Type: openai.F(openai.ChatCompletionContentPartTextTypeText)}}), + }}), + Model: openai.F(openai.ChatModelO1Preview), + }) + if err == nil || res != nil { + t.Error("Expected there to be a cancel error and for the response to be nil") + } + + expectedRetryCountHeaders := []string{"", "", ""} + if !reflect.DeepEqual(retryCountHeaders, expectedRetryCountHeaders) { + t.Errorf("Expected %v retry count headers, got %v", expectedRetryCountHeaders, retryCountHeaders) + } +} + +func TestOverwriteRetryCountHeader(t *testing.T) { + retryCountHeaders := make([]string, 0) + client := openai.NewClient( + option.WithHTTPClient(&http.Client{ + Transport: &closureTransport{ + fn: func(req *http.Request) (*http.Response, error) { + retryCountHeaders = append(retryCountHeaders, req.Header.Get("X-Stainless-Retry-Count")) + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{ + http.CanonicalHeaderKey("Retry-After"): []string{"0.1"}, + }, + }, nil + }, + }, + }), + option.WithHeader("X-Stainless-Retry-Count", "42"), + ) + res, err := client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{ + Messages: openai.F([]openai.ChatCompletionMessageParamUnion{openai.ChatCompletionUserMessageParam{ + Role: openai.F(openai.ChatCompletionUserMessageParamRoleUser), + Content: openai.F([]openai.ChatCompletionContentPartUnionParam{openai.ChatCompletionContentPartTextParam{Text: openai.F("text"), Type: openai.F(openai.ChatCompletionContentPartTextTypeText)}}), + }}), + Model: openai.F(openai.ChatModelO1Preview), + }) + if err == nil || res != nil { + t.Error("Expected there to be a cancel error and for the response to be nil") + } + + expectedRetryCountHeaders := []string{"42", "42", "42"} + if !reflect.DeepEqual(retryCountHeaders, expectedRetryCountHeaders) { + t.Errorf("Expected %v retry count headers, got %v", expectedRetryCountHeaders, retryCountHeaders) } } diff --git a/internal/requestconfig/requestconfig.go b/internal/requestconfig/requestconfig.go index ae8e1e9..62e7956 100644 --- a/internal/requestconfig/requestconfig.go +++ b/internal/requestconfig/requestconfig.go @@ -137,6 +137,7 @@ func NewRequestConfig(ctx context.Context, method string, u string, body interfa } req.Header.Set("Accept", "application/json") + req.Header.Set("X-Stainless-Retry-Count", "0") for k, v := range getDefaultHeaders() { req.Header.Add(k, v) } @@ -333,6 +334,9 @@ func (cfg *RequestConfig) Execute() (err error) { handler = applyMiddleware(cfg.Middlewares[i], handler) } + // Don't send the current retry count in the headers if the caller modified the header defaults. + shouldSendRetryCount := cfg.Request.Header.Get("X-Stainless-Retry-Count") == "0" + var res *http.Response for retryCount := 0; retryCount <= cfg.MaxRetries; retryCount += 1 { ctx := cfg.Request.Context() @@ -342,7 +346,12 @@ func (cfg *RequestConfig) Execute() (err error) { defer cancel() } - res, err = handler(cfg.Request.Clone(ctx)) + req := cfg.Request.Clone(ctx) + if shouldSendRetryCount { + req.Header.Set("X-Stainless-Retry-Count", strconv.Itoa(retryCount)) + } + + res, err = handler(req) if ctx != nil && ctx.Err() != nil { return ctx.Err() }