package httpcache import ( "bytes" "errors" "flag" "io" "io/ioutil" "net/http" "net/http/httptest" "os" "strconv" "testing" "time" ) var s struct { server *httptest.Server client http.Client transport *Transport done chan struct{} // Closed to unlock infinite handlers. } type fakeClock struct { elapsed time.Duration } func (c *fakeClock) since(t time.Time) time.Duration { return c.elapsed } func TestMain(m *testing.M) { flag.Parse() setup() code := m.Run() teardown() os.Exit(code) } func setup() { tp := NewMemoryCacheTransport() client := http.Client{Transport: tp} s.transport = tp s.client = client s.done = make(chan struct{}) mux := http.NewServeMux() s.server = httptest.NewServer(mux) mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "max-age=3600") })) mux.HandleFunc("/method", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "max-age=3600") w.Write([]byte(r.Method)) })) mux.HandleFunc("/range", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { lm := "Fri, 14 Dec 2010 01:01:50 GMT" if r.Header.Get("if-modified-since") == lm { w.WriteHeader(http.StatusNotModified) return } w.Header().Set("last-modified", lm) if r.Header.Get("range") == "bytes=4-9" { w.WriteHeader(http.StatusPartialContent) w.Write([]byte(" text ")) return } w.Write([]byte("Some text content")) })) mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-store") })) mux.HandleFunc("/etag", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { etag := "124567" if r.Header.Get("if-none-match") == etag { w.WriteHeader(http.StatusNotModified) return } w.Header().Set("etag", etag) })) mux.HandleFunc("/lastmodified", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { lm := "Fri, 14 Dec 2010 01:01:50 GMT" if r.Header.Get("if-modified-since") == lm { w.WriteHeader(http.StatusNotModified) return } w.Header().Set("last-modified", lm) })) mux.HandleFunc("/varyaccept", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "max-age=3600") w.Header().Set("Content-Type", "text/plain") w.Header().Set("Vary", "Accept") w.Write([]byte("Some text content")) })) mux.HandleFunc("/doublevary", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "max-age=3600") w.Header().Set("Content-Type", "text/plain") w.Header().Set("Vary", "Accept, Accept-Language") w.Write([]byte("Some text content")) })) mux.HandleFunc("/2varyheaders", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "max-age=3600") w.Header().Set("Content-Type", "text/plain") w.Header().Add("Vary", "Accept") w.Header().Add("Vary", "Accept-Language") w.Write([]byte("Some text content")) })) mux.HandleFunc("/varyunused", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "max-age=3600") w.Header().Set("Content-Type", "text/plain") w.Header().Set("Vary", "X-Madeup-Header") w.Write([]byte("Some text content")) })) mux.HandleFunc("/cachederror", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { etag := "abc" if r.Header.Get("if-none-match") == etag { w.WriteHeader(http.StatusNotModified) return } w.Header().Set("etag", etag) w.WriteHeader(http.StatusNotFound) w.Write([]byte("Not found")) })) updateFieldsCounter := 0 mux.HandleFunc("/updatefields", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Counter", strconv.Itoa(updateFieldsCounter)) w.Header().Set("Etag", `"e"`) updateFieldsCounter++ if r.Header.Get("if-none-match") != "" { w.WriteHeader(http.StatusNotModified) return } w.Write([]byte("Some text content")) })) // Take 3 seconds to return 200 OK (for testing client timeouts). mux.HandleFunc("/3seconds", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(3 * time.Second) })) mux.HandleFunc("/infinite", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { for { select { case <-s.done: return default: w.Write([]byte{0}) } } })) } func teardown() { close(s.done) s.server.Close() } func resetTest() { s.transport.Cache = NewMemoryCache() clock = &realClock{} } // TestCacheableMethod ensures that uncacheable method does not get stored // in cache and get incorrectly used for a following cacheable method request. func TestCacheableMethod(t *testing.T) { resetTest() { req, err := http.NewRequest("POST", s.server.URL+"/method", nil) if err != nil { t.Fatal(err) } resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } var buf bytes.Buffer _, err = io.Copy(&buf, resp.Body) if err != nil { t.Fatal(err) } err = resp.Body.Close() if err != nil { t.Fatal(err) } if got, want := buf.String(), "POST"; got != want { t.Errorf("got %q, want %q", got, want) } if resp.StatusCode != http.StatusOK { t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) } } { req, err := http.NewRequest("GET", s.server.URL+"/method", nil) if err != nil { t.Fatal(err) } resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } var buf bytes.Buffer _, err = io.Copy(&buf, resp.Body) if err != nil { t.Fatal(err) } err = resp.Body.Close() if err != nil { t.Fatal(err) } if got, want := buf.String(), "GET"; got != want { t.Errorf("got wrong body %q, want %q", got, want) } if resp.StatusCode != http.StatusOK { t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) } if resp.Header.Get(XFromCache) != "" { t.Errorf("XFromCache header isn't blank") } } } func TestDontServeHeadResponseToGetRequest(t *testing.T) { resetTest() url := s.server.URL + "/" req, err := http.NewRequest(http.MethodHead, url, nil) if err != nil { t.Fatal(err) } _, err = s.client.Do(req) if err != nil { t.Fatal(err) } req, err = http.NewRequest(http.MethodGet, url, nil) if err != nil { t.Fatal(err) } resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } if resp.Header.Get(XFromCache) != "" { t.Errorf("Cache should not match") } } func TestDontStorePartialRangeInCache(t *testing.T) { resetTest() { req, err := http.NewRequest("GET", s.server.URL+"/range", nil) if err != nil { t.Fatal(err) } req.Header.Set("range", "bytes=4-9") resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } var buf bytes.Buffer _, err = io.Copy(&buf, resp.Body) if err != nil { t.Fatal(err) } err = resp.Body.Close() if err != nil { t.Fatal(err) } if got, want := buf.String(), " text "; got != want { t.Errorf("got %q, want %q", got, want) } if resp.StatusCode != http.StatusPartialContent { t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) } } { req, err := http.NewRequest("GET", s.server.URL+"/range", nil) if err != nil { t.Fatal(err) } resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } var buf bytes.Buffer _, err = io.Copy(&buf, resp.Body) if err != nil { t.Fatal(err) } err = resp.Body.Close() if err != nil { t.Fatal(err) } if got, want := buf.String(), "Some text content"; got != want { t.Errorf("got %q, want %q", got, want) } if resp.StatusCode != http.StatusOK { t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) } if resp.Header.Get(XFromCache) != "" { t.Error("XFromCache header isn't blank") } } { req, err := http.NewRequest("GET", s.server.URL+"/range", nil) if err != nil { t.Fatal(err) } resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } var buf bytes.Buffer _, err = io.Copy(&buf, resp.Body) if err != nil { t.Fatal(err) } err = resp.Body.Close() if err != nil { t.Fatal(err) } if got, want := buf.String(), "Some text content"; got != want { t.Errorf("got %q, want %q", got, want) } if resp.StatusCode != http.StatusOK { t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) } if resp.Header.Get(XFromCache) != "1" { t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) } } { req, err := http.NewRequest("GET", s.server.URL+"/range", nil) if err != nil { t.Fatal(err) } req.Header.Set("range", "bytes=4-9") resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } var buf bytes.Buffer _, err = io.Copy(&buf, resp.Body) if err != nil { t.Fatal(err) } err = resp.Body.Close() if err != nil { t.Fatal(err) } if got, want := buf.String(), " text "; got != want { t.Errorf("got %q, want %q", got, want) } if resp.StatusCode != http.StatusPartialContent { t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) } } } func TestCacheOnlyIfBodyRead(t *testing.T) { resetTest() { req, err := http.NewRequest("GET", s.server.URL, nil) if err != nil { t.Fatal(err) } resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } // We do not read the body resp.Body.Close() } { req, err := http.NewRequest("GET", s.server.URL, nil) if err != nil { t.Fatal(err) } resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatalf("XFromCache header isn't blank") } } } func TestOnlyReadBodyOnDemand(t *testing.T) { resetTest() req, err := http.NewRequest("GET", s.server.URL+"/infinite", nil) if err != nil { t.Fatal(err) } resp, err := s.client.Do(req) // This shouldn't hang forever. if err != nil { t.Fatal(err) } buf := make([]byte, 10) // Only partially read the body. _, err = resp.Body.Read(buf) if err != nil { t.Fatal(err) } resp.Body.Close() } func TestGetOnlyIfCachedHit(t *testing.T) { resetTest() { req, err := http.NewRequest("GET", s.server.URL, nil) if err != nil { t.Fatal(err) } resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } } { req, err := http.NewRequest("GET", s.server.URL, nil) if err != nil { t.Fatal(err) } req.Header.Add("cache-control", "only-if-cached") resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "1" { t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) } if resp.StatusCode != http.StatusOK { t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) } } } func TestGetOnlyIfCachedMiss(t *testing.T) { resetTest() req, err := http.NewRequest("GET", s.server.URL, nil) if err != nil { t.Fatal(err) } req.Header.Add("cache-control", "only-if-cached") resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } if resp.StatusCode != http.StatusGatewayTimeout { t.Fatalf("response status code isn't 504 GatewayTimeout: %v", resp.StatusCode) } } func TestGetNoStoreRequest(t *testing.T) { resetTest() req, err := http.NewRequest("GET", s.server.URL, nil) if err != nil { t.Fatal(err) } req.Header.Add("Cache-Control", "no-store") { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } } } func TestGetNoStoreResponse(t *testing.T) { resetTest() req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) if err != nil { t.Fatal(err) } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } } } func TestGetWithEtag(t *testing.T) { resetTest() req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) if err != nil { t.Fatal(err) } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "1" { t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) } // additional assertions to verify that 304 response is converted properly if resp.StatusCode != http.StatusOK { t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) } if _, ok := resp.Header["Connection"]; ok { t.Fatalf("Connection header isn't absent") } } } func TestGetWithLastModified(t *testing.T) { resetTest() req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) if err != nil { t.Fatal(err) } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "1" { t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) } } } func TestGetWithVary(t *testing.T) { resetTest() req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) if err != nil { t.Fatal(err) } req.Header.Set("Accept", "text/plain") { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get("Vary") != "Accept" { t.Fatalf(`Vary header isn't "Accept": %v`, resp.Header.Get("Vary")) } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "1" { t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) } } req.Header.Set("Accept", "text/html") { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } } req.Header.Set("Accept", "") { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } } } func TestGetWithDoubleVary(t *testing.T) { resetTest() req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) if err != nil { t.Fatal(err) } req.Header.Set("Accept", "text/plain") req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get("Vary") == "" { t.Fatalf(`Vary header is blank`) } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "1" { t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) } } req.Header.Set("Accept-Language", "") { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } } req.Header.Set("Accept-Language", "da") { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } } } func TestGetWith2VaryHeaders(t *testing.T) { resetTest() // Tests that multiple Vary headers' comma-separated lists are // merged. See https://github.com/gregjones/httpcache/issues/27. const ( accept = "text/plain" acceptLanguage = "da, en-gb;q=0.8, en;q=0.7" ) req, err := http.NewRequest("GET", s.server.URL+"/2varyheaders", nil) if err != nil { t.Fatal(err) } req.Header.Set("Accept", accept) req.Header.Set("Accept-Language", acceptLanguage) { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get("Vary") == "" { t.Fatalf(`Vary header is blank`) } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "1" { t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) } } req.Header.Set("Accept-Language", "") { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } } req.Header.Set("Accept-Language", "da") { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } } req.Header.Set("Accept-Language", acceptLanguage) req.Header.Set("Accept", "") { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } } req.Header.Set("Accept", "image/png") { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "" { t.Fatal("XFromCache header isn't blank") } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "1" { t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) } } } func TestGetVaryUnused(t *testing.T) { resetTest() req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) if err != nil { t.Fatal(err) } req.Header.Set("Accept", "text/plain") { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get("Vary") == "" { t.Fatalf(`Vary header is blank`) } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "1" { t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) } } } func TestUpdateFields(t *testing.T) { resetTest() req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) if err != nil { t.Fatal(err) } var counter, counter2 string { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() counter = resp.Header.Get("x-counter") _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.Header.Get(XFromCache) != "1" { t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) } counter2 = resp.Header.Get("x-counter") } if counter == counter2 { t.Fatalf(`both "x-counter" values are equal: %v %v`, counter, counter2) } } // This tests the fix for https://github.com/gregjones/httpcache/issues/74. // Previously, after validating a cached response, its StatusCode // was incorrectly being replaced. func TestCachedErrorsKeepStatus(t *testing.T) { resetTest() req, err := http.NewRequest("GET", s.server.URL+"/cachederror", nil) if err != nil { t.Fatal(err) } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() io.Copy(ioutil.Discard, resp.Body) } { resp, err := s.client.Do(req) if err != nil { t.Fatal(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusNotFound { t.Fatalf("Status code isn't 404: %d", resp.StatusCode) } } } func TestParseCacheControl(t *testing.T) { resetTest() h := http.Header{} for range parseCacheControl(h) { t.Fatal("cacheControl should be empty") } h.Set("cache-control", "no-cache") { cc := parseCacheControl(h) if _, ok := cc["foo"]; ok { t.Error(`Value "foo" shouldn't exist`) } noCache, ok := cc["no-cache"] if !ok { t.Fatalf(`"no-cache" value isn't set`) } if noCache != "" { t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) } } h.Set("cache-control", "no-cache, max-age=3600") { cc := parseCacheControl(h) noCache, ok := cc["no-cache"] if !ok { t.Fatalf(`"no-cache" value isn't set`) } if noCache != "" { t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) } if cc["max-age"] != "3600" { t.Fatalf(`"max-age" value isn't "3600": %v`, cc["max-age"]) } } } func TestNoCacheRequestExpiration(t *testing.T) { resetTest() respHeaders := http.Header{} respHeaders.Set("Cache-Control", "max-age=7200") reqHeaders := http.Header{} reqHeaders.Set("Cache-Control", "no-cache") if getFreshness(respHeaders, reqHeaders) != transparent { t.Fatal("freshness isn't transparent") } } func TestNoCacheResponseExpiration(t *testing.T) { resetTest() respHeaders := http.Header{} respHeaders.Set("Cache-Control", "no-cache") respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") reqHeaders := http.Header{} if getFreshness(respHeaders, reqHeaders) != stale { t.Fatal("freshness isn't stale") } } func TestReqMustRevalidate(t *testing.T) { resetTest() // not paying attention to request setting max-stale means never returning stale // responses, so always acting as if must-revalidate is set respHeaders := http.Header{} reqHeaders := http.Header{} reqHeaders.Set("Cache-Control", "must-revalidate") if getFreshness(respHeaders, reqHeaders) != stale { t.Fatal("freshness isn't stale") } } func TestRespMustRevalidate(t *testing.T) { resetTest() respHeaders := http.Header{} respHeaders.Set("Cache-Control", "must-revalidate") reqHeaders := http.Header{} if getFreshness(respHeaders, reqHeaders) != stale { t.Fatal("freshness isn't stale") } } func TestFreshExpiration(t *testing.T) { resetTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) reqHeaders := http.Header{} if getFreshness(respHeaders, reqHeaders) != fresh { t.Fatal("freshness isn't fresh") } clock = &fakeClock{elapsed: 3 * time.Second} if getFreshness(respHeaders, reqHeaders) != stale { t.Fatal("freshness isn't stale") } } func TestMaxAge(t *testing.T) { resetTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("cache-control", "max-age=2") reqHeaders := http.Header{} if getFreshness(respHeaders, reqHeaders) != fresh { t.Fatal("freshness isn't fresh") } clock = &fakeClock{elapsed: 3 * time.Second} if getFreshness(respHeaders, reqHeaders) != stale { t.Fatal("freshness isn't stale") } } func TestMaxAgeZero(t *testing.T) { resetTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("cache-control", "max-age=0") reqHeaders := http.Header{} if getFreshness(respHeaders, reqHeaders) != stale { t.Fatal("freshness isn't stale") } } func TestBothMaxAge(t *testing.T) { resetTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("cache-control", "max-age=2") reqHeaders := http.Header{} reqHeaders.Set("cache-control", "max-age=0") if getFreshness(respHeaders, reqHeaders) != stale { t.Fatal("freshness isn't stale") } } func TestMinFreshWithExpires(t *testing.T) { resetTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) reqHeaders := http.Header{} reqHeaders.Set("cache-control", "min-fresh=1") if getFreshness(respHeaders, reqHeaders) != fresh { t.Fatal("freshness isn't fresh") } reqHeaders = http.Header{} reqHeaders.Set("cache-control", "min-fresh=2") if getFreshness(respHeaders, reqHeaders) != stale { t.Fatal("freshness isn't stale") } } func TestEmptyMaxStale(t *testing.T) { resetTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("cache-control", "max-age=20") reqHeaders := http.Header{} reqHeaders.Set("cache-control", "max-stale") clock = &fakeClock{elapsed: 10 * time.Second} if getFreshness(respHeaders, reqHeaders) != fresh { t.Fatal("freshness isn't fresh") } clock = &fakeClock{elapsed: 60 * time.Second} if getFreshness(respHeaders, reqHeaders) != fresh { t.Fatal("freshness isn't fresh") } } func TestMaxStaleValue(t *testing.T) { resetTest() now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) respHeaders.Set("cache-control", "max-age=10") reqHeaders := http.Header{} reqHeaders.Set("cache-control", "max-stale=20") clock = &fakeClock{elapsed: 5 * time.Second} if getFreshness(respHeaders, reqHeaders) != fresh { t.Fatal("freshness isn't fresh") } clock = &fakeClock{elapsed: 15 * time.Second} if getFreshness(respHeaders, reqHeaders) != fresh { t.Fatal("freshness isn't fresh") } clock = &fakeClock{elapsed: 30 * time.Second} if getFreshness(respHeaders, reqHeaders) != stale { t.Fatal("freshness isn't stale") } } func containsHeader(headers []string, header string) bool { for _, v := range headers { if http.CanonicalHeaderKey(v) == http.CanonicalHeaderKey(header) { return true } } return false } func TestGetEndToEndHeaders(t *testing.T) { resetTest() var ( headers http.Header end2end []string ) headers = http.Header{} headers.Set("content-type", "text/html") headers.Set("te", "deflate") end2end = getEndToEndHeaders(headers) if !containsHeader(end2end, "content-type") { t.Fatal(`doesn't contain "content-type" header`) } if containsHeader(end2end, "te") { t.Fatal(`doesn't contain "te" header`) } headers = http.Header{} headers.Set("connection", "content-type") headers.Set("content-type", "text/csv") headers.Set("te", "deflate") end2end = getEndToEndHeaders(headers) if containsHeader(end2end, "connection") { t.Fatal(`doesn't contain "connection" header`) } if containsHeader(end2end, "content-type") { t.Fatal(`doesn't contain "content-type" header`) } if containsHeader(end2end, "te") { t.Fatal(`doesn't contain "te" header`) } headers = http.Header{} end2end = getEndToEndHeaders(headers) if len(end2end) != 0 { t.Fatal(`non-zero end2end headers`) } headers = http.Header{} headers.Set("connection", "content-type") end2end = getEndToEndHeaders(headers) if len(end2end) != 0 { t.Fatal(`non-zero end2end headers`) } } type transportMock struct { response *http.Response err error } func (t transportMock) RoundTrip(req *http.Request) (resp *http.Response, err error) { return t.response, t.err } func TestStaleIfErrorRequest(t *testing.T) { resetTest() now := time.Now() tmock := transportMock{ response: &http.Response{ Status: http.StatusText(http.StatusOK), StatusCode: http.StatusOK, Header: http.Header{ "Date": []string{now.Format(time.RFC1123)}, "Cache-Control": []string{"no-cache"}, }, Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), }, err: nil, } tp := NewMemoryCacheTransport() tp.Transport = &tmock // First time, response is cached on success r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) r.Header.Set("Cache-Control", "stale-if-error") resp, err := tp.RoundTrip(r) if err != nil { t.Fatal(err) } if resp == nil { t.Fatal("resp is nil") } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } // On failure, response is returned from the cache tmock.response = nil tmock.err = errors.New("some error") resp, err = tp.RoundTrip(r) if err != nil { t.Fatal(err) } if resp == nil { t.Fatal("resp is nil") } } func TestStaleIfErrorRequestLifetime(t *testing.T) { resetTest() now := time.Now() tmock := transportMock{ response: &http.Response{ Status: http.StatusText(http.StatusOK), StatusCode: http.StatusOK, Header: http.Header{ "Date": []string{now.Format(time.RFC1123)}, "Cache-Control": []string{"no-cache"}, }, Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), }, err: nil, } tp := NewMemoryCacheTransport() tp.Transport = &tmock // First time, response is cached on success r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) r.Header.Set("Cache-Control", "stale-if-error=100") resp, err := tp.RoundTrip(r) if err != nil { t.Fatal(err) } if resp == nil { t.Fatal("resp is nil") } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } // On failure, response is returned from the cache tmock.response = nil tmock.err = errors.New("some error") resp, err = tp.RoundTrip(r) if err != nil { t.Fatal(err) } if resp == nil { t.Fatal("resp is nil") } // Same for http errors tmock.response = &http.Response{StatusCode: http.StatusInternalServerError} tmock.err = nil resp, err = tp.RoundTrip(r) if err != nil { t.Fatal(err) } if resp == nil { t.Fatal("resp is nil") } // If failure last more than max stale, error is returned clock = &fakeClock{elapsed: 200 * time.Second} _, err = tp.RoundTrip(r) if err != tmock.err { t.Fatalf("got err %v, want %v", err, tmock.err) } } func TestStaleIfErrorResponse(t *testing.T) { resetTest() now := time.Now() tmock := transportMock{ response: &http.Response{ Status: http.StatusText(http.StatusOK), StatusCode: http.StatusOK, Header: http.Header{ "Date": []string{now.Format(time.RFC1123)}, "Cache-Control": []string{"no-cache, stale-if-error"}, }, Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), }, err: nil, } tp := NewMemoryCacheTransport() tp.Transport = &tmock // First time, response is cached on success r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) resp, err := tp.RoundTrip(r) if err != nil { t.Fatal(err) } if resp == nil { t.Fatal("resp is nil") } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } // On failure, response is returned from the cache tmock.response = nil tmock.err = errors.New("some error") resp, err = tp.RoundTrip(r) if err != nil { t.Fatal(err) } if resp == nil { t.Fatal("resp is nil") } } func TestStaleIfErrorResponseLifetime(t *testing.T) { resetTest() now := time.Now() tmock := transportMock{ response: &http.Response{ Status: http.StatusText(http.StatusOK), StatusCode: http.StatusOK, Header: http.Header{ "Date": []string{now.Format(time.RFC1123)}, "Cache-Control": []string{"no-cache, stale-if-error=100"}, }, Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), }, err: nil, } tp := NewMemoryCacheTransport() tp.Transport = &tmock // First time, response is cached on success r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) resp, err := tp.RoundTrip(r) if err != nil { t.Fatal(err) } if resp == nil { t.Fatal("resp is nil") } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } // On failure, response is returned from the cache tmock.response = nil tmock.err = errors.New("some error") resp, err = tp.RoundTrip(r) if err != nil { t.Fatal(err) } if resp == nil { t.Fatal("resp is nil") } // If failure last more than max stale, error is returned clock = &fakeClock{elapsed: 200 * time.Second} _, err = tp.RoundTrip(r) if err != tmock.err { t.Fatalf("got err %v, want %v", err, tmock.err) } } // This tests the fix for https://github.com/gregjones/httpcache/issues/74. // Previously, after a stale response was used after encountering an error, // its StatusCode was being incorrectly replaced. func TestStaleIfErrorKeepsStatus(t *testing.T) { resetTest() now := time.Now() tmock := transportMock{ response: &http.Response{ Status: http.StatusText(http.StatusNotFound), StatusCode: http.StatusNotFound, Header: http.Header{ "Date": []string{now.Format(time.RFC1123)}, "Cache-Control": []string{"no-cache"}, }, Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), }, err: nil, } tp := NewMemoryCacheTransport() tp.Transport = &tmock // First time, response is cached on success r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) r.Header.Set("Cache-Control", "stale-if-error") resp, err := tp.RoundTrip(r) if err != nil { t.Fatal(err) } if resp == nil { t.Fatal("resp is nil") } _, err = ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } // On failure, response is returned from the cache tmock.response = nil tmock.err = errors.New("some error") resp, err = tp.RoundTrip(r) if err != nil { t.Fatal(err) } if resp == nil { t.Fatal("resp is nil") } if resp.StatusCode != http.StatusNotFound { t.Fatalf("Status wasn't 404: %d", resp.StatusCode) } } // Test that http.Client.Timeout is respected when cache transport is used. // That is so as long as request cancellation is propagated correctly. // In the past, that required CancelRequest to be implemented correctly, // but modern http.Client uses Request.Cancel (or request context) instead, // so we don't have to do anything. func TestClientTimeout(t *testing.T) { if testing.Short() { t.Skip("skipping timeout test in short mode") // Because it takes at least 3 seconds to run. } resetTest() client := &http.Client{ Transport: NewMemoryCacheTransport(), Timeout: time.Second, } started := time.Now() resp, err := client.Get(s.server.URL + "/3seconds") taken := time.Since(started) if err == nil { t.Error("got nil error, want timeout error") } if resp != nil { t.Error("got non-nil resp, want nil resp") } if taken >= 2*time.Second { t.Error("client.Do took 2+ seconds, want < 2 seconds") } }