Mercurial > stress-tester
changeset 479:717b8d9db5b6
ST-60 in progress
author | Devel 1 |
---|---|
date | Tue, 08 Aug 2017 12:39:51 +0200 |
parents | ce8d1a1cda94 |
children | 17ebea60b229 |
files | stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFormExtractor.java stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFormFilter.java stress-tester/src/test/java/com/passus/st/client/http/filter/HttpCsrfFormExtractorTest.java stress-tester/src/test/java/com/passus/st/client/http/filter/HttpCsrfFormFilterTest.java |
diffstat | 4 files changed, 250 insertions(+), 15 deletions(-) [+] |
line wrap: on
line diff
--- a/stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFormExtractor.java	Tue Aug 08 09:35:22 2017 +0200
+++ b/stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFormExtractor.java	Tue Aug 08 12:39:51 2017 +0200
@@ -49,7 +49,79 @@
         return result;
     }
 
-    static TokenEntry extract(String formString, String inputId, String inputName) {
+    /**
+     * Resolves a form {@code action} attribute against the URI of the page the
+     * form was found in.
+     *
+     * <p>Root-relative ({@code /path}), protocol-relative ({@code //host/path})
+     * and absolute ({@code scheme:...}) actions are returned unchanged, and an
+     * empty action resolves to {@code baseUri}. For other actions the query and
+     * fragment of {@code baseUri} are dropped first; a query-only action
+     * ({@code ?q}) is then appended to the remaining path and any other action
+     * replaces its last segment.
+     *
+     * @param formAction value of the form's {@code action} attribute
+     * @param baseUri URI of the page containing the form
+     * @return the resolved action
+     * @throws IllegalArgumentException if a relative action has to be resolved
+     *         against a {@code baseUri} whose path contains no {@code '/'}
+     */
+    static String resolveAction(String formAction, String baseUri) {
+        if (formAction.startsWith("/") || hasScheme(formAction)) {
+            return formAction;
+        }
+        if (formAction.isEmpty()) {
+            return baseUri;
+        }
+
+        // the query and fragment of the page URI may contain '/' and must not
+        // affect where its last path segment starts
+        String basePath = stripQueryAndFragment(baseUri);
+        if (formAction.startsWith("?")) {
+            return basePath + formAction;
+        }
+        int idx = basePath.lastIndexOf('/');
+        if (idx < 0) {
+            throw new IllegalArgumentException("Invalid base URI: " + baseUri);
+        }
+        return basePath.substring(0, idx) + '/' + formAction;
+    }
+
+    /** Tells whether {@code uri} starts with an RFC 3986 scheme followed by {@code ':'}. */
+    private static boolean hasScheme(String uri) {
+        int colon = uri.indexOf(':');
+        if (colon <= 0 || !isAsciiLetter(uri.charAt(0))) {
+            return false;
+        }
+        for (int i = 1; i < colon; i++) {
+            char c = uri.charAt(i);
+            if (!isAsciiLetter(c) && !(c >= '0' && c <= '9') && c != '+' && c != '-' && c != '.') {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /** ASCII-only letter test; URI schemes must not contain other letters. */
+    private static boolean isAsciiLetter(char c) {
+        return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
+    }
+
+    /** Returns {@code uri} without its {@code ?query} and {@code #fragment} parts. */
+    private static String stripQueryAndFragment(String uri) {
+        int end = uri.length();
+        int query = uri.indexOf('?');
+        if (query >= 0) {
+            end = query;
+        }
+        int fragment = uri.indexOf('#');
+        if (fragment >= 0 && fragment < end) {
+            end = fragment;
+        }
+        return uri.substring(0, end);
+    }
+
+    static TokenEntry extract(String formString, String inputId, String inputName, String uri) {
         Document root = Jsoup.parse(formString);
         Elements forms = root.getElementsByTag(TAG_FORM);
         if (forms.size() != 1) {
@@ -82,17 +154,18 @@
         }
 
         String action = form.attr(ATTR_ACTION);
+        action = resolveAction(action, uri);
         String method = form.attr(ATTR_METHOD);
         String name = form.attr(ATTR_NAME);
         String value = tokenInput.attr(ATTR_VALUE);
         return new TokenEntry(action, method, name, value);
     }
 
-    public static List<TokenEntry> extractAll(String document, String inputId, String inputName) {
+    public static List<TokenEntry> extractAll(String document, String inputId, String inputName, String uri) {
         List<FormBoundary> forms = scan(document);
         List<TokenEntry> result = new ArrayList<>(forms.size());
         for (FormBoundary form : forms) {
-            TokenEntry entry = extract(document.substring(form.start, form.end), inputId, inputName);
+            TokenEntry entry = extract(document.substring(form.start, form.end), inputId, inputName, uri);
             if (entry != null) {
                 result.add(entry);
             }
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stress-tester/src/main/java/com/passus/st/client/http/filter/HttpCsrfFormFilter.java	Tue Aug 08 12:39:51 2017 +0200
@@ -0,0 +1,188 @@
+package com.passus.st.client.http.filter;
+
+import com.passus.data.ByteString;
+import com.passus.data.HeapByteBuff;
+import com.passus.net.http.HttpContentType;
+import com.passus.net.http.HttpMessageHelper;
+import com.passus.net.http.HttpMethod;
+import com.passus.net.http.HttpParameters;
+import com.passus.net.http.HttpRequest;
+import com.passus.net.http.HttpResponse;
+import com.passus.st.ParametersBag;
+import com.passus.st.client.http.HttpFlowContext;
+import com.passus.st.client.http.filter.HttpCsrfFormExtractor.TokenEntry;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+/**
+ * Rewrites CSRF tokens in form submissions.
+ *
+ * <p>Inbound HTML/XHTML responses are scanned for forms and the token of each
+ * form is stored in the session under {@link #SESSION_KEY}, keyed by
+ * {@code METHOD:action}. Outbound form-urlencoded requests whose method and URI
+ * match a stored entry have the parameter named {@code inputName}, if present,
+ * replaced with the stored token value.
+ *
+ * @author mikolaj.podbielski
+ */
+public class HttpCsrfFormFilter extends HttpFilter {
+    // TODO: form name
+
+    private static final Set<String> CONTENT_TYPES_TO_SCAN = new HashSet<>(Arrays.asList(
+            "text/html", "application/xhtml+xml"
+    ));
+
+    public static final String SESSION_KEY = "_csrfFormTokens";
+
+    private static final Logger LOGGER = LogManager.getLogger(HttpCsrfFormFilter.class);
+
+    private final HttpMessageHelper helper = new HttpMessageHelper(); // needs instance, because header decoders are not thread safe
+
+    private String inputName;
+    private String inputId;
+    private final Set<ByteString> contentTypesToScan = CONTENT_TYPES_TO_SCAN.stream().map(ByteString::create).collect(Collectors.toSet());
+
+    /** Sets the {@code name} of the token input; the request parameter with this name is the one rewritten. */
+    public void setInputName(String inputName) {
+        this.inputName = inputName;
+    }
+
+    /** Sets the {@code id} of the token input, used when extracting tokens from responses. */
+    public void setInputId(String inputId) {
+        this.inputId = inputId;
+    }
+
+    /**
+     * Replaces the token parameter of a form-urlencoded request with the token
+     * captured for the same method and URI in this session, if there is one.
+     */
+    @Override
+    public int filterOutbound(HttpRequest request, HttpResponse resp, HttpFlowContext context) {
+        if (inputName == null) {
+            // no parameter name configured: nothing in the request body can be matched
+            return DUNNO;
+        }
+
+        ParametersBag session = context.scopes().getSession(request, false);
+        if (session == null) {
+            return DUNNO;
+        }
+
+        Map<String, TokenEntry> tokens = getTokens(session);
+        if (tokens == null) {
+            return DUNNO;
+        }
+
+        String key = tokenKey(request.getMethod().getName().toString(), request.getUri().toString());
+        TokenEntry entry = tokens.get(key);
+        if (entry == null) {
+            return DUNNO;
+        }
+
+        LOGGER.debug("Found token matching current request.");
+        try {
+            HttpParameters parameters = helper.decodeFormUrlencoded(request);
+            if (parameters != null && parameters.contains(inputName)) {
+                parameters.set(inputName, entry.value);
+                helper.setFormUrlencoded(request, parameters);
+                LOGGER.debug("Token inserted.");
+            }
+        } catch (IOException ex) {
+            LOGGER.debug("Could not decode request.");
+        }
+        return DUNNO;
+    }
+
+    /**
+     * Extracts form tokens from HTML/XHTML response bodies and stores them in
+     * the session for {@link #filterOutbound}.
+     */
+    @Override
+    public int filterInbound(HttpRequest request, HttpResponse resp, HttpFlowContext context) {
+        HttpContentType contentType = helper.getContentType(resp);
+        if (contentType == null || !contentTypesToScan.contains(contentType.getMimeType())) {
+            return DUNNO;
+        }
+
+        try {
+            HeapByteBuff contentBuff = new HeapByteBuff();
+            helper.readContent(resp, contentBuff, true);
+            String content = contentBuff.toString();
+
+            ParametersBag session = context.scopes().getSession(resp);
+            Map<String, TokenEntry> tokens = getTokens(session);
+            if (tokens == null) {
+                tokens = new HashMap<>();
+                session.set(SESSION_KEY, tokens);
+            }
+
+            List<TokenEntry> entries = HttpCsrfFormExtractor.extractAll(content, inputId, inputName, request.getUri().toString());
+            for (TokenEntry entry : entries) {
+                String method = resolveMethod(entry.method);
+                if (method == null) {
+                    // one form with an unsupported method must not drop the tokens of the other forms
+                    LOGGER.debug("Unsupported form method, token skipped {}", entry);
+                    continue;
+                }
+                tokens.put(tokenKey(method, entry.action), entry);
+                LOGGER.debug("Token saved {}", entry);
+            }
+        } catch (IOException ex) {
+            LOGGER.debug("Could not read response body");
+        } catch (IllegalArgumentException ex) {
+            // e.g. HttpCsrfFormExtractor.resolveAction rejecting a relative action it cannot resolve against the request URI
+            LOGGER.debug("Could not extract form tokens: {}", ex.getMessage());
+        }
+        return DUNNO;
+    }
+
+    /** Builds the session map key that links a captured token to a request. */
+    private static String tokenKey(String method, String uri) {
+        return method + ':' + uri;
+    }
+
+    // this filter stores a Map<String, TokenEntry> under SESSION_KEY
+    @SuppressWarnings("unchecked")
+    private static Map<String, TokenEntry> getTokens(ParametersBag session) {
+        return (Map<String, TokenEntry>) session.get(SESSION_KEY);
+    }
+
+    /**
+     * Normalizes a form {@code method} attribute (case-insensitively): empty or
+     * {@code get} maps to {@code GET}, {@code post} maps to {@code POST}.
+     *
+     * @return the normalized method, or {@code null} for any other value
+     */
+    private static String resolveMethod(String mtd) {
+        String method = mtd.toUpperCase(Locale.ROOT);
+        if (method.isEmpty() || method.equals("GET")) {
+            return "GET";
+        } else if (method.equals("POST")) {
+            return "POST";
+        }
+        return null;
+    }
+
+    /**
+     * Creates a copy for a worker, configured with the same input name and id.
+     * Each copy gets its own {@link HttpMessageHelper}, whose header decoders
+     * are not thread safe.
+     */
+    @Override
+    public HttpCsrfFormFilter instanceForWorker(int index) {
+        HttpCsrfFormFilter copy = new HttpCsrfFormFilter();
+        copy.inputName = inputName;
+        copy.inputId = inputId;
+        return copy;
+    }
+
+}
--- a/stress-tester/src/test/java/com/passus/st/client/http/filter/HttpCsrfFormExtractorTest.java	Tue Aug 08 09:35:22 2017 +0200
+++ b/stress-tester/src/test/java/com/passus/st/client/http/filter/HttpCsrfFormExtractorTest.java	Tue Aug 08 12:39:51 2017 +0200
@@ -3,6 +3,7 @@
 import com.passus.commons.utils.ResourceUtils;
 import com.passus.st.client.http.filter.HttpCsrfFormExtractor.TokenEntry;
 import com.passus.st.client.http.filter.HttpCsrfFormExtractor.FormBoundary;
+import static com.passus.st.client.http.filter.HttpCsrfFormExtractor.resolveAction;
 import java.io.File;
 import java.io.IOException;
 import java.nio.file.Files;
@@ -36,13 +37,30 @@
     }
 
     @Test
+    public void testResolveAction() { // actions starting with '/' ignore the base URI
+        assertEquals("/form/action", resolveAction("/form/action", ""));
+        assertEquals("/form/action", resolveAction("/form/action", "/some"));
+        assertEquals("/form/action", resolveAction("/form/action", "/some/path"));
+
+        try {
+            assertEquals("", resolveAction("action", "")); // a base URI without '/' cannot anchor a relative action
+            fail("Should throw IllArgEx");
+        } catch (IllegalArgumentException ignore) {
+        }
+
+        assertEquals("/some/action", resolveAction("action", "/some/path")); // relative action replaces the last path segment
+
+        assertEquals("/some/path", resolveAction("", "/some/path")); // empty action resolves to the base URI itself
+    }
+
+    @Test
     public void testExtractFromComplexHtml() {
         FormBoundary fb = HttpCsrfFormExtractor.scan(document).get(0);
         String form = document.substring(fb.start, fb.end);
 
         // no action
-        TokenEntry entry = HttpCsrfFormExtractor.extract(form, "form__token", "form[_token]");
-        assertEntryEquals(entry, "", "post", "form1", "token-1qwerty");
+        TokenEntry entry = HttpCsrfFormExtractor.extract(form, "form__token", "form[_token]", "/save"); // last argument is the page URI the action is resolved against
+        assertEntryEquals(entry, "/save", "post", "form1", "token-1qwerty");
     }
 
     @Test
@@ -52,18 +70,18 @@
 
         // absolute action, no method
         tag = "<form name=\"f1\" action=\"/svc/save\"><input type=\"hidden\" id=\"t_id\" name=\"t_n\" value=\"t4val\" /></form>";
-        entry = HttpCsrfFormExtractor.extract(tag, "t_id", "t_n");
+        entry = HttpCsrfFormExtractor.extract(tag, "t_id", "t_n", ""); // base URI is irrelevant for an action starting with '/'
         assertEntryEquals(entry, "/svc/save", "", "f1", "t4val");
 
         // relative action, no form name
         tag = "<form action=\"save\" method=\"post\"><input type=\"hidden\" id=\"t_id\" name=\"t_n\" value=\"t4val\" /></form>";
-        entry = HttpCsrfFormExtractor.extract(tag, "t_id", "t_n");
-        assertEntryEquals(entry, "save", "post", "", "t4val");
+        entry = HttpCsrfFormExtractor.extract(tag, "t_id", "t_n", "/path/abc");
+        assertEntryEquals(entry, "/path/save", "post", "", "t4val"); // "save" replaces the last segment "abc" of the page URI
     }
 
     @Test
     public void testExtractByIdName() {
-        String tag = "<form>"
+        String tag = "<form action=\"/save\">" // explicit action, so the empty page URIs passed below are never used for resolution
                 + "<input type=\"hidden\" value=\"t4val1\" />"
                 + "<input type=\"hidden\" name=\"t_n\" value=\"t4val2\" />"
                 + "<input type=\"hidden\" id=\"t_id\" value=\"t4val3\" />"
@@ -72,19 +90,19 @@
         TokenEntry entry;
 
         // token by name and id
-        entry = HttpCsrfFormExtractor.extract(tag, "t_id", "t_n");
+        entry = HttpCsrfFormExtractor.extract(tag, "t_id", "t_n", "");
         assertEquals("t4val4", entry.value);
 
         // token by id
-        entry = HttpCsrfFormExtractor.extract(tag, "t_id", null);
+        entry = HttpCsrfFormExtractor.extract(tag, "t_id", null, "");
         assertEquals("t4val3", entry.value);
-        
+
         // token by name
-        entry = HttpCsrfFormExtractor.extract(tag, null, "t_n");
+        entry = HttpCsrfFormExtractor.extract(tag, null, "t_n", "");
         assertEquals("t4val2", entry.value);
-        
+
         // token is first hidden input
-        entry = HttpCsrfFormExtractor.extract(tag, null, null);
+        entry = HttpCsrfFormExtractor.extract(tag, null, null, "");
         assertEquals("t4val1", entry.value);
     }
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stress-tester/src/test/java/com/passus/st/client/http/filter/HttpCsrfFormFilterTest.java	Tue Aug 08 12:39:51 2017 +0200
@@ -0,0 +1,75 @@
+package com.passus.st.client.http.filter;
+
+import com.passus.data.ByteBuffDataSource;
+import com.passus.data.DataSource;
+import com.passus.net.http.HttpMessage;
+import com.passus.net.http.HttpMessageHelper;
+import com.passus.net.http.HttpParameters;
+import com.passus.net.http.HttpRequest;
+import com.passus.net.http.HttpRequestBuilder;
+import com.passus.net.http.HttpResponse;
+import com.passus.net.http.HttpResponseBuilder;
+import static com.passus.st.client.http.HttpConsts.TAG_SESSION_ID;
+import com.passus.st.client.http.HttpFlowContext;
+import com.passus.st.client.http.HttpScopes;
+import java.io.IOException;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.testng.AssertJUnit.*;
+import org.testng.annotations.Test;
+
+/**
+ * Tests {@link HttpCsrfFormFilter}: a token captured from a live response must
+ * replace the stale token recorded in a replayed form submission.
+ *
+ * @author mikolaj.podbielski
+ */
+public class HttpCsrfFormFilterTest {
+
+    private final HttpFlowContext mockContext = mock(HttpFlowContext.class);
+
+    @Test
+    public void testFilter() throws IOException {
+        HttpRequest req1 = HttpRequestBuilder.get("example.com/edit").build();
+        HttpResponse resp1Orig = HttpResponseBuilder.ok().content(form("oldToken"))
+                .header("Content-Type", "text/html").build();
+        HttpResponse resp1Live = HttpResponseBuilder.ok().content(form("newToken123"))
+                .header("Content-Type", "text/html").build();
+        HttpRequest req2 = HttpRequestBuilder.post("example.com/save").content(post("oldToken"))
+                .header("Content-Type", "application/x-www-form-urlencoded").build();
+        tagMessages(req1, resp1Orig, resp1Live, req2);
+
+        // fake the context and the session
+        when(mockContext.scopes()).thenReturn(new HttpScopes());
+        HttpCsrfFormFilter filter = new HttpCsrfFormFilter();
+        filter.setInputName("_token");
+
+        when(mockContext.origReponse()).thenReturn(resp1Orig);
+        filter.filterInbound(req1, resp1Live, mockContext);
+
+        filter.filterOutbound(req2, null, mockContext);
+
+        HttpParameters params = HttpMessageHelper.STRICT.decodeFormUrlencoded(req2);
+        assertEquals("newToken123", params.get("_token").toString());
+        // 34 == "abc=def&_token=newToken123&ghi=jkl".length()
+        assertEquals(34, req2.getContent().available());
+    }
+
+    private static String form(String value) {
+        return "<form action=\"save\" method=\"post\"><input type=\"hidden\" name=\"_token\" value=\"" + value + "\" /></form>";
+    }
+
+    private static HttpParameters post(CharSequence value) {
+        HttpParameters form = new HttpParameters();
+        form.add("abc", "def");
+        form.add("_token", value);
+        form.add("ghi", "jkl");
+        return form;
+    }
+
+    private static void tagMessages(HttpMessage... messages) {
+        for (HttpMessage message : messages) {
+            message.setTag(TAG_SESSION_ID, "sid1");
+        }
+    }
+}