SpringBoot实现防XSS攻击 超详细教程

寿鸣
2023-12-01

前言:

最近在熟悉公司的项目框架,看到这一块内容时深入了解了一下,在这里总结一下。
XSS攻击通常是指利用网页开发时留下的漏洞,以巧妙的方法将恶意指令代码注入网页,使用户加载并执行攻击者恶意制造的网页程序。
跨站脚本攻击(XSS),是最普遍的Web应用安全漏洞。这类漏洞能够使得攻击者嵌入恶意脚本代码到正常用户会访问到的页面中,当正常用户访问该页面时,则可导致嵌入的恶意脚本代码的执行,从而达到恶意攻击用户的目的。

实现教程:

  1. 教程中会用到几个工具类库,需要提前在pom文件中加入以下依赖:
<!--io常用工具类 -->
<dependency>
   <groupId>commons-io</groupId>
   <artifactId>commons-io</artifactId>
   <version>2.5</version>
</dependency>
<!--commons-lang3 工具类 -->
<dependency>
   <groupId>org.apache.commons</groupId>
   <artifactId>commons-lang3</artifactId>
   <version>3.5</version>
</dependency>
  2. 然后将下面两个工具类复制到项目中即可(直接 ctrl+c、ctrl+v)。
  • EscapeUtil:提供转义(escape)、解码(unescape,处理 %XX / %uXXXX 形式的编码)以及按白名单清理HTML标签(clean,内部调用 HTMLFilter)的方法。
package com.blog.utils.html;

import org.springframework.util.StringUtils;

/**
 * HTML escaping helpers plus an allow-list based HTML cleaner.
 *
 * @author LiWT
 * @Date 2021/6/23
 */
public class EscapeUtil {
    /** Regular expression matching HTML opening, closing and self-closing tags. */
    public static final String RE_HTML_MARK = "(<[^<]*?>)|(<[\\s]*?/[^<]*?>)|(<[^<]*?/[\\s]*?>)";

    /**
     * Replacement table for every char below 64: the HTML-special characters map to numeric
     * entities, every other char maps to itself.
     */
    private static final char[][] TEXT = new char[64][];

    static {
        for (int i = 0; i < 64; i++) {
            TEXT[i] = new char[]{(char) i};
        }

        // special HTML characters
        TEXT['\''] = "&#039;".toCharArray(); // single quote
        TEXT['"'] = "&#34;".toCharArray(); // double quote
        TEXT['&'] = "&#38;".toCharArray(); // ampersand
        TEXT['<'] = "&#60;".toCharArray(); // less-than sign
        TEXT['>'] = "&#62;".toCharArray(); // greater-than sign
    }

    /**
     * Escapes the HTML-special characters {@code ' " & < >} as numeric entities.
     *
     * @param text text to escape; {@code null} is treated as empty
     * @return the escaped text, never {@code null}
     */
    public static String escape(String text) {
        return encode(text);
    }

    /**
     * Decodes {@code %XX} and {@code %uXXXX} escape sequences; see {@link #decode(String)}.
     * NOTE(review): this does not reverse {@link #escape(String)}, which emits numeric HTML
     * entities such as {@code &#60;}; confirm this is the intended contract.
     *
     * @param content text containing escape sequences
     * @return the decoded text
     */
    public static String unescape(String content) {
        return decode(content);
    }

    /**
     * Filters HTML through {@link HTMLFilter}'s default allow-list. Allowed tags (a, img, b, strong,
     * i, em) and their allowed attributes are kept. All other tags are removed while their text
     * content is kept.
     *
     * @param content text that may contain HTML; {@code null} is returned as {@code null}
     * @return the filtered text
     */
    public static String clean(String content) {
        if (content == null) {
            return null;
        }
        // a fresh filter per call: HTMLFilter keeps per-call tag counts and is not shareable
        return new HTMLFilter().filter(content);
    }

    /**
     * Escape encoding: replaces each char below 64 through the {@link #TEXT} table.
     *
     * @param text text to encode
     * @return the encoded text, or "" for {@code null}/empty input
     */
    private static String encode(String text) {
        int len;
        if ((text == null) || ((len = text.length()) == 0)) {
            return "";
        }
        StringBuilder buffer = new StringBuilder(len + (len >> 2));
        char c;
        for (int i = 0; i < len; i++) {
            c = text.charAt(i);
            if (c < 64) {
                buffer.append(TEXT[c]);
            } else {
                buffer.append(c);
            }
        }
        return buffer.toString();
    }

    /**
     * Decodes JavaScript-{@code escape()}-style sequences: {@code %uXXXX} (4 hex digits) and
     * {@code %XX} (2 hex digits). A {@code %} that does not start a complete, valid sequence is
     * kept literally instead of throwing, as the previous implementation did for input such as
     * {@code "100%"} or {@code "%zz"}.
     *
     * @param content text to decode
     * @return the decoded text; {@code null} and "" are returned unchanged
     */
    public static String decode(String content) {
        if (!StringUtils.hasLength(content)) {
            return content;
        }
        final int len = content.length();
        StringBuilder tmp = new StringBuilder(len);
        int i = 0;
        while (i < len) {
            char ch = content.charAt(i);
            if (ch == '%') {
                if (i + 6 <= len && content.charAt(i + 1) == 'u') {
                    int code = parseHex(content, i + 2, i + 6);
                    if (code >= 0) {
                        tmp.append((char) code);
                        i += 6;
                        continue;
                    }
                }
                if (i + 3 <= len) {
                    int code = parseHex(content, i + 1, i + 3);
                    if (code >= 0) {
                        tmp.append((char) code);
                        i += 3;
                        continue;
                    }
                }
            }
            // ordinary char, or a '%' that does not start a valid sequence
            tmp.append(ch);
            i++;
        }
        return tmp.toString();
    }

    /**
     * Parses {@code s[from, to)} as ASCII hexadecimal digits.
     *
     * @return the parsed value, or -1 if any char is not a hex digit
     */
    private static int parseHex(String s, int from, int to) {
        int value = 0;
        for (int i = from; i < to; i++) {
            int digit = hexValue(s.charAt(i));
            if (digit < 0) {
                return -1;
            }
            value = (value << 4) | digit;
        }
        return value;
    }

    /** Returns the value of an ASCII hex digit, or -1 (rejects signs and non-ASCII digits). */
    private static int hexValue(char ch) {
        if (ch >= '0' && ch <= '9') {
            return ch - '0';
        }
        if (ch >= 'a' && ch <= 'f') {
            return ch - 'a' + 10;
        }
        if (ch >= 'A' && ch <= 'F') {
            return ch - 'A' + 10;
        }
        return -1;
    }
}

  • HTMLFilter:基于白名单的HTML过滤器。默认只保留 a、img、b、strong、i、em 标签及其允许的属性,其余标签会被删除(标签内的文本保留),以消除XSS隐患。
package com.blog.utils.html;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * HTML filter used to remove XSS risks from user-supplied markup.
 * <p>
 * Only allow-listed elements and attributes are kept. Every other tag is removed while the text
 * between tags is kept. With the default constructor the allow-list is {@code a(href, target)},
 * {@code img(src, width, height, alt)}, {@code b}, {@code strong}, {@code i} and {@code em}.
 * {@code src}/{@code href} values may only use the http, https or mailto protocols.
 * <p>
 * An instance keeps per-call state ({@code vTagCounts}) in an unsynchronized map, so it must not
 * be shared between threads.
 *
 * @author LiWT
 * @Date 2021/6/23
 */
public final class HTMLFilter {
   /**
    * regex flag union representing /si modifiers in php
    **/
   private static final int REGEX_FLAGS_SI = Pattern.CASE_INSENSITIVE | Pattern.DOTALL;
   private static final Pattern P_COMMENTS = Pattern.compile("<!--(.*?)-->", Pattern.DOTALL);
   private static final Pattern P_COMMENT = Pattern.compile("^!--(.*)--$", REGEX_FLAGS_SI);
   private static final Pattern P_TAGS = Pattern.compile("<(.*?)>", Pattern.DOTALL);
   private static final Pattern P_END_TAG = Pattern.compile("^/([a-z0-9]+)", REGEX_FLAGS_SI);
   private static final Pattern P_START_TAG = Pattern.compile("^([a-z0-9]+)(.*?)(/?)$", REGEX_FLAGS_SI);
   private static final Pattern P_QUOTED_ATTRIBUTES = Pattern.compile("([a-z0-9]+)=([\"'])(.*?)\\2", REGEX_FLAGS_SI);
   private static final Pattern P_UNQUOTED_ATTRIBUTES = Pattern.compile("([a-z0-9]+)(=)([^\"\\s']+)", REGEX_FLAGS_SI);
   private static final Pattern P_PROTOCOL = Pattern.compile("^([^:]+):", REGEX_FLAGS_SI);
   /** decimal character reference, e.g. "&#58;" */
   private static final Pattern P_ENTITY = Pattern.compile("&#(\\d+);?");
   /**
    * hexadecimal character reference, e.g. "&#x3a;". Case-insensitive so that "&#X3A;" and "&#x3A;" are
    * decoded in full (as browsers do) instead of stopping at the first upper-case hex digit.
    */
   private static final Pattern P_ENTITY_UNICODE = Pattern.compile("&#x([0-9a-f]+);?", Pattern.CASE_INSENSITIVE);
   /** percent-encoded byte with lower-case hex digits, e.g. "%3a"; upper-case forms are left untouched */
   private static final Pattern P_ENCODE = Pattern.compile("%([0-9a-f]{2});?");
   private static final Pattern P_VALID_ENTITIES = Pattern.compile("&([^&;]*)(?=(;|&|$))");
   private static final Pattern P_VALID_QUOTES = Pattern.compile("(>|^)([^<]+?)(<|$)", Pattern.DOTALL);
   private static final Pattern P_END_ARROW = Pattern.compile("^>");
   private static final Pattern P_BODY_TO_END = Pattern.compile("<([^>]*?)(?=<|$)");
   private static final Pattern P_XML_CONTENT = Pattern.compile("(^|>)([^<]*?)(?=>)");
   private static final Pattern P_STRAY_LEFT_ARROW = Pattern.compile("<([^>]*?)(?=<|$)");
   private static final Pattern P_STRAY_RIGHT_ARROW = Pattern.compile("(^|>)([^<]*?)(?=>)");
   private static final Pattern P_AMP = Pattern.compile("&");
   private static final Pattern P_QUOTE = Pattern.compile("\"");
   private static final Pattern P_LEFT_ARROW = Pattern.compile("<");
   private static final Pattern P_RIGHT_ARROW = Pattern.compile(">");
   private static final Pattern P_BOTH_ARROWS = Pattern.compile("<>");

   // @xxx could grow large... maybe use sesat's ReferenceMap
   private static final ConcurrentMap<String, Pattern> P_REMOVE_PAIR_BLANKS = new ConcurrentHashMap<>();
   private static final ConcurrentMap<String, Pattern> P_REMOVE_SELF_BLANKS = new ConcurrentHashMap<>();

   /**
    * set of allowed html elements, along with allowed attributes for each element
    **/
   private final Map<String, List<String>> vAllowed;
   /**
    * counts of open tags for each (allowable) html element
    **/
   private final Map<String, Integer> vTagCounts = new HashMap<>();

   /**
    * html elements which must always be self-closing (e.g. "<img />")
    **/
   private final String[] vSelfClosingTags;
   /**
    * html elements which must always have separate opening and closing tags (e.g. "<b></b>")
    **/
   private final String[] vNeedClosingTags;
   /**
    * set of disallowed html elements
    **/
   private final String[] vDisallowed;
   /**
    * attributes which should be checked for valid protocols
    **/
   private final String[] vProtocolAtts;
   /**
    * allowed protocols
    **/
   private final String[] vAllowedProtocols;
   /**
    * tags which should be removed if they contain no content (e.g. "<b></b>" or "<b />")
    **/
   private final String[] vRemoveBlanks;
   /**
    * entities allowed within html markup
    **/
   private final String[] vAllowedEntities;
   /**
    * when true, comments are removed from the output; when false they are kept (with their content escaped).
    */
   private final boolean stripComment;
   /**
    * enables the {@link #encodeQuotes(String)} pass (which currently leaves text content unchanged).
    */
   private final boolean encodeQuotes;
   /**
    * flag determining whether to try to make tags when presented with "unbalanced" angle brackets (e.g. "<b text </b>"
    * becomes "<b> text </b>"). If set to false, unbalanced angle brackets will be html escaped.
    */
   private final boolean alwaysMakeTags;

   /**
    * Default constructor: allows a(href, target), img(src, width, height, alt), b, strong, i and em;
    * protocols http, https and mailto; entities amp, gt, lt and quot.
    */
   public HTMLFilter() {
      vAllowed = new HashMap<>();

      final ArrayList<String> a_atts = new ArrayList<>();
      a_atts.add("href");
      a_atts.add("target");
      vAllowed.put("a", a_atts);

      final ArrayList<String> img_atts = new ArrayList<>();
      img_atts.add("src");
      img_atts.add("width");
      img_atts.add("height");
      img_atts.add("alt");
      vAllowed.put("img", img_atts);

      final ArrayList<String> no_atts = new ArrayList<>();
      vAllowed.put("b", no_atts);
      vAllowed.put("strong", no_atts);
      vAllowed.put("i", no_atts);
      vAllowed.put("em", no_atts);

      vSelfClosingTags = new String[]{"img"};
      vNeedClosingTags = new String[]{"a", "b", "strong", "i", "em"};
      vDisallowed = new String[]{};
      vAllowedProtocols = new String[]{"http", "mailto", "https"}; // no ftp.
      vProtocolAtts = new String[]{"src", "href"};
      vRemoveBlanks = new String[]{"a", "b", "strong", "i", "em"};
      vAllowedEntities = new String[]{"amp", "gt", "lt", "quot"};
      stripComment = true;
      encodeQuotes = true;
      alwaysMakeTags = true;
   }

   /**
    * Map-parameter configurable constructor.
    *
    * @param conf map containing configuration. keys match field names. "vAllowed" may be any
    *             {@code Map<String, List<String>>} (previously it had to be a {@code HashMap}).
    */
   @SuppressWarnings("unchecked")
   public HTMLFilter(final Map<String, Object> conf) {

      assert conf.containsKey("vAllowed") : "configuration requires vAllowed";
      assert conf.containsKey("vSelfClosingTags") : "configuration requires vSelfClosingTags";
      assert conf.containsKey("vNeedClosingTags") : "configuration requires vNeedClosingTags";
      assert conf.containsKey("vDisallowed") : "configuration requires vDisallowed";
      assert conf.containsKey("vAllowedProtocols") : "configuration requires vAllowedProtocols";
      assert conf.containsKey("vProtocolAtts") : "configuration requires vProtocolAtts";
      assert conf.containsKey("vRemoveBlanks") : "configuration requires vRemoveBlanks";
      assert conf.containsKey("vAllowedEntities") : "configuration requires vAllowedEntities";

      vAllowed = Collections.unmodifiableMap((Map<String, List<String>>) conf.get("vAllowed"));
      vSelfClosingTags = (String[]) conf.get("vSelfClosingTags");
      vNeedClosingTags = (String[]) conf.get("vNeedClosingTags");
      vDisallowed = (String[]) conf.get("vDisallowed");
      vAllowedProtocols = (String[]) conf.get("vAllowedProtocols");
      vProtocolAtts = (String[]) conf.get("vProtocolAtts");
      vRemoveBlanks = (String[]) conf.get("vRemoveBlanks");
      vAllowedEntities = (String[]) conf.get("vAllowedEntities");
      stripComment = conf.containsKey("stripComment") ? (Boolean) conf.get("stripComment") : true;
      encodeQuotes = conf.containsKey("encodeQuotes") ? (Boolean) conf.get("encodeQuotes") : true;
      alwaysMakeTags = conf.containsKey("alwaysMakeTags") ? (Boolean) conf.get("alwaysMakeTags") : true;
   }

   private void reset() {
      vTagCounts.clear();
   }

   // ---------------------------------------------------------------
   // my versions of some PHP library functions

   /** Returns the char with the given value (truncated to 16 bits) as a String. */
   public static String chr(final int decimal) {
      return String.valueOf((char) decimal);
   }

   /** Escapes {@code & " < >} as {@code &amp; &quot; &lt; &gt;}. */
   public static String htmlSpecialChars(final String s) {
      String result = s;
      result = regexReplace(P_AMP, "&amp;", result);
      result = regexReplace(P_QUOTE, "&quot;", result);
      result = regexReplace(P_LEFT_ARROW, "&lt;", result);
      result = regexReplace(P_RIGHT_ARROW, "&gt;", result);
      return result;
   }

   // ---------------------------------------------------------------

   /**
    * given a user submitted input String, filter out any invalid or restricted html.
    *
    * @param input text (i.e. submitted by a user) than may contain html
    * @return "clean" version of input, with only valid, whitelisted html elements allowed
    */
   public String filter(final String input) {
      reset();
      String s = input;

      s = escapeComments(s);

      s = balanceHTML(s);

      s = checkTags(s);

      s = processRemoveBlanks(s);

      s = validateEntities(s);

      return s;
   }

   public boolean isAlwaysMakeTags() {
      return alwaysMakeTags;
   }

   public boolean isStripComments() {
      return stripComment;
   }

   /**
    * HTML-escapes the content of every {@code <!-- ... -->} comment so that markup inside a comment
    * is not treated as tags by the later passes.
    */
   private String escapeComments(final String s) {
      final Matcher m = P_COMMENTS.matcher(s);
      final StringBuffer buf = new StringBuffer();
      // every comment, not only the first one (the previous code used a single "if")
      while (m.find()) {
         final String match = m.group(1); // (.*?)
         m.appendReplacement(buf, Matcher.quoteReplacement("<!--" + htmlSpecialChars(match) + "-->"));
      }
      m.appendTail(buf);

      return buf.toString();
   }

   private String balanceHTML(String s) {
      if (alwaysMakeTags) {
         //
         // try and form html
         //
         s = regexReplace(P_END_ARROW, "", s);
         s = regexReplace(P_BODY_TO_END, "<$1>", s);
         s = regexReplace(P_XML_CONTENT, "$1<$2", s);

      } else {
         //
         // escape stray brackets
         //
         s = regexReplace(P_STRAY_LEFT_ARROW, "&lt;$1", s);
         s = regexReplace(P_STRAY_RIGHT_ARROW, "$1$2&gt;<", s);

         //
         // the last regexp causes '<>' entities to appear
         // (we need to do a lookahead assertion so that the last bracket can
         // be used in the next pass of the regexp)
         //
         s = regexReplace(P_BOTH_ARROWS, "", s);
      }

      return s;
   }

   /**
    * Replaces every {@code <...>} with the output of {@link #processTag(String)}, then appends
    * closing tags for elements that are still open.
    */
   private String checkTags(String s) {
      Matcher m = P_TAGS.matcher(s);

      final StringBuffer buf = new StringBuffer();
      while (m.find()) {
         String replaceStr = m.group(1);
         replaceStr = processTag(replaceStr);
         m.appendReplacement(buf, Matcher.quoteReplacement(replaceStr));
      }
      m.appendTail(buf);

      // these get tallied in processTag
      // (remember to reset before subsequent calls to filter method)
      final StringBuilder sBuilder = new StringBuilder(buf.toString());
      for (Map.Entry<String, Integer> entry : vTagCounts.entrySet()) {
         for (int ii = 0; ii < entry.getValue(); ii++) {
            sBuilder.append("</").append(entry.getKey()).append(">");
         }
      }
      return sBuilder.toString();
   }

   /** Removes empty {@code <tag ...></tag>} pairs and {@code <tag .../>} for each tag in vRemoveBlanks. */
   private String processRemoveBlanks(final String s) {
      String result = s;
      for (String tag : vRemoveBlanks) {
         // tag names are data, not regex: quote them
         final Pattern pairBlank = P_REMOVE_PAIR_BLANKS.computeIfAbsent(tag,
               t -> Pattern.compile("<" + Pattern.quote(t) + "(\\s[^>]*)?></" + Pattern.quote(t) + ">"));
         result = regexReplace(pairBlank, "", result);
         final Pattern selfBlank = P_REMOVE_SELF_BLANKS.computeIfAbsent(tag,
               t -> Pattern.compile("<" + Pattern.quote(t) + "(\\s[^>]*)?/>"));
         result = regexReplace(selfBlank, "", result);
      }

      return result;
   }

   private static String regexReplace(final Pattern regex_pattern, final String replacement, final String s) {
      Matcher m = regex_pattern.matcher(s);
      return m.replaceAll(replacement);
   }

   /**
    * Rewrites the inside of one {@code <...>} (without the angle brackets).
    *
    * @param s tag body, e.g. {@code a href="x"} or {@code /b}
    * @return the sanitized tag, or "" when the tag is not allowed
    */
   private String processTag(final String s) {
      // ending tags
      Matcher m = P_END_TAG.matcher(s);
      if (m.find()) {
         final String name = m.group(1).toLowerCase(Locale.ROOT);
         if (allowed(name) && !inArray(name, vSelfClosingTags)) {
            final Integer openCount = vTagCounts.get(name);
            // only close a tag that is actually open; stray closers are dropped instead of
            // driving the count negative and being passed through
            if (openCount != null && openCount > 0) {
               vTagCounts.put(name, openCount - 1);
               return "</" + name + ">";
            }
         }
      }

      // starting tags
      m = P_START_TAG.matcher(s);
      if (m.find()) {
         final String name = m.group(1).toLowerCase(Locale.ROOT);
         final String body = m.group(2);
         String ending = m.group(3);

         if (allowed(name)) {
            final StringBuilder params = new StringBuilder();

            final Matcher m2 = P_QUOTED_ATTRIBUTES.matcher(body);
            final Matcher m3 = P_UNQUOTED_ATTRIBUTES.matcher(body);
            final List<String> paramNames = new ArrayList<>();
            final List<String> paramValues = new ArrayList<>();
            while (m2.find()) {
               paramNames.add(m2.group(1)); // ([a-z0-9]+)
               paramValues.add(m2.group(3)); // (.*?)
            }
            while (m3.find()) {
               paramNames.add(m3.group(1)); // ([a-z0-9]+)
               paramValues.add(m3.group(3)); // ([^\"\\s']+)
            }

            String paramName, paramValue;
            for (int ii = 0; ii < paramNames.size(); ii++) {
               paramName = paramNames.get(ii).toLowerCase(Locale.ROOT);
               paramValue = paramValues.get(ii);

               if (allowedAttribute(name, paramName)) {
                  if (inArray(paramName, vProtocolAtts)) {
                     paramValue = processParamProtocol(paramValue);
                  }
                  // Values are always re-emitted inside double quotes. A raw '"' (from a single-quoted
                  // or entity-encoded value) would close the attribute and let the rest of the value
                  // become new attributes such as onmouseover, so encode it.
                  paramValue = regexReplace(P_QUOTE, "&quot;", paramValue);
                  params.append(' ').append(paramName).append("=\"").append(paramValue).append("\"");
               }
            }

            if (inArray(name, vSelfClosingTags)) {
               ending = " /";
            }

            if (inArray(name, vNeedClosingTags)) {
               ending = "";
            }

            if (ending == null || ending.length() < 1) {
               vTagCounts.merge(name, 1, Integer::sum);
            } else {
               ending = " /";
            }
            return "<" + name + params + ending + ">";
         } else {
            return "";
         }
      }

      // comments
      m = P_COMMENT.matcher(s);
      if (!stripComment && m.find()) {
         return "<" + m.group() + ">";
      }

      return "";
   }

   /**
    * Decodes the attribute value and, if it starts with a protocol that is not allowed, turns it
    * into a local anchor ("#...").
    */
   private String processParamProtocol(String s) {
      s = decodeEntities(s);
      final Matcher m = P_PROTOCOL.matcher(s);
      if (m.find()) {
         final String protocol = m.group(1);
         if (!inArray(protocol, vAllowedProtocols)) {
            // bad protocol, turn into local anchor link instead
            s = "#" + s.substring(protocol.length() + 1);
            if (s.startsWith("#//")) {
               s = "#" + s.substring(3);
            }
         }
      }

      return s;
   }

   /**
    * Decodes numeric character references (decimal "&#58;", hexadecimal "&#x3a;") and lower-case
    * percent-encoded bytes ("%3a"), then re-validates entities. Only called from
    * {@link #processParamProtocol(String)}, so that an encoded "javascript:" is still recognised.
    */
   private String decodeEntities(String s) {
      s = replaceNumericReferences(P_ENTITY, 10, s);
      s = replaceNumericReferences(P_ENTITY_UNICODE, 16, s);
      s = replaceNumericReferences(P_ENCODE, 16, s);
      return validateEntities(s);
   }

   /**
    * Replaces each match of {@code pattern} (group 1 = digits in {@code radix}) with the character it
    * encodes. Values that overflow an int or are not valid code points are left as-is, so the
    * following validateEntities pass escapes their '&'. Previously {@code Integer.decode} read
    * leading zeros as octal and threw NumberFormatException for input like "&#09;".
    */
   private static String replaceNumericReferences(final Pattern pattern, final int radix, final String s) {
      final Matcher m = pattern.matcher(s);
      final StringBuffer buf = new StringBuffer();
      while (m.find()) {
         String replacement;
         try {
            final int codePoint = Integer.parseInt(m.group(1), radix);
            replacement = Character.isValidCodePoint(codePoint)
                  ? new String(Character.toChars(codePoint))
                  : m.group();
         } catch (NumberFormatException e) {
            // too many digits for an int: keep the original text
            replacement = m.group();
         }
         m.appendReplacement(buf, Matcher.quoteReplacement(replacement));
      }
      m.appendTail(buf);
      return buf.toString();
   }

   /** Escapes every '&' that does not start an allowed, ';'-terminated entity, then runs encodeQuotes. */
   private String validateEntities(final String s) {
      StringBuffer buf = new StringBuffer();

      // validate entities throughout the string
      Matcher m = P_VALID_ENTITIES.matcher(s);
      while (m.find()) {
         final String one = m.group(1); // ([^&;]*)
         final String two = m.group(2); // (?=(;|&|$))
         m.appendReplacement(buf, Matcher.quoteReplacement(checkEntity(one, two)));
      }
      m.appendTail(buf);

      return encodeQuotes(buf.toString());
   }

   /**
    * Text-content quote pass. Double quotes are deliberately NOT replaced with &quot; here, to keep
    * JSON request bodies valid, so the pass currently returns the text unchanged.
    */
   private String encodeQuotes(final String s) {
      if (encodeQuotes) {
         StringBuffer buf = new StringBuffer();
         Matcher m = P_VALID_QUOTES.matcher(s);
         while (m.find()) {
            final String one = m.group(1); // (>|^)
            final String two = m.group(2); // ([^<]+?)
            final String three = m.group(3); // (<|$)
            // quotes are not replaced with &quot; so that JSON stays valid: regexReplace(P_QUOTE, "&quot;", two)
            m.appendReplacement(buf, Matcher.quoteReplacement(one + two + three));
         }
         m.appendTail(buf);
         return buf.toString();
      } else {
         return s;
      }
   }

   private String checkEntity(final String preamble, final String term) {

      return ";".equals(term) && isValidEntity(preamble) ? '&' + preamble : "&amp;" + preamble;
   }

   private boolean isValidEntity(final String entity) {
      return inArray(entity, vAllowedEntities);
   }

   private static boolean inArray(final String s, final String[] array) {
      for (String item : array) {
         if (item != null && item.equals(s)) {
            return true;
         }
      }
      return false;
   }

   private boolean allowed(final String name) {
      return (vAllowed.isEmpty() || vAllowed.containsKey(name)) && !inArray(name, vDisallowed);
   }

   private boolean allowedAttribute(final String name, final String paramName) {
      return allowed(name) && (vAllowed.isEmpty() || vAllowed.get(name).contains(paramName));
   }
}
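  3. 准备工作到此结束,下面开始实现防XSS攻击的过滤。
    首先在application.yml配置文件中新增XSS相关配置:
    • enabled:是否开启防XSS攻击,true 为开启,false 为关闭
    • excludes:不需要防XSS攻击的路径,多个用逗号分隔。过滤器会把每一项拼成 "^" + 配置项 的正则表达式,去匹配请求 servletPath 的开头。因此写成 /notice/ 即可排除 /notice/list、/notice/add 等接口。写成 /notice/* 也能生效,但它是正则而不是通配符,含义是 "/notice" 后跟任意个 "/",所以也会匹配 /noticeboard 这类路径。
    • urlPatterns:需要防XSS攻击的路径(Servlet 的 url-pattern 语法,如 /test/*),多个用逗号分隔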
# 防止XSS攻击
xss:
  # 过滤开关
  enabled: true
  # 排除链接(多个用逗号分隔)
  excludes:
  # 匹配链接(多个用逗号分隔)
  urlPatterns: /test/*,/system/*
  4. 实现防止XSS攻击的过滤器
package com.blog.config.filter;

import org.springframework.util.StringUtils;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;


/**
 * Servlet filter that protects requests against XSS by wrapping them in
 * {@link XssHttpServletRequestWrapper}.
 * <p>
 * Init parameters:
 * <ul>
 *   <li>{@code enabled}: "true"/"false" (default false). When false every request passes through unwrapped.</li>
 *   <li>{@code excludes}: comma-separated regular expressions. Each one is matched against the start
 *       of the servlet path, and matching requests are not wrapped.</li>
 * </ul>
 *
 * @author LiWT
 * @Date 2021/6/23
 */
public class XssFilter implements Filter {
    /**
     * Exclusion entries (trimmed, non-empty) read from the {@code excludes} init parameter.
     */
    public List<String> excludes = new ArrayList<>();

    /**
     * XSS filtering switch.
     */
    public boolean enabled = false;

    /** {@link #excludes} compiled once in {@link #init} instead of on every request. */
    private final List<Pattern> excludePatterns = new ArrayList<>();

    @Override
    public void init(FilterConfig filterConfig) {
        String tempExcludes = filterConfig.getInitParameter("excludes");
        if (StringUtils.hasText(tempExcludes)) {
            for (String entry : tempExcludes.split(",")) {
                // tolerate "a, b" style lists; blank entries are ignored
                String pattern = entry.trim();
                if (pattern.isEmpty()) {
                    continue;
                }
                excludes.add(pattern);
                // an invalid regex now fails here, at startup, instead of on every filtered request
                excludePatterns.add(Pattern.compile("^" + pattern));
            }
        }
        String tempEnabled = filterConfig.getInitParameter("enabled");
        if (StringUtils.hasText(tempEnabled)) {
            enabled = Boolean.parseBoolean(tempEnabled.trim());
        }
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        if (!(request instanceof HttpServletRequest)) {
            // nothing to wrap for non-HTTP requests
            chain.doFilter(request, response);
            return;
        }
        HttpServletRequest req = (HttpServletRequest) request;
        if (handleExcludeURL(req)) {
            chain.doFilter(request, response);
            return;
        }
        chain.doFilter(new XssHttpServletRequestWrapper(req), response);
    }

    /**
     * Decides whether a request bypasses XSS filtering.
     *
     * @param request the incoming request
     * @return true when filtering is disabled or the servlet path matches an exclusion
     */
    private boolean handleExcludeURL(HttpServletRequest request) {
        if (!enabled) {
            return true;
        }
        String url = request.getServletPath();
        for (Pattern pattern : excludePatterns) {
            if (pattern.matcher(url).find()) {
                return true;
            }
        }
        return false;
    }

    @Override
    public void destroy() {
        // nothing to release
    }
}
  5. 实现XSS过滤的请求包装类。这里直接调用了 EscapeUtil.clean() 方法,按 HTMLFilter 的白名单处理HTML:删除不在白名单中的标签(如 script),保留标签内的文本。也可以改为对内容进行转义,等需要展示时再反转义,可以根据业务场景调整。
package com.blog.config.filter;

import com.blog.utils.html.EscapeUtil;
import org.apache.commons.io.IOUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.util.StringUtils;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;


/**
 * Request wrapper that passes user-supplied request data through {@link EscapeUtil#clean(String)}
 * and trims the result.
 * <p>
 * Covers request parameters ({@link #getParameter}, {@link #getParameterValues},
 * {@link #getParameterMap}) and, for JSON requests, the body ({@link #getInputStream},
 * {@link #getReader}). Headers, cookies and non-JSON bodies are passed through unchanged.
 *
 * @author LiWT
 * @Date 2021/6/23
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Filtered JSON body, produced on first access so that the body can be read more than once. */
    private byte[] filteredJsonBody;

    /**
     * @param request the request to wrap
     */
    public XssHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    /**
     * Returns the filtered value of a parameter. Without this override the inherited method
     * returns the raw value, bypassing the filter.
     */
    @Override
    public String getParameter(String name) {
        String value = super.getParameter(name);
        return value == null ? null : cleanValue(value);
    }

    @Override
    public String[] getParameterValues(String name) {
        return cleanValues(super.getParameterValues(name));
    }

    /**
     * Returns all parameters with filtered values; the Servlet API specifies an immutable map.
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> cleaned = new LinkedHashMap<>();
        for (Map.Entry<String, String[]> entry : super.getParameterMap().entrySet()) {
            cleaned.put(entry.getKey(), cleanValues(entry.getValue()));
        }
        return Collections.unmodifiableMap(cleaned);
    }

    @Override
    public ServletInputStream getInputStream() throws IOException {
        // non-JSON bodies are returned untouched
        if (!isJsonRequest()) {
            return super.getInputStream();
        }

        final ByteArrayInputStream bis = new ByteArrayInputStream(getFilteredJsonBody());
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return bis.available() == 0;
            }

            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // no-op: the body is already fully buffered in memory
            }

            @Override
            public int read() {
                return bis.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return bis.read(b, off, len);
            }

            @Override
            public int available() {
                return bis.available();
            }
        };
    }

    /**
     * For JSON requests, returns a UTF-8 reader over the filtered body so that reading through
     * {@code getReader()} cannot bypass the filter.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        if (!isJsonRequest()) {
            return super.getReader();
        }
        return new BufferedReader(new InputStreamReader(getInputStream(), StandardCharsets.UTF_8));
    }

    /**
     * Whether the request body is JSON, based on the media type of the Content-Type header:
     * {@code application/json} or any {@code +json} structured-syntax type. Parameters such as
     * {@code ; charset=utf-8} are ignored. Previously an exact string comparison let
     * "application/json; charset=utf-8" (with a space) skip filtering.
     *
     * @return true for JSON requests
     */
    public boolean isJsonRequest() {
        String header = super.getHeader(HttpHeaders.CONTENT_TYPE);
        if (header == null) {
            return false;
        }
        String mediaType = header.split(";", 2)[0].trim();
        return MediaType.APPLICATION_JSON_VALUE.equalsIgnoreCase(mediaType)
                || mediaType.regionMatches(true, mediaType.length() - 5, "+json", 0, 5);
    }

    /** Reads, filters and caches the JSON body (assumed UTF-8). */
    private byte[] getFilteredJsonBody() throws IOException {
        if (filteredJsonBody == null) {
            String json = IOUtils.toString(super.getInputStream(), StandardCharsets.UTF_8);
            if (StringUtils.hasLength(json)) {
                json = EscapeUtil.clean(json).trim();
            }
            filteredJsonBody = json.getBytes(StandardCharsets.UTF_8);
        }
        return filteredJsonBody;
    }

    /** XSS-filters a single value and strips leading/trailing whitespace. */
    private static String cleanValue(String value) {
        return EscapeUtil.clean(value).trim();
    }

    /** Applies {@link #cleanValue(String)} to each element; {@code null} stays {@code null}. */
    private static String[] cleanValues(String[] values) {
        if (values == null) {
            return null;
        }
        String[] escapedValues = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            escapedValues[i] = cleanValue(values[i]);
        }
        return escapedValues;
    }
}
  6. 接着,将XSS过滤器注册到Spring容器中
package com.blog.config;

import com.blog.config.filter.XssFilter;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.util.StringUtils;

import javax.servlet.DispatcherType;
import java.util.HashMap;
import java.util.Map;

/**
 * Filter registration.
 * <p>
 * Registers {@link XssFilter} for the comma-separated {@code xss.urlPatterns}. It passes
 * {@code xss.enabled} and {@code xss.excludes} to the filter as init parameters.
 *
 * @author LiWT
 * @Date 2021/6/23
 */
@Configuration
public class FilterConfig {
    @Value("${xss.enabled}")
    private String enabled;

    /** Optional; a missing property means "no exclusions" instead of a startup failure. */
    @Value("${xss.excludes:}")
    private String excludes;

    @Value("${xss.urlPatterns}")
    private String urlPatterns;

    @Bean
    public FilterRegistrationBean<XssFilter> xssFilterRegistration() {
        FilterRegistrationBean<XssFilter> registration = new FilterRegistrationBean<>();
        registration.setDispatcherTypes(DispatcherType.REQUEST);
        registration.setFilter(new XssFilter());
        // tokenizeToStringArray splits on every comma, trims, and drops empty entries.
        // StringUtils.split only split at the first comma and returned null when there was none,
        // so a single pattern such as "/test/*" failed in addUrlPatterns at startup.
        // An empty list leaves the registration without explicit patterns (Spring Boot then maps "/*").
        registration.addUrlPatterns(StringUtils.tokenizeToStringArray(urlPatterns, ","));
        registration.setName("xssFilter");
        registration.setOrder(FilterRegistrationBean.HIGHEST_PRECEDENCE);
        Map<String, String> initParameters = new HashMap<>();
        initParameters.put("excludes", excludes);
        initParameters.put("enabled", enabled);
        registration.setInitParameters(initParameters);
        return registration;
    }

}
  7. 至此,防XSS攻击的过滤就完成了,下面编写一个测试接口来验证。
    测试接口的路径为 /test/testXss,因此application.yml中配置的 urlPatterns 需要包含 /test/*,这样 /test 下的所有接口都会开启防XSS攻击。
package com.blog.controller;

import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * Demo endpoint for observing what the XSS filter does to a request parameter.
 *
 * @Author: LiWT
 * @Date: 2021/6/22
 */
@RestController
@RequestMapping("/test")
public class TestController {

    /**
     * Prints the {@code param} request parameter to standard output. When /test/* is covered by
     * xss.urlPatterns, the value arrives here after XssFilter has processed it.
     *
     * @param param query parameter to echo
     */
    @GetMapping("/testXss")
    public void testXss(String param) {
        final String message = "param:" + param;
        System.out.println(message);
    }
}
  8. 最后进行测试,依次访问以下链接:
http://localhost:8081/test/testXss?param=1  
http://localhost:8081/test/testXss?param="参数"   
http://localhost:8081/test/testXss?param=<aside><a href="#" target="_blank">链接</a></aside>   
http://localhost:8081/test/testXss?param=<script>alert("1")</script>  

控制台输出为:

param:1
param:"参数"
param:
param:alert("1")

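可以看出防XSS攻击的过滤生效了:不在白名单中的标签(如 script)被删除,只保留了标签内的文本。需要注意以下几点:
- 第三个链接中的 # 会被浏览器当作 URL 片段(fragment)的起始符,服务器实际只收到 `<aside><a href="`,过滤后结果为空。实际测试时应对参数做 URL 编码(例如把 # 写成 %23)。
- clean() 并不会删除所有标签。a、img、b、strong、i、em 等白名单标签会被保留,例如 param=<b>1</b> 过滤后仍为 <b>1</b>。
- 如果业务需要,也可以改为对内容进行转义,等需要展示时再反转义。这里为了简单,直接清除了标签。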

 类似资料: