diff --git a/proxy/direct.go b/proxy/direct.go index 26b51c34..3d66bdef 100644 --- a/proxy/direct.go +++ b/proxy/direct.go @@ -11,9 +11,14 @@ import ( type direct struct{} -// Direct is a direct proxy: one that makes network connections directly. +// Direct implements Dialer by making network connections directly using net.Dial or net.DialContext. var Direct = direct{} +var ( + _ Dialer = Direct + _ ContextDialer = Direct +) + // Dial directly invokes net.Dial with the supplied parameters. func (direct) Dial(network, addr string) (net.Conn, error) { return net.Dial(network, addr) diff --git a/proxy/proxy.go b/proxy/proxy.go index 37d3cabd..9ff4b9a7 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -26,21 +26,30 @@ type Auth struct { User, Password string } -// FromEnvironment returns the dialer specified by the proxy related variables in -// the environment. +// FromEnvironment returns the dialer specified by the proxy-related +// variables in the environment and makes underlying connections +// directly. func FromEnvironment() Dialer { + return FromEnvironmentUsing(Direct) +} + +// FromEnvironmentUsing returns the dialer specify by the proxy-related +// variables in the environment and makes underlying connections +// using the provided forwarding Dialer (for instance, a *net.Dialer +// with desired configuration). +func FromEnvironmentUsing(forward Dialer) Dialer { allProxy := allProxyEnv.Get() if len(allProxy) == 0 { - return Direct + return forward } proxyURL, err := url.Parse(allProxy) if err != nil { - return Direct + return forward } - proxy, err := FromURL(proxyURL, Direct) + proxy, err := FromURL(proxyURL, forward) if err != nil { - return Direct + return forward } noProxy := noProxyEnv.Get() @@ -48,7 +57,7 @@ func FromEnvironment() Dialer { return proxy } - perHost := NewPerHost(proxy, Direct) + perHost := NewPerHost(proxy, forward) perHost.AddFromString(noProxy) return perHost } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index d260d699..567fc9c3 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -6,7 +6,10 @@ package proxy import ( "bytes" + "context" + "errors" "fmt" + "net" "net/url" "os" "strings" @@ -110,6 +113,37 @@ func TestSOCKS5(t *testing.T) { c.Close() } +type funcFailDialer func(context.Context) error + +func (f funcFailDialer) Dial(net, addr string) (net.Conn, error) { + panic("shouldn't see a call to Dial") +} + +func (f funcFailDialer) DialContext(ctx context.Context, net, addr string) (net.Conn, error) { + return nil, f(ctx) +} + +// Check that FromEnvironmentUsing uses our dialer. +func TestFromEnvironmentUsing(t *testing.T) { + ResetProxyEnv() + errFoo := errors.New("some error to check our dialer was used)") + type key string + ctx := context.WithValue(context.Background(), key("foo"), "bar") + dialer := FromEnvironmentUsing(funcFailDialer(func(ctx context.Context) error { + if got := ctx.Value(key("foo")); got != "bar" { + t.Errorf("Resolver context = %T %v, want %q", got, got, "bar") + } + return errFoo + })) + _, err := dialer.(ContextDialer).DialContext(ctx, "tcp", "foo.tld:123") + if err == nil { + t.Fatalf("unexpected success") + } + if !strings.Contains(err.Error(), errFoo.Error()) { + t.Errorf("got unexpected error %q; want substr %q", err, errFoo) + } +} + func ResetProxyEnv() { for _, env := range []*envOnce{allProxyEnv, noProxyEnv} { for _, v := range env.names {