-
Notifications
You must be signed in to change notification settings - Fork 0
/
redis
150 lines (134 loc) · 5.56 KB
/
redis
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import redis.clients.jedis.Jedis;
import java.util.UUID;
/**
 * Counting semaphore shared between processes through Redis.
 *
 * <p>A counter stored under {@code lockKey} holds the number of slots currently
 * taken. Every successful {@link #acquire()} also writes a per-holder marker key
 * ({@code ownerKey + ":" + token}) so that {@link #release(String)} only gives a
 * slot back for a token that really holds one.
 *
 * <p>Known limitation: the marker expires after {@code timeout} seconds but the
 * counter is only decremented by {@code release}. A holder that crashes, or that
 * releases after its marker has expired, leaves its slot counted as taken.
 */
public class GlobalSemaphore {
    /**
     * Atomically takes a slot when fewer than ARGV[1] are in use.
     * A missing counter key is treated as 0 so the first caller can succeed.
     * KEYS[1] = counter, KEYS[2] = per-holder marker,
     * ARGV[1] = limit, ARGV[2] = holder token, ARGV[3] = marker TTL in seconds.
     */
    private static final String ACQUIRE_SCRIPT =
            "local current = tonumber(redis.call('GET', KEYS[1]) or '0') " +
            "if current and current < tonumber(ARGV[1]) then " +
            "redis.call('INCR', KEYS[1]) " +
            "redis.call('SET', KEYS[2], ARGV[2], 'EX', ARGV[3]) " +
            "return ARGV[2] " +
            "else return nil end";

    /**
     * Gives a slot back only if the per-holder marker still carries the token.
     * KEYS[1] = counter, KEYS[2] = per-holder marker, ARGV[1] = holder token.
     */
    private static final String RELEASE_SCRIPT =
            "if redis.call('GET', KEYS[1]) then " +
            "if redis.call('GET', KEYS[2]) == ARGV[1] then " +
            "redis.call('DECR', KEYS[1]) " +
            "redis.call('DEL', KEYS[2]) " +
            "return 1 end end " +
            "return 0";

    private final Jedis jedis;
    private final String lockKey = "semaphore";
    private final String ownerKey = "semaphore_owner";
    private final int maxConcurrent;
    private final int timeout; // seconds: acquire wait limit and marker TTL

    /**
     * @param maxConcurrent maximum number of simultaneous holders
     * @param timeout       seconds to wait in {@link #acquire()}; also the TTL of
     *                      the per-holder marker key
     */
    public GlobalSemaphore(int maxConcurrent, int timeout) {
        this.jedis = new Jedis("localhost");
        this.maxConcurrent = maxConcurrent;
        this.timeout = timeout;
    }

    /** Builds the marker key that belongs to one holder token. */
    private String holderKey(String threadId) {
        return ownerKey + ":" + threadId;
    }

    /**
     * Polls Redis every 100 ms until a slot is free.
     *
     * @return the token that must later be passed to {@link #release(String)}
     * @throws RuntimeException     if no slot became free within {@code timeout} seconds
     * @throws InterruptedException if the thread is interrupted while waiting
     */
    public String acquire() throws InterruptedException {
        String threadId = UUID.randomUUID().toString(); // unique holder token
        String markerKey = holderKey(threadId);
        // 1000L forces long arithmetic so large timeouts cannot overflow int.
        long waitLimitMillis = timeout * 1000L;
        long startTime = System.currentTimeMillis();
        while (true) {
            Object result;
            // The single Jedis connection is shared by every caller, so
            // access to it is serialized; the sleep stays outside the lock.
            synchronized (jedis) {
                result = jedis.eval(ACQUIRE_SCRIPT, 2, lockKey, markerKey,
                        String.valueOf(maxConcurrent), threadId, String.valueOf(timeout));
            }
            if (result != null) {
                return threadId; // slot acquired
            }
            if (System.currentTimeMillis() - startTime > waitLimitMillis) {
                throw new RuntimeException("Failed to acquire semaphore within timeout");
            }
            Thread.sleep(100); // back off before retrying
        }
    }

    /**
     * Returns the slot held by {@code threadId}. Does nothing when the token is
     * {@code null} or its marker no longer exists.
     *
     * @param threadId token previously returned by {@link #acquire()}
     */
    public void release(String threadId) {
        if (threadId == null) {
            return; // nothing was acquired
        }
        synchronized (jedis) {
            jedis.eval(RELEASE_SCRIPT, 2, lockKey, holderKey(threadId), threadId);
        }
    }

    /** Demo: ten threads compete for five slots; about half fail on purpose. */
    public static void main(String[] args) {
        GlobalSemaphore semaphore = new GlobalSemaphore(5, 5); // max concurrency, timeout
        // Spawn several threads to simulate concurrent requests.
        for (int i = 0; i < 10; i++) {
            new Thread(() -> {
                // Declared outside try so the finally clause can see it.
                String threadId = null;
                try {
                    threadId = semaphore.acquire();
                    System.out.println(Thread.currentThread().getName() + " processing request...");
                    Thread.sleep(2000); // simulated work
                    // Deliberately fail sometimes to exercise the release path.
                    if (Math.random() > 0.5) {
                        throw new RuntimeException("Simulated error");
                    }
                } catch (Exception e) {
                    System.out.println(Thread.currentThread().getName() + " error: " + e.getMessage());
                } finally {
                    // Only release when acquire() actually handed out a token.
                    if (threadId != null) {
                        semaphore.release(threadId);
                    }
                }
            }).start();
        }
    }
}
import redis
import time
import uuid
from threading import Thread, current_thread
class GlobalSemaphore:
    """Counting semaphore shared between processes through Redis.

    A counter stored under ``lock_key`` holds the number of slots in use.
    Each successful ``acquire`` also writes a per-holder marker key
    (``"<owner_key>:<token>"``) so ``release`` only returns a slot for a
    token that really holds one.

    Known limitation: the marker expires after ``timeout`` seconds, but the
    counter is only decremented by ``release``. A holder that crashes, or
    that releases after its marker expired, leaves its slot counted as taken.
    """

    # KEYS[1] = counter, KEYS[2] = per-holder marker,
    # ARGV[1] = limit, ARGV[2] = holder token, ARGV[3] = marker TTL (seconds).
    # A missing counter is treated as 0 so the first caller can succeed.
    _ACQUIRE_SCRIPT = """
    local current = tonumber(redis.call('GET', KEYS[1]) or '0')
    if current and current < tonumber(ARGV[1]) then
        redis.call('INCR', KEYS[1])
        redis.call('SET', KEYS[2], ARGV[2], 'EX', ARGV[3])
        return ARGV[2]
    else
        return nil
    end
    """

    # KEYS[1] = counter, KEYS[2] = per-holder marker, ARGV[1] = holder token.
    # The slot is returned only while the marker still carries the token.
    _RELEASE_SCRIPT = """
    if redis.call('GET', KEYS[1]) then
        if redis.call('GET', KEYS[2]) == ARGV[1] then
            redis.call('DECR', KEYS[1])
            redis.call('DEL', KEYS[2])
            return 1
        end
    end
    return 0
    """

    def __init__(self, max_concurrent, timeout):
        """Connect to the local Redis server and store the limits.

        Args:
            max_concurrent: Maximum number of simultaneous holders.
            timeout: Seconds to wait in ``acquire``; also used as the TTL of
                the per-holder marker key.
        """
        self.redis = redis.Redis(host='localhost', port=6379, db=0)
        self.max_concurrent = max_concurrent
        self.timeout = timeout
        self.lock_key = 'semaphore'
        self.owner_key = 'semaphore_owner'

    def _holder_key(self, thread_id):
        """Return the marker key that belongs to one holder token."""
        return f"{self.owner_key}:{thread_id}"

    def acquire(self) -> str:
        """Poll Redis every 0.1 s until a slot is free.

        Returns:
            The token that must later be passed to ``release``.

        Raises:
            TimeoutError: No slot became free within ``timeout`` seconds.
        """
        thread_id = str(uuid.uuid4())  # unique holder token
        holder_key = self._holder_key(thread_id)
        # Monotonic clock: the deadline must not move with wall-clock changes.
        start_time = time.monotonic()
        while True:
            owner = self.redis.eval(
                self._ACQUIRE_SCRIPT,
                2,
                self.lock_key,
                holder_key,
                self.max_concurrent,  # the configured limit, not a constant
                thread_id,
                self.timeout,
            )
            if owner:
                return thread_id
            if time.monotonic() - start_time > self.timeout:
                raise TimeoutError("Failed to acquire semaphore within timeout")
            time.sleep(0.1)  # back off before retrying

    def release(self, thread_id) -> None:
        """Return the slot held by ``thread_id``.

        Does nothing when ``thread_id`` is ``None`` or its marker key no
        longer exists.

        Args:
            thread_id: Token previously returned by ``acquire``.
        """
        if thread_id is None:
            return
        self.redis.eval(
            self._RELEASE_SCRIPT,
            2,
            self.lock_key,
            self._holder_key(thread_id),
            thread_id,
        )
# Example request handler
def handle_request(semaphore):
    """Simulate one request guarded by ``semaphore``.

    Acquires a slot, sleeps two seconds to mimic work, sometimes raises a
    simulated error, and releases the slot afterwards. Any exception,
    including a failed acquire, is printed rather than propagated.

    Args:
        semaphore: Object exposing ``acquire()`` returning a token and
            ``release(token)``.
    """
    # Bound before the try so the finally clause never sees an unbound name
    # when acquire() itself raises.
    thread_id = None
    try:
        thread_id = semaphore.acquire()
        print(f"Thread {current_thread().name} processing request...")
        time.sleep(2)  # simulated work
        # Deliberately fail about half the time to exercise the release path.
        if time.time() % 2 > 1:
            raise Exception("Simulated error")
    except Exception as e:
        print(f"Thread {current_thread().name} error: {e}")
    finally:
        # Only give back a slot that was actually obtained.
        if thread_id is not None:
            semaphore.release(thread_id)
# Entry point: ten worker threads compete for five semaphore slots.
if __name__ == "__main__":
    # At most 5 concurrent holders; wait up to 5 seconds for a slot.
    semaphore = GlobalSemaphore(max_concurrent=5, timeout=5)

    # Build the workers that simulate concurrent requests, then launch them.
    threads = [
        Thread(target=handle_request, args=(semaphore,))
        for _ in range(10)
    ]
    for worker in threads:
        worker.start()

    # Block until every simulated request has finished.
    for worker in threads:
        worker.join()