package org.traccar.helper;

import org.junit.Before;
import org.junit.Test;

import javax.servlet.*;
import javax.servlet.http.*;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.security.Principal;
import java.util.*;

import static org.junit.Assert.assertEquals;

/**
 * Tests for {@code ServletHelper.retrieveRemoteAddress}, driven through a minimal
 * {@link HttpServletRequest} stub that only carries a remote address and headers.
 */
public class ServletHelperTest {

    private MockHttpServletRequestForRemoteAddr mockHttpServletRequest;

    /** Creates a fresh request stub so that no header or address leaks between tests. */
    @Before
    public void init() {
        mockHttpServletRequest = new MockHttpServletRequestForRemoteAddr();
    }

    /**
     * When an {@code X-FORWARDED-FOR} header is present, the first address listed in it
     * is expected instead of the socket-level remote address.
     */
    @Test
    public void testIpBehindReverseProxy() {
        mockHttpServletRequest.setRemoteAddr("147.120.1.5");
        mockHttpServletRequest.addHeader("X-FORWARDED-FOR", "231.23.45.65, 10.20.10.33, 10.20.20.34");

        assertEquals("231.23.45.65", ServletHelper.retrieveRemoteAddress(mockHttpServletRequest));
    }

    /** Without any forwarding header, the request's own remote address is expected. */
    @Test
    public void testNormalIp() {
        mockHttpServletRequest.setRemoteAddr("231.23.45.65");

        assertEquals("231.23.45.65", ServletHelper.retrieveRemoteAddress(mockHttpServletRequest));
    }

    /**
     * This mock implementation only supports IP address-related operations:
     * {@link #getRemoteAddr()} and {@link #getHeader(String)} return what was configured
     * through {@link #setRemoteAddr(String)} and {@link #addHeader(String, String)}.
     * Every other method returns a fixed placeholder ({@code null}, {@code 0},
     * {@code false} or an empty array) or does nothing.
     */
    private static final class MockHttpServletRequestForRemoteAddr implements HttpServletRequest {

        private String remoteAddr;

        // HttpServletRequest#getHeader is specified as case-insensitive on the header name,
        // so the backing map ignores case when comparing keys.
        private final Map<String, String> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);

        /**
         * Sets the value later returned by {@link #getRemoteAddr()}.
         *
         * @param remoteAddr the address to report as the remote peer
         */
        public void setRemoteAddr(String remoteAddr) {
            this.remoteAddr = remoteAddr;
        }

        /**
         * Registers a single-valued header, replacing any previous value stored under the
         * same name (compared case-insensitively).
         *
         * @param name  the header name; must not be {@code null}
         * @param value the header value
         */
        public void addHeader(String name, String value) {
            headers.put(name, value);
        }

        /**
         * Returns the value registered through {@link #addHeader(String, String)}.
         *
         * @param name the header name, matched case-insensitively
         * @return the stored value, or {@code null} if the header is absent or
         *         {@code name} is {@code null}
         */
        @Override
        public String getHeader(String name) {
            // The case-insensitive comparator cannot handle a null key, so report "absent".
            return name == null ? null : headers.get(name);
        }

        /** Returns the address configured through {@link #setRemoteAddr(String)}. */
        @Override
        public String getRemoteAddr() {
            return remoteAddr;
        }

        // ---------------------------------------------------------------------------------
        // Everything below is an unsupported placeholder required by the interface.
        // ---------------------------------------------------------------------------------

        @Override
        public String getAuthType() {
            return null;
        }

        @Override
        public Cookie[] getCookies() {
            return new Cookie[0];
        }

        @Override
        public long getDateHeader(String name) {
            return 0;
        }

        @Override
        public Enumeration<String> getHeaders(String name) {
            return null;
        }

        @Override
        public Enumeration<String> getHeaderNames() {
            return null;
        }

        @Override
        public int getIntHeader(String name) {
            return 0;
        }

        @Override
        public String getMethod() {
            return null;
        }

        @Override
        public String getPathInfo() {
            return null;
        }

        @Override
        public String getPathTranslated() {
            return null;
        }

        @Override
        public String getContextPath() {
            return null;
        }

        @Override
        public String getQueryString() {
            return null;
        }

        @Override
        public String getRemoteUser() {
            return null;
        }

        @Override
        public boolean isUserInRole(String role) {
            return false;
        }

        @Override
        public Principal getUserPrincipal() {
            return null;
        }

        @Override
        public String getRequestedSessionId() {
            return null;
        }

        @Override
        public String getRequestURI() {
            return null;
        }

        @Override
        public StringBuffer getRequestURL() {
            return null;
        }

        @Override
        public String getServletPath() {
            return null;
        }

        @Override
        public HttpSession getSession(boolean create) {
            return null;
        }

        @Override
        public HttpSession getSession() {
            return null;
        }

        @Override
        public String changeSessionId() {
            return null;
        }

        @Override
        public boolean isRequestedSessionIdValid() {
            return false;
        }

        @Override
        public boolean isRequestedSessionIdFromCookie() {
            return false;
        }

        @Override
        public boolean isRequestedSessionIdFromURL() {
            return false;
        }

        @Override
        public boolean isRequestedSessionIdFromUrl() {
            return false;
        }

        @Override
        public boolean authenticate(HttpServletResponse response) throws IOException, ServletException {
            return false;
        }

        @Override
        public void login(String username, String password) throws ServletException {
        }

        @Override
        public void logout() throws ServletException {
        }

        @Override
        public Collection<Part> getParts() throws IOException, ServletException {
            return null;
        }

        @Override
        public Part getPart(String name) throws IOException, ServletException {
            return null;
        }

        @Override
        public <T extends HttpUpgradeHandler> T upgrade(Class<T> handlerClass)
                throws IOException, ServletException {
            return null;
        }

        @Override
        public Object getAttribute(String name) {
            return null;
        }

        @Override
        public Enumeration<String> getAttributeNames() {
            return null;
        }

        @Override
        public String getCharacterEncoding() {
            return null;
        }

        @Override
        public void setCharacterEncoding(String env) throws UnsupportedEncodingException {
        }

        @Override
        public int getContentLength() {
            return 0;
        }

        @Override
        public long getContentLengthLong() {
            return 0;
        }

        @Override
        public String getContentType() {
            return null;
        }

        @Override
        public ServletInputStream getInputStream() throws IOException {
            return null;
        }

        @Override
        public String getParameter(String name) {
            return null;
        }

        @Override
        public Enumeration<String> getParameterNames() {
            return null;
        }

        @Override
        public String[] getParameterValues(String name) {
            return new String[0];
        }

        @Override
        public Map<String, String[]> getParameterMap() {
            return null;
        }

        @Override
        public String getProtocol() {
            return null;
        }

        @Override
        public String getScheme() {
            return null;
        }

        @Override
        public String getServerName() {
            return null;
        }

        @Override
        public int getServerPort() {
            return 0;
        }

        @Override
        public BufferedReader getReader() throws IOException {
            return null;
        }

        @Override
        public String getRemoteHost() {
            return null;
        }

        @Override
        public void setAttribute(String name, Object o) {
        }

        @Override
        public void removeAttribute(String name) {
        }

        @Override
        public Locale getLocale() {
            return null;
        }

        @Override
        public Enumeration<Locale> getLocales() {
            return null;
        }

        @Override
        public boolean isSecure() {
            return false;
        }

        @Override
        public RequestDispatcher getRequestDispatcher(String path) {
            return null;
        }

        @Override
        public String getRealPath(String path) {
            return null;
        }

        @Override
        public int getRemotePort() {
            return 0;
        }

        @Override
        public String getLocalName() {
            return null;
        }

        @Override
        public String getLocalAddr() {
            return null;
        }

        @Override
        public int getLocalPort() {
            return 0;
        }

        @Override
        public ServletContext getServletContext() {
            return null;
        }

        @Override
        public AsyncContext startAsync() throws IllegalStateException {
            return null;
        }

        @Override
        public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse)
                throws IllegalStateException {
            return null;
        }

        @Override
        public boolean isAsyncStarted() {
            return false;
        }

        @Override
        public boolean isAsyncSupported() {
            return false;
        }

        @Override
        public AsyncContext getAsyncContext() {
            return null;
        }

        @Override
        public DispatcherType getDispatcherType() {
            return null;
        }
    }
}