changeset 432:0db3ae52a2f3

HttpCsrfFilter
author Devel 2
date Thu, 27 Jul 2017 13:36:55 +0200
parents bbc6b52ab089
children f2d245d74663
files stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFilter.java stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFilterExtractorTransformer.java stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFilterInjectorTransformer.java stress-tester/src/test/java/com/passus/st/client/http/filter/HttpCsrfFilterTest.java
diffstat 4 files changed, 534 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFilter.java	Thu Jul 27 13:36:55 2017 +0200
@@ -0,0 +1,268 @@
+package com.passus.st.client.http.filter;
+
+import com.passus.commons.Assert;
+import com.passus.commons.annotations.Plugin;
+import com.passus.config.Configuration;
+import com.passus.config.annotations.NodeDefinitionCreate;
+import static com.passus.config.schema.ConfigurationSchemaBuilder.mapDef;
+import static com.passus.config.schema.ConfigurationSchemaBuilder.tupleDef;
+import static com.passus.config.schema.ConfigurationSchemaBuilder.valueDef;
+import com.passus.config.schema.KeyNameVaryListNodeDefinition;
+import com.passus.config.schema.NodeDefinition;
+import com.passus.config.schema.NodeDefinitionCreator;
+import com.passus.data.ByteString;
+import com.passus.net.http.HttpCookie;
+import com.passus.net.http.HttpMessage;
+import com.passus.net.http.HttpMessageHelper;
+import com.passus.net.http.HttpRequest;
+import com.passus.net.http.HttpResponse;
+import com.passus.st.ParametersBag;
+import com.passus.st.client.http.HttpFlowContext;
+import com.passus.st.plugin.PluginConstants;
+import com.passus.st.validation.HeaderNameValidator;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+
+/**
+ *
+ * @author Mirosław Hawrot
+ */
+@NodeDefinitionCreate(HttpCsrfFilter.HttpCsrfFilterNodeDefCreator.class)
+@Plugin(name = HttpCsrfFilter.TYPE, category = PluginConstants.CATEGORY_HTTP_FILTER)
+public final class HttpCsrfFilter extends HttpFilter {
+
+    /**
+     * Base type for strategies that read a CSRF token out of an HTTP message.
+     */
+    public abstract static class Extractor {
+
+        /** Type label supplied by the concrete subclass. */
+        public final String type;
+
+        /**
+         * @param type type label stored in {@link #type}
+         */
+        protected Extractor(String type) {
+            this.type = type;
+        }
+
+        /**
+         * @return the type label given to the constructor
+         */
+        public String getType() {
+            return this.type;
+        }
+
+        /**
+         * Looks up the token in the given message.
+         *
+         * @param msg message to inspect
+         * @return the token found in the message; implementations in this file
+         * may return {@code null}
+         */
+        public abstract ByteString extract(HttpMessage msg);
+    }
+
+    /**
+     * Extractor that returns the value of a named cookie.
+     */
+    public static final class CookieExtractor extends Extractor {
+
+        private final ByteString cookieName;
+
+        /**
+         * @param cookieName name of the cookie holding the token; must not be null
+         */
+        public CookieExtractor(CharSequence cookieName) {
+            super("cookie");
+            Assert.notNull(cookieName, "cookieName");
+            this.cookieName = ByteString.create(cookieName);
+        }
+
+        /**
+         * @return name of the cookie this extractor reads
+         */
+        public ByteString getCookieName() {
+            return this.cookieName;
+        }
+
+        /**
+         * @param msg message whose cookies are searched
+         * @return the cookie value, or {@code null} when the helper finds no such cookie
+         */
+        @Override
+        public ByteString extract(HttpMessage msg) {
+            HttpCookie found = HELPER.getCookie(msg, cookieName);
+            return found == null ? null : found.getValue();
+        }
+
+    }
+
+    /**
+     * Extractor that returns the value of a named header.
+     */
+    public static final class HeaderExtractor extends Extractor {
+
+        private final ByteString headerName;
+
+        /**
+         * @param headerName name of the header holding the token; must not be null
+         */
+        public HeaderExtractor(CharSequence headerName) {
+            // The type label must be "header"; it previously reported "cookie".
+            super("header");
+            Assert.notNull(headerName, "headerName");
+            this.headerName = ByteString.create(headerName);
+        }
+
+        /**
+         * @return name of the header this extractor reads
+         */
+        public ByteString getHeaderName() {
+            return headerName;
+        }
+
+        /**
+         * @param msg message whose headers are searched
+         * @return whatever the header collection returns for the configured name
+         */
+        @Override
+        public ByteString extract(HttpMessage msg) {
+            return msg.getHeaders().get(headerName);
+        }
+    }
+
+    /**
+     * Base type for strategies that write a CSRF token into an HTTP message.
+     */
+    public abstract static class Injector {
+
+        /** Type label supplied by the concrete subclass. */
+        public final String type;
+
+        /**
+         * @param type type label stored in {@link #type}
+         */
+        protected Injector(String type) {
+            this.type = type;
+        }
+
+        /**
+         * @return the type label given to the constructor
+         */
+        public String getType() {
+            return this.type;
+        }
+
+        /**
+         * Writes the token into the given message.
+         *
+         * @param msg message to modify
+         * @param csrfToken token value to write
+         */
+        public abstract void inject(HttpMessage msg, ByteString csrfToken);
+
+    }
+
+    /**
+     * Injector that stores the token in a named header.
+     */
+    public static final class HeaderInjector extends Injector {
+
+        private final ByteString headerName;
+
+        /**
+         * @param headerName name of the header to set; must not be null
+         */
+        public HeaderInjector(CharSequence headerName) {
+            super("header");
+            Assert.notNull(headerName, "headerName");
+            this.headerName = ByteString.create(headerName);
+        }
+
+        /**
+         * @return name of the header this injector sets
+         */
+        public ByteString getHeaderName() {
+            return this.headerName;
+        }
+
+        /**
+         * Sets the configured header on the message to the given token.
+         *
+         * @param msg message to modify
+         * @param csrfToken token value to set
+         */
+        @Override
+        public void inject(HttpMessage msg, ByteString csrfToken) {
+            msg.getHeaders().set(this.headerName, csrfToken);
+        }
+
+    }
+
+    // Helper used by CookieExtractor to look up cookies in a message.
+    private static final HttpMessageHelper HELPER = HttpMessageHelper.NOT_STRICT;
+
+    // Plugin name of this filter (see the @Plugin annotation on the class).
+    public static final String TYPE = "csrf";
+
+    // Key under which the queue of extracted tokens is kept in the session parameters.
+    public static final String SESSION_KEY = "_csrfTokens";
+
+    // Extractors tried in order by filterInbound; the first non-null token wins.
+    private final List<Extractor> extractors = new ArrayList<>();
+
+    // Injectors applied by filterOutbound; each one receives the same token.
+    private final List<Injector> injectors = new ArrayList<>();
+
+    /**
+     * Creates a filter with no extractors and no injectors.
+     */
+    public HttpCsrfFilter() {
+    }
+
+    /**
+     * Creates a filter pre-populated with the given extractors and injectors.
+     *
+     * @param extractors extractors to copy into this filter
+     * @param injectors injectors to copy into this filter
+     */
+    public HttpCsrfFilter(List<Extractor> extractors, List<Injector> injectors) {
+        this.setExtractors(extractors);
+        this.setInjectors(injectors);
+    }
+
+    /**
+     * @return a read-only view of the configured extractors
+     */
+    public List<Extractor> getExtractors() {
+        return Collections.unmodifiableList(this.extractors);
+    }
+
+    /**
+     * Replaces the configured extractors with the elements of the given list.
+     *
+     * @param extractors new extractors; checked with {@code Assert.notContainsNull}
+     */
+    public void setExtractors(List<Extractor> extractors) {
+        Assert.notContainsNull(extractors, "extractors");
+        // Snapshot first: the argument may be the view returned by getExtractors(),
+        // which is backed by this.extractors and would be emptied by clear().
+        List<Extractor> snapshot = new ArrayList<>(extractors);
+        this.extractors.clear();
+        this.extractors.addAll(snapshot);
+    }
+
+    /**
+     * Appends an extractor to the end of the list.
+     *
+     * @param extractor extractor to add; must not be null
+     */
+    public void addExtractor(Extractor extractor) {
+        Assert.notNull(extractor, "extractor");
+        extractors.add(extractor);
+    }
+
+    /**
+     * Removes the given extractor from the list, if present.
+     *
+     * @param extractor extractor to remove
+     */
+    public void removeExtractor(Extractor extractor) {
+        extractors.remove(extractor);
+    }
+
+    /**
+     * @return a read-only view of the configured injectors
+     */
+    public List<Injector> getInjectors() {
+        return Collections.unmodifiableList(this.injectors);
+    }
+
+    /**
+     * Replaces the configured injectors with the elements of the given list.
+     *
+     * @param injectors new injectors; checked with {@code Assert.notContainsNull}
+     */
+    public void setInjectors(List<Injector> injectors) {
+        Assert.notContainsNull(injectors, "injectors");
+        // Snapshot first: the argument may be the view returned by getInjectors(),
+        // which is backed by this.injectors and would be emptied by clear().
+        List<Injector> snapshot = new ArrayList<>(injectors);
+        this.injectors.clear();
+        this.injectors.addAll(snapshot);
+    }
+
+    /**
+     * Appends an injector to the end of the list.
+     *
+     * @param injector injector to add; must not be null
+     */
+    public void addInjector(Injector injector) {
+        Assert.notNull(injector, "injector");
+        injectors.add(injector);
+    }
+
+    /**
+     * Removes the given injector from the list, if present.
+     *
+     * @param injector injector to remove
+     */
+    public void removeInjector(Injector injector) {
+        injectors.remove(injector);
+    }
+
+    @Override
+    public void configure(Configuration cfg) {
+        setExtractors((List<Extractor>) cfg.get("extract", Collections.EMPTY_LIST));
+        setInjectors((List<Injector>) cfg.get("inject", Collections.EMPTY_LIST));
+    }
+
+    /**
+     * Takes the oldest token queued for the request's session, if any, and
+     * passes it to every configured injector.
+     *
+     * @param req outgoing request; nothing happens when it is null
+     * @param resp not used by this method
+     * @param context flow context providing the session scope
+     * @return always {@code DUNNO}
+     */
+    @Override
+    @SuppressWarnings("unchecked")
+    public int filterOutbound(HttpRequest req, HttpResponse resp, HttpFlowContext context) {
+        if (req == null) {
+            return DUNNO;
+        }
+
+        ParametersBag session = context.scopes().getSession(req);
+        if (session == null) {
+            return DUNNO;
+        }
+
+        Queue<ByteString> pending = (Queue<ByteString>) session.get(SESSION_KEY);
+        if (pending == null || pending.isEmpty()) {
+            return DUNNO;
+        }
+
+        // The token is consumed here, so each queued token is injected into one request only.
+        ByteString token = pending.poll();
+        for (Injector injector : injectors) {
+            injector.inject(req, token);
+        }
+
+        return DUNNO;
+    }
+
+    /**
+     * Runs the extractors in order against the response and queues the first
+     * non-null token in the response's session parameters.
+     *
+     * @param request not used by this method
+     * @param resp incoming response; nothing happens when it is null
+     * @param context flow context providing the session scope
+     * @return always {@code DUNNO}
+     */
+    @Override
+    @SuppressWarnings("unchecked")
+    public int filterInbound(HttpRequest request, HttpResponse resp, HttpFlowContext context) {
+        if (resp == null) {
+            return DUNNO;
+        }
+
+        // Stop at the first extractor that yields a token.
+        ByteString token = null;
+        for (int i = 0; i < extractors.size() && token == null; i++) {
+            token = extractors.get(i).extract(resp);
+        }
+
+        if (token == null) {
+            return DUNNO;
+        }
+
+        ParametersBag session = context.scopes().getSession(resp);
+        if (session == null) {
+            return DUNNO;
+        }
+
+        Queue<ByteString> pending = (Queue<ByteString>) session.get(SESSION_KEY);
+        if (pending == null) {
+            // First token for this session: create the queue lazily.
+            pending = new LinkedList<>();
+            session.set(SESSION_KEY, pending);
+        }
+        pending.add(token);
+
+        return DUNNO;
+    }
+
+    @Override
+    public HttpCsrfFilter instanceForWorker(int index) {
+        return new HttpCsrfFilter(extractors, injectors);
+    }
+
+    /**
+     * Builds the configuration schema of the CSRF filter: an {@code extract}
+     * list accepting {@code header} and {@code cookie} keys and an
+     * {@code inject} list accepting {@code header} keys.
+     */
+    public static class HttpCsrfFilterNodeDefCreator implements NodeDefinitionCreator {
+
+        @Override
+        public NodeDefinition create() {
+            return mapDef(
+                    tupleDef("extract", extractDefinition()),
+                    tupleDef("inject", injectDefinition())
+            );
+        }
+
+        /** Schema of the {@code extract} entry. */
+        private static KeyNameVaryListNodeDefinition extractDefinition() {
+            KeyNameVaryListNodeDefinition def = new KeyNameVaryListNodeDefinition()
+                    .setNodeTransformer(new HttpCsrfFilterExtractorTransformer())
+                    .add("header", valueDef().addValidator(HeaderNameValidator.INSTANCE))
+                    .add("cookie", valueDef());
+            return def;
+        }
+
+        /** Schema of the {@code inject} entry. */
+        private static KeyNameVaryListNodeDefinition injectDefinition() {
+            KeyNameVaryListNodeDefinition def = new KeyNameVaryListNodeDefinition()
+                    .setNodeTransformer(new HttpCsrfFilterInjectorTransformer())
+                    .add("header", valueDef().addValidator(HeaderNameValidator.INSTANCE));
+            return def;
+        }
+
+    }
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFilterExtractorTransformer.java	Thu Jul 27 13:36:55 2017 +0200
@@ -0,0 +1,83 @@
+package com.passus.st.client.http.filter;
+
+import com.passus.config.CMapNode;
+import com.passus.config.CNode;
+import com.passus.config.CTupleNode;
+import com.passus.config.CValueNode;
+import com.passus.config.NodeType;
+import com.passus.config.schema.NodeTransformer;
+import com.passus.config.validation.Errors;
+import com.passus.st.client.http.filter.HttpCsrfFilter.CookieExtractor;
+import com.passus.st.client.http.filter.HttpCsrfFilter.Extractor;
+import com.passus.st.client.http.filter.HttpCsrfFilter.HeaderExtractor;
+import static com.passus.st.validation.NodeValidationUtils.validateType;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ *
+ * @author Mirosław Hawrot
+ */
+public class HttpCsrfFilterExtractorTransformer implements NodeTransformer {
+
+    private Extractor createNameExtractor(CTupleNode tuple, Errors errors, Class<? extends Extractor> clazz) {
+        if (validateType(tuple.getNode(), NodeType.VALUE, errors)) {
+            CValueNode valNode = (CValueNode) tuple.getNode();
+            try {
+                return clazz
+                        .getConstructor(CharSequence.class)
+                        .newInstance(valNode.getValue().toString());
+            } catch (Exception e) {
+                throw new RuntimeException(e.getMessage(), e);
+            }
+        }
+
+        return null;
+    }
+
+    @Override
+    public Object transform(CNode node, Errors errors) {
+        CMapNode mapNode = (CMapNode) node;
+
+        List<CTupleNode> tuples = mapNode.getChildren();
+        List<Extractor> extractors;
+        if (tuples.isEmpty()) {
+            extractors = Collections.EMPTY_LIST;
+        } else {
+            extractors = new ArrayList<>();
+        }
+
+        for (CTupleNode tuple : tuples) {
+            String opName = tuple.getName();
+            try {
+                errors.pushNestedPath(opName);
+                Extractor extractor = null;
+                switch (opName.toLowerCase()) {
+                    case "cookie":
+                        extractor = createNameExtractor(tuple, errors, CookieExtractor.class);
+                        break;
+                    case "header":
+                        extractor = createNameExtractor(tuple, errors, HeaderExtractor.class);
+                        break;
+                    default:
+                        throw new IllegalStateException("Not supported extractor '" + opName + "'.");
+                }
+
+                if (extractor != null) {
+                    extractors.add(extractor);
+                }
+            } finally {
+                errors.popNestedPath();
+            }
+        }
+
+        return new CValueNode(extractors);
+    }
+
+    @Override
+    public Object reverseTransform(CNode node, Errors errors) {
+        throw new UnsupportedOperationException("Not supported yet.");
+    }
+
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFilterInjectorTransformer.java	Thu Jul 27 13:36:55 2017 +0200
@@ -0,0 +1,79 @@
+package com.passus.st.client.http.filter;
+
+import com.passus.config.CMapNode;
+import com.passus.config.CNode;
+import com.passus.config.CTupleNode;
+import com.passus.config.CValueNode;
+import com.passus.config.NodeType;
+import com.passus.config.schema.NodeTransformer;
+import com.passus.config.validation.Errors;
+import com.passus.st.client.http.filter.HttpCsrfFilter.HeaderInjector;
+import com.passus.st.client.http.filter.HttpCsrfFilter.Injector;
+import static com.passus.st.validation.NodeValidationUtils.validateType;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Node transformer that turns the {@code inject} map of the CSRF filter
+ * configuration into a list of {@link Injector} objects wrapped in a
+ * {@link CValueNode}.
+ *
+ * @author Mirosław Hawrot
+ */
+public class HttpCsrfFilterInjectorTransformer implements NodeTransformer {
+
+    /**
+     * Instantiates an injector from a tuple whose node is a plain value.
+     *
+     * @param tuple configuration entry; its value is passed to the constructor as a string
+     * @param errors collector handed to the node type validation
+     * @param clazz injector class; must expose a public {@code (CharSequence)} constructor
+     * @return the new injector, or {@code null} when the node type validation fails
+     * @throws RuntimeException if the injector cannot be instantiated
+     */
+    private Injector createNameInjector(CTupleNode tuple, Errors errors, Class<? extends Injector> clazz) {
+        if (!validateType(tuple.getNode(), NodeType.VALUE, errors)) {
+            return null;
+        }
+
+        CValueNode valNode = (CValueNode) tuple.getNode();
+        try {
+            return clazz
+                    .getConstructor(CharSequence.class)
+                    .newInstance(valNode.getValue().toString());
+        } catch (Exception e) {
+            // Reflection wrappers may carry a null message, so name the target class explicitly.
+            throw new RuntimeException(
+                    "Unable to create " + clazz.getSimpleName() + ": " + e.getMessage(), e);
+        }
+    }
+
+    /**
+     * Creates one injector per child tuple of the given map node.
+     *
+     * @param node node to transform; cast to {@link CMapNode}
+     * @param errors error collector; the tuple name is pushed as nested path while it is processed
+     * @return a {@link CValueNode} holding the list of created injectors
+     * @throws IllegalStateException if a tuple name is not {@code header}
+     */
+    @Override
+    public Object transform(CNode node, Errors errors) {
+        CMapNode mapNode = (CMapNode) node;
+
+        List<CTupleNode> tuples = mapNode.getChildren();
+        List<Injector> injectors = tuples.isEmpty()
+                ? Collections.<Injector>emptyList()
+                : new ArrayList<Injector>();
+
+        for (CTupleNode tuple : tuples) {
+            String opName = tuple.getName();
+            try {
+                errors.pushNestedPath(opName);
+                Injector injector;
+                // Locale.ROOT keeps the match independent of the default locale.
+                switch (opName.toLowerCase(java.util.Locale.ROOT)) {
+                    case "header":
+                        injector = createNameInjector(tuple, errors, HeaderInjector.class);
+                        break;
+                    default:
+                        // The message previously said "extractor", a copy/paste slip.
+                        throw new IllegalStateException("Not supported injector '" + opName + "'.");
+                }
+
+                if (injector != null) {
+                    injectors.add(injector);
+                }
+            } finally {
+                errors.popNestedPath();
+            }
+        }
+
+        return new CValueNode(injectors);
+    }
+
+    /**
+     * Not implemented.
+     *
+     * @throws UnsupportedOperationException always
+     */
+    @Override
+    public Object reverseTransform(CNode node, Errors errors) {
+        throw new UnsupportedOperationException("Not supported yet.");
+    }
+
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stress-tester/src/test/java/com/passus/st/client/http/filter/HttpCsrfFilterTest.java	Thu Jul 27 13:36:55 2017 +0200
@@ -0,0 +1,104 @@
+package com.passus.st.client.http.filter;
+
+import com.passus.config.validation.Errors;
+import com.passus.net.http.HttpMessage;
+import com.passus.net.http.HttpRequest;
+import com.passus.net.http.HttpRequestBuilder;
+import com.passus.net.http.HttpResponse;
+import com.passus.net.http.HttpResponseBuilder;
+import static com.passus.st.client.http.HttpConsts.TAG_SESSION_ID;
+import com.passus.st.client.http.HttpFlowContext;
+import com.passus.st.client.http.HttpScopes;
+import com.passus.st.client.http.filter.HttpCsrfFilter.CookieExtractor;
+import com.passus.st.client.http.filter.HttpCsrfFilter.Extractor;
+import com.passus.st.client.http.filter.HttpCsrfFilter.HeaderExtractor;
+import com.passus.st.client.http.filter.HttpCsrfFilter.HeaderInjector;
+import com.passus.st.client.http.filter.HttpCsrfFilter.Injector;
+import com.passus.st.emitter.SessionInfo;
+import java.util.List;
+import static org.testng.AssertJUnit.assertEquals;
+import static org.testng.AssertJUnit.assertTrue;
+import org.testng.annotations.Test;
+
+/**
+ * Tests for {@link HttpCsrfFilter}: token hand-over between a response and
+ * the following request, and construction from a textual configuration.
+ *
+ * @author Mirosław Hawrot
+ */
+public class HttpCsrfFilterTest {
+
+    /** Tags the message with the session id used throughout these tests. */
+    private void tagSessionId(HttpMessage msg) {
+        msg.setTag(TAG_SESSION_ID, "1");
+    }
+
+    /**
+     * A token delivered in a response cookie must replace the header of the
+     * next request, while the request sent before it stays untouched.
+     */
+    @Test
+    public void testFilter_CookieExtractorHeaderInjector() throws Exception {
+        HttpRequest firstRequest = HttpRequestBuilder.get("http://test/test1")
+                .header("x-csrf-token", "token")
+                .build();
+        HttpResponse firstResponse = HttpResponseBuilder.ok()
+                .cookie("x_csrf_token", "newToken")
+                .build();
+        HttpRequest secondRequest = HttpRequestBuilder.get("http://test/test2")
+                .header("x-csrf-token", "token1")
+                .build();
+
+        tagSessionId(firstRequest);
+        tagSessionId(firstResponse);
+        tagSessionId(secondRequest);
+
+        SessionInfo session = new SessionInfo("1.1.1.1:5000", "2.2.2.2:80");
+        HttpScopes scopes = new HttpScopes();
+        HttpFlowContext context = new HttpFlowContext(session, scopes);
+
+        HttpCsrfFilter filter = new HttpCsrfFilter();
+        filter.addExtractor(new CookieExtractor("x_csrf_token"));
+        filter.addInjector(new HeaderInjector("x-csrf-token"));
+
+        filter.filterOutbound(firstRequest, null, context);
+        filter.filterInbound(null, firstResponse, context);
+        filter.filterOutbound(secondRequest, null, context);
+
+        assertEquals("token", firstRequest.getHeaders().get("x-csrf-token").toString());
+        assertEquals("newToken", secondRequest.getHeaders().get("x-csrf-token").toString());
+    }
+
+    /**
+     * The textual configuration must yield one CSRF filter with a header and
+     * a cookie extractor (in that order) and a single header injector.
+     */
+    @Test
+    public void testConfigure() throws Exception {
+        String filterConfig = "filters:\n"
+                + "    - type: csrf\n"
+                + "      extract:\n"
+                + "        header: \"csrf-header\"\n"
+                + "        cookie: \"csrf-cookie\"\n"
+                + "      inject:\n"
+                + "        header: \"csrf-header-inject\"\n";
+
+        Errors errors = new Errors();
+        List<HttpFilter> filters = HttpFiltersConfigurator.getFilters(filterConfig, errors);
+
+        assertEquals(0, errors.getErrorCount());
+        assertEquals(1, filters.size());
+
+        HttpFilter configured = filters.get(0);
+        assertTrue(configured instanceof HttpCsrfFilter);
+
+        HttpCsrfFilter filter = (HttpCsrfFilter) configured;
+        List<Extractor> extractors = filter.getExtractors();
+        List<Injector> injectors = filter.getInjectors();
+
+        assertEquals(2, extractors.size());
+        assertEquals(1, injectors.size());
+
+        Extractor first = extractors.get(0);
+        assertTrue(first instanceof HeaderExtractor);
+        assertEquals("csrf-header", ((HeaderExtractor) first).getHeaderName().toString());
+
+        Extractor second = extractors.get(1);
+        assertTrue(second instanceof CookieExtractor);
+        assertEquals("csrf-cookie", ((CookieExtractor) second).getCookieName().toString());
+
+        Injector injector = injectors.get(0);
+        assertTrue(injector instanceof HeaderInjector);
+        assertEquals("csrf-header-inject", ((HeaderInjector) injector).getHeaderName().toString());
+
+    }
+}