changeset 479:717b8d9db5b6

ST-60 in progress
author Devel 1
date Tue, 08 Aug 2017 12:39:51 +0200
parents ce8d1a1cda94
children 17ebea60b229
files stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFormExtractor.java stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFormFilter.java stress-tester/src/test/java/com/passus/st/client/http/filter/HttpCsrfFormExtractorTest.java stress-tester/src/test/java/com/passus/st/client/http/filter/HttpCsrfFormFilterTest.java
diffstat 4 files changed, 250 insertions(+), 15 deletions(-) [+]
line wrap: on
line diff
--- a/stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFormExtractor.java	Tue Aug 08 09:35:22 2017 +0200
+++ b/stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFormExtractor.java	Tue Aug 08 12:39:51 2017 +0200
@@ -49,7 +49,22 @@
         return result;
     }
 
-    static TokenEntry extract(String formString, String inputId, String inputName) {
+    /**
+     * Resolves a form {@code action} attribute against the URI of the page the form was found on.
+     * <p>
+     * An absolute path (starting with {@code /}) is returned unchanged and an empty action resolves
+     * to {@code baseUri} itself. A query-only action ({@code ?...}) replaces the query of the base
+     * URI. Any other value is treated as a path relative to the "directory" of the base URI, i.e.
+     * everything before its last {@code /}. The query string and fragment of the base URI are
+     * ignored for this purpose, because a {@code /} inside them is not a path separator.
+     * <p>
+     * NOTE(review): actions carrying a scheme ({@code http://...}) are not recognised and are
+     * treated as relative paths.
+     *
+     * @param formAction value of the form's {@code action} attribute, possibly empty
+     * @param baseUri URI of the request whose response contained the form
+     * @return the resolved action
+     * @throws IllegalArgumentException if a relative action has to be resolved and the path part of
+     *         {@code baseUri} contains no {@code /}
+     */
+    static String resolveAction(String formAction, String baseUri) {
+        if (formAction.startsWith("/")) {
+            return formAction;
+        }
+        if (formAction.isEmpty()) {
+            return baseUri;
+        }
+
+        String basePath = stripQueryAndFragment(baseUri);
+        if (formAction.startsWith("?")) {
+            // only the query changes, the path of the page is kept
+            return basePath + formAction;
+        }
+
+        int idx = basePath.lastIndexOf('/');
+        if (idx < 0) {
+            throw new IllegalArgumentException("Invalid base URI: " + baseUri);
+        }
+        return basePath.substring(0, idx) + '/' + formAction;
+    }
+
+    /** Returns {@code uri} without its query string and fragment, if present. */
+    private static String stripQueryAndFragment(String uri) {
+        int end = uri.length();
+        int query = uri.indexOf('?');
+        if (query >= 0) {
+            end = query;
+        }
+        int fragment = uri.indexOf('#');
+        if (fragment >= 0 && fragment < end) {
+            end = fragment;
+        }
+        return uri.substring(0, end);
+    }
+
+    static TokenEntry extract(String formString, String inputId, String inputName, String uri) {
         Document root = Jsoup.parse(formString);
         Elements forms = root.getElementsByTag(TAG_FORM);
         if (forms.size() != 1) {
@@ -82,17 +97,18 @@
         }
 
         String action = form.attr(ATTR_ACTION);
+        action = resolveAction(action, uri);
         String method = form.attr(ATTR_METHOD);
         String name = form.attr(ATTR_NAME);
         String value = tokenInput.attr(ATTR_VALUE);
         return new TokenEntry(action, method, name, value);
     }
 
-    public static List<TokenEntry> extractAll(String document, String inputId, String inputName) {
+    public static List<TokenEntry> extractAll(String document, String inputId, String inputName, String uri) {
         List<FormBoundary> forms = scan(document);
         List<TokenEntry> result = new ArrayList<>(forms.size());
         for (FormBoundary form : forms) {
-            TokenEntry entry = extract(document.substring(form.start, form.end), inputId, inputName);
+            TokenEntry entry = extract(document.substring(form.start, form.end), inputId, inputName, uri);
             if (entry != null) {
                 result.add(entry);
             }
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFormFilter.java	Tue Aug 08 12:39:51 2017 +0200
@@ -0,0 +1,128 @@
+package com.passus.st.client.http.filter;
+
+import com.passus.data.ByteString;
+import com.passus.data.HeapByteBuff;
+import com.passus.net.http.HttpContentType;
+import com.passus.net.http.HttpMessageHelper;
+import com.passus.net.http.HttpMethod;
+import com.passus.net.http.HttpParameters;
+import com.passus.net.http.HttpRequest;
+import com.passus.net.http.HttpResponse;
+import com.passus.st.ParametersBag;
+import com.passus.st.client.http.HttpFlowContext;
+import com.passus.st.client.http.filter.HttpCsrfFormExtractor.TokenEntry;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+/**
+ * Re-applies CSRF tokens found in forms of received HTML pages to later form submissions.
+ * <p>
+ * {@link #filterInbound} scans {@code text/html} and {@code application/xhtml+xml} responses with
+ * {@link HttpCsrfFormExtractor#extractAll} and stores every token found in the session under
+ * {@link #SESSION_KEY}, keyed by {@code "METHOD:resolvedAction"}. {@link #filterOutbound} looks up
+ * the token for the request's method and URI. If the form-urlencoded request body already contains
+ * a field with the configured input name, it overwrites that field's value with the stored token.
+ * <p>
+ * NOTE(review): only tokens carried in the request body are replaced. For GET forms the token
+ * presumably travels in the query string, so the request URI will generally not equal the stored
+ * {@code "GET:" + action} key.
+ * <p>
+ * A single instance is not thread safe (see the {@link HttpMessageHelper} field). Use
+ * {@link #instanceForWorker} to get one instance per worker.
+ *
+ * @author mikolaj.podbielski
+ */
+public class HttpCsrfFormFilter extends HttpFilter {
+    // TODO: form name
+
+    private static final Set<String> CONTENT_TYPES_TO_SCAN = new HashSet<>(Arrays.asList(
+            "text/html", "application/xhtml+xml"
+    ));
+
+    /** Session key under which the {@code Map<String, TokenEntry>} of extracted tokens is kept. */
+    public static final String SESSION_KEY = "_csrfFormTokens";
+
+    private static final Logger LOGGER = LogManager.getLogger(HttpCsrfFormFilter.class);
+
+    private final HttpMessageHelper helper = new HttpMessageHelper(); // needs instance, because header decoders are not thread safe
+
+    private final Set<ByteString> contentTypesToScan = CONTENT_TYPES_TO_SCAN.stream().map(ByteString::create).collect(Collectors.toSet());
+
+    /** Name of the hidden input carrying the token; required to rewrite outbound requests. */
+    private String inputName;
+
+    /** Id of the hidden input carrying the token; only used when extracting tokens. */
+    private String inputId;
+
+    /**
+     * @param inputName name of the hidden form input that carries the CSRF token
+     */
+    public void setInputName(String inputName) {
+        this.inputName = inputName;
+    }
+
+    /**
+     * @param inputId id of the hidden form input that carries the CSRF token
+     */
+    public void setInputId(String inputId) {
+        this.inputId = inputId;
+    }
+
+    /**
+     * Replaces the token in a form-urlencoded request body with the one extracted for the same
+     * method and URI, if any. The request is modified in place; the filter always returns
+     * {@code DUNNO}.
+     */
+    @Override
+    public int filterOutbound(HttpRequest request, HttpResponse resp, HttpFlowContext context) {
+        if (inputName == null) {
+            // without the field name the token cannot be located in the request body
+            return DUNNO;
+        }
+
+        ParametersBag session = context.scopes().getSession(request, false);
+        if (session == null) {
+            return DUNNO;
+        }
+        Map<String, TokenEntry> tokens = getTokens(session);
+        if (tokens == null) {
+            return DUNNO;
+        }
+
+        String key = tokenKey(request.getMethod().getName().toString(), request.getUri().toString());
+        TokenEntry entry = tokens.get(key);
+        if (entry == null) {
+            return DUNNO;
+        }
+
+        LOGGER.debug("Found token matching current request.");
+        try {
+            HttpParameters parameters = helper.decodeFormUrlencoded(request);
+            // only replace an existing field, never inject one the recorded request did not send
+            if (parameters != null && parameters.contains(inputName)) {
+                parameters.set(inputName, entry.value);
+                helper.setFormUrlencoded(request, parameters);
+                LOGGER.debug("Token inserted.");
+            }
+        } catch (IOException ex) {
+            LOGGER.debug("Could not decode request.", ex);
+        }
+        return DUNNO;
+    }
+
+    /**
+     * Extracts CSRF tokens from HTML/XHTML responses and stores them in the session, keyed by
+     * form method and action resolved against the request URI. Always returns {@code DUNNO}.
+     */
+    @Override
+    public int filterInbound(HttpRequest request, HttpResponse resp, HttpFlowContext context) {
+        HttpContentType contentType = helper.getContentType(resp);
+        if (contentType == null || !contentTypesToScan.contains(contentType.getMimeType())) {
+            return DUNNO;
+        }
+
+        String content;
+        try {
+            HeapByteBuff contentBuff = new HeapByteBuff();
+            helper.readContent(resp, contentBuff, true);
+            content = contentBuff.toString();
+        } catch (IOException ex) {
+            LOGGER.debug("Could not read response body", ex);
+            return DUNNO;
+        }
+
+        ParametersBag session = context.scopes().getSession(resp);
+        if (session == null) {
+            LOGGER.debug("No session for response, tokens not saved.");
+            return DUNNO;
+        }
+        Map<String, TokenEntry> tokens = getTokens(session);
+        if (tokens == null) {
+            tokens = new HashMap<>();
+            session.set(SESSION_KEY, tokens);
+        }
+
+        List<TokenEntry> entries;
+        try {
+            entries = HttpCsrfFormExtractor.extractAll(content, inputId, inputName, request.getUri().toString());
+        } catch (IllegalArgumentException ex) {
+            // e.g. a relative form action on a page whose URI contains no '/';
+            // malformed page content must not propagate out of the filter
+            LOGGER.debug("Could not extract form tokens.", ex);
+            return DUNNO;
+        }
+
+        for (TokenEntry entry : entries) {
+            String method = resolveMethod(entry.method);
+            if (method == null) {
+                LOGGER.debug("Unsupported form method, token skipped {}", entry);
+                continue;
+            }
+            tokens.put(tokenKey(method, entry.action), entry);
+            LOGGER.debug("Token saved {}", entry);
+        }
+        return DUNNO;
+    }
+
+    /** Builds the session map key shared by the inbound and outbound paths. */
+    private static String tokenKey(String method, String uri) {
+        return method + ':' + uri;
+    }
+
+    /** Returns the token map stored in {@code session}, or {@code null} if none was stored yet. */
+    @SuppressWarnings("unchecked") // only this class writes SESSION_KEY, always with this map type
+    private static Map<String, TokenEntry> getTokens(ParametersBag session) {
+        return (Map<String, TokenEntry>) session.get(SESSION_KEY);
+    }
+
+    /**
+     * Normalizes a form {@code method} attribute (case-insensitive, locale-independent).
+     *
+     * @param mtd raw attribute value; empty means the default method, GET
+     * @return {@code "GET"} or {@code "POST"}, or {@code null} for any other value
+     *         (e.g. {@code "dialog"}), for which no token is stored
+     */
+    private static String resolveMethod(String mtd) {
+        if (mtd == null || mtd.isEmpty() || "GET".equalsIgnoreCase(mtd)) {
+            return "GET";
+        }
+        if ("POST".equalsIgnoreCase(mtd)) {
+            return "POST";
+        }
+        return null;
+    }
+
+    /**
+     * Creates a filter for a worker with the same configuration (token input name and id).
+     * Extracted tokens live in the session, not in the filter, so nothing else is copied.
+     */
+    @Override
+    public HttpCsrfFormFilter instanceForWorker(int index) {
+        HttpCsrfFormFilter copy = new HttpCsrfFormFilter();
+        copy.inputName = inputName;
+        copy.inputId = inputId;
+        return copy;
+    }
+
+}
--- a/stress-tester/src/test/java/com/passus/st/client/http/filter/HttpCsrfFormExtractorTest.java	Tue Aug 08 09:35:22 2017 +0200
+++ b/stress-tester/src/test/java/com/passus/st/client/http/filter/HttpCsrfFormExtractorTest.java	Tue Aug 08 12:39:51 2017 +0200
@@ -3,6 +3,7 @@
 import com.passus.commons.utils.ResourceUtils;
 import com.passus.st.client.http.filter.HttpCsrfFormExtractor.TokenEntry;
 import com.passus.st.client.http.filter.HttpCsrfFormExtractor.FormBoundary;
+import static com.passus.st.client.http.filter.HttpCsrfFormExtractor.resolveAction;
 import java.io.File;
 import java.io.IOException;
 import java.nio.file.Files;
@@ -36,13 +37,30 @@
     }
 
     @Test
+    public void testResolveAction() {
+        // an absolute action does not depend on the base URI at all
+        String absolute = "/form/action";
+        for (String base : new String[]{"", "/some", "/some/path"}) {
+            assertEquals(absolute, resolveAction(absolute, base));
+        }
+
+        // a relative action cannot be resolved against a base URI without any '/'
+        try {
+            resolveAction("action", "");
+            fail("Should throw IllArgEx");
+        } catch (IllegalArgumentException expected) {
+            // expected
+        }
+
+        // a relative action replaces the last path segment of the base URI
+        assertEquals("/some/action", resolveAction("action", "/some/path"));
+
+        // an empty action resolves to the base URI itself
+        assertEquals("/some/path", resolveAction("", "/some/path"));
+    }
+
+    /** Extracts the token from the first form found in the {@code document} fixture. */
+    @Test
     public void testExtractFromComplexHtml() {
         FormBoundary fb = HttpCsrfFormExtractor.scan(document).get(0);
         String form = document.substring(fb.start, fb.end);
 
         // no action
-        TokenEntry entry = HttpCsrfFormExtractor.extract(form, "form__token", "form[_token]");
-        assertEntryEquals(entry, "", "post", "form1", "token-1qwerty");
+        // the form has no action attribute, so the action resolves to the page URI passed in ("/save")
+        TokenEntry entry = HttpCsrfFormExtractor.extract(form, "form__token", "form[_token]", "/save");
+        assertEntryEquals(entry, "/save", "post", "form1", "token-1qwerty");
     }
 
     @Test
@@ -52,18 +70,18 @@
 
         // absolute action, no method
         tag = "<form name=\"f1\" action=\"/svc/save\"><input type=\"hidden\" id=\"t_id\" name=\"t_n\" value=\"t4val\" /></form>";
-        entry = HttpCsrfFormExtractor.extract(tag, "t_id", "t_n");
+        entry = HttpCsrfFormExtractor.extract(tag, "t_id", "t_n", "");
         assertEntryEquals(entry, "/svc/save", "", "f1", "t4val");
 
         // relative action, no form name
         tag = "<form action=\"save\" method=\"post\"><input type=\"hidden\" id=\"t_id\" name=\"t_n\" value=\"t4val\" /></form>";
-        entry = HttpCsrfFormExtractor.extract(tag, "t_id", "t_n");
-        assertEntryEquals(entry, "save", "post", "", "t4val");
+        entry = HttpCsrfFormExtractor.extract(tag, "t_id", "t_n", "/path/abc");
+        assertEntryEquals(entry, "/path/save", "post", "", "t4val");
     }
 
     @Test
     public void testExtractByIdName() {
-        String tag = "<form>"
+        String tag = "<form action=\"/save\">"
                 + "<input type=\"hidden\"                          value=\"t4val1\" />"
                 + "<input type=\"hidden\"             name=\"t_n\" value=\"t4val2\" />"
                 + "<input type=\"hidden\" id=\"t_id\"              value=\"t4val3\" />"
@@ -72,19 +90,19 @@
         TokenEntry entry;
 
 // token by name and id
-        entry = HttpCsrfFormExtractor.extract(tag, "t_id", "t_n");
+        entry = HttpCsrfFormExtractor.extract(tag, "t_id", "t_n", "");
         assertEquals("t4val4", entry.value);
 
 // token by id
-        entry = HttpCsrfFormExtractor.extract(tag, "t_id", null);
+        entry = HttpCsrfFormExtractor.extract(tag, "t_id", null, "");
         assertEquals("t4val3", entry.value);
-        
+
 // token by name
-        entry = HttpCsrfFormExtractor.extract(tag, null, "t_n");
+        entry = HttpCsrfFormExtractor.extract(tag, null, "t_n", "");
         assertEquals("t4val2", entry.value);
-        
+
 // token is first hidden input
-        entry = HttpCsrfFormExtractor.extract(tag, null, null);
+        entry = HttpCsrfFormExtractor.extract(tag, null, null, "");
         assertEquals("t4val1", entry.value);
     }
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stress-tester/src/test/java/com/passus/st/client/http/filter/HttpCsrfFormFilterTest.java	Tue Aug 08 12:39:51 2017 +0200
@@ -0,0 +1,73 @@
+package com.passus.st.client.http.filter;
+
+import com.passus.data.ByteBuffDataSource;
+import com.passus.data.DataSource;
+import com.passus.net.http.HttpMessage;
+import com.passus.net.http.HttpMessageHelper;
+import com.passus.net.http.HttpParameters;
+import com.passus.net.http.HttpRequest;
+import com.passus.net.http.HttpRequestBuilder;
+import com.passus.net.http.HttpResponse;
+import com.passus.net.http.HttpResponseBuilder;
+import static com.passus.st.client.http.HttpConsts.TAG_SESSION_ID;
+import com.passus.st.client.http.HttpFlowContext;
+import com.passus.st.client.http.HttpScopes;
+import java.io.IOException;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.testng.AssertJUnit.*;
+import org.testng.annotations.Test;
+
+/**
+ * Tests {@link HttpCsrfFormFilter}: a token extracted from a received form page must replace the
+ * stale token in a subsequent recorded form submission.
+ *
+ * @author mikolaj.podbielski
+ */
+public class HttpCsrfFormFilterTest {
+
+    private final HttpFlowContext mockContext = mock(HttpFlowContext.class);
+
+    @Test
+    public void testFilter() throws IOException {
+        HttpRequest req1 = HttpRequestBuilder.get("example.com/edit").build();
+        HttpResponse resp1Orig = HttpResponseBuilder.ok().content(form("oldToken"))
+                .header("Content-Type", "text/html").build();
+        HttpResponse resp1Live = HttpResponseBuilder.ok().content(form("newToken123"))
+                .header("Content-Type", "text/html").build();
+        HttpRequest req2 = HttpRequestBuilder.post("example.com/save").content(post("oldToken"))
+                .header("Content-Type", "application/x-www-form-urlencoded").build();
+        tagMessages(req1, resp1Orig, resp1Live, req2);
+
+        // fake the flow context: real session scopes and the original (recorded) response
+        when(mockContext.scopes()).thenReturn(new HttpScopes());
+        HttpCsrfFormFilter filter = new HttpCsrfFormFilter();
+        filter.setInputName("_token");
+
+        when(mockContext.origReponse()).thenReturn(resp1Orig);
+        filter.filterInbound(req1, resp1Live, mockContext);
+
+        filter.filterOutbound(req2, null, mockContext);
+
+        HttpParameters form = HttpMessageHelper.STRICT.decodeFormUrlencoded(req2);
+        assertEquals("newToken123", form.get("_token").toString());
+        // the other fields must survive re-encoding untouched
+        assertEquals("def", form.get("abc").toString());
+        assertEquals("jkl", form.get("ghi").toString());
+        // 34 == length of "abc=def&_token=newToken123&ghi=jkl"
+        assertEquals(34, req2.getContent().available());
+    }
+
+    /** HTML page with a single POST form whose relative action is "save". */
+    private static String form(String value) {
+        return "<form action=\"save\" method=\"post\"><input type=\"hidden\" name=\"_token\" value=\"" + value + "\" /></form>";
+    }
+
+    /** Form-urlencoded body with the token between two unrelated fields. */
+    private static HttpParameters post(CharSequence value) {
+        HttpParameters form = new HttpParameters();
+        form.add("abc", "def");
+        form.add("_token", value);
+        form.add("ghi", "jkl");
+        return form;
+    }
+
+    /** Puts all messages into the same session. */
+    private static void tagMessages(HttpMessage... messages) {
+        for (HttpMessage message : messages) {
+            message.setTag(TAG_SESSION_ID, "sid1");
+        }
+    }
+}