changeset 444:6002d2c3f9d1

HttpCsrfFilter - Store strategy
author Devel 1
date Fri, 28 Jul 2017 15:54:52 +0200
parents c74451bbdc9c
children cf6662b69c07
files stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFilter.java
diffstat 1 files changed, 56 insertions(+), 44 deletions(-) [+]
line wrap: on
line diff
--- a/stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFilter.java	Fri Jul 28 14:47:33 2017 +0200
+++ b/stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFilter.java	Fri Jul 28 15:54:52 2017 +0200
@@ -38,27 +38,16 @@
 
     private static final Logger LOGGER = LogManager.getLogger(HttpCsrfFilter.class);
 
-    public static abstract class Extractor {
-
-        public final String type;
+    public interface Extractor { // Strategy that pulls a value (cookie, header, ...) out of an HTTP message.
 
-        protected Extractor(String type) {
-            this.type = type;
-        }
-
-        public String getType() {
-            return type;
-        }
-
-        public abstract ByteString extract(HttpMessage msg);
+        public ByteString extract(HttpMessage msg); // may return null (CookieExtractor can)
     }
 
-    public static final class CookieExtractor extends Extractor {
+    public static final class CookieExtractor implements Extractor {
 
         private final ByteString cookieName;
 
         public CookieExtractor(CharSequence cookieName) {
-            super("cookie");
             Assert.notNull(cookieName, "cookieName");
             this.cookieName = ByteString.create(cookieName);
         }
@@ -76,15 +65,13 @@
 
             return null;
         }
-
     }
 
-    public static final class HeaderExtractor extends Extractor {
+    public static final class HeaderExtractor implements Extractor { // configured with a header name (see constructor)
 
         private final ByteString headerName;
 
         public HeaderExtractor(CharSequence headerName) {
-            super("cookie");
             Assert.notNull(headerName, "headerName");
             this.headerName = ByteString.create(headerName);
         }
@@ -99,28 +86,16 @@
         }
     }
 
-    public static abstract class Injector {
-
-        public final String type;
+    public interface Injector { // Strategy that writes a CSRF token into an outgoing HTTP message.
 
-        protected Injector(String type) {
-            this.type = type;
-        }
-
-        public String getType() {
-            return type;
-        }
-
-        public abstract void inject(HttpMessage msg, ByteString csrfToken);
-
+        public void inject(HttpMessage msg, ByteString csrfToken); // invoked for each configured injector in filterOutbound
     }
 
-    public static final class HeaderInjector extends Injector {
+    public static final class HeaderInjector implements Injector {
 
         private final ByteString headerName;
 
         public HeaderInjector(CharSequence headerName) {
-            super("header");
             Assert.notNull(headerName, "headerName");
             this.headerName = ByteString.create(headerName);
         }
@@ -133,7 +108,43 @@
         public void inject(HttpMessage msg, ByteString csrfToken) {
             msg.getHeaders().set(headerName, csrfToken);
         }
+    }
 
+    public interface Store { // Per-session CSRF token storage strategy (kept under SESSION_KEY).
+        /** Remembers a token taken from a response for later requests in the same session. */
+        public void save(ByteString token);
+        /** Returns a token to inject into a request, or null if none is available. */
+        public ByteString load();
+    }
+
+    public static final class QueueStore implements Store {
+        // FIFO: each saved token is handed out to at most one load() call.
+        private final Queue<ByteString> tokens = new LinkedList<>();
+
+        @Override
+        public void save(ByteString token) {
+            tokens.add(token);
+        }
+
+        @Override
+        public ByteString load() {
+            return tokens.poll(); // null once the queue is drained
+        }
+    }
+
+    public static final class SingleTokenStore implements Store {
+        // Keeps only the most recently saved token; load() does not consume it.
+        private ByteString token;
+
+        @Override
+        public void save(ByteString token) {
+            this.token = token;
+        }
+
+        @Override
+        public ByteString load() {
+            return token;
+        }
     }
 
     private static final HttpMessageHelper HELPER = HttpMessageHelper.NOT_STRICT;
@@ -201,11 +212,12 @@
     @Override
     public int filterOutbound(HttpRequest req, HttpResponse resp, HttpFlowContext context) {
         if (req != null) {
-            ParametersBag params = context.scopes().getSession(req);
-            if (params != null) {
-                Queue<ByteString> tokens = (Queue<ByteString>) params.get(SESSION_KEY);
-                if (tokens != null && !tokens.isEmpty()) {
-                    ByteString token = tokens.poll();
+            ParametersBag session = context.scopes().getSession(req);
+            if (session != null) {
+                Store store = (Store) session.get(SESSION_KEY);
+                ByteString token = (store != null) ? store.load() : null;
+                if (token != null) {
+                    // Inject only when the store yields a token (a QueueStore may be drained).
                     for (Injector injector : injectors) {
                         injector.inject(req, token);
                     }
@@ -228,16 +240,16 @@
             }
 
             if (token != null) {
-                ParametersBag params = context.scopes().getSession(resp);
-                if (params != null) {
-                    Queue<ByteString> tokens = (Queue<ByteString>) params.get(SESSION_KEY);
+                ParametersBag session = context.scopes().getSession(resp);
+                if (session != null) {
+                    Store store = (Store) session.get(SESSION_KEY);
 
-                    if (tokens == null) {
-                        tokens = new LinkedList<>();
-                        params.set(SESSION_KEY, tokens);
+                    if (store == null) { // first token for this session: create its store lazily
+                        store = new QueueStore(); // NOTE(review): always a QueueStore; SingleTokenStore is never chosen on this path
+                        session.set(SESSION_KEY, store);
                     }
 
-                    tokens.add(token);
+                    store.save(token); // read back by filterOutbound via store.load()
                 }
             }
         }