0%

ThreadLocal线程本地存储

ThreadLocal线程本地存储

当访问共享数据时,通常需要使用同步来控制并发程序的访问。那么有没有别的方法来解决呢?当然有,那就是让共享数据不共享了,ThreadLocal就实现了该操作。该类提供了线程局部变量,为每一个线程创建一个单独的变量副本,使得每个线程都可以独立的改变自己所拥有的变量副本,而不会影响其他线程所对应的副本,消除了竞争条件。

采用的以空间换时间的做法,在每个Thread里维护一个以开地址法实现的ThreadLocal.ThreadLocalMap,把数据隔离
线程本地存储根除了对变量的共享,每当线程访问threadLocals变量时,访问的都是各自线程自己的threadLocals变量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
public class ThreadLocalVariableHolder {
private static ThreadLocal<Integer> myThreadLocal = new ThreadLocal<Integer>() {
// 初始值默认为null 设置初始值为0
protected Integer initialValue() {
return 0;
}
};

public static void increment() {
myThreadLocal.set(myThreadLocal.get() + 1);
}

public static int get(){
return myThreadLocal.get();
}

public static void main(String[] args) {
// 线程池
ExecutorService executorService = Executors.newCachedThreadPool();
for(int i=0;i<3;i++){
executorService.execute(new MyThread());
}
executorService.shutdown();
}
}

public class MyThread implements Runnable {

@Override
public void run() {
for (int i = 0; i < 3; i++) {
ThreadLocalVariableHolder.increment();
System.out.println(Thread.currentThread().getName() + ":" + ThreadLocalVariableHolder.get());
Thread.yield();
}
}
}

输出结果

1
2
3
4
5
6
7
8
9
pool-1-thread-1:1
pool-1-thread-2:1
pool-1-thread-3:1
pool-1-thread-1:2
pool-1-thread-2:2
pool-1-thread-3:2
pool-1-thread-1:3
pool-1-thread-2:3
pool-1-thread-3:3

可以看到ThreadLocal把不同线程的数据进行了隔离,互不影响

来看一下ThreadLocal是如何实现的吧

Thread类

1
2
3
4
5
6
// 线程的本地变量存储在线程的threadLocals变量中,并不是存储在ThreadLocal实例中
// 每一个线程都有一个自己的ThreadLocalMap,ThreadLocalMap类似于Map,key为ThreadLocal对象,value为存储的值,所以一个ThreadLocalMap可以存储多个ThreadLocal

ThreadLocal.ThreadLocalMap threadLocals = null;

ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

ThreadLocal源码逻辑

  • ThreadLocal对象通过Thread.currentThread()获取当前Thread对象
  • 当前Thread获取对象内部持有的ThreadLocalMap容器
  • 从ThreadLocalMap容器中用ThreadLocal对象作为key,操作当前Thread中的变量副本

提供了四个方法:

  • get() 返回此线程局部变量的当前线程副本中的值

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    public T get() {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程实例的threadLocals
    ThreadLocalMap map = getMap(t);
    // 不为空
    if (map != null) {
    // 根据当前的ThreadLocal对象引用来取值
    ThreadLocalMap.Entry e = map.getEntry(this);
    if (e != null) {
    @SuppressWarnings("unchecked")
    T result = (T)e.value;
    return result;
    }
    }
    // 为空则设置key为当前的ThreadLocal对象,value为initialValue设置的初始值
    return setInitialValue();
    }
  • initialValue() 返回此线程局部变量当前线程的初始值,当线程第一次调用get()或set()方法时调用,并且只调用一次

  • remove() 移除此线程局部变量当前线程的值

1
2
3
4
5
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}
  • set(T value) 将此线程局部变量的当前线程副本中的值设置为指定值

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    public void set(T value) {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取该线程实例对象的threadLocals变量 getMap方法 return t.threadLocals;
    ThreadLocalMap map = getMap(t);
    // 不为空,key为当前的ThreadLocal对象引用,value为所存储的值
    if (map != null)
    map.set(this, value);
    else
    // 为空,则为threadLocals实例化对象
    createMap(t, value);
    }

    ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
    }

    void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
  • 还有一个静态内部类 ThreadLocalMap 提供了一种用键值对方式存储每一个线程的变量副本的方法,key为当前的ThreadLocal对象,value为对象线程的变量副本
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
public class MyThread implements Runnable {

@Override
public void run() {
for (int i = 0; i < 3; i++) {
ThreadLocalVariableHolder.increment();
System.out.println(Thread.currentThread().getName() + ":" + ThreadLocalVariableHolder.get());
Thread.yield();
}
}
}

public class ThreadLocalVariableHolder {
private static ThreadLocal<Integer> myThreadLocal = new ThreadLocal<Integer>() {
// 初始值默认为null 设置初始值为0
protected Integer initialValue() {
return 0;
}
};

public static void increment() {
myThreadLocal.set(myThreadLocal.get() + 1);
}

public static int get(){
return myThreadLocal.get();
}

public static void main(String[] args) {
ExecutorService executorService = Executors.newCachedThreadPool();
for(int i=0;i<5;i++){
executorService.execute(new MyThread());
}
executorService.shutdown();
}
}

执行结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
pool-1-thread-1:1
pool-1-thread-2:1
pool-1-thread-3:1
pool-1-thread-3:2
pool-1-thread-2:2
pool-1-thread-1:2
pool-1-thread-2:3
pool-1-thread-3:3
pool-1-thread-4:1
pool-1-thread-1:3
pool-1-thread-4:2
pool-1-thread-4:3
pool-1-thread-5:1
pool-1-thread-5:2
pool-1-thread-5:3

存在内存泄露问题,每次使用完ThreadLocal,都调用它的remove()方法,清除数据

ThreadLocalMap源码逻辑

作为数据真正存储的位置,ThreadLocalMap对于ThreadLocal还是相当重要的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// 初始容量,该容量必须是2的次幂
private static final int INITIAL_CAPACITY = 16;
// 存储数据的Entry数组
private Entry[] table;
// entry的数量
private int size = 0;
// 阈值
private int threshold;

// Entry中key为ThreadLocal,value为值
// key是一个弱引用
// 每个线程在往ThreadLocal中放值的时候,都是放入线程本身的ThreadLocalMap中,key是ThreadLocal,从而实现了线程隔离
// 只要发生了垃圾回收,且该对象没有强引用存在的话,弱引用就会被回收
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
构造函数
1
2
3
4
5
6
7
8
9
10
11
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
// 初始化
table = new Entry[INITIAL_CAPACITY];
// 使用key的哈希值与上15 计算数组下标
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
// 初始化该节点
table[i] = new Entry(firstKey, firstValue);
size = 1;
// 设置阈值
setThreshold(INITIAL_CAPACITY);
}
重要方法
set方法

存储数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
private void set(ThreadLocal<?> key, Object value) {

// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.

Entry[] tab = table;
int len = tab.length;
// 计算数组下标
int i = key.threadLocalHashCode & (len-1);
// 如果出现hash冲突,会采用线性探测,如果当前位置有值,则会获取下一个索引来进行存储,该结构是一个环形
// 直到找到空闲位置为止
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) { // 槽位不为空,进行线性探测
ThreadLocal<?> k = e.get();
// key值与当前key值相等,直接进行覆盖
if (k == key) {
e.value = value;
return;
}
// Entry不为null,但是key为null,当前位置的key已经为null(失效了),则进行替换过期数据
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
// 该循环完成,说明既没有找到原本存储key的位置,也没有找到失效的位置
}
// 执行到这里说明查找到了entry为null的位置
tab[i] = new Entry(key, value);
int sz = ++size;
// 没有清除掉失效的槽位,并且当前数量已经达到了阈值,则进行扩容
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

// 获取下一个数组下标位置
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}

// 替换掉失效的值 staleSlot为失效的槽位
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;

// Back up to check for prior stale entry in current run.
// We clean out whole runs at a time to avoid continual
// incremental rehashing due to garbage collector freeing
// up refs in bunches (i.e., whenever the collector runs).
int slotToExpunge = staleSlot;
// 向前扫描,直到找到无效的槽位
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
// 找到无效的槽位,slotToExpunge复为当前位置
if (e.get() == null)
slotToExpunge = i;

// Find either the key or trailing null slot of run, whichever
// occurs first
// 向后扫描,找到失效的槽位
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();

// If we find key, then we need to swap it
// with the stale entry to maintain hash table order.
// The newly stale slot, or any other stale slot
// encountered above it, can then be sent to expungeStaleEntry
// to remove or rehash all of the other entries in run.
// 找到key,替换掉无效的槽位中的值
if (k == key) {
e.value = value;

tab[i] = tab[staleSlot];
tab[staleSlot] = e;

// Start expunge at preceding stale entry if it exists
// 向前扫描时没有找到无效的槽位,将slotToExpunge设为当前位置
if (slotToExpunge == staleSlot)
slotToExpunge = i;
// 从slotToExpunge开始清理槽位
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}

// If we didn't find stale entry on backward scan, the
// first stale entry seen while scanning for key is the
// first still present in the run.
// 如果当前的槽位已经无效,并且向前扫描时没有找到无效的槽位,将slotToExpunge设为当前位置
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}

// 没有找到对应的key新增加一个entry
// If key not found, put new entry in stale slot
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);

// If there are any other stale entries in run, expunge them
// 扫描过程中存在无效的槽位,进行清理
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

// 清理无效的槽位
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
// key已失效
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}

private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

// expunge entry at staleSlot
// 去掉对value的引用
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

// Rehash until we encounter null
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// key为null,去掉对value的引用
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;

// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}

// 进行扩容操作
private void rehash() {
expungeStaleEntries();

// Use lower threshold for doubling to avoid hysteresis
if (size >= threshold - threshold / 4)
resize();
}
// 扩容扩大2倍
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;

for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}

setThreshold(newLen);
size = count;
table = newTab;
}
getEntry方法

获取Entry值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
private Entry getEntry(ThreadLocal<?> key) {
// 计算数组下标
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
// 没有命中
// 在存放数据的时候采用的是开方定址法,所以可能存在当前key的散列值和元素所在索引并不完全对应的情况
return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;

while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null) // key等于null,说明该key已经失效了,进行回收,可以有效的避免内存泄漏
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}
remove方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
// 计算数组下标
int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
// 找到对应的key
if (e.get() == key) {
// 清除对ThreadLocal的弱引用
e.clear();
// 清理key为null的元素
expungeStaleEntry(i);
return;
}
}
}

InheritableThreadLocal类解决ThreadLocal继承性问题

InheritableThreadLocal是ThreadLocal的子类,该类提供了一个特性,可以让子线程访问在父线程中设置的本地变量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public class InheritableThreadLocal<T> extends ThreadLocal<T> {

protected T childValue(T parentValue) {
return parentValue;
}

// 获取ThreadLocalMap时获取的是该线程中的inheritableThreadLocals变量
ThreadLocalMap getMap(Thread t) {
return t.inheritableThreadLocals;
}

// 创建ThreadLocalMap时,使用的是inheritableThreadLocals变量而不是threadLocals变量
void createMap(Thread t, T firstValue) {
t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
}
}

如果一个线程是从其他某个线程中创建的,这个类将提供继承的值,如果一个线程A在线程局部变量中已有值,那么当线程A创建其他某个线程B时,线程B的线程局部变量将跟线程A是一样的,可以重写childValue()方法,该方法用来初始化子线程在线程局部变量中的值,使用父线程在线程局部变量中的值作为传入参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// Thread实例化时的初始化过程
private void init(ThreadGroup g, Runnable target, String name,
long stackSize, AccessControlContext acc,
boolean inheritThreadLocals) {
// 省略无关代码

// 当前线程
Thread parent = currentThread();

// inheritThreadLocals为true 且 父线程的inheritableThreadLocals不为null
if (inheritThreadLocals && parent.inheritableThreadLocals != null)
// 则设置子线程的inheritableThreadLocals,值为父线程的inheritableThreadLocals的值复制而来
this.inheritableThreadLocals =
ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);

}

javaweb中就是使用这种方式来使得子线程可以获取请求信息的

1
RequestContextHolder.setRequestAttributes(requestAttributes,true);

内存泄漏问题

1
2
3
4
5
6
7
8
9
10
// ThreadLocal中的Entry
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
ThreadLocal内存泄漏

在ThreadLocal.ThreadLocalMap中的key为ThreadLocal对象实例,这个Map中的key是一个弱引用,当把ThreadLocal对象实例置为null后,没有任何强引用指向ThreadLocal对象实例,所以ThreadLocal对象实例会被gc回收,但是Map中的value却不会被回收(此时entry是一个key为null,但是value不为null),因为存在一条从当前thread连接过来的强引用,只有当前thread结束之后,当前thread的强引用才会断开,此时Map中的value才会被gc回收。

但是在使用线程池的时候,由于线程是重复利用的,不会被回收,所以就可能出现内存泄漏

当然JDK的设计中也有考虑到这点,所以在get()、set()、remove()中会扫描key为null的Entry,将对应的value也设置为null,这样value就会被回收

所以当使用完之后,需要调用ThreadLocal对象的remove()方法来删除掉

欢迎关注我的其它发布渠道